Skip to content

test: start migrating dbauthz tests to mocked db #19257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
"testing"
"time"

"github.com/brianvoe/gofakeit/v7"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"

"cdr.dev/slog"
Expand All @@ -22,6 +24,7 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
Expand Down Expand Up @@ -204,14 +207,15 @@ func defaultIPAddress() pqtype.Inet {
}

func (s *MethodTestSuite) TestAPIKey() {
s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌

key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
s.Run("DeleteAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.APIKey{})
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
dbm.EXPECT().DeleteAPIKeyByID(gomock.Any(), key.ID).Return(nil).AnyTimes()
check.Args(key.ID).Asserts(key, policy.ActionDelete).Returns()
}))
s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
s.Run("GetAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.APIKey{})
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
check.Args(key.ID).Asserts(key, policy.ActionRead).Returns(key)
}))
s.Run("GetAPIKeyByName", s.Subtest(func(db database.Store, check *expects) {
Expand All @@ -234,14 +238,12 @@ func (s *MethodTestSuite) TestAPIKey() {
Asserts(a, policy.ActionRead, b, policy.ActionRead).
Returns(slice.New(a, b))
}))
s.Run("GetAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) {
u1 := dbgen.User(s.T(), db, database.User{})
u2 := dbgen.User(s.T(), db, database.User{})

keyA, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
keyB, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})
_, _ = dbgen.APIKey(s.T(), db, database.APIKey{UserID: u2.ID, LoginType: database.LoginTypeToken})
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
keyA := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
keyB := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})

dbm.EXPECT().GetAPIKeysByUserID(gomock.Any(), gomock.Any()).Return(slice.New(keyA, keyB), nil).AnyTimes()
check.Args(database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: u1.ID}).
Asserts(keyA, policy.ActionRead, keyB, policy.ActionRead).
Returns(slice.New(keyA, keyB))
Expand Down
34 changes: 30 additions & 4 deletions coderd/database/dbauthz/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"testing"

"github.com/brianvoe/gofakeit/v7"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
Expand All @@ -20,14 +21,14 @@ import (
"golang.org/x/xerrors"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/rbac/policy"

"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/util/slice"
)
Expand Down Expand Up @@ -105,19 +106,44 @@ func (s *MethodTestSuite) TearDownSuite() {

var testActorID = uuid.New()

// Subtest is a helper function that returns a function that can be passed to
// Mocked runs a subtest with a mocked database. Removing the overhead of a real
// postgres database resulting in much faster tests.
func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *gofakeit.Faker, check *expects)) func() {
t := s.T()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes()

// Use a constant seed to prevent flakes from random data generation.
faker := gofakeit.New(0)

// The usual Subtest assumes the test setup will use a real database to populate
// with data. In this mocked case, we want to pass the underlying mocked database
// to the test case instead.
return s.SubtestWithDB(mDB, func(_ database.Store, check *expects) {
testCaseF(mDB, faker, check)
})
}

// Subtest starts up a real postgres database for each test case.
// Deprecated: Use 'Mocked' instead for much faster tests.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
t := s.T()
db, _ := dbtestutil.NewDB(t)
return s.SubtestWithDB(db, testCaseF)
}

// SubtestWithDB is a helper function that returns a function that can be passed to
// s.Run(). This function will run the test case for the method that is being
// tested. The check parameter is used to assert the results of the method.
// If the caller does not use the `check` parameter, the test will fail.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
func (s *MethodTestSuite) SubtestWithDB(db database.Store, testCaseF func(db database.Store, check *expects)) func() {
return func() {
t := s.T()
testName := s.T().Name()
names := strings.Split(testName, "/")
methodName := names[len(names)-1]
s.methodAccounting[methodName]++

db, _ := dbtestutil.NewDB(t)
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ require (
)

require (
github.com/brianvoe/gofakeit/v7 v7.3.0
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
github.com/coder/aisdk-go v0.0.9
github.com/coder/preview v1.0.3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bramvdbogaerde/go-scp v1.5.0 h1:a9BinAjTfQh273eh7vd3qUgmBC+bx+3TRDtkZWmIpzM=
github.com/bramvdbogaerde/go-scp v1.5.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ=
github.com/brianvoe/gofakeit/v7 v7.3.0 h1:TWStf7/lLpAjKw+bqwzeORo9jvrxToWEwp9b1J2vApQ=
github.com/brianvoe/gofakeit/v7 v7.3.0/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA=
Expand Down
67 changes: 67 additions & 0 deletions testutil/faker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package testutil

import (
"reflect"
"testing"

"github.com/brianvoe/gofakeit/v7"
"github.com/stretchr/testify/require"
)

// Fake will populate any zero fields in the provided struct with fake data.
// Non-zero fields will remain unchanged.
// Usage:
//
// key := Fake(t, faker, database.APIKey{
// TokenName: "keep-my-name",
// })
func Fake[T any](t *testing.T, faker *gofakeit.Faker, seed T) T {
t.Helper()

var tmp T
err := faker.Struct(&tmp)
require.NoError(t, err, "failed to generate fake data for type %T", tmp)

mergeZero(&seed, tmp)
return seed
}

// mergeZero merges the fields of src into dst, but only if the field in dst is
// currently the zero value.
// Make sure `dst` is a pointer to a struct, otherwise the fields are not assignable.
func mergeZero(dst any, src any) {
srcv := reflect.ValueOf(src)
if srcv.Kind() == reflect.Ptr {
srcv = srcv.Elem()
}
remain := [][2]reflect.Value{
{reflect.ValueOf(dst).Elem(), srcv},
}

// Traverse the struct fields and set them only if they are currently zero.
// This is a breadth-first traversal of the struct fields. Struct definitions
// Should not be that deep, so we should not hit any stack overflow issues.
for {
if len(remain) == 0 {
return
}
dv, sv := remain[0][0], remain[0][1]
remain = remain[1:] //
for i := 0; i < dv.NumField(); i++ {
df := dv.Field(i)
sf := sv.Field(i)
if !df.CanSet() {
continue
}
if df.IsZero() { // only write if currently zero
df.Set(sf)
continue
}

if dv.Field(i).Kind() == reflect.Struct {
// If the field is a struct, we need to traverse it as well.
remain = append(remain, [2]reflect.Value{df, sf})
}
}
}
}
71 changes: 71 additions & 0 deletions testutil/faker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package testutil_test

import (
"testing"

"github.com/brianvoe/gofakeit/v7"
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/testutil"
)

type simpleStruct struct {
ID uuid.UUID
Name string
Description string
Age int `fake:"{number:18,60}"`
}

type nestedStruct struct {
Person simpleStruct
Address string
}

func TestFake(t *testing.T) {
t.Parallel()

t.Run("Simple", func(t *testing.T) {
t.Parallel()

faker := gofakeit.New(0)
person := testutil.Fake(t, faker, simpleStruct{
Name: "alice",
})
require.Equal(t, "alice", person.Name)
require.NotEqual(t, uuid.Nil, person.ID)
require.NotEmpty(t, person.Description)
require.Greater(t, person.Age, 17, "Age should be greater than 17")
require.Less(t, person.Age, 61, "Age should be less than 61")
})

t.Run("Nested", func(t *testing.T) {
t.Parallel()

faker := gofakeit.New(0)
person := testutil.Fake(t, faker, nestedStruct{
Person: simpleStruct{
Name: "alice",
},
})
require.Equal(t, "alice", person.Person.Name)
require.NotEqual(t, uuid.Nil, person.Person.ID)
require.NotEmpty(t, person.Person.Description)
require.Greater(t, person.Person.Age, 17, "Age should be greater than 17")
require.NotEmpty(t, person.Address)
})

t.Run("DatabaseType", func(t *testing.T) {
t.Parallel()

faker := gofakeit.New(0)
id := uuid.New()
key := testutil.Fake(t, faker, database.APIKey{
UserID: id,
TokenName: "keep-my-name",
})
require.Equal(t, id, key.UserID)
require.NotEmpty(t, key.TokenName)
})
}
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy