Skip to content

Commit 2ea31ab

Browse files
committed
fixing tests
1 parent 7413907 commit 2ea31ab

File tree

2 files changed

+19
-162
lines changed

2 files changed

+19
-162
lines changed

tailnet/resume.go

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@ package tailnet
33
import (
44
"context"
55
"crypto/rand"
6-
"database/sql"
7-
"encoding/hex"
86
"time"
97

10-
"github.com/go-jose/go-jose/v4"
118
"github.com/go-jose/go-jose/v4/jwt"
129
"github.com/google/uuid"
1310
"golang.org/x/xerrors"
@@ -53,47 +50,6 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
5350
return key, nil
5451
}
5552

56-
type ResumeTokenSigningKeyDatabaseStore interface {
57-
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
58-
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
59-
}
60-
61-
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
62-
// signing key from the database. If the key is not found, a new key is
63-
// generated and inserted into the database.
64-
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
65-
var resumeTokenKey ResumeTokenSigningKey
66-
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
67-
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
68-
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
69-
}
70-
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
71-
newKey, err := GenerateResumeTokenSigningKey()
72-
if err != nil {
73-
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
74-
}
75-
76-
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
77-
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
78-
if err != nil {
79-
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
80-
}
81-
}
82-
83-
resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
84-
if err != nil {
85-
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
86-
}
87-
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
88-
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
89-
}
90-
copy(resumeTokenKey[:], resumeTokenKeyBytes)
91-
if resumeTokenKey == [64]byte{} {
92-
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
93-
}
94-
return resumeTokenKey, nil
95-
}
96-
9753
type ResumeTokenKeyProvider struct {
9854
key jwtutils.SigningKeyManager
9955
clock quartz.Clock
@@ -111,19 +67,11 @@ func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Cloc
11167
}
11268
}
11369

114-
type resumeTokenPayload struct {
115-
jwt.Claims
116-
PeerID uuid.UUID `json:"sub"`
117-
Expiry int64 `json:"exp"`
118-
}
119-
12070
func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
12171
exp := p.clock.Now().Add(p.expiry)
122-
payload := resumeTokenPayload{
123-
PeerID: peerID,
124-
Claims: jwt.Claims{
125-
Expiry: jwt.NewNumericDate(exp),
126-
},
72+
payload := jwt.Claims{
73+
Subject: peerID.String(),
74+
Expiry: jwt.NewNumericDate(exp),
12775
}
12876

12977
token, err := jwtutils.Sign(ctx, p.key, payload)
@@ -142,12 +90,16 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID
14290
// returns the payload. If the token is invalid or expired, an error is
14391
// returned.
14492
func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) {
145-
var tok resumeTokenPayload
93+
var tok jwt.Claims
14694
err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{
14795
Time: p.clock.Now(),
14896
}))
14997
if err != nil {
15098
return uuid.Nil, xerrors.Errorf("verify payload: %w", err)
15199
}
152-
return tok.PeerID, nil
100+
parsed, err := uuid.Parse(tok.Subject)
101+
if err != nil {
102+
return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err)
103+
}
104+
return parsed, nil
153105
}

tailnet/resume_test.go

Lines changed: 10 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,20 @@
11
package tailnet_test
22

33
import (
4-
"context"
5-
"encoding/hex"
64
"testing"
75
"time"
86

7+
"github.com/go-jose/go-jose/v4"
8+
"github.com/go-jose/go-jose/v4/jwt"
99
"github.com/google/uuid"
10-
"github.com/stretchr/testify/assert"
1110
"github.com/stretchr/testify/require"
12-
"go.uber.org/mock/gomock"
1311

14-
"github.com/coder/coder/v2/coderd/database/dbmock"
15-
"github.com/coder/coder/v2/coderd/database/dbtestutil"
1612
"github.com/coder/coder/v2/coderd/jwtutils"
1713
"github.com/coder/coder/v2/tailnet"
1814
"github.com/coder/coder/v2/testutil"
1915
"github.com/coder/quartz"
2016
)
2117

22-
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
23-
t.Parallel()
24-
25-
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
26-
t.Helper()
27-
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
28-
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
29-
}
30-
31-
t.Run("GenerateRetrieve", func(t *testing.T) {
32-
t.Parallel()
33-
34-
db, _ := dbtestutil.NewDB(t)
35-
ctx := testutil.Context(t, testutil.WaitShort)
36-
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
37-
require.NoError(t, err)
38-
assertRandomKey(t, key1)
39-
40-
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
41-
require.NoError(t, err)
42-
require.Equal(t, key1, key2, "keys should not be different")
43-
})
44-
45-
t.Run("GetError", func(t *testing.T) {
46-
t.Parallel()
47-
48-
db := dbmock.NewMockStore(gomock.NewController(t))
49-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
50-
51-
ctx := testutil.Context(t, testutil.WaitShort)
52-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
53-
require.ErrorIs(t, err, assert.AnError)
54-
})
55-
56-
t.Run("UpsertError", func(t *testing.T) {
57-
t.Parallel()
58-
59-
db := dbmock.NewMockStore(gomock.NewController(t))
60-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
61-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
62-
63-
ctx := testutil.Context(t, testutil.WaitShort)
64-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
65-
require.ErrorIs(t, err, assert.AnError)
66-
})
67-
68-
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
69-
t.Parallel()
70-
71-
db := dbmock.NewMockStore(gomock.NewController(t))
72-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
73-
74-
var storedKey tailnet.ResumeTokenSigningKey
75-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
76-
keyBytes, err := hex.DecodeString(value)
77-
require.NoError(t, err)
78-
require.Len(t, keyBytes, len(storedKey))
79-
copy(storedKey[:], keyBytes)
80-
return nil
81-
})
82-
83-
ctx := testutil.Context(t, testutil.WaitShort)
84-
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
85-
require.NoError(t, err)
86-
assertRandomKey(t, key)
87-
require.Equal(t, storedKey, key, "key should match stored value")
88-
})
89-
90-
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
91-
t.Parallel()
92-
93-
db := dbmock.NewMockStore(gomock.NewController(t))
94-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
95-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
96-
97-
ctx := testutil.Context(t, testutil.WaitShort)
98-
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
99-
require.NoError(t, err)
100-
assertRandomKey(t, key)
101-
})
102-
103-
t.Run("EmptyError", func(t *testing.T) {
104-
t.Parallel()
105-
106-
db := dbmock.NewMockStore(gomock.NewController(t))
107-
emptyKey := hex.EncodeToString(make([]byte, 64))
108-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
109-
110-
ctx := testutil.Context(t, testutil.WaitShort)
111-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
112-
require.ErrorContains(t, err, "is empty")
113-
})
114-
}
115-
11618
func TestResumeTokenKeyProvider(t *testing.T) {
11719
t.Parallel()
11820

@@ -156,7 +58,7 @@ func TestResumeTokenKeyProvider(t *testing.T) {
15658
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
15759

15860
_, err = provider.VerifyResumeToken(ctx, token.Token)
159-
require.ErrorContains(t, err, "expired")
61+
require.ErrorIs(t, err, jwt.ErrExpired)
16062
})
16163

16264
t.Run("InvalidToken", func(t *testing.T) {
@@ -175,17 +77,20 @@ func TestResumeTokenKeyProvider(t *testing.T) {
17577
// Generate a resume token with a different key
17678
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
17779
require.NoError(t, err)
178-
otherProvider := tailnet.NewResumeTokenKeyProvider(newKeySigner(otherKey), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
80+
otherSigner := newKeySigner(otherKey)
81+
otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
17982
token, err := otherProvider.GenerateResumeToken(ctx, uuid.New())
18083
require.NoError(t, err)
18184

182-
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
85+
signer := newKeySigner(key)
86+
signer.ID = otherSigner.ID
87+
provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
18388
_, err = provider.VerifyResumeToken(ctx, token.Token)
184-
require.ErrorContains(t, err, "verify JWS")
89+
require.ErrorIs(t, err, jose.ErrCryptoFailure)
18590
})
18691
}
18792

188-
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.SigningKeyManager {
93+
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKeyManager {
18994
return jwtutils.StaticKeyManager{
19095
ID: uuid.New().String(),
19196
Key: key[:],

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy