diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6a768fa9b4dfd..d12b9aba23863 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1680,6 +1680,10 @@ func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt) } +func (q *querier) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + return fetch(q.log, q.auth, q.db.GetProvisionerKeyByHashedSecret)(ctx, hashedSecret) +} + func (q *querier) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (database.ProvisionerKey, error) { return fetch(q.log, q.auth, q.db.GetProvisionerKeyByID)(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 6514d2f0dfeb0..0ec7d2b17fb9c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1825,6 +1825,11 @@ func (s *MethodTestSuite) TestProvisionerKeys() { pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) check.Args(pk.ID).Asserts(pk, policy.ActionRead).Returns(pk) })) + s.Run("GetProvisionerKeyByHashedSecret", s.Subtest(func(db database.Store, check *expects) { + org := dbgen.Organization(s.T(), db, database.Organization{}) + pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID, HashedSecret: []byte("foo")}) + check.Args([]byte("foo")).Asserts(pk, policy.ActionRead).Returns(pk) + })) s.Run("GetProvisionerKeyByName", s.Subtest(func(db database.Store, check *expects) { org := dbgen.Organization(s.T(), db, database.Organization{}) pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 827d99a2c14df..8d1088616f6bc 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -3240,6 +3240,19 @@ func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti return jobs, nil } +func (q *FakeQuerier) GetProvisionerKeyByHashedSecret(_ context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, key := range q.provisionerKeys { + if bytes.Equal(key.HashedSecret, hashedSecret) { + return key, nil + } + } + + return database.ProvisionerKey{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetProvisionerKeyByID(_ context.Context, id uuid.UUID) (database.ProvisionerKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index e6642da53974f..f987d0505653b 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -914,6 +914,13 @@ func (m metricsStore) GetProvisionerJobsCreatedAfter(ctx context.Context, create return jobs, err } +func (m metricsStore) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + start := time.Now() + r0, r1 := m.s.GetProvisionerKeyByHashedSecret(ctx, hashedSecret) + m.queryLatencies.WithLabelValues("GetProvisionerKeyByHashedSecret").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (database.ProvisionerKey, error) { start := time.Now() r0, r1 := m.s.GetProvisionerKeyByID(ctx, id) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 8517a7a8e5f21..78cd95a69cde5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1840,6 +1840,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobsCreatedAfter(arg0, arg1 any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsCreatedAfter", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsCreatedAfter), arg0, arg1) } +// GetProvisionerKeyByHashedSecret mocks base method. +func (m *MockStore) GetProvisionerKeyByHashedSecret(arg0 context.Context, arg1 []byte) (database.ProvisionerKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProvisionerKeyByHashedSecret", arg0, arg1) + ret0, _ := ret[0].(database.ProvisionerKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProvisionerKeyByHashedSecret indicates an expected call of GetProvisionerKeyByHashedSecret. +func (mr *MockStoreMockRecorder) GetProvisionerKeyByHashedSecret(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerKeyByHashedSecret", reflect.TypeOf((*MockStore)(nil).GetProvisionerKeyByHashedSecret), arg0, arg1) +} + // GetProvisionerKeyByID mocks base method. func (m *MockStore) GetProvisionerKeyByID(arg0 context.Context, arg1 uuid.UUID) (database.ProvisionerKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 78ebf958739d6..9d0494813e306 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -186,6 +186,7 @@ type sqlcQuerier interface { GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) + GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (ProvisionerKey, error) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (ProvisionerKey, error) GetProvisionerKeyByName(ctx context.Context, arg GetProvisionerKeyByNameParams) (ProvisionerKey, error) GetProvisionerLogsAfterID(ctx context.Context, arg GetProvisionerLogsAfterIDParams) ([]ProvisionerJobLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f383f2e7c0d5d..2e3a5c9892d40 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5531,6 +5531,29 @@ func (q *sqlQuerier) DeleteProvisionerKey(ctx context.Context, id uuid.UUID) err return err } +const getProvisionerKeyByHashedSecret = `-- name: GetProvisionerKeyByHashedSecret :one +SELECT + id, created_at, organization_id, name, hashed_secret, tags +FROM + provisioner_keys +WHERE + hashed_secret = $1 +` + +func (q *sqlQuerier) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (ProvisionerKey, error) { + row := q.db.QueryRowContext(ctx, getProvisionerKeyByHashedSecret, hashedSecret) + var i ProvisionerKey + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.OrganizationID, + &i.Name, + &i.HashedSecret, + &i.Tags, + ) + return i, err +} + const getProvisionerKeyByID = `-- name: GetProvisionerKeyByID :one SELECT id, created_at, organization_id, name, hashed_secret, tags diff --git a/coderd/database/queries/provisionerkeys.sql b/coderd/database/queries/provisionerkeys.sql index ac41eb2d444d2..cb4c763f1061e 100644 --- a/coderd/database/queries/provisionerkeys.sql +++ b/coderd/database/queries/provisionerkeys.sql @@ -19,6 +19,14 @@ FROM WHERE id = $1; +-- name: GetProvisionerKeyByHashedSecret :one +SELECT + * +FROM + provisioner_keys +WHERE + hashed_secret = $1; + -- name: GetProvisionerKeyByName :one SELECT * diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go index 243af82598ff8..cac4aa0cba0a9 100644 --- a/coderd/httpmw/provisionerdaemon.go +++ b/coderd/httpmw/provisionerdaemon.go @@ -71,16 +71,17 @@ func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) fu return } - id, keyValue, err := provisionerkey.Parse(key) + err := provisionerkey.Validate(key) if err != nil { - handleOptional(http.StatusUnauthorized, codersdk.Response{ + handleOptional(http.StatusBadRequest, codersdk.Response{ Message: "provisioner daemon key invalid", + Detail: err.Error(), }) return } - + hashedKey := provisionerkey.HashSecret(key) // nolint:gocritic // System must check if the provisioner key is valid. - pk, err := opts.DB.GetProvisionerKeyByID(dbauthz.AsSystemRestricted(ctx), id) + pk, err := opts.DB.GetProvisionerKeyByHashedSecret(dbauthz.AsSystemRestricted(ctx), hashedKey) if err != nil { if httpapi.Is404Error(err) { handleOptional(http.StatusUnauthorized, codersdk.Response{ @@ -90,12 +91,13 @@ func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) fu } handleOptional(http.StatusInternalServerError, codersdk.Response{ - Message: "get provisioner daemon key: " + err.Error(), + Message: "get provisioner daemon key", + Detail: err.Error(), }) return } - if provisionerkey.Compare(pk.HashedSecret, provisionerkey.HashSecret(keyValue)) { + if provisionerkey.Compare(pk.HashedSecret, hashedKey) { handleOptional(http.StatusUnauthorized, codersdk.Response{ Message: "provisioner daemon key invalid", }) diff --git a/coderd/provisionerkey/provisionerkey.go b/coderd/provisionerkey/provisionerkey.go index 5be3658f6a5be..bfd70fb0295e0 100644 --- a/coderd/provisionerkey/provisionerkey.go +++ b/coderd/provisionerkey/provisionerkey.go @@ -3,8 +3,6 @@ package provisionerkey import ( "crypto/sha256" "crypto/subtle" - "fmt" - "strings" "github.com/google/uuid" "golang.org/x/xerrors" @@ -14,41 +12,36 @@ import ( "github.com/coder/coder/v2/cryptorand" ) +const ( + secretLength = 43 +) + func New(organizationID uuid.UUID, name string, tags map[string]string) (database.InsertProvisionerKeyParams, string, error) { - id := uuid.New() - secret, err := cryptorand.HexString(64) + secret, err := cryptorand.String(secretLength) if err != nil { - return database.InsertProvisionerKeyParams{}, "", xerrors.Errorf("generate token: %w", err) + return database.InsertProvisionerKeyParams{}, "", xerrors.Errorf("generate secret: %w", err) } - hashedSecret := HashSecret(secret) - token := fmt.Sprintf("%s:%s", id, secret) if tags == nil { tags = map[string]string{} } return database.InsertProvisionerKeyParams{ - ID: id, + ID: uuid.New(), CreatedAt: dbtime.Now(), OrganizationID: organizationID, Name: name, - HashedSecret: hashedSecret, + HashedSecret: HashSecret(secret), Tags: tags, - }, token, nil + }, secret, nil } -func Parse(token string) (uuid.UUID, string, error) { - parts := strings.Split(token, ":") - if len(parts) != 2 { - return uuid.UUID{}, "", xerrors.Errorf("invalid token format") - } - - id, err := uuid.Parse(parts[0]) - if err != nil { - return uuid.UUID{}, "", xerrors.Errorf("parse id: %w", err) +func Validate(token string) error { + if len(token) != secretLength { + return xerrors.Errorf("must be %d characters", secretLength) } - return id, parts[1], nil + return nil } func HashSecret(secret string) []byte { diff --git a/enterprise/cli/provisionerdaemonstart.go b/enterprise/cli/provisionerdaemonstart.go index b0dfff227dbe3..f92b0126c46a7 100644 --- a/enterprise/cli/provisionerdaemonstart.go +++ b/enterprise/cli/provisionerdaemonstart.go @@ -122,9 +122,9 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command { if len(rawTags) > 0 { return xerrors.New("cannot provide tags when using provisioner key") } - _, _, err := provisionerkey.Parse(provisionerKey) + err = provisionerkey.Validate(provisionerKey) if err != nil { - return xerrors.Errorf("parse provisioner key: %w", err) + return xerrors.Errorf("validate provisioner key: %w", err) } } diff --git a/enterprise/cli/provisionerkeys_test.go b/enterprise/cli/provisionerkeys_test.go index 5b62b1e9d46fd..47df45ed98596 100644 --- a/enterprise/cli/provisionerkeys_test.go +++ b/enterprise/cli/provisionerkeys_test.go @@ -4,11 +4,11 @@ import ( "strings" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/provisionerkey" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -58,10 +58,7 @@ func TestProvisionerKeys(t *testing.T) { _ = pty.ReadLine(ctx) key := pty.ReadLine(ctx) require.NotEmpty(t, key) - parts := strings.Split(key, ":") - require.Len(t, parts, 2, "expected 2 parts") - _, err = uuid.Parse(parts[0]) - require.NoError(t, err, "expected token to be a uuid") + require.NoError(t, provisionerkey.Validate(key)) inv, conf = newCLI( t, diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 451ff2249a15d..a3cf9a23cc75e 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "strings" "testing" "github.com/google/uuid" @@ -612,36 +611,12 @@ func TestProvisionerDaemonServe(t *testing.T) { errStatusCode: http.StatusUnauthorized, }, { - name: "WrongKey", + name: "InvalidKey", multiOrgFeatureEnabled: true, multiOrgExperimentEnabled: true, insertParams: insertParams, requestProvisionerKey: "provisionersftw", - errStatusCode: http.StatusUnauthorized, - }, - { - name: "IdOKKeyValueWrong", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: insertParams.ID.String() + ":" + "wrong", - errStatusCode: http.StatusUnauthorized, - }, - { - name: "IdWrongKeyValueOK", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: uuid.NewString() + ":" + token, - errStatusCode: http.StatusUnauthorized, - }, - { - name: "KeyValueOnly", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: strings.Split(token, ":")[1], - errStatusCode: http.StatusUnauthorized, + errStatusCode: http.StatusBadRequest, }, { name: "KeyAndPSK",
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: