diff --git a/cli/server.go b/cli/server.go index 5adb44c3c0a7d..2c1f8fab10c1d 100644 --- a/cli/server.go +++ b/cli/server.go @@ -10,7 +10,6 @@ import ( "crypto/tls" "crypto/x509" "database/sql" - "encoding/hex" "errors" "flag" "fmt" @@ -62,6 +61,7 @@ import ( "github.com/coder/serpent" "github.com/coder/wgtunnel/tunnelsdk" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/notifications/reports" "github.com/coder/coder/v2/coderd/runtimeconfig" @@ -97,7 +97,6 @@ import ( "github.com/coder/coder/v2/coderd/updatecheck" "github.com/coder/coder/v2/coderd/util/slice" stringutil "github.com/coder/coder/v2/coderd/util/strings" - "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" @@ -741,90 +740,31 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. return xerrors.Errorf("set deployment id: %w", err) } } - - // Read the app signing key from the DB. We store it hex encoded - // since the config table uses strings for the value and we - // don't want to deal with automatic encoding issues. - appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return xerrors.Errorf("get app signing key: %w", err) - } - // If the string in the DB is an invalid hex string or the - // length is not equal to the current key length, generate a new - // one. - // - // If the key is regenerated, old signed tokens and encrypted - // strings will become invalid. New signed app tokens will be - // generated automatically on failure. Any workspace app token - // smuggling operations in progress may fail, although with a - // helpful error. - if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) { - b := make([]byte, len(workspaceapps.SecurityKey{})) - _, err := rand.Read(b) - if err != nil { - return xerrors.Errorf("generate fresh app signing key: %w", err) - } - - appSecurityKeyStr = hex.EncodeToString(b) - err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr) - if err != nil { - return xerrors.Errorf("insert freshly generated app signing key to database: %w", err) - } - } - - appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr) - if err != nil { - return xerrors.Errorf("decode app signing key from database: %w", err) - } - - options.AppSecurityKey = appSecurityKey - - // Read the oauth signing key from the database. Like the app security, generate a new one - // if it is invalid for any reason. - oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return xerrors.Errorf("get app oauth signing key: %w", err) - } - if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) { - b := make([]byte, len(options.OAuthSigningKey)) - _, err := rand.Read(b) - if err != nil { - return xerrors.Errorf("generate fresh oauth signing key: %w", err) - } - - oauthSigningKeyStr = hex.EncodeToString(b) - err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr) - if err != nil { - return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err) - } - } - - oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr) - if err != nil { - return xerrors.Errorf("decode oauth signing key from database: %w", err) - } - if len(oauthKeyBytes) != len(options.OAuthSigningKey) { - return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes)) - } - copy(options.OAuthSigningKey[:], oauthKeyBytes) - if options.OAuthSigningKey == [32]byte{} { - return xerrors.Errorf("oauth signing key in database is empty") - } - - // Read the coordinator resume token signing key from the - // database. - resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx) - if err != nil { - return xerrors.Errorf("get coordinator resume token key from database: %w", err) - } - options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry) - return nil }, nil) if err != nil { - return err + return xerrors.Errorf("set deployment id: %w", err) + } + + fetcher := &cryptokeys.DBFetcher{ + DB: options.Database, + } + + resumeKeycache, err := cryptokeys.NewSigningCache(ctx, + logger, + fetcher, + codersdk.CryptoKeyFeatureTailnetResume, + ) + if err != nil { + logger.Critical(ctx, "failed to properly instantiate tailnet resume signing cache", slog.Error(err)) } + options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider( + resumeKeycache, + quartz.NewReal(), + tailnet.DefaultResumeTokenExpiry, + ) + options.RuntimeConfig = runtimeconfig.NewManager() // This should be output before the logs start streaming. diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 76084b1ff54dd..09f070046066a 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -7646,6 +7646,15 @@ const docTemplate = `{ ], "summary": "Get workspace proxy crypto keys", "operationId": "get-workspace-proxy-crypto-keys", + "parameters": [ + { + "type": "string", + "description": "Feature key", + "name": "feature", + "in": "query", + "required": true + } + ], "responses": { "200": { "description": "OK", @@ -10011,12 +10020,14 @@ const docTemplate = `{ "codersdk.CryptoKeyFeature": { "type": "string", "enum": [ - "workspace_apps", + "workspace_apps_api_key", + "workspace_apps_token", "oidc_convert", "tailnet_resume" ], "x-enum-varnames": [ - "CryptoKeyFeatureWorkspaceApp", + "CryptoKeyFeatureWorkspaceAppsAPIKey", + "CryptoKeyFeatureWorkspaceAppsToken", "CryptoKeyFeatureOIDCConvert", "CryptoKeyFeatureTailnetResume" ] @@ -16244,9 +16255,6 @@ const docTemplate = `{ "wsproxysdk.RegisterWorkspaceProxyResponse": { "type": "object", "properties": { - "app_security_key": { - "type": "string" - }, "derp_force_websockets": { "type": "boolean" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index beff69ca22373..42b34d576509a 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -6758,6 +6758,15 @@ "tags": ["Enterprise"], "summary": "Get workspace proxy crypto keys", "operationId": "get-workspace-proxy-crypto-keys", + "parameters": [ + { + "type": "string", + "description": "Feature key", + "name": "feature", + "in": "query", + "required": true + } + ], "responses": { "200": { "description": "OK", @@ -8914,9 +8923,15 @@ }, "codersdk.CryptoKeyFeature": { "type": "string", - "enum": ["workspace_apps", "oidc_convert", "tailnet_resume"], + "enum": [ + "workspace_apps_api_key", + "workspace_apps_token", + "oidc_convert", + "tailnet_resume" + ], "x-enum-varnames": [ - "CryptoKeyFeatureWorkspaceApp", + "CryptoKeyFeatureWorkspaceAppsAPIKey", + "CryptoKeyFeatureWorkspaceAppsToken", "CryptoKeyFeatureOIDCConvert", "CryptoKeyFeatureTailnetResume" ] @@ -14853,9 +14868,6 @@ "wsproxysdk.RegisterWorkspaceProxyResponse": { "type": "object", "properties": { - "app_security_key": { - "type": "string" - }, "derp_force_websockets": { "type": "boolean" }, diff --git a/coderd/coderd.go b/coderd/coderd.go index cb0884808ef27..3011c2d58d39c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -40,6 +40,7 @@ import ( "github.com/coder/quartz" "github.com/coder/serpent" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" @@ -185,9 +186,6 @@ type Options struct { TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore] - // AppSecurityKey is the crypto key used to sign and encrypt tokens related to - // workspace applications. It consists of both a signing and encryption key. - AppSecurityKey workspaceapps.SecurityKey // CoordinatorResumeTokenProvider is used to provide and validate resume // tokens issued by and passed to the coordinator DRPC API. CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider @@ -251,6 +249,12 @@ type Options struct { // OneTimePasscodeValidityPeriod specifies how long a one time passcode should be valid for. OneTimePasscodeValidityPeriod time.Duration + + // Keycaches + AppSigningKeyCache cryptokeys.SigningKeycache + AppEncryptionKeyCache cryptokeys.EncryptionKeycache + OIDCConvertKeyCache cryptokeys.SigningKeycache + Clock quartz.Clock } // @title Coder API @@ -352,6 +356,9 @@ func New(options *Options) *API { if options.PrometheusRegistry == nil { options.PrometheusRegistry = prometheus.NewRegistry() } + if options.Clock == nil { + options.Clock = quartz.NewReal() + } if options.DERPServer == nil && options.DeploymentValues.DERP.Server.Enable { options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp"))) } @@ -444,6 +451,49 @@ func New(options *Options) *API { if err != nil { panic(xerrors.Errorf("get deployment ID: %w", err)) } + + fetcher := &cryptokeys.DBFetcher{ + DB: options.Database, + } + + if options.OIDCConvertKeyCache == nil { + options.OIDCConvertKeyCache, err = cryptokeys.NewSigningCache(ctx, + options.Logger.Named("oidc_convert_keycache"), + fetcher, + codersdk.CryptoKeyFeatureOIDCConvert, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate oidc convert signing cache", slog.Error(err)) + } + } + + if options.AppSigningKeyCache == nil { + options.AppSigningKeyCache, err = cryptokeys.NewSigningCache(ctx, + options.Logger.Named("app_signing_keycache"), + fetcher, + codersdk.CryptoKeyFeatureWorkspaceAppsToken, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate app signing key cache", slog.Error(err)) + } + } + + if options.AppEncryptionKeyCache == nil { + options.AppEncryptionKeyCache, err = cryptokeys.NewEncryptionCache(ctx, + options.Logger, + fetcher, + codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey, + ) + if err != nil { + options.Logger.Critical(ctx, "failed to properly instantiate app encryption key cache", slog.Error(err)) + } + } + + // Start a background process that rotates keys. We intentionally start this after the caches + // are created to force initial requests for a key to populate the caches. This helps catch + // bugs that may only occur when a key isn't precached in tests and the latency cost is minimal. + cryptokeys.StartRotator(ctx, options.Logger, options.Database) + api := &API{ ctx: ctx, cancel: cancel, @@ -464,7 +514,7 @@ func New(options *Options) *API { options.DeploymentValues, oauthConfigs, options.AgentInactiveDisconnectTimeout, - options.AppSecurityKey, + options.AppSigningKeyCache, ), metricsCache: metricsCache, Auditor: atomic.Pointer[audit.Auditor]{}, @@ -606,7 +656,7 @@ func New(options *Options) *API { ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider, }) if err != nil { - api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err)) + api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err)) } api.statsReporter = workspacestats.NewReporter(workspacestats.ReporterOptions{ @@ -628,9 +678,6 @@ func New(options *Options) *API { options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter } - if options.AppSecurityKey.IsZero() { - api.Logger.Fatal(api.ctx, "app security key cannot be zero") - } api.workspaceAppServer = &workspaceapps.Server{ Logger: workspaceAppsLogger, @@ -642,11 +689,11 @@ func New(options *Options) *API { SignedTokenProvider: api.WorkspaceAppsProvider, AgentProvider: api.agentProvider, - AppSecurityKey: options.AppSecurityKey, StatsCollector: workspaceapps.NewStatsCollector(options.WorkspaceAppsStatsCollectorOptions), - DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), - SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(), + DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), + SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(), + APIKeyEncryptionKeycache: options.AppEncryptionKeyCache, } apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -1434,6 +1481,9 @@ func (api *API) Close() error { _ = api.agentProvider.Close() _ = api.statsReporter.Close() _ = api.NetworkTelemetryBatcher.Close() + _ = api.OIDCConvertKeyCache.Close() + _ = api.AppSigningKeyCache.Close() + _ = api.AppEncryptionKeyCache.Close() return nil } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 05c31f35bd20a..d94a6fbe93c4e 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/autobuild" "github.com/coder/coder/v2/coderd/awsidentity" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -88,12 +89,9 @@ import ( sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) -// AppSecurityKey is a 96-byte key used to sign JWTs and encrypt JWEs for -// workspace app tokens in tests. -var AppSecurityKey = must(workspaceapps.KeyFromString("6465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e2077617320686572")) - type Options struct { // AccessURL denotes a custom access URL. By default we use the httptest // server's URL. Setting this may result in unexpected behavior (especially @@ -161,8 +159,10 @@ type Options struct { DatabaseRolluper *dbrollup.Rolluper WorkspaceUsageTrackerFlush chan int WorkspaceUsageTrackerTick chan time.Time - - NotificationsEnqueuer notifications.Enqueuer + NotificationsEnqueuer notifications.Enqueuer + APIKeyEncryptionCache cryptokeys.EncryptionKeycache + OIDCConvertKeyCache cryptokeys.SigningKeycache + Clock quartz.Clock } // New constructs a codersdk client connected to an in-memory API instance. @@ -525,7 +525,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(options.DeploymentValues.Options()), UpdateCheckOptions: options.UpdateCheckOptions, SwaggerEndpoint: options.SwaggerEndpoint, - AppSecurityKey: AppSecurityKey, SSHConfig: options.ConfigSSH, HealthcheckFunc: options.HealthcheckFunc, HealthcheckTimeout: options.HealthcheckTimeout, @@ -538,6 +537,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can WorkspaceUsageTracker: wuTracker, NotificationsEnqueuer: options.NotificationsEnqueuer, OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod, + Clock: options.Clock, + AppEncryptionKeyCache: options.APIKeyEncryptionCache, + OIDCConvertKeyCache: options.OIDCConvertKeyCache, } } diff --git a/coderd/cryptokeys/cache.go b/coderd/cryptokeys/cache.go index 74fb025d416fd..7777d5f75b942 100644 --- a/coderd/cryptokeys/cache.go +++ b/coderd/cryptokeys/cache.go @@ -3,6 +3,7 @@ package cryptokeys import ( "context" "encoding/hex" + "fmt" "io" "strconv" "sync" @@ -12,7 +13,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) @@ -25,7 +26,7 @@ var ( ) type Fetcher interface { - Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) + Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) } type EncryptionKeycache interface { @@ -62,27 +63,26 @@ const ( ) type DBFetcher struct { - DB database.Store - Feature database.CryptoKeyFeature + DB database.Store } -func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { - keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature) +func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { + keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature)) if err != nil { return nil, xerrors.Errorf("get crypto keys by feature: %w", err) } - return db2sdk.CryptoKeys(keys), nil + return toSDKKeys(keys), nil } // cache implements the caching functionality for both signing and encryption keys. type cache struct { - clock quartz.Clock - refreshCtx context.Context - refreshCancel context.CancelFunc - fetcher Fetcher - logger slog.Logger - feature codersdk.CryptoKeyFeature + ctx context.Context + cancel context.CancelFunc + clock quartz.Clock + fetcher Fetcher + logger slog.Logger + feature codersdk.CryptoKeyFeature mu sync.Mutex keys map[int32]codersdk.CryptoKey @@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, if !isSigningKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } - return newCache(ctx, logger, fetcher, feature, opts...) + logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature)) + return newCache(ctx, logger, fetcher, feature, opts...), nil } func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, @@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher if !isEncryptionKeyFeature(feature) { return nil, xerrors.Errorf("invalid feature: %s", feature) } - return newCache(ctx, logger, fetcher, feature, opts...) + logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature)) + return newCache(ctx, logger, fetcher, feature, opts...), nil } -func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) { +func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache { cache := &cache{ clock: quartz.NewReal(), logger: logger, @@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature } cache.cond = sync.NewCond(&cache.mu) - cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + //nolint:gocritic // We need to be able to read the keys in order to cache them. + cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx)) cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh) - keys, err := cache.cryptoKeys(ctx) + keys, err := cache.cryptoKeys(cache.ctx) if err != nil { - cache.refreshCancel() - return nil, xerrors.Errorf("initial fetch: %w", err) + cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err)) } cache.keys = keys - return cache, nil + return cache } func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) { @@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) return "", nil, ErrInvalidFeature } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) return c.cryptoKey(ctx, latestSequence) } @@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro return nil, xerrors.Errorf("parse id: %w", err) } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) _, secret, err := c.cryptoKey(ctx, int32(seq)) if err != nil { return nil, xerrors.Errorf("crypto key: %w", err) @@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) { return "", nil, ErrInvalidFeature } + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) return c.cryptoKey(ctx, latestSequence) } @@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error if err != nil { return nil, xerrors.Errorf("parse id: %w", err) } - + //nolint:gocritic // cache can only read crypto keys. + ctx = dbauthz.AsKeyReader(ctx) _, secret, err := c.cryptoKey(ctx, int32(seq)) if err != nil { return nil, xerrors.Errorf("crypto key: %w", err) @@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error } func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool { - return feature == codersdk.CryptoKeyFeatureWorkspaceApp + return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey } func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool { switch feature { - case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert: + case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken: return true default: return false @@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, [] func (c *cache) refresh() { now := c.clock.Now("CryptoKeyCache", "refresh") c.mu.Lock() - defer c.mu.Unlock() if c.closed { + c.mu.Unlock() return } // If something's already fetching, we don't need to do anything. if c.fetching { + c.mu.Unlock() return } @@ -307,20 +317,21 @@ func (c *cache) refresh() { // is ongoing but prior to the timer getting reset. In this case we want to // avoid double fetching. if now.Sub(c.lastFetch) < refreshInterval { + c.mu.Unlock() return } c.fetching = true c.mu.Unlock() - keys, err := c.cryptoKeys(c.refreshCtx) + keys, err := c.cryptoKeys(c.ctx) if err != nil { - c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err)) + c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err)) return } - // We don't defer an unlock here due to the deferred unlock at the top of the function. c.mu.Lock() + defer c.mu.Unlock() c.lastFetch = c.clock.Now() c.refresher.Reset(refreshInterval) @@ -332,9 +343,9 @@ func (c *cache) refresh() { // cryptoKeys queries the control plane for the crypto keys. // Outside of initialization, this should only be called by fetch. func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { - keys, err := c.fetcher.Fetch(ctx) + keys, err := c.fetcher.Fetch(ctx, c.feature) if err != nil { - return nil, xerrors.Errorf("crypto keys: %w", err) + return nil, xerrors.Errorf("fetch: %w", err) } cache := toKeyMap(keys, c.clock.Now()) return cache, nil @@ -361,9 +372,28 @@ func (c *cache) Close() error { } c.closed = true - c.refreshCancel() + c.cancel() c.refresher.Stop() c.cond.Broadcast() return nil } + +// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys) +func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey { + into := make([]codersdk.CryptoKey, 0, len(keys)) + for _, key := range keys { + into = append(into, toSDK(key)) + } + return into +} + +func toSDK(key database.CryptoKey) codersdk.CryptoKey { + return codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeature(key.Feature), + Sequence: key.Sequence, + StartsAt: key.StartsAt, + DeletesAt: key.DeletesAt.Time, + Secret: key.Secret.String, + } +} diff --git a/coderd/cryptokeys/cache_test.go b/coderd/cryptokeys/cache_test.go index 92fc4527ae7b3..cda87315605a4 100644 --- a/coderd/cryptokeys/cache_test.go +++ b/coderd/cryptokeys/cache_test.go @@ -488,7 +488,7 @@ type fakeFetcher struct { called int } -func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { +func (f *fakeFetcher) Fetch(_ context.Context, _ codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { f.called++ return f.keys, nil } diff --git a/coderd/cryptokeys/rotate.go b/coderd/cryptokeys/rotate.go index 14a623e2156db..5d7d7b33b9dec 100644 --- a/coderd/cryptokeys/rotate.go +++ b/coderd/cryptokeys/rotate.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/quartz" ) @@ -53,10 +54,12 @@ func WithKeyDuration(keyDuration time.Duration) RotatorOption { // StartRotator starts a background process that rotates keys in the database. // It ensures there's at least one valid key per feature prior to returning. // Canceling the provided context will stop the background process. -func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) error { +func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) { + //nolint:gocritic // KeyRotator can only rotate crypto keys. + ctx = dbauthz.AsKeyRotator(ctx) kr := &rotator{ db: db, - logger: logger, + logger: logger.Named("keyrotator"), clock: quartz.NewReal(), keyDuration: DefaultKeyDuration, features: database.AllCryptoKeyFeatureValues(), @@ -68,12 +71,10 @@ func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, op err := kr.rotateKeys(ctx) if err != nil { - return xerrors.Errorf("rotate keys: %w", err) + kr.logger.Critical(ctx, "failed to rotate keys", slog.Error(err)) } go kr.start(ctx) - - return nil } // start begins the process of rotating keys. @@ -226,9 +227,11 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: return generateKey(32) - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureWorkspaceAppsToken: + return generateKey(64) + case database.CryptoKeyFeatureOIDCConvert: return generateKey(64) case database.CryptoKeyFeatureTailnetResume: return generateKey(64) @@ -247,9 +250,11 @@ func generateKey(length int) (string, error) { func tokenDuration(feature database.CryptoKeyFeature) time.Duration { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: + return WorkspaceAppsTokenDuration + case database.CryptoKeyFeatureWorkspaceAppsToken: return WorkspaceAppsTokenDuration - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureOIDCConvert: return OIDCConvertTokenDuration case database.CryptoKeyFeatureTailnetResume: return TailnetResumeTokenDuration diff --git a/coderd/cryptokeys/rotate_internal_test.go b/coderd/cryptokeys/rotate_internal_test.go index 43754c1d8750f..e427a3c6216ac 100644 --- a/coderd/cryptokeys/rotate_internal_test.go +++ b/coderd/cryptokeys/rotate_internal_test.go @@ -38,7 +38,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -46,7 +46,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key. oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 15, }) @@ -69,11 +69,11 @@ func Test_rotateKeys(t *testing.T) { // The new key should be created and have a starts_at of the old key's expires_at. newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: oldKey.Sequence + 1, }) require.NoError(t, err) - requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1) // Advance the clock just before the keys delete time. clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) @@ -123,7 +123,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -131,7 +131,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 123, }) @@ -179,7 +179,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -187,7 +187,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration), Sequence: 789, DeletesAt: sql.NullTime{ @@ -232,7 +232,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -240,7 +240,7 @@ func Test_rotateKeys(t *testing.T) { // Seed the database with an existing key deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 456, DeletesAt: sql.NullTime{ @@ -281,7 +281,7 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } @@ -291,7 +291,7 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) require.Len(t, keys, 1) - requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, clock.Now().UTC(), nullTime, 1) }) // Assert we insert a new key when the only key was manually deleted. @@ -312,14 +312,14 @@ func Test_rotateKeys(t *testing.T) { clock: clock, logger: logger, features: []database.CryptoKeyFeature{ - database.CryptoKeyFeatureWorkspaceApps, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, }, } now := dbnow(clock) deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 19, DeletesAt: sql.NullTime{ @@ -338,7 +338,7 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) require.Len(t, keys, 1) - requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, now, nullTime, deletedkey.Sequence+1) }) // This tests ensures that rotation works with multiple @@ -365,9 +365,11 @@ func Test_rotateKeys(t *testing.T) { now := dbnow(clock) - // We'll test a scenario where one feature has no valid keys. - // Another has a key that should be rotate. And one that - // has a valid key that shouldn't trigger an action. + // We'll test a scenario where: + // - One feature has no valid keys. + // - One has a key that should be rotated. + // - One has a valid key that shouldn't trigger an action. + // - One has no keys at all. _ = dbgen.CryptoKey(t, db, database.CryptoKey{ Feature: database.CryptoKeyFeatureTailnetResume, StartsAt: now.Add(-keyDuration), @@ -377,6 +379,7 @@ func Test_rotateKeys(t *testing.T) { Valid: false, }, }) + // Generate another deleted key to ensure we insert after the latest sequence. deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ Feature: database.CryptoKeyFeatureTailnetResume, StartsAt: now.Add(-keyDuration), @@ -389,14 +392,14 @@ func Test_rotateKeys(t *testing.T) { // Insert a key that should be rotated. rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration + time.Hour), Sequence: 42, }) // Insert a key that should not trigger an action. validKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, + Feature: database.CryptoKeyFeatureOIDCConvert, StartsAt: now, Sequence: 17, }) @@ -406,26 +409,28 @@ func Test_rotateKeys(t *testing.T) { keys, err := db.GetCryptoKeys(ctx) require.NoError(t, err) - require.Len(t, keys, 4) + require.Len(t, keys, 5) kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues()) require.NoError(t, err) // No actions on OIDC convert. - require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1) + require.Len(t, kbf[database.CryptoKeyFeatureOIDCConvert], 1) // Workspace apps should have been rotated. - require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2) + require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey], 2) // No existing key for tailnet resume should've // caused a key to be inserted. require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1) + require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsToken], 1) - oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0] + oidcKey := kbf[database.CryptoKeyFeatureOIDCConvert][0] tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0] - requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence) + appTokenKey := kbf[database.CryptoKeyFeatureWorkspaceAppsToken][0] + requireKey(t, oidcKey, database.CryptoKeyFeatureOIDCConvert, now, nullTime, validKey.Sequence) requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1) - - newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0] - oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1] + requireKey(t, appTokenKey, database.CryptoKeyFeatureWorkspaceAppsToken, now, nullTime, 1) + newKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][0] + oldKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][1] if newKey.Sequence == rotatedKey.Sequence { oldKey, newKey = newKey, oldKey } @@ -433,8 +438,8 @@ func Test_rotateKeys(t *testing.T) { Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour), Valid: true, } - requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence) - requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1) + requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1) }) t.Run("UnknownFeature", func(t *testing.T) { @@ -478,11 +483,11 @@ func Test_rotateKeys(t *testing.T) { keyDuration: keyDuration, clock: clock, logger: logger, - features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey}, } expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration), Sequence: 345, }) @@ -522,19 +527,19 @@ func Test_rotateKeys(t *testing.T) { keyDuration: keyDuration, clock: clock, logger: logger, - features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey}, } now := dbnow(clock) expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-keyDuration - 2*time.Hour), Sequence: 19, }) deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now, Sequence: 20, Secret: sql.NullString{ @@ -587,9 +592,11 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey require.NoError(t, err) switch key.Feature { - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureOIDCConvert: + require.Len(t, secret, 64) + case database.CryptoKeyFeatureWorkspaceAppsToken: require.Len(t, secret, 64) - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: require.Len(t, secret, 32) case database.CryptoKeyFeatureTailnetResume: require.Len(t, secret, 64) diff --git a/coderd/cryptokeys/rotate_test.go b/coderd/cryptokeys/rotate_test.go index 190ad213b1153..9e147c8f921f0 100644 --- a/coderd/cryptokeys/rotate_test.go +++ b/coderd/cryptokeys/rotate_test.go @@ -34,8 +34,7 @@ func TestRotator(t *testing.T) { require.NoError(t, err) require.Len(t, dbkeys, 0) - err = cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) - require.NoError(t, err) + cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) // Fetch the keys from the database and ensure they // are as expected. @@ -58,7 +57,7 @@ func TestRotator(t *testing.T) { now := clock.Now().UTC() rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-cryptokeys.DefaultKeyDuration + time.Hour + time.Minute), Sequence: 12345, }) @@ -66,8 +65,7 @@ func TestRotator(t *testing.T) { trap := clock.Trap().TickerFunc() t.Cleanup(trap.Close) - err := cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) - require.NoError(t, err) + cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock)) initialKeyLen := len(database.AllCryptoKeyFeatureValues()) // Fetch the keys from the database and ensure they @@ -85,7 +83,7 @@ func TestRotator(t *testing.T) { require.NoError(t, err) require.Len(t, keys, initialKeyLen+1) - newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence) require.Equal(t, rotatingKey.ExpiresAt(cryptokeys.DefaultKeyDuration), newKey.StartsAt.UTC()) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 052f25450e6a5..35e4f09250ff8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -228,6 +228,42 @@ var ( Scope: rbac.ScopeAll, }.WithCachedASTValue() + // See cryptokeys package. + subjectCryptoKeyRotator = rbac.Subject{ + FriendlyName: "Crypto Key Rotator", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "keyrotator"}, + DisplayName: "Key Rotator", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + + // See cryptokeys package. + subjectCryptoKeyReader = rbac.Subject{ + FriendlyName: "Crypto Key Reader", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "keyrotator"}, + DisplayName: "Key Rotator", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + subjectSystemRestricted = rbac.Subject{ FriendlyName: "System", ID: uuid.Nil.String(), @@ -281,6 +317,16 @@ func AsHangDetector(ctx context.Context) context.Context { return context.WithValue(ctx, authContextKey{}, subjectHangDetector) } +// AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys. +func AsKeyRotator(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator) +} + +// AsKeyReader returns a context with an actor that has permissions required for reading crypto keys. +func AsKeyReader(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader) +} + // AsSystemRestricted returns a context with an actor that has permissions // required for various system operations (login, logout, metrics cache). func AsSystemRestricted(ctx context.Context) context.Context { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 6a34e88104ce1..439cf1bdaec19 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2243,13 +2243,13 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("InsertCryptoKey", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertCryptoKeyParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, }). Asserts(rbac.ResourceCryptoKey, policy.ActionCreate) })) s.Run("DeleteCryptoKey", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.DeleteCryptoKeyParams{ @@ -2259,7 +2259,7 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("GetCryptoKeyByFeatureAndSequence", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.GetCryptoKeyByFeatureAndSequenceParams{ @@ -2269,14 +2269,14 @@ func (s *MethodTestSuite) TestCryptoKeys() { })) s.Run("GetLatestCryptoKeyByFeature", s.Subtest(func(db database.Store, check *expects) { dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) - check.Args(database.CryptoKeyFeatureWorkspaceApps).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) + check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) s.Run("UpdateCryptoKeyDeletesAt", s.Subtest(func(db database.Store, check *expects) { key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 4, }) check.Args(database.UpdateCryptoKeyDeletesAtParams{ @@ -2286,7 +2286,7 @@ func (s *MethodTestSuite) TestCryptoKeys() { }).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate) })) s.Run("GetCryptoKeysByFeature", s.Subtest(func(db database.Store, check *expects) { - check.Args(database.CryptoKeyFeatureWorkspaceApps). + check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey). Asserts(rbac.ResourceCryptoKey, policy.ActionRead) })) } diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 255c62f82aef2..69419b98c79b1 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -943,7 +943,7 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { t.Helper() - seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps) + seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceAppsAPIKey) // An empty string for the secret is interpreted as // a caller wanting a new secret to be generated. @@ -1048,9 +1048,11 @@ func takeFirst[Value comparable](values ...Value) Value { func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) { switch feature { - case database.CryptoKeyFeatureWorkspaceApps: + case database.CryptoKeyFeatureWorkspaceAppsAPIKey: return generateCryptoKey(32) - case database.CryptoKeyFeatureOidcConvert: + case database.CryptoKeyFeatureWorkspaceAppsToken: + return generateCryptoKey(64) + case database.CryptoKeyFeatureOIDCConvert: return generateCryptoKey(64) case database.CryptoKeyFeatureTailnetResume: return generateCryptoKey(64) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 3a9a5a7a2d8f6..fc7819e38f218 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -38,7 +38,8 @@ CREATE TYPE build_reason AS ENUM ( ); CREATE TYPE crypto_key_feature AS ENUM ( - 'workspace_apps', + 'workspace_apps_token', + 'workspace_apps_api_key', 'oidc_convert', 'tailnet_resume' ); diff --git a/coderd/database/migrations/000271_cryptokey_features.down.sql b/coderd/database/migrations/000271_cryptokey_features.down.sql new file mode 100644 index 0000000000000..7cdd00d222da8 --- /dev/null +++ b/coderd/database/migrations/000271_cryptokey_features.down.sql @@ -0,0 +1,18 @@ +-- Step 1: Remove the new entries from crypto_keys table +DELETE FROM crypto_keys +WHERE feature IN ('workspace_apps_token', 'workspace_apps_api_key'); + +CREATE TYPE old_crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'tailnet_resume' +); + +ALTER TABLE crypto_keys + ALTER COLUMN feature TYPE old_crypto_key_feature + USING (feature::text::old_crypto_key_feature); + +DROP TYPE crypto_key_feature; + +ALTER TYPE old_crypto_key_feature RENAME TO crypto_key_feature; + diff --git a/coderd/database/migrations/000271_cryptokey_features.up.sql b/coderd/database/migrations/000271_cryptokey_features.up.sql new file mode 100644 index 0000000000000..bca75d220d0c7 --- /dev/null +++ b/coderd/database/migrations/000271_cryptokey_features.up.sql @@ -0,0 +1,18 @@ +-- Create a new enum type with the desired values +CREATE TYPE new_crypto_key_feature AS ENUM ( + 'workspace_apps_token', + 'workspace_apps_api_key', + 'oidc_convert', + 'tailnet_resume' +); + +DELETE FROM crypto_keys WHERE feature = 'workspace_apps'; + +-- Drop the old type and rename the new one +ALTER TABLE crypto_keys + ALTER COLUMN feature TYPE new_crypto_key_feature + USING (feature::text::new_crypto_key_feature); + +DROP TYPE crypto_key_feature; + +ALTER TYPE new_crypto_key_feature RENAME TO crypto_key_feature; diff --git a/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql b/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql new file mode 100644 index 0000000000000..5cb2cd4c95509 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000271_cryptokey_features.up.sql @@ -0,0 +1,40 @@ +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'workspace_apps_token', + 1, + 'abc', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'workspace_apps_api_key', + 1, + 'def', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'oidc_convert', + 2, + 'ghi', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + +INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at) +VALUES ( + 'tailnet_resume', + 2, + 'jkl', + NULL, + '1970-01-01 00:00:00 UTC'::timestamptz, + '2100-01-01 00:00:00 UTC'::timestamptz +); + diff --git a/coderd/database/models.go b/coderd/database/models.go index 1207587d46529..e7d90acf5ea94 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -345,9 +345,10 @@ func AllBuildReasonValues() []BuildReason { type CryptoKeyFeature string const ( - CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps" - CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert" - CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" + CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token" + CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key" + CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" ) func (e *CryptoKeyFeature) Scan(src interface{}) error { @@ -387,8 +388,9 @@ func (ns NullCryptoKeyFeature) Value() (driver.Value, error) { func (e CryptoKeyFeature) Valid() bool { switch e { - case CryptoKeyFeatureWorkspaceApps, - CryptoKeyFeatureOidcConvert, + case CryptoKeyFeatureWorkspaceAppsToken, + CryptoKeyFeatureWorkspaceAppsAPIKey, + CryptoKeyFeatureOIDCConvert, CryptoKeyFeatureTailnetResume: return true } @@ -397,8 +399,9 @@ func (e CryptoKeyFeature) Valid() bool { func AllCryptoKeyFeatureValues() []CryptoKeyFeature { return []CryptoKeyFeature{ - CryptoKeyFeatureWorkspaceApps, - CryptoKeyFeatureOidcConvert, + CryptoKeyFeatureWorkspaceAppsToken, + CryptoKeyFeatureWorkspaceAppsAPIKey, + CryptoKeyFeatureOIDCConvert, CryptoKeyFeatureTailnetResume, } } diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index a70e45a522989..257c95ddb2d7a 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -135,6 +135,8 @@ sql: api_key_id: APIKeyID callback_url: CallbackURL login_type_oauth2_provider_app: LoginTypeOAuth2ProviderApp + crypto_key_feature_workspace_apps_api_key: CryptoKeyFeatureWorkspaceAppsAPIKey + crypto_key_feature_oidc_convert: CryptoKeyFeatureOIDCConvert rules: - name: do-not-use-public-schema-in-queries message: "do not use public schema in queries" diff --git a/coderd/jwtutils/jwe.go b/coderd/jwtutils/jwe.go index d03816a477a26..bc9d0ddd2a9c8 100644 --- a/coderd/jwtutils/jwe.go +++ b/coderd/jwtutils/jwe.go @@ -65,6 +65,12 @@ func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, return compact, nil } +func WithDecryptExpected(expected jwt.Expected) func(*DecryptOptions) { + return func(opts *DecryptOptions) { + opts.RegisteredClaims = expected + } +} + // DecryptOptions are options for decrypting a JWE. type DecryptOptions struct { RegisteredClaims jwt.Expected @@ -100,7 +106,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla kid := object.Header.KeyID if kid == "" { - return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + return ErrMissingKeyID } key, err := d.DecryptingKey(ctx, kid) diff --git a/coderd/jwtutils/jws.go b/coderd/jwtutils/jws.go index 73f35e672492d..0c8ca9aa30f39 100644 --- a/coderd/jwtutils/jws.go +++ b/coderd/jwtutils/jws.go @@ -10,10 +10,27 @@ import ( "golang.org/x/xerrors" ) +var ErrMissingKeyID = xerrors.New("missing key ID") + const ( keyIDHeaderKey = "kid" ) +// RegisteredClaims is a convenience type for embedding jwt.Claims. It should be +// preferred over embedding jwt.Claims directly since it will ensure that certain fields are set. +type RegisteredClaims jwt.Claims + +func (r RegisteredClaims) Validate(e jwt.Expected) error { + if r.Expiry == nil { + return xerrors.Errorf("expiry is required") + } + if e.Time.IsZero() { + return xerrors.Errorf("expected time is required") + } + + return (jwt.Claims(r)).Validate(e) +} + // Claims defines the payload for a JWT. Most callers // should embed jwt.Claims type Claims interface { @@ -24,6 +41,11 @@ const ( signingAlgo = jose.HS512 ) +type SigningKeyManager interface { + SigningKeyProvider + VerifyKeyProvider +} + type SigningKeyProvider interface { SigningKey(ctx context.Context) (id string, key interface{}, err error) } @@ -75,6 +97,12 @@ type VerifyOptions struct { SignatureAlgorithm jose.SignatureAlgorithm } +func WithVerifyExpected(expected jwt.Expected) func(*VerifyOptions) { + return func(opts *VerifyOptions) { + opts.RegisteredClaims = expected + } +} + // Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims. func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error { options := VerifyOptions{ @@ -105,7 +133,7 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim kid := signature.Header.KeyID if kid == "" { - return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey) + return ErrMissingKeyID } key, err := v.VerifyingKey(ctx, kid) @@ -125,3 +153,35 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim return claims.Validate(options.RegisteredClaims) } + +// StaticKey fulfills the SigningKeycache and EncryptionKeycache interfaces. Useful for testing. +type StaticKey struct { + ID string + Key interface{} +} + +func (s StaticKey) SigningKey(_ context.Context) (string, interface{}, error) { + return s.ID, s.Key, nil +} + +func (s StaticKey) VerifyingKey(_ context.Context, id string) (interface{}, error) { + if id != s.ID { + return nil, xerrors.Errorf("invalid id %q", id) + } + return s.Key, nil +} + +func (s StaticKey) EncryptingKey(_ context.Context) (string, interface{}, error) { + return s.ID, s.Key, nil +} + +func (s StaticKey) DecryptingKey(_ context.Context, id string) (interface{}, error) { + if id != s.ID { + return nil, xerrors.Errorf("invalid id %q", id) + } + return s.Key, nil +} + +func (StaticKey) Close() error { + return nil +} diff --git a/coderd/jwtutils/jwt_test.go b/coderd/jwtutils/jwt_test.go index 697e5d210d858..5d1f4d48bdb4a 100644 --- a/coderd/jwtutils/jwt_test.go +++ b/coderd/jwtutils/jwt_test.go @@ -236,11 +236,11 @@ func TestJWS(t *testing.T) { ctx = testutil.Context(t, testutil.WaitShort) db, _ = dbtestutil.NewDB(t) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, + Feature: database.CryptoKeyFeatureOIDCConvert, StartsAt: time.Now(), }) log = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert} + fetcher = &cryptokeys.DBFetcher{DB: db} ) cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) @@ -326,15 +326,15 @@ func TestJWE(t *testing.T) { ctx = testutil.Context(t, testutil.WaitShort) db, _ = dbtestutil.NewDB(t) _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: time.Now(), }) log = slogtest.Make(t, nil) - fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps} + fetcher = &cryptokeys.DBFetcher{DB: db} ) - cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp) + cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) claims := testClaims{ diff --git a/coderd/userauth.go b/coderd/userauth.go index 85ab0d77e6cc1..f1a19d77d23d0 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -15,7 +15,8 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt/v4" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" @@ -23,6 +24,9 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/audit" @@ -32,7 +36,6 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" - "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/rbac" @@ -49,7 +52,7 @@ const ( ) type OAuthConvertStateClaims struct { - jwt.RegisteredClaims + jwtutils.RegisteredClaims UserID uuid.UUID `json:"user_id"` State string `json:"state"` @@ -57,6 +60,10 @@ type OAuthConvertStateClaims struct { ToLoginType codersdk.LoginType `json:"to_login_type"` } +func (o *OAuthConvertStateClaims) Validate(e jwt.Expected) error { + return o.RegisteredClaims.Validate(e) +} + // postConvertLoginType replies with an oauth state token capable of converting // the user to an oauth user. // @@ -149,11 +156,11 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { // Eg: Developers with more than 1 deployment. now := time.Now() claims := &OAuthConvertStateClaims{ - RegisteredClaims: jwt.RegisteredClaims{ + RegisteredClaims: jwtutils.RegisteredClaims{ Issuer: api.DeploymentID, Subject: stateString, Audience: []string{user.ID.String()}, - ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute * 5)), + Expiry: jwt.NewNumericDate(now.Add(time.Minute * 5)), NotBefore: jwt.NewNumericDate(now.Add(time.Second * -1)), IssuedAt: jwt.NewNumericDate(now), ID: uuid.NewString(), @@ -164,9 +171,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { ToLoginType: req.ToType, } - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) - // Key must be a byte slice, not an array. So make sure to include the [:] - tokenString, err := token.SignedString(api.OAuthSigningKey[:]) + token, err := jwtutils.Sign(ctx, api.OIDCConvertKeyCache, claims) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error signing state jwt.", @@ -176,8 +181,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { } aReq.New = database.AuditOAuthConvertState{ - CreatedAt: claims.IssuedAt.Time, - ExpiresAt: claims.ExpiresAt.Time, + CreatedAt: claims.IssuedAt.Time(), + ExpiresAt: claims.Expiry.Time(), FromLoginType: database.LoginType(claims.FromLoginType), ToLoginType: database.LoginType(claims.ToLoginType), UserID: claims.UserID, @@ -186,8 +191,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { http.SetCookie(rw, &http.Cookie{ Name: OAuthConvertCookieValue, Path: "/", - Value: tokenString, - Expires: claims.ExpiresAt.Time, + Value: token, + Expires: claims.Expiry.Time(), Secure: api.SecureAuthCookie, HttpOnly: true, // Must be SameSite to work on the redirected auth flow from the @@ -196,7 +201,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) { }) httpapi.Write(ctx, rw, http.StatusCreated, codersdk.OAuthConversionResponse{ StateString: stateString, - ExpiresAt: claims.ExpiresAt.Time, + ExpiresAt: claims.Expiry.Time(), ToType: claims.ToLoginType, UserID: claims.UserID, }) @@ -1677,10 +1682,9 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data } } var claims OAuthConvertStateClaims - token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) { - return api.OAuthSigningKey[:], nil - }) - if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid { + + err = jwtutils.Verify(ctx, api.OIDCConvertKeyCache, jwtCookie.Value, &claims) + if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) || xerrors.Is(err, jwtutils.ErrMissingKeyID) { // These errors are probably because the user is mixing 2 coder deployments. return database.User{}, idpsync.HTTPError{ Code: http.StatusBadRequest, @@ -1709,7 +1713,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data oauthConvertAudit.UserID = claims.UserID oauthConvertAudit.Old = user - if claims.RegisteredClaims.Issuer != api.DeploymentID { + if claims.Issuer != api.DeploymentID { return database.User{}, idpsync.HTTPError{ Code: http.StatusForbidden, Msg: "Request to convert login type failed. Issuer mismatch. Found a cookie from another coder deployment, please try again.", diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 20dfe7f723899..6386be7eb8be4 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -3,6 +3,8 @@ package coderd_test import ( "context" "crypto" + "crypto/rand" + "encoding/json" "fmt" "io" "net/http" @@ -13,6 +15,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v4" "github.com/golang-jwt/jwt/v4" "github.com/google/go-github/v43/github" "github.com/google/uuid" @@ -27,10 +30,12 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" @@ -1316,6 +1321,7 @@ func TestUserOIDC(t *testing.T) { owner := coderdtest.CreateFirstUser(t, client) user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + require.Equal(t, codersdk.LoginTypePassword, userData.LoginType) claims := jwt.MapClaims{ "email": userData.Email, @@ -1323,15 +1329,17 @@ func TestUserOIDC(t *testing.T) { var err error user.HTTPClient.Jar, err = cookiejar.New(nil) require.NoError(t, err) + user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone() ctx := testutil.Context(t, testutil.WaitShort) + convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{ ToType: codersdk.LoginTypeOIDC, Password: "SomeSecurePassword!", }) require.NoError(t, err) - fake.LoginWithClient(t, user, claims, func(r *http.Request) { + _, _ = fake.LoginWithClient(t, user, claims, func(r *http.Request) { r.URL.RawQuery = url.Values{ "oidc_merge_state": {convertResponse.StateString}, }.Encode() @@ -1341,6 +1349,99 @@ func TestUserOIDC(t *testing.T) { r.AddCookie(cookie) } }) + + info, err := client.User(ctx, userData.ID.String()) + require.NoError(t, err) + require.Equal(t, codersdk.LoginTypeOIDC, info.LoginType) + }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitMedium) + logger = slogtest.Make(t, nil) + ) + + auditor := audit.NewMock() + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefresh(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + + db, ps := dbtestutil.NewDB(t) + fetcher := &cryptokeys.DBFetcher{ + DB: db, + } + + kc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureOIDCConvert) + require.NoError(t, err) + + client := coderdtest.New(t, &coderdtest.Options{ + Auditor: auditor, + OIDCConfig: cfg, + Database: db, + Pubsub: ps, + OIDCConvertKeyCache: kc, + }) + + owner := coderdtest.CreateFirstUser(t, client) + user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + + claims := jwt.MapClaims{ + "email": userData.Email, + } + user.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone() + + convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{ + ToType: codersdk.LoginTypeOIDC, + Password: "SomeSecurePassword!", + }) + require.NoError(t, err) + + // Update the cookie to use a bad signing key. We're asserting the behavior of the scenario + // where a JWT gets minted on an old version of Coder but gets verified on a new version. + _, resp := fake.AttemptLogin(t, user, claims, func(r *http.Request) { + r.URL.RawQuery = url.Values{ + "oidc_merge_state": {convertResponse.StateString}, + }.Encode() + r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken()) + + cookies := user.HTTPClient.Jar.Cookies(user.URL) + for i, cookie := range cookies { + if cookie.Name != coderd.OAuthConvertCookieValue { + continue + } + + jwt := cookie.Value + var claims coderd.OAuthConvertStateClaims + err := jwtutils.Verify(ctx, kc, jwt, &claims) + require.NoError(t, err) + badJWT := generateBadJWT(t, claims) + cookie.Value = badJWT + cookies[i] = cookie + } + + user.HTTPClient.Jar.SetCookies(user.URL, cookies) + + for _, cookie := range cookies { + fmt.Printf("cookie: %+v\n", cookie) + r.AddCookie(cookie) + } + }) + defer resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + var respErr codersdk.Response + err = json.NewDecoder(resp.Body).Decode(&respErr) + require.NoError(t, err) + require.Contains(t, respErr.Message, "Using an invalid jwt to authorize this action.") }) t.Run("AlternateUsername", func(t *testing.T) { @@ -2022,3 +2123,24 @@ func inflateClaims(t testing.TB, seed jwt.MapClaims, size int) jwt.MapClaims { seed["random_data"] = junk return seed } + +// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated. +func generateBadJWT(t *testing.T, claims interface{}) string { + t.Helper() + + var buf [64]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.HS512, + Key: buf[:], + }, nil) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := signer.Sign(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6ea631f2e7d0c..a181697f27279 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -852,8 +853,12 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R ) if resumeToken != "" { var err error - peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken) - if err != nil { + peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(ctx, resumeToken) + // If the token is missing the key ID, it's probably an old token in which + // case we just want to generate a new peer ID. + if xerrors.Is(err, jwtutils.ErrMissingKeyID) { + peerID = uuid.New() + } else if err != nil { httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ Message: workspacesdk.CoordinateAPIInvalidResumeToken, Detail: err.Error(), @@ -862,9 +867,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R }, }) return + } else { + api.Logger.Debug(ctx, "accepted coordinate resume token for peer", + slog.F("peer_id", peerID.String())) } - api.Logger.Debug(ctx, "accepted coordinate resume token for peer", - slog.F("peer_id", peerID.String())) } api.WebsocketWaitMutex.Lock() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 8c0801a914d61..ba677975471d6 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -531,20 +533,20 @@ func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeToke } } -func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) { +func (r *resumeTokenRecordingProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) { select { case r.generateCalls <- peerID: - return r.ResumeTokenProvider.GenerateResumeToken(peerID) + return r.ResumeTokenProvider.GenerateResumeToken(ctx, peerID) default: r.t.Error("generateCalls full") return nil, xerrors.New("generateCalls full") } } -func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) { +func (r *resumeTokenRecordingProvider) VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) { select { case r.verifyCalls <- token: - return r.ResumeTokenProvider.VerifyResumeToken(token) + return r.ResumeTokenProvider.VerifyResumeToken(ctx, token) default: r.t.Error("verifyCalls full") return uuid.Nil, xerrors.New("verifyCalls full") @@ -554,69 +556,136 @@ func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUI func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - clock := quartz.NewMock(t) - resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() - require.NoError(t, err) - resumeTokenProvider := newResumeTokenRecordingProvider( - t, - tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour), - ) - client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - Coordinator: tailnet.NewCoordinator(logger), - CoordinatorResumeTokenProvider: resumeTokenProvider, + t.Run("OK", func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + require.NoError(t, err) + resumeTokenProvider := newResumeTokenRecordingProvider( + t, + tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour), + ) + client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Coordinator: tailnet.NewCoordinator(logger), + CoordinatorResumeTokenProvider: resumeTokenProvider, + }) + defer closer.Close() + user := coderdtest.CreateFirstUser(t, client) + + // Create a workspace with an agent. No need to connect it since clients can + // still connect to the coordinator while the agent isn't connected. + r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentTokenUUID, err := uuid.Parse(r.AgentToken) + require.NoError(t, err) + ctx := testutil.Context(t, testutil.WaitLong) + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + require.NoError(t, err) + + // Connect with no resume token, and ensure that the peer ID is set to a + // random value. + originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") + require.NoError(t, err) + originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, uuid.Nil) + + // Connect with a valid resume token, and ensure that the peer ID is set to + // the stored value. + clock.Advance(time.Second) + newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) + require.NoError(t, err) + verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, originalResumeToken, verifiedToken) + newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.Equal(t, originalPeerID, newPeerID) + require.NotEqual(t, originalResumeToken, newResumeToken) + + // Connect with an invalid resume token, and ensure that the request is + // rejected. + clock.Advance(time.Second) + _, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid") + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) + require.Len(t, sdkErr.Validations, 1) + require.Equal(t, "resume_token", sdkErr.Validations[0].Field) + verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, "invalid", verifiedToken) + + select { + case <-resumeTokenProvider.generateCalls: + t.Fatal("unexpected peer ID in channel") + default: + } }) - defer closer.Close() - user := coderdtest.CreateFirstUser(t, client) - // Create a workspace with an agent. No need to connect it since clients can - // still connect to the coordinator while the agent isn't connected. - r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent().Do() - agentTokenUUID, err := uuid.Parse(r.AgentToken) - require.NoError(t, err) - ctx := testutil.Context(t, testutil.WaitLong) - agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint - require.NoError(t, err) + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() - // Connect with no resume token, and ensure that the peer ID is set to a - // random value. - originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") - require.NoError(t, err) - originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) - require.NotEqual(t, originalPeerID, uuid.Nil) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + require.NoError(t, err) + resumeTokenProvider := newResumeTokenRecordingProvider( + t, + tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour), + ) + client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + Coordinator: tailnet.NewCoordinator(logger), + CoordinatorResumeTokenProvider: resumeTokenProvider, + }) + defer closer.Close() + user := coderdtest.CreateFirstUser(t, client) - // Connect with a valid resume token, and ensure that the peer ID is set to - // the stored value. - clock.Advance(time.Second) - newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) - require.NoError(t, err) - verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) - require.Equal(t, originalResumeToken, verifiedToken) - newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) - require.Equal(t, originalPeerID, newPeerID) - require.NotEqual(t, originalResumeToken, newResumeToken) - - // Connect with an invalid resume token, and ensure that the request is - // rejected. - clock.Advance(time.Second) - _, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid") - require.Error(t, err) - var sdkErr *codersdk.Error - require.ErrorAs(t, err, &sdkErr) - require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) - require.Len(t, sdkErr.Validations, 1) - require.Equal(t, "resume_token", sdkErr.Validations[0].Field) - verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) - require.Equal(t, "invalid", verifiedToken) + // Create a workspace with an agent. No need to connect it since clients can + // still connect to the coordinator while the agent isn't connected. + r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentTokenUUID, err := uuid.Parse(r.AgentToken) + require.NoError(t, err) + ctx := testutil.Context(t, testutil.WaitLong) + agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint + require.NoError(t, err) - select { - case <-resumeTokenProvider.generateCalls: - t.Fatal("unexpected peer ID in channel") - default: - } + // Connect with no resume token, and ensure that the peer ID is set to a + // random value. + originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") + require.NoError(t, err) + originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, uuid.Nil) + + // Connect with an outdated token, and ensure that the peer ID is set to a + // random value. We don't want to fail requests just because + // a user got unlucky during a deployment upgrade. + outdatedToken := generateBadJWT(t, jwtutils.RegisteredClaims{ + Subject: originalPeerID.String(), + Expiry: jwt.NewNumericDate(clock.Now().Add(time.Minute)), + }) + + clock.Advance(time.Second) + newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, outdatedToken) + require.NoError(t, err) + verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls) + require.Equal(t, outdatedToken, verifiedToken) + newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls) + require.NotEqual(t, originalPeerID, newPeerID) + require.NotEqual(t, originalResumeToken, newResumeToken) + }) } // connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index d2fa11b9ea2ea..e264dbd80b58d 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" @@ -122,10 +123,11 @@ func (api *API) workspaceApplicationAuth(rw http.ResponseWriter, r *http.Request return } - // Encrypt the API key. - encryptedAPIKey, err := api.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ + payload := workspaceapps.EncryptedAPIKeyPayload{ APIKey: cookie.Value, - }) + } + payload.Fill(api.Clock.Now()) + encryptedAPIKey, err := jwtutils.Encrypt(ctx, api.AppEncryptionKeyCache, payload) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to encrypt API key.", diff --git a/coderd/workspaceapps/apptest/apptest.go b/coderd/workspaceapps/apptest/apptest.go index 14adf2d61d362..c6e251806230d 100644 --- a/coderd/workspaceapps/apptest/apptest.go +++ b/coderd/workspaceapps/apptest/apptest.go @@ -3,6 +3,7 @@ package apptest import ( "bufio" "context" + "crypto/rand" "encoding/json" "fmt" "io" @@ -408,6 +409,67 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { require.Equal(t, http.StatusInternalServerError, resp.StatusCode) assertWorkspaceLastUsedAtNotUpdated(t, appDetails) }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + appDetails := setupProxyTest(t, nil) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + u := appDetails.PathAppURL(appDetails.Apps.Owner) + resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + + appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotNil(t, appTokenCookie, "no signed app token cookie in response") + require.Equal(t, appTokenCookie.Path, u.Path, "incorrect path on app token cookie") + + object, err := jose.ParseSigned(appTokenCookie.Value) + require.NoError(t, err) + require.Len(t, object.Signatures, 1) + + // Parse the payload. + var tok workspaceapps.SignedToken + //nolint:gosec + err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok) + require.NoError(t, err) + + appTokenClient := appDetails.AppClient(t) + apiKey := appTokenClient.SessionToken() + appTokenClient.SetSessionToken("") + appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + // Sign the token with an old-style key. + appTokenCookie.Value = generateBadJWT(t, tok) + appTokenClient.HTTPClient.Jar.SetCookies(u, + []*http.Cookie{ + appTokenCookie, + { + Name: codersdk.PathAppSessionTokenCookie, + Value: apiKey, + }, + }, + ) + + resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + assertWorkspaceLastUsedAtUpdated(t, appDetails) + + // Since the old token is invalid, the signed app token cookie should have a new value. + newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value) + }) }) t.Run("WorkspaceApplicationAuth", func(t *testing.T) { @@ -463,7 +525,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { appClient.SetSessionToken("") // Try to load the application without authentication. - u := c.appURL + u := *c.appURL u.Path = path.Join(u.Path, "/test") req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) require.NoError(t, err) @@ -500,7 +562,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { // Copy the query parameters and then check equality. u.RawQuery = gotLocation.RawQuery - require.Equal(t, u, gotLocation) + require.Equal(t, u, *gotLocation) // Verify the API key is set. encryptedAPIKey := gotLocation.Query().Get(workspaceapps.SubdomainProxyAPIKeyParam) @@ -580,6 +642,38 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) }) + + t.Run("BadJWE", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + currentKeyStr := appDetails.SDKClient.SessionToken() + appClient := appDetails.AppClient(t) + appClient.SetSessionToken("") + u := *c.appURL + u.Path = path.Join(u.Path, "/test") + badToken := generateBadJWE(t, workspaceapps.EncryptedAPIKeyPayload{ + APIKey: currentKeyStr, + }) + + u.RawQuery = (url.Values{ + workspaceapps.SubdomainProxyAPIKeyParam: {badToken}, + }).Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + require.NoError(t, err) + + var resp *http.Response + resp, err = doWithRetries(t, appClient, req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "Could not decrypt API key. Please remove the query parameter and try again.") + }) } }) }) @@ -1077,6 +1171,68 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) { assertWorkspaceLastUsedAtNotUpdated(t, appDetails) }) }) + + t.Run("BadJWT", func(t *testing.T) { + t.Parallel() + + appDetails := setupProxyTest(t, nil) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + u := appDetails.SubdomainAppURL(appDetails.Apps.Owner) + resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + + appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotNil(t, appTokenCookie, "no signed token cookie in response") + require.Equal(t, appTokenCookie.Path, "/", "incorrect path on signed token cookie") + + object, err := jose.ParseSigned(appTokenCookie.Value) + require.NoError(t, err) + require.Len(t, object.Signatures, 1) + + // Parse the payload. + var tok workspaceapps.SignedToken + //nolint:gosec + err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok) + require.NoError(t, err) + + appTokenClient := appDetails.AppClient(t) + apiKey := appTokenClient.SessionToken() + appTokenClient.SetSessionToken("") + appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil) + require.NoError(t, err) + // Sign the token with an old-style key. + appTokenCookie.Value = generateBadJWT(t, tok) + appTokenClient.HTTPClient.Jar.SetCookies(u, + []*http.Cookie{ + appTokenCookie, + { + Name: codersdk.SubdomainAppSessionTokenCookie, + Value: apiKey, + }, + }, + ) + + // We should still be able to successfully proxy. + resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil) + require.NoError(t, err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + assertWorkspaceLastUsedAtUpdated(t, appDetails) + + // Since the old token is invalid, the signed app token cookie should have a new value. + newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie) + require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value) + }) }) t.Run("PortSharing", func(t *testing.T) { @@ -1789,3 +1945,57 @@ func assertWorkspaceLastUsedAtNotUpdated(t testing.TB, details *Details) { require.NoError(t, err) require.Equal(t, before.LastUsedAt, after.LastUsedAt, "workspace LastUsedAt updated when it should not have been") } + +func generateBadJWE(t *testing.T, claims interface{}) string { + t.Helper() + var buf [32]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + encrypt, err := jose.NewEncrypter( + jose.A256GCM, + jose.Recipient{ + Algorithm: jose.A256GCMKW, + Key: buf[:], + }, &jose.EncrypterOptions{ + Compression: jose.DEFLATE, + }, + ) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := encrypt.Encrypt(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} + +// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated. +func generateBadJWT(t *testing.T, claims interface{}) string { + t.Helper() + + var buf [64]byte + _, err := rand.Read(buf[:]) + require.NoError(t, err) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.HS512, + Key: buf[:], + }, nil) + require.NoError(t, err) + payload, err := json.Marshal(claims) + require.NoError(t, err) + signed, err := signer.Sign(payload) + require.NoError(t, err) + compact, err := signed.CompactSerialize() + require.NoError(t, err) + return compact +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} diff --git a/coderd/workspaceapps/db.go b/coderd/workspaceapps/db.go index 1b369cf6d6ef4..1aa4dfe91bdd0 100644 --- a/coderd/workspaceapps/db.go +++ b/coderd/workspaceapps/db.go @@ -13,11 +13,15 @@ import ( "golang.org/x/exp/slices" "golang.org/x/xerrors" + "github.com/go-jose/go-jose/v4/jwt" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -35,12 +39,20 @@ type DBTokenProvider struct { DeploymentValues *codersdk.DeploymentValues OAuth2Configs *httpmw.OAuth2Configs WorkspaceAgentInactiveTimeout time.Duration - SigningKey SecurityKey + Keycache cryptokeys.SigningKeycache } var _ SignedTokenProvider = &DBTokenProvider{} -func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authorizer, db database.Store, cfg *codersdk.DeploymentValues, oauth2Cfgs *httpmw.OAuth2Configs, workspaceAgentInactiveTimeout time.Duration, signingKey SecurityKey) SignedTokenProvider { +func NewDBTokenProvider(log slog.Logger, + accessURL *url.URL, + authz rbac.Authorizer, + db database.Store, + cfg *codersdk.DeploymentValues, + oauth2Cfgs *httpmw.OAuth2Configs, + workspaceAgentInactiveTimeout time.Duration, + signer cryptokeys.SigningKeycache, +) SignedTokenProvider { if workspaceAgentInactiveTimeout == 0 { workspaceAgentInactiveTimeout = 1 * time.Minute } @@ -53,12 +65,12 @@ func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authoriz DeploymentValues: cfg, OAuth2Configs: oauth2Cfgs, WorkspaceAgentInactiveTimeout: workspaceAgentInactiveTimeout, - SigningKey: signingKey, + Keycache: signer, } } func (p *DBTokenProvider) FromRequest(r *http.Request) (*SignedToken, bool) { - return FromRequest(r, p.SigningKey) + return FromRequest(r, p.Keycache) } func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq IssueTokenRequest) (*SignedToken, string, bool) { @@ -70,7 +82,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r * dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx) appReq := issueReq.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request") return nil, "", false @@ -210,9 +222,11 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r * return nil, "", false } + token.RegisteredClaims = jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(DefaultTokenExpiry)), + } // Sign the token. - token.Expiry = time.Now().Add(DefaultTokenExpiry) - tokenStr, err := p.SigningKey.SignToken(token) + tokenStr, err := jwtutils.Sign(ctx, p.Keycache, token) if err != nil { WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "generate token") return nil, "", false diff --git a/coderd/workspaceapps/db_test.go b/coderd/workspaceapps/db_test.go index 6c5a0212aff2b..bf364f1ce62b3 100644 --- a/coderd/workspaceapps/db_test.go +++ b/coderd/workspaceapps/db_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" "github.com/coder/coder/v2/codersdk" @@ -94,8 +96,7 @@ func Test_ResolveRequest(t *testing.T) { _ = closer.Close() }) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - t.Cleanup(cancel) + ctx := testutil.Context(t, testutil.WaitMedium) firstUser := coderdtest.CreateFirstUser(t, client) me, err := client.User(ctx, codersdk.Me) @@ -276,15 +277,17 @@ func Test_ResolveRequest(t *testing.T) { _ = w.Body.Close() require.Equal(t, &workspaceapps.SignedToken{ + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(token.Expiry.Time()), + }, Request: req, - Expiry: token.Expiry, // ignored to avoid flakiness UserID: me.ID, WorkspaceID: workspace.ID, AgentID: agentID, AppURL: appURL, }, token) require.NotZero(t, token.Expiry) - require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry, time.Minute) + require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry.Time(), time.Minute) // Check that the token was set in the response and is valid. require.Len(t, w.Cookies(), 1) @@ -292,10 +295,11 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name) require.Equal(t, req.BasePath, cookie.Path) - parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookie.Value) + var parsedToken workspaceapps.SignedToken + err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken) require.NoError(t, err) // normalize expiry - require.WithinDuration(t, token.Expiry, parsedToken.Expiry, 2*time.Second) + require.WithinDuration(t, token.Expiry.Time(), parsedToken.Expiry.Time(), 2*time.Second) parsedToken.Expiry = token.Expiry require.Equal(t, token, &parsedToken) @@ -314,7 +318,7 @@ func Test_ResolveRequest(t *testing.T) { }) require.True(t, ok) // normalize expiry - require.WithinDuration(t, token.Expiry, secondToken.Expiry, 2*time.Second) + require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second) secondToken.Expiry = token.Expiry require.Equal(t, token, secondToken) } @@ -540,13 +544,16 @@ func Test_ResolveRequest(t *testing.T) { // App name differs AppSlugOrPort: appNamePublic, }).Normalize(), - Expiry: time.Now().Add(time.Minute), + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, UserID: me.ID, WorkspaceID: workspace.ID, AgentID: agentID, AppURL: appURL, } - badTokenStr, err := api.AppSecurityKey.SignToken(badToken) + + badTokenStr, err := jwtutils.Sign(ctx, api.AppSigningKeyCache, badToken) require.NoError(t, err) req := (workspaceapps.Request{ @@ -589,7 +596,8 @@ func Test_ResolveRequest(t *testing.T) { require.Len(t, cookies, 1) require.Equal(t, cookies[0].Name, codersdk.SignedAppTokenCookie) require.NotEqual(t, cookies[0].Value, badTokenStr) - parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookies[0].Value) + var parsedToken workspaceapps.SignedToken + err = jwtutils.Verify(ctx, api.AppSigningKeyCache, cookies[0].Value, &parsedToken) require.NoError(t, err) require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort) }) diff --git a/coderd/workspaceapps/provider.go b/coderd/workspaceapps/provider.go index 8d4b7fd149800..1887036e35cbf 100644 --- a/coderd/workspaceapps/provider.go +++ b/coderd/workspaceapps/provider.go @@ -38,7 +38,7 @@ type ResolveRequestOptions struct { func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequestOptions) (*SignedToken, bool) { appReq := opts.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { // This is a 500 since it's a coder server or proxy that's making this // request struct based on details from the request. The values should @@ -79,7 +79,7 @@ func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequest Name: codersdk.SignedAppTokenCookie, Value: tokenStr, Path: appReq.BasePath, - Expires: token.Expiry, + Expires: token.Expiry.Time(), }) return token, true diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 69f1aadca49b2..84cea4fa86678 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -11,17 +11,21 @@ import ( "strconv" "strings" "sync" + "time" "github.com/go-chi/chi/v5" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "go.opentelemetry.io/otel/trace" "nhooyr.io/websocket" "cdr.dev/slog" "github.com/coder/coder/v2/agent/agentssh" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" @@ -97,8 +101,8 @@ type Server struct { HostnameRegex *regexp.Regexp RealIPConfig *httpmw.RealIPConfig - SignedTokenProvider SignedTokenProvider - AppSecurityKey SecurityKey + SignedTokenProvider SignedTokenProvider + APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache // DisablePathApps disables path-based apps. This is a security feature as path // based apps share the same cookie as the dashboard, and are susceptible to XSS @@ -176,7 +180,10 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request, } // Exchange the encoded API key for a real one. - token, err := s.AppSecurityKey.DecryptAPIKey(encryptedAPIKey) + var payload EncryptedAPIKeyPayload + err := jwtutils.Decrypt(ctx, s.APIKeyEncryptionKeycache, encryptedAPIKey, &payload, jwtutils.WithDecryptExpected(jwt.Expected{ + Time: time.Now(), + })) if err != nil { s.Logger.Debug(ctx, "could not decrypt smuggled workspace app API key", slog.Error(err)) site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ @@ -225,7 +232,7 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request, // server using the wrong value. http.SetCookie(rw, &http.Cookie{ Name: AppConnectSessionTokenCookieName(accessMethod), - Value: token, + Value: payload.APIKey, Domain: domain, Path: "/", MaxAge: 0, diff --git a/coderd/workspaceapps/request.go b/coderd/workspaceapps/request.go index 4f6a6f3a64e65..0833ab731fe67 100644 --- a/coderd/workspaceapps/request.go +++ b/coderd/workspaceapps/request.go @@ -124,9 +124,9 @@ func (r Request) Normalize() Request { return req } -// Validate ensures the request is correct and contains the necessary +// Check ensures the request is correct and contains the necessary // parameters. -func (r Request) Validate() error { +func (r Request) Check() error { switch r.AccessMethod { case AccessMethodPath, AccessMethodSubdomain, AccessMethodTerminal: default: diff --git a/coderd/workspaceapps/request_test.go b/coderd/workspaceapps/request_test.go index b6e4bb7a2e65f..fbabc840745e9 100644 --- a/coderd/workspaceapps/request_test.go +++ b/coderd/workspaceapps/request_test.go @@ -279,7 +279,7 @@ func Test_RequestValidate(t *testing.T) { if !c.noNormalize { req = c.req.Normalize() } - err := req.Validate() + err := req.Check() if c.errContains == "" { require.NoError(t, err) } else { diff --git a/coderd/workspaceapps/token.go b/coderd/workspaceapps/token.go index 33428b0e25f13..dcd8c5a0e5c34 100644 --- a/coderd/workspaceapps/token.go +++ b/coderd/workspaceapps/token.go @@ -1,35 +1,27 @@ package workspaceapps import ( - "encoding/base64" - "encoding/hex" - "encoding/json" "net/http" "strings" "time" - "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" ) -const ( - tokenSigningAlgorithm = jose.HS512 - apiKeyEncryptionAlgorithm = jose.A256GCMKW -) - // SignedToken is the struct data contained inside a workspace app JWE. It // contains the details of the workspace app that the token is valid for to // avoid database queries. type SignedToken struct { + jwtutils.RegisteredClaims // Request details. Request `json:"request"` - // Trusted resolved details. - Expiry time.Time `json:"expiry"` // set by GenerateToken if unset UserID uuid.UUID `json:"user_id"` WorkspaceID uuid.UUID `json:"workspace_id"` AgentID uuid.UUID `json:"agent_id"` @@ -57,191 +49,32 @@ func (t SignedToken) MatchesRequest(req Request) bool { t.AppSlugOrPort == req.AppSlugOrPort } -// SecurityKey is used for signing and encrypting app tokens and API keys. -// -// The first 64 bytes of the key are used for signing tokens with HMAC-SHA256, -// and the last 32 bytes are used for encrypting API keys with AES-256-GCM. -// We use a single key for both operations to avoid having to store and manage -// two keys. -type SecurityKey [96]byte - -func (k SecurityKey) IsZero() bool { - return k == SecurityKey{} -} - -func (k SecurityKey) String() string { - return hex.EncodeToString(k[:]) -} - -func (k SecurityKey) signingKey() []byte { - return k[:64] -} - -func (k SecurityKey) encryptionKey() []byte { - return k[64:] -} - -func KeyFromString(str string) (SecurityKey, error) { - var key SecurityKey - decoded, err := hex.DecodeString(str) - if err != nil { - return key, xerrors.Errorf("decode key: %w", err) - } - if len(decoded) != len(key) { - return key, xerrors.Errorf("expected key to be %d bytes, got %d", len(key), len(decoded)) - } - copy(key[:], decoded) - - return key, nil -} - -// SignToken generates a signed workspace app token with the given payload. If -// the payload doesn't have an expiry, it will be set to the current time plus -// the default expiry. -func (k SecurityKey) SignToken(payload SignedToken) (string, error) { - if payload.Expiry.IsZero() { - payload.Expiry = time.Now().Add(DefaultTokenExpiry) - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return "", xerrors.Errorf("marshal payload to JSON: %w", err) - } - - signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: tokenSigningAlgorithm, - Key: k.signingKey(), - }, nil) - if err != nil { - return "", xerrors.Errorf("create signer: %w", err) - } - - signedObject, err := signer.Sign(payloadBytes) - if err != nil { - return "", xerrors.Errorf("sign payload: %w", err) - } - - serialized, err := signedObject.CompactSerialize() - if err != nil { - return "", xerrors.Errorf("serialize JWS: %w", err) - } - - return serialized, nil -} - -// VerifySignedToken parses a signed workspace app token with the given key and -// returns the payload. If the token is invalid or expired, an error is -// returned. -func (k SecurityKey) VerifySignedToken(str string) (SignedToken, error) { - object, err := jose.ParseSigned(str) - if err != nil { - return SignedToken{}, xerrors.Errorf("parse JWS: %w", err) - } - if len(object.Signatures) != 1 { - return SignedToken{}, xerrors.New("expected 1 signature") - } - if object.Signatures[0].Header.Algorithm != string(tokenSigningAlgorithm) { - return SignedToken{}, xerrors.Errorf("expected token signing algorithm to be %q, got %q", tokenSigningAlgorithm, object.Signatures[0].Header.Algorithm) - } - - output, err := object.Verify(k.signingKey()) - if err != nil { - return SignedToken{}, xerrors.Errorf("verify JWS: %w", err) - } - - var tok SignedToken - err = json.Unmarshal(output, &tok) - if err != nil { - return SignedToken{}, xerrors.Errorf("unmarshal payload: %w", err) - } - if tok.Expiry.Before(time.Now()) { - return SignedToken{}, xerrors.New("signed app token expired") - } - - return tok, nil -} - type EncryptedAPIKeyPayload struct { - APIKey string `json:"api_key"` - ExpiresAt time.Time `json:"expires_at"` + jwtutils.RegisteredClaims + APIKey string `json:"api_key"` } -// EncryptAPIKey encrypts an API key for subdomain token smuggling. -func (k SecurityKey) EncryptAPIKey(payload EncryptedAPIKeyPayload) (string, error) { - if payload.APIKey == "" { - return "", xerrors.New("API key is empty") - } - if payload.ExpiresAt.IsZero() { - // Very short expiry as these keys are only used once as part of an - // automatic redirection flow. - payload.ExpiresAt = dbtime.Now().Add(time.Minute) - } - - payloadBytes, err := json.Marshal(payload) - if err != nil { - return "", xerrors.Errorf("marshal payload: %w", err) - } - - // JWEs seem to apply a nonce themselves. - encrypter, err := jose.NewEncrypter( - jose.A256GCM, - jose.Recipient{ - Algorithm: apiKeyEncryptionAlgorithm, - Key: k.encryptionKey(), - }, - &jose.EncrypterOptions{ - Compression: jose.DEFLATE, - }, - ) - if err != nil { - return "", xerrors.Errorf("initializer jose encrypter: %w", err) - } - encryptedObject, err := encrypter.Encrypt(payloadBytes) - if err != nil { - return "", xerrors.Errorf("encrypt jwe: %w", err) - } - - encrypted := encryptedObject.FullSerialize() - return base64.RawURLEncoding.EncodeToString([]byte(encrypted)), nil +func (e *EncryptedAPIKeyPayload) Fill(now time.Time) { + e.Issuer = "coderd" + e.Audience = jwt.Audience{"wsproxy"} + e.Expiry = jwt.NewNumericDate(now.Add(time.Minute)) + e.NotBefore = jwt.NewNumericDate(now.Add(-time.Minute)) } -// DecryptAPIKey undoes EncryptAPIKey and is used in the subdomain app handler. -func (k SecurityKey) DecryptAPIKey(encryptedAPIKey string) (string, error) { - encrypted, err := base64.RawURLEncoding.DecodeString(encryptedAPIKey) - if err != nil { - return "", xerrors.Errorf("base64 decode encrypted API key: %w", err) +func (e EncryptedAPIKeyPayload) Validate(ex jwt.Expected) error { + if e.NotBefore == nil { + return xerrors.Errorf("not before is required") } - object, err := jose.ParseEncrypted(string(encrypted)) - if err != nil { - return "", xerrors.Errorf("parse encrypted API key: %w", err) - } - if object.Header.Algorithm != string(apiKeyEncryptionAlgorithm) { - return "", xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", apiKeyEncryptionAlgorithm, object.Header.Algorithm) - } - - // Decrypt using the hashed secret. - decrypted, err := object.Decrypt(k.encryptionKey()) - if err != nil { - return "", xerrors.Errorf("decrypt API key: %w", err) - } - - // Unmarshal the payload. - var payload EncryptedAPIKeyPayload - if err := json.Unmarshal(decrypted, &payload); err != nil { - return "", xerrors.Errorf("unmarshal decrypted payload: %w", err) - } - - // Validate expiry. - if payload.ExpiresAt.Before(dbtime.Now()) { - return "", xerrors.New("encrypted API key expired") - } + ex.Issuer = "coderd" + ex.AnyAudience = jwt.Audience{"wsproxy"} - return payload.APIKey, nil + return e.RegisteredClaims.Validate(ex) } // FromRequest returns the signed token from the request, if it exists and is // valid. The caller must check that the token matches the request. -func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { +func FromRequest(r *http.Request, mgr cryptokeys.SigningKeycache) (*SignedToken, bool) { // Get all signed app tokens from the request. This includes the query // parameter and all matching cookies sent with the request. If there are // somehow multiple signed app token cookies, we want to try all of them @@ -270,8 +103,12 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { tokens = tokens[:4] } + ctx := r.Context() for _, tokenStr := range tokens { - token, err := key.VerifySignedToken(tokenStr) + var token SignedToken + err := jwtutils.Verify(ctx, mgr, tokenStr, &token, jwtutils.WithVerifyExpected(jwt.Expected{ + Time: time.Now(), + })) if err == nil { req := token.Request.Normalize() if hasQueryParam && req.AccessMethod != AccessMethodTerminal { @@ -280,7 +117,7 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) { return nil, false } - err := req.Validate() + err := req.Check() if err == nil { // The request has a valid signed app token, which is a valid // token signed by us. The caller must check that it matches diff --git a/coderd/workspaceapps/token_test.go b/coderd/workspaceapps/token_test.go index c656ae2ab77b8..db070268fa196 100644 --- a/coderd/workspaceapps/token_test.go +++ b/coderd/workspaceapps/token_test.go @@ -1,22 +1,22 @@ package workspaceapps_test import ( - "fmt" + "crypto/rand" "net/http" "net/http/httptest" "testing" "time" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" - "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" - "github.com/coder/coder/v2/cryptorand" ) func Test_TokenMatchesRequest(t *testing.T) { @@ -283,129 +283,6 @@ func Test_TokenMatchesRequest(t *testing.T) { } } -func Test_GenerateToken(t *testing.T) { - t.Parallel() - - t.Run("SetExpiry", func(t *testing.T) { - t.Parallel() - - tokenStr, err := coderdtest.AppSecurityKey.SignToken(workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Time{}, - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }) - require.NoError(t, err) - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr) - require.NoError(t, err) - - require.WithinDuration(t, time.Now().Add(time.Minute), token.Expiry, 15*time.Second) - }) - - future := time.Now().Add(time.Hour) - cases := []struct { - name string - token workspaceapps.SignedToken - parseErrContains string - }{ - { - name: "OK1", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: future, - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }, - }, - { - name: "OK2", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodSubdomain, - BasePath: "/", - UsernameOrID: "oof", - WorkspaceNameOrID: "rab", - AgentNameOrID: "zab", - AppSlugOrPort: "xuq", - }, - - Expiry: future, - UserID: uuid.MustParse("6fa684a3-11aa-49fd-8512-ab527bd9b900"), - WorkspaceID: uuid.MustParse("b2d816cc-505c-441d-afdf-dae01781bc0b"), - AgentID: uuid.MustParse("6c4396e1-af88-4a8a-91a3-13ea54fc29fb"), - AppURL: "http://localhost:9090", - }, - }, - { - name: "Expired", - token: workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodSubdomain, - BasePath: "/", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Now().Add(-time.Hour), - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }, - parseErrContains: "token expired", - }, - } - - for _, c := range cases { - c := c - - t.Run(c.name, func(t *testing.T) { - t.Parallel() - - str, err := coderdtest.AppSecurityKey.SignToken(c.token) - require.NoError(t, err) - - // Tokens aren't deterministic as they have a random nonce, so we - // can't compare them directly. - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(str) - if c.parseErrContains != "" { - require.Error(t, err) - require.ErrorContains(t, err, c.parseErrContains) - } else { - require.NoError(t, err) - // normalize the expiry - require.WithinDuration(t, c.token.Expiry, token.Expiry, 10*time.Second) - c.token.Expiry = token.Expiry - require.Equal(t, c.token, token) - } - }) - } -} - func Test_FromRequest(t *testing.T) { t.Parallel() @@ -419,7 +296,13 @@ func Test_FromRequest(t *testing.T) { Value: "invalid", }) + ctx := testutil.Context(t, testutil.WaitShort) + signer := newSigner(t) + token := workspaceapps.SignedToken{ + RegisteredClaims: jwtutils.RegisteredClaims{ + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, Request: workspaceapps.Request{ AccessMethod: workspaceapps.AccessMethodSubdomain, BasePath: "/", @@ -429,7 +312,6 @@ func Test_FromRequest(t *testing.T) { AgentNameOrID: "agent", AppSlugOrPort: "app", }, - Expiry: time.Now().Add(time.Hour), UserID: uuid.New(), WorkspaceID: uuid.New(), AgentID: uuid.New(), @@ -438,16 +320,15 @@ func Test_FromRequest(t *testing.T) { // Add an expired cookie expired := token - expired.Expiry = time.Now().Add(time.Hour * -1) - expiredStr, err := coderdtest.AppSecurityKey.SignToken(token) + expired.RegisteredClaims.Expiry = jwt.NewNumericDate(time.Now().Add(time.Hour * -1)) + expiredStr, err := jwtutils.Sign(ctx, signer, expired) require.NoError(t, err) r.AddCookie(&http.Cookie{ Name: codersdk.SignedAppTokenCookie, Value: expiredStr, }) - // Add a valid token - validStr, err := coderdtest.AppSecurityKey.SignToken(token) + validStr, err := jwtutils.Sign(ctx, signer, token) require.NoError(t, err) r.AddCookie(&http.Cookie{ @@ -455,147 +336,27 @@ func Test_FromRequest(t *testing.T) { Value: validStr, }) - signed, ok := workspaceapps.FromRequest(r, coderdtest.AppSecurityKey) + signed, ok := workspaceapps.FromRequest(r, signer) require.True(t, ok, "expected a token to be found") // Confirm it is the correct token. require.Equal(t, signed.UserID, token.UserID) }) } -// The ParseToken fn is tested quite thoroughly in the GenerateToken test as -// well. -func Test_ParseToken(t *testing.T) { - t.Parallel() - - t.Run("InvalidJWS", func(t *testing.T) { - t.Parallel() - - token, err := coderdtest.AppSecurityKey.VerifySignedToken("invalid") - require.Error(t, err) - require.ErrorContains(t, err, "parse JWS") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) - - t.Run("VerifySignature", func(t *testing.T) { - t.Parallel() +func newSigner(t *testing.T) jwtutils.StaticKey { + t.Helper() - // Create a valid token using a different key. - var otherKey workspaceapps.SecurityKey - copy(otherKey[:], coderdtest.AppSecurityKey[:]) - for i := range otherKey { - otherKey[i] ^= 0xff - } - require.NotEqual(t, coderdtest.AppSecurityKey, otherKey) - - tokenStr, err := otherKey.SignToken(workspaceapps.SignedToken{ - Request: workspaceapps.Request{ - AccessMethod: workspaceapps.AccessMethodPath, - BasePath: "/app", - UsernameOrID: "foo", - WorkspaceNameOrID: "bar", - AgentNameOrID: "baz", - AppSlugOrPort: "qux", - }, - - Expiry: time.Now().Add(time.Hour), - UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"), - WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"), - AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"), - AppURL: "http://127.0.0.1:8080", - }) - require.NoError(t, err) - - // Verify the token is invalid. - token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr) - require.Error(t, err) - require.ErrorContains(t, err, "verify JWS") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) - - t.Run("InvalidBody", func(t *testing.T) { - t.Parallel() - - // Create a signature for an invalid body. - signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS512, Key: coderdtest.AppSecurityKey[:64]}, nil) - require.NoError(t, err) - signedObject, err := signer.Sign([]byte("hi")) - require.NoError(t, err) - serialized, err := signedObject.CompactSerialize() - require.NoError(t, err) - - token, err := coderdtest.AppSecurityKey.VerifySignedToken(serialized) - require.Error(t, err) - require.ErrorContains(t, err, "unmarshal payload") - require.Equal(t, workspaceapps.SignedToken{}, token) - }) -} - -func TestAPIKeyEncryption(t *testing.T) { - t.Parallel() - - genAPIKey := func(t *testing.T) string { - id, _ := cryptorand.String(10) - secret, _ := cryptorand.String(22) - - return fmt.Sprintf("%s-%s", id, secret) + return jwtutils.StaticKey{ + ID: "test", + Key: generateSecret(t, 64), } +} - t.Run("OK", func(t *testing.T) { - t.Parallel() - - key := genAPIKey(t) - encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - }) - require.NoError(t, err) - - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.NoError(t, err) - require.Equal(t, key, decryptedKey) - }) - - t.Run("Verifies", func(t *testing.T) { - t.Parallel() - - t.Run("Expiry", func(t *testing.T) { - t.Parallel() - - key := genAPIKey(t) - encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - ExpiresAt: dbtime.Now().Add(-1 * time.Hour), - }) - require.NoError(t, err) - - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.Error(t, err) - require.ErrorContains(t, err, "expired") - require.Empty(t, decryptedKey) - }) - - t.Run("EncryptionKey", func(t *testing.T) { - t.Parallel() - - // Create a valid token using a different key. - var otherKey workspaceapps.SecurityKey - copy(otherKey[:], coderdtest.AppSecurityKey[:]) - for i := range otherKey { - otherKey[i] ^= 0xff - } - require.NotEqual(t, coderdtest.AppSecurityKey, otherKey) - - // Encrypt with the other key. - key := genAPIKey(t) - encrypted, err := otherKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{ - APIKey: key, - }) - require.NoError(t, err) +func generateSecret(t *testing.T, size int) []byte { + t.Helper() - // Decrypt with the original key. - decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted) - require.Error(t, err) - require.ErrorContains(t, err, "decrypt API key") - require.Empty(t, decryptedKey) - }) - }) + secret := make([]byte, size) + _, err := rand.Read(secret) + require.NoError(t, err) + return secret } diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index 1d00b7daa7bd9..52b3e18b4e6ad 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -5,16 +5,23 @@ import ( "net/http" "net/url" "testing" + "time" + "github.com/go-jose/go-jose/v4/jwt" "github.com/stretchr/testify/require" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) func TestGetAppHost(t *testing.T) { @@ -181,16 +188,28 @@ func TestWorkspaceApplicationAuth(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - db, pubsub := dbtestutil.NewDB(t) - + ctx := testutil.Context(t, testutil.WaitMedium) + logger := slogtest.Make(t, nil) accessURL, err := url.Parse(c.accessURL) require.NoError(t, err) + db, ps := dbtestutil.NewDB(t) + fetcher := &cryptokeys.DBFetcher{ + DB: db, + } + + kc, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) + require.NoError(t, err) + + clock := quartz.NewMock(t) + client := coderdtest.New(t, &coderdtest.Options{ - Database: db, - Pubsub: pubsub, - AccessURL: accessURL, - AppHostname: c.appHostname, + AccessURL: accessURL, + AppHostname: c.appHostname, + Database: db, + Pubsub: ps, + APIKeyEncryptionCache: kc, + Clock: clock, }) _ = coderdtest.CreateFirstUser(t, client) @@ -240,7 +259,15 @@ func TestWorkspaceApplicationAuth(t *testing.T) { loc.RawQuery = q.Encode() require.Equal(t, c.expectRedirect, loc.String()) - // The decrypted key is verified in the apptest test suite. + var token workspaceapps.EncryptedAPIKeyPayload + err = jwtutils.Decrypt(ctx, kc, encryptedAPIKey, &token, jwtutils.WithDecryptExpected(jwt.Expected{ + Time: clock.Now(), + AnyAudience: jwt.Audience{"wsproxy"}, + Issuer: "coderd", + })) + require.NoError(t, err) + require.Equal(t, jwt.NewNumericDate(clock.Now().Add(time.Minute)), token.Expiry) + require.Equal(t, jwt.NewNumericDate(clock.Now().Add(-time.Minute)), token.NotBefore) }) } } diff --git a/codersdk/deployment.go b/codersdk/deployment.go index d6840df504b85..391d0039f0369 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -3109,9 +3109,11 @@ func (c *Client) SSHConfiguration(ctx context.Context) (SSHConfigResponse, error type CryptoKeyFeature string const ( - CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps" - CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" - CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" + CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key" + //nolint:gosec // This denotes a type of key, not a literal. + CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token" + CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" ) type CryptoKey struct { diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go index 7a339a0079ba2..19f1930c89bc5 100644 --- a/codersdk/workspacesdk/connector_internal_test.go +++ b/codersdk/workspacesdk/connector_internal_test.go @@ -25,6 +25,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" @@ -61,7 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), }) require.NoError(t, err) @@ -165,13 +166,17 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) { clock := quartz.NewMock(t) resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) + mgr := jwtutils.StaticKey{ + ID: "123", + Key: resumeTokenSigningKey[:], + } + resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, ResumeTokenProvider: resumeTokenProvider, }) require.NoError(t, err) @@ -190,7 +195,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) { t.Logf("received resume token: %s", resumeToken) assert.Equal(t, expectResumeToken, resumeToken) if resumeToken != "" { - peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken) + peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken) assert.NoError(t, err, "failed to parse resume token") if err != nil { httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ @@ -280,13 +285,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) { clock := quartz.NewMock(t) resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) + mgr := jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: resumeTokenSigningKey[:], + } + resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {}, + NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {}, ResumeTokenProvider: resumeTokenProvider, }) require.NoError(t, err) diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index ed3800b3a27cd..f4e683305029b 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -1454,7 +1454,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ```json { "deletes_at": "2019-08-24T14:15:22Z", - "feature": "workspace_apps", + "feature": "workspace_apps_api_key", "secret": "string", "sequence": 0, "starts_at": "2019-08-24T14:15:22Z" @@ -1474,18 +1474,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ## codersdk.CryptoKeyFeature ```json -"workspace_apps" +"workspace_apps_api_key" ``` ### Properties #### Enumerated Values -| Value | -| ---------------- | -| `workspace_apps` | -| `oidc_convert` | -| `tailnet_resume` | +| Value | +| ------------------------ | +| `workspace_apps_api_key` | +| `workspace_apps_token` | +| `oidc_convert` | +| `tailnet_resume` | ## codersdk.CustomRoleRequest @@ -9893,7 +9894,7 @@ _None_ "crypto_keys": [ { "deletes_at": "2019-08-24T14:15:22Z", - "feature": "workspace_apps", + "feature": "workspace_apps_api_key", "secret": "string", "sequence": 0, "starts_at": "2019-08-24T14:15:22Z" @@ -9971,7 +9972,6 @@ _None_ ```json { - "app_security_key": "string", "derp_force_websockets": true, "derp_map": { "homeParams": { @@ -10052,7 +10052,6 @@ _None_ | Name | Type | Required | Restrictions | Description | | ----------------------- | --------------------------------------------- | -------- | ------------ | -------------------------------------------------------------------------------------- | -| `app_security_key` | string | false | | | | `derp_force_websockets` | boolean | false | | | | `derp_map` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | | | `derp_mesh_key` | string | false | | | diff --git a/enterprise/coderd/coderdenttest/proxytest.go b/enterprise/coderd/coderdenttest/proxytest.go index 6e5a822bdf251..a6f2c7384b16f 100644 --- a/enterprise/coderd/coderdenttest/proxytest.go +++ b/enterprise/coderd/coderdenttest/proxytest.go @@ -65,6 +65,8 @@ type WorkspaceProxy struct { // owner client. If a token is provided, the proxy will become a replica of the // existing proxy region. func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) WorkspaceProxy { + t.Helper() + ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(cancelFunc) @@ -142,8 +144,10 @@ func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *coders statsCollectorOptions.Flush = options.FlushStats } + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())) + wssrv, err := wsproxy.New(ctx, &wsproxy.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())), + Logger: logger, Experiments: options.Experiments, DashboardURL: coderdAPI.AccessURL, AccessURL: accessURL, diff --git a/enterprise/coderd/workspaceproxy.go b/enterprise/coderd/workspaceproxy.go index 47bdf53493489..4008de69e4faa 100644 --- a/enterprise/coderd/workspaceproxy.go +++ b/enterprise/coderd/workspaceproxy.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "slices" "strings" "time" @@ -33,6 +34,13 @@ import ( "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" ) +// whitelistedCryptoKeyFeatures is a list of crypto key features that are +// allowed to be queried with workspace proxies. +var whitelistedCryptoKeyFeatures = []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceAppsToken, + database.CryptoKeyFeatureWorkspaceAppsAPIKey, +} + // forceWorkspaceProxyHealthUpdate forces an update of the proxy health. // This is useful when a proxy is created or deleted. Errors will be logged. func (api *API) forceWorkspaceProxyHealthUpdate(ctx context.Context) { @@ -700,7 +708,6 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request) } httpapi.Write(ctx, rw, http.StatusCreated, wsproxysdk.RegisterWorkspaceProxyResponse{ - AppSecurityKey: api.AppSecurityKey.String(), DERPMeshKey: api.DERPServer.MeshKey(), DERPRegionID: regionID, DERPMap: api.AGPL.DERPMap(), @@ -721,13 +728,29 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request) // @Security CoderSessionToken // @Produce json // @Tags Enterprise +// @Param feature query string true "Feature key" // @Success 200 {object} wsproxysdk.CryptoKeysResponse // @Router /workspaceproxies/me/crypto-keys [get] // @x-apidocgen {"skip": true} func (api *API) workspaceProxyCryptoKeys(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - keys, err := api.Database.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + feature := database.CryptoKeyFeature(r.URL.Query().Get("feature")) + if feature == "" { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing feature query parameter.", + }) + return + } + + if !slices.Contains(whitelistedCryptoKeyFeatures, feature) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Invalid feature: %q", feature), + }) + return + } + + keys, err := api.Database.GetCryptoKeysByFeature(ctx, feature) if err != nil { httpapi.InternalServerError(rw, err) return diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 5231a0b0c4241..0be112b532b7a 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -320,7 +320,6 @@ func TestProxyRegisterDeregister(t *testing.T) { } registerRes1, err := proxyClient.RegisterWorkspaceProxy(ctx, req) require.NoError(t, err) - require.NotEmpty(t, registerRes1.AppSecurityKey) require.NotEmpty(t, registerRes1.DERPMeshKey) require.EqualValues(t, 10001, registerRes1.DERPRegionID) require.Empty(t, registerRes1.SiblingReplicas) @@ -609,11 +608,8 @@ func TestProxyRegisterDeregister(t *testing.T) { func TestIssueSignedAppToken(t *testing.T) { t.Parallel() - db, pubsub := dbtestutil.NewDB(t) client, user := coderdenttest.New(t, &coderdenttest.Options{ Options: &coderdtest.Options{ - Database: db, - Pubsub: pubsub, IncludeProvisionerDaemon: true, }, LicenseOptions: &coderdenttest.LicenseOptions{ @@ -716,6 +712,10 @@ func TestReconnectingPTYSignedToken(t *testing.T) { closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + // Create a workspace + apps authToken := uuid.NewString() version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ @@ -915,51 +915,86 @@ func TestGetCryptoKeys(t *testing.T) { now := time.Now() expectedKey1 := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-time.Hour), Sequence: 2, }) - key1 := db2sdk.CryptoKey(expectedKey1) + encryptionKey := db2sdk.CryptoKey(expectedKey1) expectedKey2 := dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, StartsAt: now, Sequence: 3, }) - key2 := db2sdk.CryptoKey(expectedKey2) + signingKey := db2sdk.CryptoKey(expectedKey2) // Create a deleted key. _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, StartsAt: now.Add(-time.Hour), Secret: sql.NullString{ String: "secret1", Valid: false, }, - Sequence: 1, - }) - - // Create a key with different features. - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureTailnetResume, - StartsAt: now.Add(-time.Hour), - Sequence: 1, - }) - _ = dbgen.CryptoKey(t, db, database.CryptoKey{ - Feature: database.CryptoKeyFeatureOidcConvert, - StartsAt: now.Add(-time.Hour), - Sequence: 1, + Sequence: 4, }) proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{ Name: testutil.GetRandomName(t), }) - keys, err := proxy.SDKClient.CryptoKeys(ctx) + keys, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.NotEmpty(t, keys) + // 1 key is generated on startup, the other we manually generated. require.Equal(t, 2, len(keys.CryptoKeys)) - requireContainsKeys(t, keys.CryptoKeys, key1, key2) + requireContainsKeys(t, keys.CryptoKeys, encryptionKey) + requireNotContainsKeys(t, keys.CryptoKeys, signingKey) + + keys, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsToken) + require.NoError(t, err) + require.NotEmpty(t, keys) + // 1 key is generated on startup, the other we manually generated. + require.Equal(t, 2, len(keys.CryptoKeys)) + requireContainsKeys(t, keys.CryptoKeys, signingKey) + requireNotContainsKeys(t, keys.CryptoKeys, encryptionKey) + }) + + t.Run("InvalidFeature", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitMedium) + db, pubsub := dbtestutil.NewDB(t) + cclient, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + IncludeProvisionerDaemon: true, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + + proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{ + Name: testutil.GetRandomName(t), + }) + + _, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureOIDCConvert) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + _, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureTailnetResume) + require.Error(t, err) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + _, err = proxy.SDKClient.CryptoKeys(ctx, "invalid") + require.Error(t, err) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) }) t.Run("Unauthorized", func(t *testing.T) { @@ -987,7 +1022,7 @@ func TestGetCryptoKeys(t *testing.T) { client := wsproxysdk.New(cclient.URL) client.SetSessionToken(cclient.SessionToken()) - _, err := client.CryptoKeys(ctx) + _, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.Error(t, err) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) @@ -995,6 +1030,18 @@ func TestGetCryptoKeys(t *testing.T) { }) } +func requireNotContainsKeys(t *testing.T, keys []codersdk.CryptoKey, unexpected ...codersdk.CryptoKey) { + t.Helper() + + for _, unexpectedKey := range unexpected { + for _, key := range keys { + if key.Feature == unexpectedKey.Feature && key.Sequence == unexpectedKey.Sequence { + t.Fatalf("unexpected key %+v found", unexpectedKey) + } + } + } +} + func requireContainsKeys(t *testing.T, keys []codersdk.CryptoKey, expected ...codersdk.CryptoKey) { t.Helper() diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 432dc90061677..a96c32aaa8aae 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -397,12 +397,12 @@ func TestCryptoKeys(t *testing.T) { _ = dbgen.CryptoKey(t, crypt, database.CryptoKey{ Secret: sql.NullString{String: "test", Valid: true}, }) - key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) require.Equal(t, "test", key.Secret.String) require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) - key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey) require.NoError(t, err) requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test") require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) @@ -415,7 +415,7 @@ func TestCryptoKeys(t *testing.T) { Secret: sql.NullString{String: "test", Valid: true}, }) key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: key.Sequence, }) require.NoError(t, err) @@ -423,7 +423,7 @@ func TestCryptoKeys(t *testing.T) { require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: key.Sequence, }) require.NoError(t, err) @@ -459,7 +459,7 @@ func TestCryptoKeys(t *testing.T) { Secret: sql.NullString{String: "test", Valid: true}, }) _ = dbgen.CryptoKey(t, crypt, database.CryptoKey{ - Feature: database.CryptoKeyFeatureWorkspaceApps, + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, Sequence: 43, }) keys, err := crypt.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume) diff --git a/enterprise/workspaceapps_test.go b/enterprise/workspaceapps_test.go index f4ba577f13e33..51d0314c45767 100644 --- a/enterprise/workspaceapps_test.go +++ b/enterprise/workspaceapps_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps/apptest" "github.com/coder/coder/v2/codersdk" @@ -36,6 +37,9 @@ func TestWorkspaceApps(t *testing.T) { flushStatsCollectorCh <- flushStatsCollectorDone <-flushStatsCollectorDone } + + db, pubsub := dbtestutil.NewDB(t) + client, _, _, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -51,6 +55,8 @@ func TestWorkspaceApps(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ diff --git a/enterprise/wsproxy/keyfetcher.go b/enterprise/wsproxy/keyfetcher.go index f30fffb2cd093..1a1745d6ccd2d 100644 --- a/enterprise/wsproxy/keyfetcher.go +++ b/enterprise/wsproxy/keyfetcher.go @@ -13,12 +13,11 @@ import ( var _ cryptokeys.Fetcher = &ProxyFetcher{} type ProxyFetcher struct { - Client *wsproxysdk.Client - Feature codersdk.CryptoKeyFeature + Client *wsproxysdk.Client } -func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { - keys, err := p.Client.CryptoKeys(ctx) +func (p *ProxyFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) { + keys, err := p.Client.CryptoKeys(ctx, feature) if err != nil { return nil, xerrors.Errorf("crypto keys: %w", err) } diff --git a/enterprise/wsproxy/tokenprovider.go b/enterprise/wsproxy/tokenprovider.go index 38822a4e7a22d..5093c6015725e 100644 --- a/enterprise/wsproxy/tokenprovider.go +++ b/enterprise/wsproxy/tokenprovider.go @@ -7,6 +7,8 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" ) @@ -18,18 +20,19 @@ type TokenProvider struct { AccessURL *url.URL AppHostname string - Client *wsproxysdk.Client - SecurityKey workspaceapps.SecurityKey - Logger slog.Logger + Client *wsproxysdk.Client + TokenSigningKeycache cryptokeys.SigningKeycache + APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache + Logger slog.Logger } func (p *TokenProvider) FromRequest(r *http.Request) (*workspaceapps.SignedToken, bool) { - return workspaceapps.FromRequest(r, p.SecurityKey) + return workspaceapps.FromRequest(r, p.TokenSigningKeycache) } func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq workspaceapps.IssueTokenRequest) (*workspaceapps.SignedToken, string, bool) { appReq := issueReq.AppRequest.Normalize() - err := appReq.Validate() + err := appReq.Check() if err != nil { workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request") return nil, "", false @@ -42,7 +45,8 @@ func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *ht } // Check that it verifies properly and matches the string. - token, err := p.SecurityKey.VerifySignedToken(resp.SignedTokenStr) + var token workspaceapps.SignedToken + err = jwtutils.Verify(ctx, p.TokenSigningKeycache, resp.SignedTokenStr, &token) if err != nil { workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "failed to verify newly generated signed token") return nil, "", false diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index 2a7e9e81e0cda..fe900fa433530 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/tracing" @@ -130,6 +131,13 @@ type Server struct { // the moon's token. SDKClient *wsproxysdk.Client + // apiKeyEncryptionKeycache manages the encryption keys for smuggling API + // tokens to the alternate domain when using workspace apps. + apiKeyEncryptionKeycache cryptokeys.EncryptionKeycache + // appTokenSigningKeycache manages the signing keys for signing the app + // tokens we use for workspace apps. + appTokenSigningKeycache cryptokeys.SigningKeycache + // DERP derpMesh *derpmesh.Mesh derpMeshTLSConfig *tls.Config @@ -195,19 +203,42 @@ func New(ctx context.Context, opts *Options) (*Server, error) { derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(opts.Logger.Named("net.derp"))) ctx, cancel := context.WithCancel(context.Background()) + + encryptionCache, err := cryptokeys.NewEncryptionCache(ctx, + opts.Logger, + &ProxyFetcher{Client: client}, + codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey, + ) + if err != nil { + cancel() + return nil, xerrors.Errorf("create api key encryption cache: %w", err) + } + signingCache, err := cryptokeys.NewSigningCache(ctx, + opts.Logger, + &ProxyFetcher{Client: client}, + codersdk.CryptoKeyFeatureWorkspaceAppsToken, + ) + if err != nil { + cancel() + return nil, xerrors.Errorf("create api token signing cache: %w", err) + } + r := chi.NewRouter() s := &Server{ - Options: opts, - Handler: r, - DashboardURL: opts.DashboardURL, - Logger: opts.Logger.Named("net.workspace-proxy"), - TracerProvider: opts.Tracing, - PrometheusRegistry: opts.PrometheusRegistry, - SDKClient: client, - derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig), - derpMeshTLSConfig: meshTLSConfig, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, + + Options: opts, + Handler: r, + DashboardURL: opts.DashboardURL, + Logger: opts.Logger.Named("net.workspace-proxy"), + TracerProvider: opts.Tracing, + PrometheusRegistry: opts.PrometheusRegistry, + SDKClient: client, + derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig), + derpMeshTLSConfig: meshTLSConfig, + apiKeyEncryptionKeycache: encryptionCache, + appTokenSigningKeycache: signingCache, } // Register the workspace proxy with the primary coderd instance and start a @@ -240,11 +271,6 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return nil, xerrors.Errorf("handle register: %w", err) } - secKey, err := workspaceapps.KeyFromString(regResp.AppSecurityKey) - if err != nil { - return nil, xerrors.Errorf("parse app security key: %w", err) - } - agentProvider, err := coderd.NewServerTailnet(ctx, s.Logger, nil, @@ -277,20 +303,21 @@ func New(ctx context.Context, opts *Options) (*Server, error) { HostnameRegex: opts.AppHostnameRegex, RealIPConfig: opts.RealIPConfig, SignedTokenProvider: &TokenProvider{ - DashboardURL: opts.DashboardURL, - AccessURL: opts.AccessURL, - AppHostname: opts.AppHostname, - Client: client, - SecurityKey: secKey, - Logger: s.Logger.Named("proxy_token_provider"), + DashboardURL: opts.DashboardURL, + AccessURL: opts.AccessURL, + AppHostname: opts.AppHostname, + Client: client, + TokenSigningKeycache: signingCache, + APIKeyEncryptionKeycache: encryptionCache, + Logger: s.Logger.Named("proxy_token_provider"), }, - AppSecurityKey: secKey, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, - AgentProvider: agentProvider, - StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions), + AgentProvider: agentProvider, + StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions), + APIKeyEncryptionKeycache: encryptionCache, } derpHandler := derphttp.Handler(derpServer) @@ -435,6 +462,8 @@ func (s *Server) Close() error { err = multierror.Append(err, agentProviderErr) } s.SDKClient.SDKClient.HTTPClient.CloseIdleConnections() + _ = s.appTokenSigningKeycache.Close() + _ = s.apiKeyEncryptionKeycache.Close() return err } diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 3d3926c5afae7..4add46af9bc0a 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -25,6 +25,9 @@ import ( "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps/apptest" @@ -932,6 +935,9 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { if opts.PrimaryAppHost == "" { opts.PrimaryAppHost = "*.primary.test.coder.com" } + + db, pubsub := dbtestutil.NewDB(t) + client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -947,6 +953,8 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -959,6 +967,13 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { _ = closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, + }) + // Create the external proxy if opts.DisableSubdomainApps { opts.AppHost = "" @@ -1002,6 +1017,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { if opts.PrimaryAppHost == "" { opts.PrimaryAppHost = "*.primary.test.coder.com" } + + db, pubsub := dbtestutil.NewDB(t) client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ DeploymentValues: deploymentValues, @@ -1017,6 +1034,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { }, }, WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions, + Database: db, + Pubsub: pubsub, }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -1029,6 +1048,13 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) { _ = closer.Close() }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsToken, + }) + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey, + }) + // Create the external proxy if opts.DisableSubdomainApps { opts.AppHost = "" diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 77d36561c6de8..a8f22c2b93063 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -205,7 +205,6 @@ type RegisterWorkspaceProxyRequest struct { } type RegisterWorkspaceProxyResponse struct { - AppSecurityKey string `json:"app_security_key"` DERPMeshKey string `json:"derp_mesh_key"` DERPRegionID int32 `json:"derp_region_id"` DERPMap *tailcfg.DERPMap `json:"derp_map"` @@ -372,12 +371,6 @@ func (l *RegisterWorkspaceProxyLoop) Start(ctx context.Context) (RegisterWorkspa } failedAttempts = 0 - // Check for consistency. - if originalRes.AppSecurityKey != resp.AppSecurityKey { - l.failureFn(xerrors.New("app security key has changed, proxy must be restarted")) - return - } - if originalRes.DERPMeshKey != resp.DERPMeshKey { l.failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted")) return @@ -586,10 +579,10 @@ type CryptoKeysResponse struct { CryptoKeys []codersdk.CryptoKey `json:"crypto_keys"` } -func (c *Client) CryptoKeys(ctx context.Context) (CryptoKeysResponse, error) { +func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeature) (CryptoKeysResponse, error) { res, err := c.Request(ctx, http.MethodGet, - "/api/v2/workspaceproxies/me/crypto-keys", - nil, + "/api/v2/workspaceproxies/me/crypto-keys", nil, + codersdk.WithQueryParam("feature", string(feature)), ) if err != nil { return CryptoKeysResponse{}, xerrors.Errorf("make request: %w", err) diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e55167ef03f88..d687fb68ec61f 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2110,8 +2110,8 @@ export type BuildReason = "autostart" | "autostop" | "initiator" export const BuildReasons: BuildReason[] = ["autostart", "autostop", "initiator"] // From codersdk/deployment.go -export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps" -export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps"] +export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps_api_key" | "workspace_apps_token" +export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps_api_key", "workspace_apps_token"] // From codersdk/workspaceagents.go export type DisplayApp = "port_forwarding_helper" | "ssh_helper" | "vscode" | "vscode_insiders" | "web_terminal" diff --git a/tailnet/resume.go b/tailnet/resume.go index b9443064a37f9..2975fa35f1674 100644 --- a/tailnet/resume.go +++ b/tailnet/resume.go @@ -3,32 +3,23 @@ package tailnet import ( "context" "crypto/rand" - "database/sql" - "encoding/hex" - "encoding/json" "time" - "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/quartz" ) const ( DefaultResumeTokenExpiry = 24 * time.Hour - - resumeTokenSigningAlgorithm = jose.HS512 ) -// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key. -// If/when we add support for multiple keys (e.g. key rotation), this will move -// to the database instead. -var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4") - // NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a // random key with short expiry for testing purposes. If any errors occur while // generating the key, the function panics. @@ -37,12 +28,15 @@ func NewInsecureTestResumeTokenProvider() ResumeTokenProvider { if err != nil { panic(err) } - return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour) + return NewResumeTokenKeyProvider(jwtutils.StaticKey{ + ID: uuid.New().String(), + Key: key[:], + }, quartz.NewReal(), time.Hour) } type ResumeTokenProvider interface { - GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) - VerifyResumeToken(token string) (uuid.UUID, error) + GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) + VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) } type ResumeTokenSigningKey [64]byte @@ -56,104 +50,37 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) { return key, nil } -type ResumeTokenSigningKeyDatabaseStore interface { - GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) - UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error -} - -// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token -// signing key from the database. If the key is not found, a new key is -// generated and inserted into the database. -func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) { - var resumeTokenKey ResumeTokenSigningKey - resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err) - } - if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) { - newKey, err := GenerateResumeTokenSigningKey() - if err != nil { - return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err) - } - - resumeTokenKeyStr = hex.EncodeToString(newKey[:]) - err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr) - if err != nil { - return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err) - } - } - - resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr) - if err != nil { - return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err) - } - if len(resumeTokenKeyBytes) != len(resumeTokenKey) { - return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes)) - } - copy(resumeTokenKey[:], resumeTokenKeyBytes) - if resumeTokenKey == [64]byte{} { - return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty") - } - return resumeTokenKey, nil -} - type ResumeTokenKeyProvider struct { - key ResumeTokenSigningKey + key jwtutils.SigningKeyManager clock quartz.Clock expiry time.Duration } -func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider { +func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider { if expiry <= 0 { expiry = DefaultResumeTokenExpiry } return ResumeTokenKeyProvider{ key: key, clock: clock, - expiry: DefaultResumeTokenExpiry, + expiry: expiry, } } -type resumeTokenPayload struct { - PeerID uuid.UUID `json:"sub"` - Expiry int64 `json:"exp"` -} - -func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) { +func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) { exp := p.clock.Now().Add(p.expiry) - payload := resumeTokenPayload{ - PeerID: peerID, - Expiry: exp.Unix(), - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return nil, xerrors.Errorf("marshal payload to JSON: %w", err) - } - - signer, err := jose.NewSigner(jose.SigningKey{ - Algorithm: resumeTokenSigningAlgorithm, - Key: p.key[:], - }, &jose.SignerOptions{ - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "kid": resumeTokenSigningKeyID.String(), - }, - }) - if err != nil { - return nil, xerrors.Errorf("create signer: %w", err) + payload := jwtutils.RegisteredClaims{ + Subject: peerID.String(), + Expiry: jwt.NewNumericDate(exp), } - signedObject, err := signer.Sign(payloadBytes) + token, err := jwtutils.Sign(ctx, p.key, payload) if err != nil { return nil, xerrors.Errorf("sign payload: %w", err) } - serialized, err := signedObject.CompactSerialize() - if err != nil { - return nil, xerrors.Errorf("serialize JWS: %w", err) - } - return &proto.RefreshResumeTokenResponse{ - Token: serialized, + Token: token, RefreshIn: durationpb.New(p.expiry / 2), ExpiresAt: timestamppb.New(exp), }, nil @@ -162,35 +89,17 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re // VerifyResumeToken parses a signed tailnet resume token with the given key and // returns the payload. If the token is invalid or expired, an error is // returned. -func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) { - object, err := jose.ParseSigned(str) - if err != nil { - return uuid.Nil, xerrors.Errorf("parse JWS: %w", err) - } - if len(object.Signatures) != 1 { - return uuid.Nil, xerrors.New("expected 1 signature") - } - if object.Signatures[0].Header.Algorithm != string(resumeTokenSigningAlgorithm) { - return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, object.Signatures[0].Header.Algorithm) - } - if object.Signatures[0].Header.KeyID != resumeTokenSigningKeyID.String() { - return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, object.Signatures[0].Header.KeyID) - } - - output, err := object.Verify(p.key[:]) +func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) { + var tok jwt.Claims + err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{ + Time: p.clock.Now(), + })) if err != nil { - return uuid.Nil, xerrors.Errorf("verify JWS: %w", err) + return uuid.Nil, xerrors.Errorf("verify payload: %w", err) } - - var tok resumeTokenPayload - err = json.Unmarshal(output, &tok) + parsed, err := uuid.Parse(tok.Subject) if err != nil { - return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err) + return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err) } - exp := time.Unix(tok.Expiry, 0) - if exp.Before(p.clock.Now()) { - return uuid.Nil, xerrors.New("signed resume token expired") - } - - return tok.PeerID, nil + return parsed, nil } diff --git a/tailnet/resume_test.go b/tailnet/resume_test.go index 3f63887cbfef3..6f32fba4c511e 100644 --- a/tailnet/resume_test.go +++ b/tailnet/resume_test.go @@ -1,117 +1,20 @@ package tailnet_test import ( - "context" - "encoding/hex" "testing" "time" + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "github.com/coder/coder/v2/coderd/database/dbmock" - "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) -func TestResumeTokenSigningKeyFromDatabase(t *testing.T) { - t.Parallel() - - assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) { - t.Helper() - assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty") - assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s") - } - - t.Run("GenerateRetrieve", func(t *testing.T) { - t.Parallel() - - db, _ := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitShort) - key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key1) - - key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - require.Equal(t, key1, key2, "keys should not be different") - }) - - t.Run("GetError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorIs(t, err, assert.AnError) - }) - - t.Run("UpsertError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil) - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorIs(t, err, assert.AnError) - }) - - t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil) - - var storedKey tailnet.ResumeTokenSigningKey - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error { - keyBytes, err := hex.DecodeString(value) - require.NoError(t, err) - require.Len(t, keyBytes, len(storedKey)) - copy(storedKey[:], keyBytes) - return nil - }) - - ctx := testutil.Context(t, testutil.WaitShort) - key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key) - require.Equal(t, storedKey, key, "key should match stored value") - }) - - t.Run("LengthErrorShouldRegenerate", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil) - db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil) - - ctx := testutil.Context(t, testutil.WaitShort) - key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.NoError(t, err) - assertRandomKey(t, key) - }) - - t.Run("EmptyError", func(t *testing.T) { - t.Parallel() - - db := dbmock.NewMockStore(gomock.NewController(t)) - emptyKey := hex.EncodeToString(make([]byte, 64)) - db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil) - - ctx := testutil.Context(t, testutil.WaitShort) - _, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db) - require.ErrorContains(t, err, "is empty") - }) -} - func TestResumeTokenKeyProvider(t *testing.T) { t.Parallel() @@ -121,17 +24,18 @@ func TestResumeTokenKeyProvider(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) id := uuid.New() clock := quartz.NewMock(t) - provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry) - token, err := provider.GenerateResumeToken(id) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry) + token, err := provider.GenerateResumeToken(ctx, id) require.NoError(t, err) require.NotNil(t, token) require.NotEmpty(t, token.Token) require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration()) require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second) - gotID, err := provider.VerifyResumeToken(token.Token) + gotID, err := provider.VerifyResumeToken(ctx, token.Token) require.NoError(t, err) require.Equal(t, id, gotID) }) @@ -139,43 +43,57 @@ func TestResumeTokenKeyProvider(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) id := uuid.New() clock := quartz.NewMock(t) - provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry) - token, err := provider.GenerateResumeToken(id) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry) + token, err := provider.GenerateResumeToken(ctx, id) require.NoError(t, err) require.NotNil(t, token) require.NotEmpty(t, token.Token) require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration()) require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second) - // Advance time past expiry - _ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second) + // Advance time past expiry. Account for leeway. + _ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second*61) - _, err = provider.VerifyResumeToken(token.Token) - require.ErrorContains(t, err, "expired") + _, err = provider.VerifyResumeToken(ctx, token.Token) + require.Error(t, err) + require.ErrorIs(t, err, jwt.ErrExpired) }) t.Run("InvalidToken", func(t *testing.T) { t.Parallel() - provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - _, err := provider.VerifyResumeToken("invalid") + ctx := testutil.Context(t, testutil.WaitShort) + provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + _, err := provider.VerifyResumeToken(ctx, "invalid") require.ErrorContains(t, err, "parse JWS") }) t.Run("VerifyError", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) // Generate a resume token with a different key otherKey, err := tailnet.GenerateResumeTokenSigningKey() require.NoError(t, err) - otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - token, err := otherProvider.GenerateResumeToken(uuid.New()) + otherSigner := newKeySigner(otherKey) + otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + token, err := otherProvider.GenerateResumeToken(ctx, uuid.New()) require.NoError(t, err) - provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) - _, err = provider.VerifyResumeToken(token.Token) - require.ErrorContains(t, err, "verify JWS") + signer := newKeySigner(key) + signer.ID = otherSigner.ID + provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry) + _, err = provider.VerifyResumeToken(ctx, token.Token) + require.ErrorIs(t, err, jose.ErrCryptoFailure) }) } + +func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKey { + return jwtutils.StaticKey{ + ID: "123", + Key: key[:], + } +} diff --git a/tailnet/service.go b/tailnet/service.go index 28a054dd8d671..7f38f63a589b3 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -177,7 +177,7 @@ func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshRe return nil, xerrors.New("no Stream ID") } - res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID) + res, err := s.ResumeTokenProvider.GenerateResumeToken(ctx, streamID.ID) if err != nil { return nil, xerrors.Errorf("generate resume token: %w", err) } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy