diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 077d704be1300..d3bbceec14252 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1041,6 +1041,10 @@ func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { return q.db.DeleteCoordinator(ctx, id) } +func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { @@ -1383,6 +1387,14 @@ func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (stri return q.db.GetCoordinatorResumeTokenSigningKey(ctx) } +func (q *querier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + panic("not implemented") +} + +func (q *querier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return nil, err @@ -1549,6 +1561,10 @@ func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { return q.db.GetLastUpdateCheck(ctx) } +func (q *querier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { return database.WorkspaceBuild{}, err @@ -2654,6 +2670,10 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) } +func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { // Org and site role upsert share the same query. So switch the assertion based on the org uuid. if arg.OrganizationID.UUID != uuid.Nil { @@ -3157,6 +3177,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) } +func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *querier) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { if arg.OrganizationID.UUID != uuid.Nil { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 79aee59d97dbe..06e40287cff29 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -2,6 +2,7 @@ package dbgen import ( "context" + "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" @@ -893,6 +894,36 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab return role } +func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey { + t.Helper() + + b := make([]byte, 96) + _, err := rand.Read(b) + require.NoError(t, err, "generate secret") + + key, err := db.InsertCryptoKey(genCtx, database.InsertCryptoKeyParams{ + Sequence: takeFirst(seed.Sequence, 123), + Secret: takeFirst(seed.Secret, sql.NullString{ + String: hex.EncodeToString(b), + Valid: true, + }), + SecretKeyID: takeFirst(seed.SecretKeyID, sql.NullString{}), + Feature: takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps), + StartsAt: takeFirst(seed.StartsAt, time.Now()), + }) + require.NoError(t, err, "insert crypto key") + + if seed.DeletesAt.Valid { + key, err = db.UpdateCryptoKeyDeletesAt(genCtx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{Time: seed.DeletesAt.Time, Valid: true}, + }) + require.NoError(t, err, "update crypto key deletes_at") + } + return key +} + func must[V any](v V, err error) V { if err != nil { panic(err) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index ed766d48ecd43..774d9296e51bc 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -153,6 +153,7 @@ type data struct { // New tables workspaceAgentStats []database.WorkspaceAgentStat auditLogs []database.AuditLog + cryptoKeys []database.CryptoKey dbcryptKeys []database.DBCryptKey files []database.File externalAuthLinks []database.ExternalAuthLink @@ -1434,6 +1435,15 @@ func (*FakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { return ErrUnimplemented } +func (q *FakeQuerier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) DeleteCustomRole(_ context.Context, arg database.DeleteCustomRoleParams) error { err := validateDatabaseType(arg) if err != nil { @@ -2309,6 +2319,32 @@ func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (st return q.coordinatorResumeTokenSigningKey, nil } +func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(_ context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, key := range q.cryptoKeys { + if key.Feature == arg.Feature && key.Sequence == arg.Sequence { + // Keys with NULL secrets are considered deleted. + if key.Secret.Valid { + return key, nil + } + return database.CryptoKey{}, sql.ErrNoRows + } + } + + return database.CryptoKey{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2806,6 +2842,10 @@ func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { return string(q.lastUpdateCheck), nil } +func (q *FakeQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + panic("not implemented") +} + func (q *FakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -6305,6 +6345,28 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit return alog, nil } +func (q *FakeQuerier) InsertCryptoKey(_ context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + key := database.CryptoKey{ + Feature: arg.Feature, + Sequence: arg.Sequence, + Secret: arg.Secret, + SecretKeyID: arg.SecretKeyID, + StartsAt: arg.StartsAt, + } + + q.cryptoKeys = append(q.cryptoKeys, key) + + return key, nil +} + func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { @@ -7774,6 +7836,15 @@ func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI return sql.ErrNoRows } +func (q *FakeQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.CryptoKey{}, err + } + + panic("not implemented") +} + func (q *FakeQuerier) UpdateCustomRole(_ context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { err := validateDatabaseType(arg) if err != nil { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 0ec70c1736d43..bf95ad82896d8 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -214,6 +214,13 @@ func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error return r0 } +func (m metricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.DeleteCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteCryptoKey").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error { start := time.Now() r0 := m.s.DeleteCustomRole(ctx, arg) @@ -543,6 +550,20 @@ func (m metricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) ( return r0, r1 } +func (m metricsStore) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeyByFeatureAndSequence(ctx, arg) + m.queryLatencies.WithLabelValues("GetCryptoKeyByFeatureAndSequence").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetCryptoKeys(ctx) + m.queryLatencies.WithLabelValues("GetCryptoKeys").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { start := time.Now() r0, r1 := m.s.GetDBCryptKeys(ctx) @@ -711,6 +732,13 @@ func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { return version, err } +func (m metricsStore) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { + start := time.Now() + r0, r1 := m.s.GetLatestCryptoKeyByFeature(ctx, feature) + m.queryLatencies.WithLabelValues("GetLatestCryptoKeyByFeature").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { start := time.Now() build, err := m.s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) @@ -1593,6 +1621,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud return log, err } +func (m metricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.InsertCryptoKey(ctx, arg) + m.queryLatencies.WithLabelValues("InsertCryptoKey").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.InsertCustomRole(ctx, arg) @@ -1992,6 +2027,13 @@ func (m metricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateA return err } +func (m metricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + start := time.Now() + key, err := m.s.UpdateCryptoKeyDeletesAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateCryptoKeyDeletesAt").Observe(time.Since(start).Seconds()) + return key, err +} + func (m metricsStore) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { start := time.Now() r0, r1 := m.s.UpdateCustomRole(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index c5d579e1c2656..0ab399f573bfe 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -317,6 +317,21 @@ func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1) } +// DeleteCryptoKey mocks base method. +func (m *MockStore) DeleteCryptoKey(arg0 context.Context, arg1 database.DeleteCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteCryptoKey indicates an expected call of DeleteCryptoKey. +func (mr *MockStoreMockRecorder) DeleteCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCryptoKey", reflect.TypeOf((*MockStore)(nil).DeleteCryptoKey), arg0, arg1) +} + // DeleteCustomRole mocks base method. func (m *MockStore) DeleteCustomRole(arg0 context.Context, arg1 database.DeleteCustomRoleParams) error { m.ctrl.T.Helper() @@ -1058,6 +1073,36 @@ func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(arg0 any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), arg0) } +// GetCryptoKeyByFeatureAndSequence mocks base method. +func (m *MockStore) GetCryptoKeyByFeatureAndSequence(arg0 context.Context, arg1 database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeyByFeatureAndSequence", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeyByFeatureAndSequence indicates an expected call of GetCryptoKeyByFeatureAndSequence. +func (mr *MockStoreMockRecorder) GetCryptoKeyByFeatureAndSequence(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeyByFeatureAndSequence", reflect.TypeOf((*MockStore)(nil).GetCryptoKeyByFeatureAndSequence), arg0, arg1) +} + +// GetCryptoKeys mocks base method. +func (m *MockStore) GetCryptoKeys(arg0 context.Context) ([]database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCryptoKeys", arg0) + ret0, _ := ret[0].([]database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCryptoKeys indicates an expected call of GetCryptoKeys. +func (mr *MockStoreMockRecorder) GetCryptoKeys(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoKeys", reflect.TypeOf((*MockStore)(nil).GetCryptoKeys), arg0) +} + // GetDBCryptKeys mocks base method. func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) { m.ctrl.T.Helper() @@ -1418,6 +1463,21 @@ func (mr *MockStoreMockRecorder) GetLastUpdateCheck(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLastUpdateCheck", reflect.TypeOf((*MockStore)(nil).GetLastUpdateCheck), arg0) } +// GetLatestCryptoKeyByFeature mocks base method. +func (m *MockStore) GetLatestCryptoKeyByFeature(arg0 context.Context, arg1 database.CryptoKeyFeature) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLatestCryptoKeyByFeature", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestCryptoKeyByFeature indicates an expected call of GetLatestCryptoKeyByFeature. +func (mr *MockStoreMockRecorder) GetLatestCryptoKeyByFeature(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestCryptoKeyByFeature", reflect.TypeOf((*MockStore)(nil).GetLatestCryptoKeyByFeature), arg0, arg1) +} + // GetLatestWorkspaceBuildByWorkspaceID mocks base method. func (m *MockStore) GetLatestWorkspaceBuildByWorkspaceID(arg0 context.Context, arg1 uuid.UUID) (database.WorkspaceBuild, error) { m.ctrl.T.Helper() @@ -3352,6 +3412,21 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1) } +// InsertCryptoKey mocks base method. +func (m *MockStore) InsertCryptoKey(arg0 context.Context, arg1 database.InsertCryptoKeyParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertCryptoKey", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertCryptoKey indicates an expected call of InsertCryptoKey. +func (mr *MockStoreMockRecorder) InsertCryptoKey(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertCryptoKey", reflect.TypeOf((*MockStore)(nil).InsertCryptoKey), arg0, arg1) +} + // InsertCustomRole mocks base method. func (m *MockStore) InsertCustomRole(arg0 context.Context, arg1 database.InsertCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() @@ -4204,6 +4279,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), arg0, arg1) } +// UpdateCryptoKeyDeletesAt mocks base method. +func (m *MockStore) UpdateCryptoKeyDeletesAt(arg0 context.Context, arg1 database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateCryptoKeyDeletesAt", arg0, arg1) + ret0, _ := ret[0].(database.CryptoKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateCryptoKeyDeletesAt indicates an expected call of UpdateCryptoKeyDeletesAt. +func (mr *MockStoreMockRecorder) UpdateCryptoKeyDeletesAt(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCryptoKeyDeletesAt", reflect.TypeOf((*MockStore)(nil).UpdateCryptoKeyDeletesAt), arg0, arg1) +} + // UpdateCustomRole mocks base method. func (m *MockStore) UpdateCustomRole(arg0 context.Context, arg1 database.UpdateCustomRoleParams) (database.CustomRole, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 6638d52745ba6..17fd3511442ec 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -36,6 +36,12 @@ CREATE TYPE build_reason AS ENUM ( 'autodelete' ); +CREATE TYPE crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + CREATE TYPE display_app AS ENUM ( 'vscode', 'vscode_insiders', @@ -494,6 +500,15 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text, + secret_key_id text, + starts_at timestamp with time zone NOT NULL, + deletes_at timestamp with time zone +); + CREATE TABLE custom_roles ( name text NOT NULL, display_name text NOT NULL, @@ -1640,6 +1655,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); + ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); @@ -2035,6 +2053,9 @@ CREATE TRIGGER update_notification_message_dedupe_hash BEFORE INSERT OR UPDATE O ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY crypto_keys + ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 0c578255f091c..6046a94e3bcad 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -7,6 +7,7 @@ type ForeignKeyConstraint string // ForeignKeyConstraint enums. const ( ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyGitSSHKeysUserID ForeignKeyConstraint = "gitsshkeys_user_id_fkey" // ALTER TABLE ONLY gitsshkeys ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); diff --git a/coderd/database/migrations/000250_crypto_keys.down.sql b/coderd/database/migrations/000250_crypto_keys.down.sql new file mode 100644 index 0000000000000..8b0cc3702bcc4 --- /dev/null +++ b/coderd/database/migrations/000250_crypto_keys.down.sql @@ -0,0 +1 @@ +DROP TABLE "keys"; diff --git a/coderd/database/migrations/000250_crypto_keys.up.sql b/coderd/database/migrations/000250_crypto_keys.up.sql new file mode 100644 index 0000000000000..7c1aa7888fdd1 --- /dev/null +++ b/coderd/database/migrations/000250_crypto_keys.up.sql @@ -0,0 +1,16 @@ +CREATE TYPE crypto_key_feature AS ENUM ( + 'workspace_apps', + 'oidc_convert', + 'peer_reconnect' +); + +CREATE TABLE crypto_keys ( + feature crypto_key_feature NOT NULL, + sequence integer NOT NULL, + secret text NULL, + secret_key_id text NULL REFERENCES dbcrypt_keys(active_key_digest), + starts_at timestamptz NOT NULL, + deletes_at timestamptz NULL, + PRIMARY KEY (feature, sequence) +); + diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 816fc4c9214b0..82be5e710c058 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -447,3 +447,7 @@ func (r GetAuthorizationUserRolesRow) RoleNames() ([]rbac.RoleIdentifier, error) } return names, nil } + +func (k CryptoKey) ExpiresAt(keyDuration time.Duration) time.Time { + return k.StartsAt.Add(keyDuration).UTC() +} diff --git a/coderd/database/models.go b/coderd/database/models.go index 9e0283ba859c1..e9bb8e42b8960 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -339,6 +339,67 @@ func AllBuildReasonValues() []BuildReason { } } +type CryptoKeyFeature string + +const ( + CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps" + CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeaturePeerReconnect CryptoKeyFeature = "peer_reconnect" +) + +func (e *CryptoKeyFeature) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = CryptoKeyFeature(s) + case string: + *e = CryptoKeyFeature(s) + default: + return fmt.Errorf("unsupported scan type for CryptoKeyFeature: %T", src) + } + return nil +} + +type NullCryptoKeyFeature struct { + CryptoKeyFeature CryptoKeyFeature `json:"crypto_key_feature"` + Valid bool `json:"valid"` // Valid is true if CryptoKeyFeature is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullCryptoKeyFeature) Scan(value interface{}) error { + if value == nil { + ns.CryptoKeyFeature, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.CryptoKeyFeature.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullCryptoKeyFeature) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.CryptoKeyFeature), nil +} + +func (e CryptoKeyFeature) Valid() bool { + switch e { + case CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect: + return true + } + return false +} + +func AllCryptoKeyFeatureValues() []CryptoKeyFeature { + return []CryptoKeyFeature{ + CryptoKeyFeatureWorkspaceApps, + CryptoKeyFeatureOidcConvert, + CryptoKeyFeaturePeerReconnect, + } +} + type DisplayApp string const ( @@ -2043,6 +2104,15 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +type CryptoKey struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + // Custom roles allow dynamic roles expanded at runtime type CustomRole struct { Name string `db:"name" json:"name"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ee9a64f12076d..8e8f587d302c8 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -69,6 +69,7 @@ type sqlcQuerier interface { DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteCoordinator(ctx context.Context, id uuid.UUID) error + DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error DeleteExternalAuthLink(ctx context.Context, arg DeleteExternalAuthLinkParams) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error @@ -131,6 +132,8 @@ type sqlcQuerier interface { // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) + GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) + GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) GetDERPMeshKey(ctx context.Context) (string, error) GetDefaultOrganization(ctx context.Context) (Organization, error) @@ -159,6 +162,7 @@ type sqlcQuerier interface { GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg GetJFrogXrayScanByWorkspaceAndAgentIDParams) (JfrogXrayScan, error) GetLastUpdateCheck(ctx context.Context) (string, error) + GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error) @@ -337,6 +341,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error InsertDERPMeshKey(ctx context.Context, value string) error @@ -410,6 +415,7 @@ type sqlcQuerier interface { UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error + UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6831415907b67..b5c198c94da42 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -761,6 +761,186 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const deleteCryptoKey = `-- name: DeleteCryptoKey :one +UPDATE crypto_keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type DeleteCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, deleteCryptoKey, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeyByFeatureAndSequence = `-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL +` + +type GetCryptoKeyByFeatureAndSequenceParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` +} + +func (q *sqlQuerier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getCryptoKeyByFeatureAndSequence, arg.Feature, arg.Sequence) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const getCryptoKeys = `-- name: GetCryptoKeys :many +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE secret IS NOT NULL +` + +func (q *sqlQuerier) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) { + rows, err := q.db.QueryContext(ctx, getCryptoKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CryptoKey + for rows.Next() { + var i CryptoKey + if err := rows.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getLatestCryptoKeyByFeature = `-- name: GetLatestCryptoKeyByFeature :one +SELECT feature, sequence, secret, secret_key_id, starts_at, deletes_at +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1 +` + +func (q *sqlQuerier) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, getLatestCryptoKeyByFeature, feature) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const insertCryptoKey = `-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( + feature, + sequence, + secret, + starts_at, + secret_key_id +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type InsertCryptoKeyParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + Secret sql.NullString `db:"secret" json:"secret"` + StartsAt time.Time `db:"starts_at" json:"starts_at"` + SecretKeyID sql.NullString `db:"secret_key_id" json:"secret_key_id"` +} + +func (q *sqlQuerier) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, insertCryptoKey, + arg.Feature, + arg.Sequence, + arg.Secret, + arg.StartsAt, + arg.SecretKeyID, + ) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + +const updateCryptoKeyDeletesAt = `-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 RETURNING feature, sequence, secret, secret_key_id, starts_at, deletes_at +` + +type UpdateCryptoKeyDeletesAtParams struct { + Feature CryptoKeyFeature `db:"feature" json:"feature"` + Sequence int32 `db:"sequence" json:"sequence"` + DeletesAt sql.NullTime `db:"deletes_at" json:"deletes_at"` +} + +func (q *sqlQuerier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) { + row := q.db.QueryRowContext(ctx, updateCryptoKeyDeletesAt, arg.Feature, arg.Sequence, arg.DeletesAt) + var i CryptoKey + err := row.Scan( + &i.Feature, + &i.Sequence, + &i.Secret, + &i.SecretKeyID, + &i.StartsAt, + &i.DeletesAt, + ) + return i, err +} + const getDBCryptKeys = `-- name: GetDBCryptKeys :many SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC ` diff --git a/coderd/database/queries/crypto_keys.sql b/coderd/database/queries/crypto_keys.sql new file mode 100644 index 0000000000000..39dc8175f95ab --- /dev/null +++ b/coderd/database/queries/crypto_keys.sql @@ -0,0 +1,44 @@ +-- name: GetCryptoKeys :many +SELECT * +FROM crypto_keys +WHERE secret IS NOT NULL; + +-- name: GetLatestCryptoKeyByFeature :one +SELECT * +FROM crypto_keys +WHERE feature = $1 +ORDER BY sequence DESC +LIMIT 1; + + +-- name: GetCryptoKeyByFeatureAndSequence :one +SELECT * +FROM crypto_keys +WHERE feature = $1 + AND sequence = $2 + AND secret IS NOT NULL; + +-- name: DeleteCryptoKey :one +UPDATE crypto_keys +SET secret = NULL +WHERE feature = $1 AND sequence = $2 RETURNING *; + +-- name: InsertCryptoKey :one +INSERT INTO crypto_keys ( + feature, + sequence, + secret, + starts_at, + secret_key_id +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING *; + +-- name: UpdateCryptoKeyDeletesAt :one +UPDATE crypto_keys +SET deletes_at = $3 +WHERE feature = $1 AND sequence = $2 RETURNING *; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index b3bf72f8178b6..01a811af9c5ed 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -9,6 +9,7 @@ const ( UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id); UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); + UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); UniqueDbcryptKeysPkey UniqueConstraint = "dbcrypt_keys_pkey" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number); diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go new file mode 100644 index 0000000000000..e9e4305a99aab --- /dev/null +++ b/coderd/keyrotate/rotate.go @@ -0,0 +1,243 @@ +package keyrotate + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/quartz" +) + +const ( + WorkspaceAppsTokenDuration = time.Minute + OIDCConvertTokenDuration = time.Minute * 5 + PeerReconnectTokenDuration = time.Hour * 24 +) + +type KeyRotator struct { + DB database.Store + KeyDuration time.Duration + Clock quartz.Clock + Logger slog.Logger + ScanInterval time.Duration + ResultsCh chan []database.CryptoKey + features []database.CryptoKeyFeature +} + +func (k *KeyRotator) Start(ctx context.Context) { + ticker := k.Clock.NewTicker(k.ScanInterval) + defer ticker.Stop() + + if len(k.features) == 0 { + k.features = database.AllCryptoKeyFeatureValues() + } + + for { + modifiedKeys, err := k.rotateKeys(ctx) + if err != nil { + k.Logger.Error(ctx, "failed to rotate keys", slog.Error(err)) + } + + // This should only be called in test code so we don't + // both to select on the push. + if k.ResultsCh != nil { + k.ResultsCh <- modifiedKeys + } + + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + } +} + +// rotateKeys checks for keys nearing expiration and rotates them if necessary. +func (k *KeyRotator) rotateKeys(ctx context.Context) ([]database.CryptoKey, error) { + var modifiedKeys []database.CryptoKey + return modifiedKeys, database.ReadModifyUpdate(k.DB, func(tx database.Store) error { + // Reset the modified keys slice for each iteration. + modifiedKeys = make([]database.CryptoKey, 0) + keys, err := tx.GetCryptoKeys(ctx) + if err != nil { + return xerrors.Errorf("get keys: %w", err) + } + + // Groups the keys by feature so that we can + // ensure we have at least one key for each feature. + keysByFeature := keysByFeature(keys, k.features) + now := dbtime.Time(k.Clock.Now().UTC()) + for feature, keys := range keysByFeature { + // It's possible there are no keys if someone + // has manually deleted all the keys. + if len(keys) == 0 { + k.Logger.Info(ctx, "no valid keys detected, inserting new key", + slog.F("feature", feature), + ) + newKey, err := k.insertNewKey(ctx, tx, feature, now) + if err != nil { + return xerrors.Errorf("insert new key: %w", err) + } + modifiedKeys = append(modifiedKeys, newKey) + } + + for _, key := range keys { + switch { + case shouldDeleteKey(key, now): + deletedKey, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + if err != nil { + return xerrors.Errorf("delete key: %w", err) + } + modifiedKeys = append(modifiedKeys, deletedKey) + case shouldRotateKey(key, k.KeyDuration, now): + rotatedKeys, err := k.rotateKey(ctx, tx, key) + if err != nil { + return xerrors.Errorf("rotate key: %w", err) + } + modifiedKeys = append(modifiedKeys, rotatedKeys...) + default: + continue + } + } + } + return nil + }) +} + +func (k *KeyRotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, now time.Time) (database.CryptoKey, error) { + secret, err := generateNewSecret(feature) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err) + } + + latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err) + } + + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: feature, + // We'll assume that the first key we insert is 1. + Sequence: latestKey.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: now.UTC(), + }) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err) + } + + k.Logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature)) + return newKey, nil +} + +func (k *KeyRotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey) ([]database.CryptoKey, error) { + // The starts at of the new key is the expiration of the old key. + newStartsAt := key.ExpiresAt(k.KeyDuration) + + secret, err := generateNewSecret(key.Feature) + if err != nil { + return nil, xerrors.Errorf("generate new secret: %w", err) + } + + // Insert new key + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + StartsAt: newStartsAt.UTC(), + }) + if err != nil { + return nil, xerrors.Errorf("inserting new key: %w", err) + } + + // Set old key's deletes_at + deletesAt := newStartsAt.Add(time.Hour).Add(tokenDuration(key.Feature)) + + updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{ + Time: deletesAt.UTC(), + Valid: true, + }, + }) + if err != nil { + return nil, xerrors.Errorf("update old key's deletes_at: %w", err) + } + + return []database.CryptoKey{updatedKey, newKey}, nil +} + +func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return generateKey(96) + case database.CryptoKeyFeatureOidcConvert: + return generateKey(32) + case database.CryptoKeyFeaturePeerReconnect: + return generateKey(64) + } + return "", xerrors.Errorf("unknown feature: %s", feature) +} + +func generateKey(length int) (string, error) { + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", xerrors.Errorf("rand read: %w", err) + } + return hex.EncodeToString(b), nil +} + +func tokenDuration(feature database.CryptoKeyFeature) time.Duration { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return WorkspaceAppsTokenDuration + case database.CryptoKeyFeatureOidcConvert: + return OIDCConvertTokenDuration + case database.CryptoKeyFeaturePeerReconnect: + return PeerReconnectTokenDuration + default: + return 0 + } +} + +func shouldDeleteKey(key database.CryptoKey, now time.Time) bool { + return key.DeletesAt.Valid && key.DeletesAt.Time.UTC().After(now.UTC()) +} + +func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool { + // If deletes_at is set, we've already inserted a key. + if key.DeletesAt.Valid { + return false + } + expirationTime := key.ExpiresAt(keyDuration) + return now.Add(time.Hour).UTC().After(expirationTime.UTC()) +} + +func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) map[database.CryptoKeyFeature][]database.CryptoKey { + m := map[database.CryptoKeyFeature][]database.CryptoKey{} + for _, feature := range features { + m[feature] = []database.CryptoKey{} + } + for _, key := range keys { + m[key.Feature] = append(m[key.Feature], key) + } + return m +} diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go new file mode 100644 index 0000000000000..a0f7e9522507a --- /dev/null +++ b/coderd/keyrotate/rotate_internal_test.go @@ -0,0 +1,186 @@ +package keyrotate + +import ( + "database/sql" + "encoding/hex" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func Test_rotateKeys(t *testing.T) { + t.Parallel() + + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key. + oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 15, + }) + + // Advance the window to just inside rotation time. + _ = clock.Advance(keyDuration - time.Minute*59) + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + now = dbnow(clock) + expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour) + + // Fetch the old key, it should have an expires_at now. + oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.NoError(t, err) + require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt) + + // The new key should be created and have a starts_at of the old key's expires_at. + newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + Sequence: oldKey.Sequence + 1, + }) + require.NoError(t, err) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), time.Time{}, oldKey.Sequence+1) + + // Advance the clock just past the keys delete time. + clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) + + // We should have deleted the old key. + keys, err = kr.rotateKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + + // The old key should be "deleted". + _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + resultsCh = make(chan []database.CryptoKey, 1) + ) + + kr := &KeyRotator{ + DB: db, + KeyDuration: keyDuration, + Clock: clock, + Logger: logger, + ScanInterval: 0, + ResultsCh: resultsCh, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 1, + }) + + // Advance the clock by 6 days, 23 hours. Once we + // breach the last hour we will insert a new key. + clock.Advance(keyDuration - time.Hour) + + keys, err := kr.rotateKeys(ctx) + require.NoError(t, err) + require.Empty(t, keys) + + // Verify that the existing key is still the only key in the database + dbKeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbKeys, 1) + requireKey(t, dbKeys[0], existingKey.Feature, existingKey.StartsAt.UTC(), existingKey.DeletesAt.Time.UTC(), existingKey.Sequence) + }) + + t.Run("DeletesExpiredKeys", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for deleting expired keys + }) + + t.Run("HandlesMultipleKeyTypes", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for handling multiple key types + }) + + t.Run("GracefullyHandlesErrors", func(t *testing.T) { + t.Parallel() + // TODO: Implement test for error handling + }) +} + +func dbnow(c quartz.Clock) time.Time { + return dbtime.Time(c.Now().UTC()) +} + +func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt time.Time, sequence int32) { + t.Helper() + require.Equal(t, feature, key.Feature) + require.Equal(t, startsAt, key.StartsAt.UTC()) + require.Equal(t, deletesAt, key.DeletesAt.Time.UTC()) + require.Equal(t, sequence, key.Sequence) + + secret, err := hex.DecodeString(key.Secret.String) + require.NoError(t, err) + + switch key.Feature { + case database.CryptoKeyFeatureOidcConvert: + require.Len(t, secret, 32) + case database.CryptoKeyFeatureWorkspaceApps: + require.Len(t, secret, 96) + case database.CryptoKeyFeaturePeerReconnect: + require.Len(t, secret, 64) + default: + t.Fatalf("unknown key feature: %s", key.Feature) + } +} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go new file mode 100644 index 0000000000000..5f82a0647c99d --- /dev/null +++ b/coderd/keyrotate/rotate_test.go @@ -0,0 +1,60 @@ +package keyrotate_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" +) + +func TestKeyRotator(t *testing.T) { + t.Run("NoExistingKeys", func(t *testing.T) { + // t.Parallel() + + // var ( + // db, _ = dbtestutil.NewDB(t) + // clock = quartz.NewMock(t) + // logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + // ctx = testutil.Context(t, testutil.WaitShort) + // resultsCh = make(chan []database.CryptoKey, 1) + // ) + + // kr := &KeyRotator{ + // DB: db, + // KeyDuration: 0, + // Clock: clock, + // Logger: logger, + // ScanInterval: 0, + // ResultsCh: resultsCh, + // } + + // now := dbnow(clock) + // keys, err := kr.rotateKeys(ctx) + // require.NoError(t, err) + // require.Len(t, keys, len(database.AllCryptoKeyFeatureValues())) + + // // Fetch the keys from the database and ensure they + // // are as expected. + // dbkeys, err := db.GetCryptoKeys(ctx) + // require.NoError(t, err) + // require.Equal(t, keys, dbkeys) + // requireContainsAllFeatures(t, keys) + // for _, key := range keys { + // requireKey(t, key, key.Feature, now, time.Time{}, 1) + // } + }) + +} + +func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { + t.Helper() + + features := make(map[database.CryptoKeyFeature]bool) + for _, key := range keys { + features[key.Feature] = true + } + require.True(t, features[database.CryptoKeyFeatureOidcConvert]) + require.True(t, features[database.CryptoKeyFeatureWorkspaceApps]) + require.True(t, features[database.CryptoKeyFeaturePeerReconnect]) +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index ec56a4897a1e3..2717ef1d48188 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -261,6 +261,31 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U return link, nil } +func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { + key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + +func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) { + if err := db.encryptField(¶ms.Secret.String, ¶ms.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + key, err := db.Store.InsertCryptoKey(ctx, params) + if err != nil { + return database.CryptoKey{}, err + } + if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil { + return database.CryptoKey{}, err + } + return key, nil +} + func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error { // If no cipher is loaded, then we can't encrypt anything! if db.ciphers == nil || db.primaryCipherDigest == "" { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 37fcc8cae55a3..7dad716b8139b 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -349,6 +349,51 @@ func TestExternalAuthLinks(t *testing.T) { }) } +func TestCryptoKeys(t *testing.T) { + t.Parallel() + ctx := context.Background() + db, crypt, ciphers := setup(t) + + // We don't write a GetCryptoKeyByFeatureAndSequence test + // because it's basically the same as InsertCryptoKey. + t.Run("InsertCryptoKey", func(t *testing.T) { + t.Parallel() + key := dbgen.CryptoKey(t, crypt, database.CryptoKey{ + Secret: sql.NullString{String: "test", Valid: true}, + }) + require.Equal(t, "test", key.Secret.String) + + key, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.NoError(t, err) + require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String) + requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test") + }) + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + key := dbgen.CryptoKey(t, db, database.CryptoKey{ + Secret: sql.NullString{ + String: fakeBase64RandomData(t, 32), + Valid: true, + }, + SecretKeyID: sql.NullString{ + String: ciphers[0].HexDigest(), + Valid: true, + }, + }) + _, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) +} + func TestNew(t *testing.T) { t.Parallel() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy