From 99c97c215e634a6ef118f5366dd9499ec68f8d64 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 10:15:52 -0500 Subject: [PATCH 01/38] wip --- coderd/coderd.go | 15 ++--- coderd/idpsync/group.go | 79 ++++++++++++++++++++++++ coderd/idpsync/idpsync.go | 22 ++++--- coderd/idpsync/organization.go | 11 ++++ enterprise/coderd/coderd.go | 16 ++--- enterprise/coderd/enidpsync/enidpsync.go | 1 - enterprise/coderd/enidpsync/groups.go | 28 +++++++++ 7 files changed, 147 insertions(+), 25 deletions(-) create mode 100644 coderd/idpsync/group.go create mode 100644 enterprise/coderd/enidpsync/groups.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 51b6780e4dc47..895aa3e501c27 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -276,13 +276,6 @@ func New(options *Options) *API { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } if options.NewTicker == nil { options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) { ticker := time.NewTicker(duration) @@ -318,6 +311,14 @@ func New(options *Options) *API { options.AccessControlStore, ) + if options.IDPSync == nil { + options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ + OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), + OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), + }) + } + experiments := ReadExperiments( options.Logger, options.DeploymentValues.Experiments.Value(), ) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go new file mode 100644 index 0000000000000..1bbc6a09a34d5 --- /dev/null +++ b/coderd/idpsync/group.go @@ -0,0 +1,79 @@ +package idpsync + +import ( + "context" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" +) + +type GroupParams struct { + // SyncEnabled if false will skip syncing the user's groups + SyncEnabled bool + MergedClaims jwt.MapClaims +} + +func (AGPLIDPSync) GroupSyncEnabled() bool { + // AGPL does not support syncing groups. + return false +} + +func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) { + return GroupParams{ + SyncEnabled: s.GroupSyncEnabled(), + }, nil +} + +// TODO: Group allowlist behavior should probably happen at this step. +func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error { + // Nothing happens if sync is not enabled + if !params.SyncEnabled { + return nil + } + + // nolint:gocritic // all syncing is done as a system user + ctx = dbauthz.AsSystemRestricted(ctx) + + db.InTx(func(tx database.Store) error { + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + if err != nil { + return xerrors.Errorf("get user groups: %w", err) + } + + // Figure out which organizations the user is a member of. + userOrgs := make(map[uuid.UUID][]database.GetGroupsRow) + for _, g := range userGroups { + g := g + userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g) + } + + // Force each organization, we sync the groups. + db.RemoveUserFromAllGroups(ctx, user.ID) + + return nil + }, nil) + + // + //tx.InTx(func(tx database.Store) error { + // // When setting the user's groups, it's easier to just clear their groups and re-add them. + // // This ensures that the user's groups are always in sync with the auth provider. + // err := tx.RemoveUserFromAllGroups(ctx, user.ID) + // if err != nil { + // return err + // } + // + // for _, org := range userOrgs { + // + // } + // + // return nil + //}, nil) + + return nil +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 73a7b9b6f530d..227436cfab998 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "net/http" + "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -29,6 +30,11 @@ type IDPSync interface { // SyncOrganizations assigns and removed users from organizations based on the // provided params. SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error + + GroupSyncEnabled() bool + // ParseGroupClaims takes claims from an OIDC provider, and returns the + // group sync params for assigning users into groups. + ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) } // AGPLIDPSync is the configuration for syncing user information from an external @@ -50,17 +56,13 @@ type SyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool -} -type OrganizationParams struct { - // SyncEnabled if false will skip syncing the user's organizations. - SyncEnabled bool - // IncludeDefault is primarily for single org deployments. It will ensure - // a user is always inserted into the default org. - IncludeDefault bool - // Organizations is the list of organizations the user should be a member of - // assuming syncing is turned on. - Organizations []uuid.UUID + // Group options here are set by the deployment config and only apply to + // the default organization. + GroupField string + CreateMissingGroups bool + GroupMapping map[string]string + GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 6d475f28ea0ef..fa091eba441ad 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -16,6 +16,17 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) +type OrganizationParams struct { + // SyncEnabled if false will skip syncing the user's organizations. + SyncEnabled bool + // IncludeDefault is primarily for single org deployments. It will ensure + // a user is always inserted into the default org. + IncludeDefault bool + // Organizations is the list of organizations the user should be a member of + // assuming syncing is turned on. + Organizations []uuid.UUID +} + func (AGPLIDPSync) OrganizationSyncEnabled() bool { // AGPL does not support syncing organizations. return false diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 6cd3e796d1825..bc6491a41198f 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -80,13 +80,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { if options.Entitlements == nil { options.Entitlements = entitlements.New() } - if options.IDPSync == nil { - options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) - } ctx, cancelFunc := context.WithCancel(ctx) @@ -118,6 +111,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } options.Database = cryptDB + + if options.IDPSync == nil { + options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ + OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), + OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), + }) + } + api := &API{ ctx: ctx, cancel: cancelFunc, diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index bb21c68501e1b..918b9f8edb118 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -2,7 +2,6 @@ package enidpsync import ( "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" ) diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go new file mode 100644 index 0000000000000..5c8328f039068 --- /dev/null +++ b/enterprise/coderd/enidpsync/groups.go @@ -0,0 +1,28 @@ +package enidpsync + +import ( + "context" + + "github.com/golang-jwt/jwt/v4" + + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/codersdk" +) + +func (e EnterpriseIDPSync) GroupSyncEnabled() bool { + return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC) + +} + +// ParseGroupClaims returns the groups from the external IDP. +// TODO: Implement group allow_list behavior here since that is deployment wide. +func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { + if !e.GroupSyncEnabled() { + return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) + } + + return idpsync.GroupParams{ + SyncEnabled: e.OrganizationSyncEnabled(), + MergedClaims: mergedClaims, + }, nil +} From bfddeb644f7c10d27ae1abd63ff6585fb61094af Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 16:46:19 -0500 Subject: [PATCH 02/38] begin group sync main work --- coderd/coderd.go | 2 +- coderd/database/dbauthz/dbauthz.go | 8 + coderd/database/dbauthz/dbauthz_test.go | 11 ++ coderd/database/dbmem/dbmem.go | 30 ++++ coderd/database/dbmetrics/dbmetrics.go | 7 + coderd/database/models.go | 2 +- coderd/database/querier.go | 4 +- coderd/database/queries.sql.go | 50 +++++- coderd/database/queries/groupmembers.sql | 19 +++ coderd/idpsync/group.go | 187 ++++++++++++++++++++++- coderd/idpsync/idpsync.go | 19 ++- enterprise/coderd/enidpsync/groups.go | 4 +- 12 files changed, 331 insertions(+), 12 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 895aa3e501c27..97c2d9f883713 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -312,7 +312,7 @@ func New(options *Options) *API { ) if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{ + options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.DeploymentSyncSettings{ OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 5782bdc8e7155..3e5e3e39164b6 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) } +func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + // This is used by OIDC sync. So only used by a system user. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.InsertUserGroupsByID(ctx, arg) +} + func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { // This will add the user to all named groups. This counts as updating a group. // NOTE: instead of checking if the user has permission to update each group, we instead diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index d23bb48184b61..2bd55c4bec499 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() { GroupNames: slice.New(g1.Name, g2.Name), }).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns() })) + s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + check.Args(database.InsertUserGroupsByIDParams{ + UserID: u1.ID, + GroupIds: slice.New(g1.ID, g2.ID), + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1, g2)) + })) s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) u1 := dbgen.User(s.T(), db, database.User{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 04f0d32537f90..c3d04e8f9f201 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7015,7 +7015,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam return user, nil } +func (q *FakeQuerier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + var groupIDs []uuid.UUID + for _, group := range q.groups { + for _, groupID := range arg.GroupIds { + if group.ID == groupID { + q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ + UserID: arg.UserID, + GroupID: groupID, + }) + groupIDs = append(groupIDs, group.ID) + } + } + } + + return groupIDs, nil +} + func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 5aa3a0c8d8cfb..510af865fc1c4 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar return user, err } +func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + start := time.Now() + r0 := m.s.InsertUserGroupsByID(ctx, arg) + m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { start := time.Now() err := m.s.InsertUserGroupsByName(ctx, arg) diff --git a/coderd/database/models.go b/coderd/database/models.go index 9e0283ba859c1..950c2674ab310 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3432bac7dada1..3499f9cf702b3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -369,6 +369,8 @@ type sqlcQuerier interface { InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) + // InsertUserGroupsByID adds a user to all provided groups, if they exist. + InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) // InsertUserGroupsByName adds a user to all provided groups, if they exist. InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 89822a72a7855..2816dad13e6ba 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -1446,6 +1446,54 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe return err } +const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY($2 :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + $1, + groups.id +FROM + groups +RETURNING group_id +` + +type InsertUserGroupsByIDParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +// InsertUserGroupsByID adds a user to all provided groups, if they exist. +func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec WITH groups AS ( SELECT diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 0ef2c72323cc9..867b1ba75d0e7 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -29,6 +29,25 @@ SELECT FROM groups; +-- InsertUserGroupsByID adds a user to all provided groups, if they exist. +-- name: InsertUserGroupsByID :many +WITH groups AS ( + SELECT + id + FROM + groups + WHERE + groups.id = ANY(@group_ids :: uuid []) +) +INSERT INTO + group_members (user_id, group_id) +SELECT + @user_id, + groups.id +FROM + groups +RETURNING group_id; + -- name: RemoveUserFromAllGroups :exec DELETE FROM group_members diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 1bbc6a09a34d5..d47a7f69045d5 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -2,13 +2,18 @@ package idpsync import ( "context" + "regexp" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/slice" ) type GroupParams struct { @@ -39,7 +44,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + resolver := runtimeconfig.NewStoreResolver(tx) + userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) if err != nil { @@ -53,9 +59,86 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g) } - // Force each organization, we sync the groups. - db.RemoveUserFromAllGroups(ctx, user.ID) + // For each org, we need to fetch the sync settings + orgSettings := make(map[uuid.UUID]GroupSyncSettings) + for orgID := range userOrgs { + orgResolver := runtimeconfig.NewOrgResolver(orgID, resolver) + settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) + if err != nil { + return xerrors.Errorf("resolve group sync settings: %w", err) + } + orgSettings[orgID] = settings.Value + } + + // collect all diffs to do 1 sql update for all orgs + groupsToAdd := make([]uuid.UUID, 0) + groupsToRemove := make([]uuid.UUID, 0) + // For each org, determine which groups the user should land in + for orgID, settings := range orgSettings { + if settings.GroupField == "" { + // No group sync enabled for this org, so do nothing. + continue + } + + expectedGroups, err := settings.ParseClaims(params.MergedClaims) + if err != nil { + s.Logger.Debug(ctx, "failed to parse claims for groups", + slog.F("organization_field", s.GroupField), + slog.F("organization_id", orgID), + slog.Error(err), + ) + // Unsure where to raise this error on the UI or database. + continue + } + // Everyone group is always implied. + expectedGroups = append(expectedGroups, ExpectedGroup{ + GroupID: &orgID, + }) + + // Now we know what groups the user should be in for a given org, + // determine if we have to do any group updates to sync the user's + // state. + existingGroups := userOrgs[orgID] + existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup { + return ExpectedGroup{ + GroupID: &f.Group.ID, + GroupName: &f.Group.Name, + } + }) + add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { + // Only the name or the name needs to be checked, priority is given to the ID. + if a.GroupID != nil && b.GroupID != nil { + return *a.GroupID == *b.GroupID + } + if a.GroupName != nil && b.GroupName != nil { + return *a.GroupName == *b.GroupName + } + return false + }) + + // HandleMissingGroups will add the new groups to the org if + // the settings specify. It will convert all group names into uuids + // for easier assignment. + assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add) + if err != nil { + return xerrors.Errorf("handle missing groups: %w", err) + } + for _, removeGroup := range remove { + // This should always be the case. + // TODO: make sure this is always the case + if removeGroup.GroupID != nil { + groupsToRemove = append(groupsToRemove, *removeGroup.GroupID) + } + } + + groupsToAdd = append(groupsToAdd, assignGroups...) + } + + tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ + UserID: user.ID, + GroupIds: groupsToAdd, + }) return nil }, nil) @@ -77,3 +160,101 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + +type GroupSyncSettings struct { + GroupField string `json:"field"` + // GroupMapping maps from an OIDC group --> Coder group ID + GroupMapping map[string][]uuid.UUID `json:"mapping"` + RegexFilter *regexp.Regexp `json:"regex_filter"` + AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` +} + +type ExpectedGroup struct { + GroupID *uuid.UUID + GroupName *string +} + +// ParseClaims will take the merged claims from the IDP and return the groups +// the user is expected to be a member of. The expected group can either be a +// name or an ID. +// It is unfortunate we cannot use exclusively names or exclusively IDs. +// When configuring though, if a group is mapped from "A" -> "UUID 1234", and +// the group "UUID 1234" is renamed, we want to maintain the mapping. +// We have to keep names because group sync supports syncing groups by name if +// the external IDP group name matches the Coder one. +func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { + groupsRaw, ok := mergedClaims[s.GroupField] + if !ok { + return []ExpectedGroup{}, nil + } + + parsedGroups, err := ParseStringSliceClaim(groupsRaw) + if err != nil { + return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err) + } + + groups := make([]ExpectedGroup, 0) + for _, group := range parsedGroups { + // Only allow through groups that pass the regex + if s.RegexFilter != nil { + if !s.RegexFilter.MatchString(group) { + continue + } + } + + mappedGroupIDs, ok := s.GroupMapping[group] + if ok { + for _, gid := range mappedGroupIDs { + gid := gid + groups = append(groups, ExpectedGroup{GroupID: &gid}) + } + continue + } + group := group + groups = append(groups, ExpectedGroup{GroupName: &group}) + } + + return groups, nil +} + +func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { + if !s.AutoCreateMissingGroups { + // Remove all groups that are missing, they will not be created + filter := make([]uuid.UUID, 0) + for _, expected := range add { + if expected.GroupID != nil { + filter = append(filter, *expected.GroupID) + } + } + + return filter, nil + } + // All expected that are missing IDs means the group does not exist + // in the database. Either remove them, or create them if auto create is + // turned on. + var missingGroups []string + addIDs := make([]uuid.UUID, 0) + + for _, expected := range add { + if expected.GroupID == nil && expected.GroupName != nil { + missingGroups = append(missingGroups, *expected.GroupName) + } else if expected.GroupID != nil { + // Keep the IDs to sync the groups. + addIDs = append(addIDs, *expected.GroupID) + } + } + + createdMissingGroups, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ + OrganizationID: orgID, + Source: database.GroupSourceOidc, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("insert missing groups: %w", err) + } + for _, created := range createdMissingGroups { + addIDs = append(addIDs, created.ID) + } + + return addIDs, nil +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 227436cfab998..2d02b941bcc80 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -13,8 +13,10 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/site" + "github.com/coder/serpent" ) // IDPSync is an interface, so we can implement this as AGPL and as enterprise, @@ -45,7 +47,8 @@ type AGPLIDPSync struct { SyncSettings } -type SyncSettings struct { +// DeploymentSyncSettings are static and are sourced from the deployment config. +type DeploymentSyncSettings struct { // OrganizationField selects the claim field to be used as the created user's // organizations. If the field is the empty string, then no organization updates // will ever come from the OIDC provider. @@ -56,6 +59,12 @@ type SyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool +} + +type SyncSettings struct { + DeploymentSyncSettings + + Group runtimeconfig.Entry[*serpent.Struct[GroupSyncSettings]] // Group options here are set by the deployment config and only apply to // the default organization. @@ -65,10 +74,12 @@ type SyncSettings struct { GroupFilter *regexp.Regexp } -func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), - SyncSettings: settings, + Logger: logger.Named("idp-sync"), + SyncSettings: SyncSettings{ + DeploymentSyncSettings: settings, + }, } } diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 5c8328f039068..02f012b8e14c3 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -14,7 +14,9 @@ func (e EnterpriseIDPSync) GroupSyncEnabled() bool { } -// ParseGroupClaims returns the groups from the external IDP. +// ParseGroupClaims parses the user claims and handles deployment wide group behavior. +// Almost all behavior is deferred since each organization configures it's own +// group sync settings. // TODO: Implement group allow_list behavior here since that is deployment wide. func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { if !e.GroupSyncEnabled() { From f2857c69a3e7fad35a17a23d72667e376ac78966 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 3 Sep 2024 17:36:24 -0500 Subject: [PATCH 03/38] initial implementation of group sync --- coderd/database/dbauthz/dbauthz.go | 4 ++ coderd/database/dbauthz/dbauthz_test.go | 2 +- coderd/database/dbmem/dbmem.go | 9 +++++ coderd/database/dbmetrics/dbmetrics.go | 11 +++++- coderd/database/dbmock/dbmock.go | 15 ++++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 37 ++++++++++++++++++ coderd/database/queries/groupmembers.sql | 8 ++++ coderd/idpsync/group.go | 48 ++++++++++++++---------- enterprise/coderd/enidpsync/enidpsync.go | 2 +- 10 files changed, 114 insertions(+), 23 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 3e5e3e39164b6..eaf994e849fc5 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3108,6 +3108,10 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) return q.db.RemoveUserFromAllGroups(ctx, userID) } +func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + panic("not implemented") +} + func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2bd55c4bec499..f9b9fb49b71fc 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -397,7 +397,7 @@ func (s *MethodTestSuite) TestGroup() { check.Args(database.InsertUserGroupsByIDParams{ UserID: u1.ID, GroupIds: slice.New(g1.ID, g2.ID), - }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1, g2)) + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) })) s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index c3d04e8f9f201..423b13ef4a774 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7637,6 +7637,15 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI return nil } +func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + err := validateDatabaseType(arg) + if err != nil { + return nil, err + } + + panic("not implemented") +} + func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 510af865fc1c4..0ec70c1736d43 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1791,9 +1791,9 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { start := time.Now() - r0 := m.s.InsertUserGroupsByID(ctx, arg) + r0, r1 := m.s.InsertUserGroupsByID(ctx, arg) m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds()) - return r0 + return r0, r1 } func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { @@ -1950,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U return r0 } +func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + start := time.Now() + r0, r1 := m.s.RemoveUserFromGroups(ctx, arg) + m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { start := time.Now() r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6d881cfe6fc1b..fe2e444ff5c67 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1) } +// InsertUserGroupsByID mocks base method. +func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID. +func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1) +} + // InsertUserGroupsByName mocks base method. func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3499f9cf702b3..3cedeeade34b7 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -398,6 +398,7 @@ type sqlcQuerier interface { ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error + RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2816dad13e6ba..3e6d6ce61c6fb 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1537,6 +1537,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU return err } +const removeUserFromGroups = `-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = $1 AND + group_id = ANY($2 :: uuid []) +RETURNING group_id +` + +type RemoveUserFromGroupsParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` +} + +func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var group_id uuid.UUID + if err := rows.Scan(&group_id); err != nil { + return nil, err + } + items = append(items, group_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const deleteGroupByID = `-- name: DeleteGroupByID :exec DELETE FROM groups diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 867b1ba75d0e7..814f878cb9232 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -54,6 +54,14 @@ DELETE FROM WHERE user_id = @user_id; +-- name: RemoveUserFromGroups :many +DELETE FROM + group_members +WHERE + user_id = @user_id AND + group_id = ANY(@group_ids :: uuid []) +RETURNING group_id; + -- name: InsertGroupMember :exec INSERT INTO group_members (user_id, group_id) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d47a7f69045d5..6d5fd11a52e5a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -135,29 +135,39 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupsToAdd = append(groupsToAdd, assignGroups...) } - tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ - UserID: user.ID, - GroupIds: groupsToAdd, + assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ + UserID: user.ID, + GroupIds: groupsToAdd, }) + if err != nil { + return xerrors.Errorf("insert user into %d groups: %w", len(groupsToAdd), err) + } + if len(assignedGroupIDs) != len(groupsToAdd) { + s.Logger.Debug(ctx, "failed to assign all groups to user", + slog.F("user_id", user.ID), + slog.F("groups_assigned_count", len(assignedGroupIDs)), + slog.F("expected_count", len(groupsToAdd)), + ) + } + + removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + UserID: user.ID, + GroupIds: groupsToRemove, + }) + if err != nil { + return xerrors.Errorf("remove user from %d groups: %w", len(groupsToRemove), err) + } + if len(removedGroupIDs) != len(groupsToRemove) { + s.Logger.Debug(ctx, "failed to remove user from all groups", + slog.F("user_id", user.ID), + slog.F("groups_removed_count", len(removedGroupIDs)), + slog.F("expected_count", len(groupsToRemove)), + ) + } + return nil }, nil) - // - //tx.InTx(func(tx database.Store) error { - // // When setting the user's groups, it's easier to just clear their groups and re-add them. - // // This ensures that the user's groups are always in sync with the auth provider. - // err := tx.RemoveUserFromAllGroups(ctx, user.ID) - // if err != nil { - // return err - // } - // - // for _, org := range userOrgs { - // - // } - // - // return nil - //}, nil) - return nil } diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index 918b9f8edb118..10988832743da 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -16,7 +16,7 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), From 791a05977df0ae118f2e88beec96876ba69e64d4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 09:24:16 -0500 Subject: [PATCH 04/38] work on moving to the manager --- coderd/idpsync/idpsync.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2d02b941bcc80..6400977387536 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -34,15 +34,19 @@ type IDPSync interface { SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error GroupSyncEnabled() bool - // ParseGroupClaims takes claims from an OIDC provider, and returns the - // group sync params for assigning users into groups. + // ParseGroupClaims takes claims from an OIDC provider, and returns the params + // for group syncing. Most of the logic happens in SyncGroups. ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) + + // SyncGroups assigns and removes users from groups based on the provided params. + SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error } // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger + Logger slog.Logger + Manager runtimeconfig.Manager SyncSettings } @@ -74,9 +78,10 @@ type SyncSettings struct { GroupFilter *regexp.Regexp } -func NewAGPLSync(logger slog.Logger, settings DeploymentSyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), + Logger: logger.Named("idp-sync"), + Manager: manager, SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, }, From 4326e9d94af1915ac1d109e72f808ed9cb3acd5c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:22:22 -0500 Subject: [PATCH 05/38] fixup compile issues --- coderd/idpsync/group.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 6d5fd11a52e5a..11e14260a7f3d 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -12,7 +12,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -44,7 +43,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - resolver := runtimeconfig.NewStoreResolver(tx) userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -62,7 +60,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := runtimeconfig.NewOrgResolver(orgID, resolver) + orgResolver := s.Manager.Scoped(orgID.String()) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) From 6d3ed2e57043c7eada5f619697c4d2997b3bf790 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:25:12 -0500 Subject: [PATCH 06/38] fixup some tests --- coderd/idpsync/organizations_test.go | 29 +++++++++++-------- enterprise/coderd/enidpsync/enidpsync.go | 5 ++-- .../coderd/enidpsync/organizations_test.go | 9 +++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index 03b1ebfa4b27b..b0e7728b0640a 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/testutil" ) @@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) { t.Run("SingleOrgDeployment", func(t *testing.T) { t.Parallel() - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "", - OrganizationMapping: nil, - OrganizationAssignDefault: true, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "", + OrganizationMapping: nil, + OrganizationAssignDefault: true, + }) ctx := testutil.Context(t, testutil.WaitMedium) @@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() // AGPL has limited behavior - s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ - OrganizationField: "orgs", - OrganizationMapping: map[string][]uuid.UUID{ - "random": {uuid.New()}, - }, - OrganizationAssignDefault: false, - }) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + OrganizationField: "orgs", + OrganizationMapping: map[string][]uuid.UUID{ + "random": {uuid.New()}, + }, + OrganizationAssignDefault: false, + }) ctx := testutil.Context(t, testutil.WaitMedium) diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index 10988832743da..a7ff1eaa07257 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -4,6 +4,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" ) // EnterpriseIDPSync enabled syncing user information from an external IDP. @@ -16,9 +17,9 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, manager runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, - AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), + AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings), } } diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 0b2ed1ef6521f..8978fa6b46ee1 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/testutil" @@ -41,7 +42,7 @@ type Expectations struct { } type OrganizationSyncTestCase struct { - Settings idpsync.SyncSettings + Settings idpsync.DeploymentSyncSettings Entitlements *entitlements.Set Exps []Expectations } @@ -89,7 +90,7 @@ func TestOrganizationSync(t *testing.T) { other := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, OrganizationAssignDefault: true, @@ -142,7 +143,7 @@ func TestOrganizationSync(t *testing.T) { three := dbgen.Organization(t, db, database.Organization{}) return OrganizationSyncTestCase{ Entitlements: entitled, - Settings: idpsync.SyncSettings{ + Settings: idpsync.DeploymentSyncSettings{ OrganizationField: "organizations", OrganizationMapping: map[string][]uuid.UUID{ "first": {one.ID}, @@ -236,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(rdb), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From 0803619e8c65eb1bb584abae3137e403bf07f8fe Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 10:51:49 -0500 Subject: [PATCH 07/38] handle allow list --- coderd/coderd.go | 6 +--- coderd/idpsync/group.go | 8 +++++ coderd/idpsync/idpsync.go | 23 ++++++++++++ coderd/userauth.go | 2 +- enterprise/coderd/coderd.go | 6 +--- enterprise/coderd/enidpsync/groups.go | 42 +++++++++++++++++++++- enterprise/coderd/enidpsync/groups_test.go | 35 ++++++++++++++++++ 7 files changed, 110 insertions(+), 12 deletions(-) create mode 100644 enterprise/coderd/enidpsync/groups_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 97c2d9f883713..b829d37a06773 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -312,11 +312,7 @@ func New(options *Options) *API { ) if options.IDPSync == nil { - options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.DeploymentSyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) + options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues)) } experiments := ReadExperiments( diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 11e14260a7f3d..0257801ae2a7a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -266,3 +266,11 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. return addIDs, nil } + +func ConvertAllowList(allowList []string) map[string]struct{} { + allowMap := make(map[string]struct{}, len(allowList)) + for _, group := range allowList { + allowMap[group] = struct{}{} + } + return allowMap +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 6400977387536..5ad2ffb52ff12 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -63,6 +63,29 @@ type DeploymentSyncSettings struct { // placed into the default organization. This is mostly a hack to support // legacy deployments. OrganizationAssignDefault bool + + // GroupField at the deployment level is used for deployment level group claim + // settings. + GroupField string + // GroupAllowList (if set) will restrict authentication to only users who + // have at least one group in this list. + // A map representation is used for easier lookup. + GroupAllowList map[string]struct{} +} + +func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings { + if dv == nil { + panic("Developer error: DeploymentValues should not be nil") + } + return DeploymentSyncSettings{ + OrganizationField: dv.OIDC.OrganizationField.Value(), + OrganizationMapping: dv.OIDC.OrganizationMapping.Value, + OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), + + GroupField: dv.OIDC.GroupField.Value(), + GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), + } + } type SyncSettings struct { diff --git a/coderd/userauth.go b/coderd/userauth.go index bb149d9d07379..a1abadc63f31a 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1142,7 +1142,7 @@ func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interfac slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)), slog.F("user_group_count", len(groups)), ) - detail := "Ask an administrator to add one of your groups to the whitelist" + detail := "Ask an administrator to add one of your groups to the allow list" if len(groups) == 0 { detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login." } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index bc6491a41198f..ce55bae8ec8d0 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -113,11 +113,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { options.Database = cryptDB if options.IDPSync == nil { - options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{ - OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(), - OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value, - OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(), - }) + options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues)) } api := &API{ diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 02f012b8e14c3..441f847c6a450 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -2,6 +2,7 @@ package enidpsync import ( "context" + "net/http" "github.com/golang-jwt/jwt/v4" @@ -11,7 +12,6 @@ import ( func (e EnterpriseIDPSync) GroupSyncEnabled() bool { return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC) - } // ParseGroupClaims parses the user claims and handles deployment wide group behavior. @@ -23,6 +23,46 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) } + if e.GroupField != "" && len(e.GroupAllowList) > 0 { + groupsRaw, ok := mergedClaims[e.GroupField] + if !ok { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "You have no groups in your claims!", + RenderStaticPage: true, + } + } + parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) + if err != nil { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusBadRequest, + Msg: "Failed read groups from claims for allow list check. Ask an administrator for help.", + Detail: err.Error(), + RenderStaticPage: true, + } + } + + inAllowList := false + AllowListCheckLoop: + for _, group := range parsedGroups { + if _, ok := e.GroupAllowList[group]; ok { + inAllowList = true + break AllowListCheckLoop + } + } + + if !inAllowList { + return idpsync.GroupParams{}, &idpsync.HTTPError{ + Code: http.StatusForbidden, + Msg: "Not a member of an allowed group", + Detail: "Ask an administrator to add one of your groups to the allow list.", + RenderStaticPage: true, + } + } + + } + return idpsync.GroupParams{ SyncEnabled: e.OrganizationSyncEnabled(), MergedClaims: mergedClaims, diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go new file mode 100644 index 0000000000000..149c57dadd79a --- /dev/null +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -0,0 +1,35 @@ +package enidpsync_test + +import ( + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/entitlements" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/enterprise/coderd/enidpsync" + "github.com/coder/coder/v2/testutil" +) + +func TestEnterpriseParseGroupClaims(t *testing.T) { + t.Parallel() + + t.Run("NoEntitlements", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitlements.New(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) +} From 596e7b467feef913f10ff70768783bb6b5f826e5 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 12:25:51 -0500 Subject: [PATCH 08/38] WIP unit test for group sync --- coderd/coderdtest/uuids.go | 21 ++ coderd/database/dbmem/dbmem.go | 17 +- coderd/idpsync/group.go | 23 +- coderd/idpsync/group_test.go | 289 +++++++++++++++++++++ coderd/idpsync/idpsync.go | 26 +- enterprise/coderd/enidpsync/groups.go | 2 +- enterprise/coderd/enidpsync/groups_test.go | 61 +++++ 7 files changed, 421 insertions(+), 18 deletions(-) create mode 100644 coderd/coderdtest/uuids.go create mode 100644 coderd/idpsync/group_test.go diff --git a/coderd/coderdtest/uuids.go b/coderd/coderdtest/uuids.go new file mode 100644 index 0000000000000..aefa6e83c0b3c --- /dev/null +++ b/coderd/coderdtest/uuids.go @@ -0,0 +1,21 @@ +package coderdtest + +import "github.com/google/uuid" + +type DeterministicUUIDGenerator struct { + Named map[string]uuid.UUID +} + +func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator { + return &DeterministicUUIDGenerator{ + Named: make(map[string]uuid.UUID), + } +} + +func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID { + if v, ok := d.Named[name]; ok { + return v + } + d.Named[name] = uuid.New() + return d.Named[name] +} diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 423b13ef4a774..37811063997db 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7643,7 +7643,22 @@ func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.Rem return nil, err } - panic("not implemented") + q.mutex.Lock() + defer q.mutex.Unlock() + + removed := make([]uuid.UUID, 0) + q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { + if groupMember.UserID != arg.UserID { + return false + } + if !slices.Contains(arg.GroupIds, groupMember.GroupID) { + return false + } + removed = append(removed, groupMember.GroupID) + return true + }) + + return removed, nil } func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 0257801ae2a7a..d45d79bf04cac 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -2,6 +2,7 @@ package idpsync import ( "context" + "encoding/json" "regexp" "github.com/golang-jwt/jwt/v4" @@ -12,6 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -32,7 +34,6 @@ func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (Group }, nil } -// TODO: Group allowlist behavior should probably happen at this step. func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error { // Nothing happens if sync is not enabled if !params.SyncEnabled { @@ -43,6 +44,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { + manager := runtimeconfig.NewStoreManager(tx) + userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -60,12 +63,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := s.Manager.Scoped(orgID.String()) + orgResolver := manager.Scoped(orgID.String()) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) } - orgSettings[orgID] = settings.Value + orgSettings[orgID] = *settings } // collect all diffs to do 1 sql update for all orgs @@ -177,6 +180,20 @@ type GroupSyncSettings struct { AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` } +func (s *GroupSyncSettings) Set(v string) error { + return json.Unmarshal([]byte(v), s) +} +func (s *GroupSyncSettings) String() string { + v, err := json.Marshal(s) + if err != nil { + return "decode failed: " + err.Error() + } + return string(v) +} +func (s *GroupSyncSettings) Type() string { + return "GroupSyncSettings" +} + type ExpectedGroup struct { GroupID *uuid.UUID GroupName *string diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go new file mode 100644 index 0000000000000..42465f115488e --- /dev/null +++ b/coderd/idpsync/group_test.go @@ -0,0 +1,289 @@ +package idpsync_test + +import ( + "context" + "regexp" + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/idpsync" + "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/testutil" +) + +func TestParseGroupClaims(t *testing.T) { + t.Parallel() + + t.Run("EmptyConfig", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{}) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + + require.False(t, params.SyncEnabled) + }) + + // AllowList has no effect in AGPL + t.Run("AllowList", func(t *testing.T) { + t.Parallel() + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.Nil(t, err) + require.False(t, params.SyncEnabled) + }) +} + +func TestGroupSyncTable(t *testing.T) { + t.Parallel() + + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + userClaims := jwt.MapClaims{ + "groups": []string{ + "foo", "bar", "baz", + "create-bar", "create-baz", + }, + } + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []orgSetupDefinition{ + { + Name: "SwitchGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")}, + "bar": {ids.ID("sg-bar")}, + "baz": {ids.ID("sg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + uuid.New(): true, + uuid.New(): true, + // Extra groups + ids.ID("sg-foo"): false, + ids.ID("sg-foo-2"): false, + ids.ID("sg-bar"): false, + ids.ID("sg-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("sg-foo"), + ids.ID("sg-foo-2"), + ids.ID("sg-bar"), + ids.ID("sg-baz"), + }, + }, + { + Name: "StayInGroup", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + // Only match foo, so bar does not map + RegexFilter: regexp.MustCompile("^foo$"), + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("gg-foo"), uuid.New()}, + "bar": {ids.ID("gg-bar")}, + "baz": {ids.ID("gg-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("gg-foo"): true, + ids.ID("gg-bar"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("gg-foo"), + }, + }, + { + Name: "UserJoinsGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + "foo": {ids.ID("ng-foo"), uuid.New()}, + "bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")}, + "baz": {ids.ID("ng-baz")}, + }, + }, + Groups: map[uuid.UUID]bool{ + ids.ID("ng-foo"): false, + ids.ID("ng-bar"): false, + ids.ID("ng-bar-2"): false, + ids.ID("ng-baz"): false, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("ng-foo"), + ids.ID("ng-bar"), + ids.ID("ng-bar-2"), + ids.ID("ng-baz"), + }, + }, + { + Name: "CreateGroups", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile("^create"), + AutoCreateMissingGroups: true, + }, + Groups: map[uuid.UUID]bool{}, + ExpectedGroups: []uuid.UUID{ + ids.ID("create-bar"), + ids.ID("create-baz"), + }, + }, + { + Name: "NoUser", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + GroupMapping: map[string][]uuid.UUID{ + // Extra ID that does not map to a group + "foo": {ids.ID("ow-foo"), uuid.New()}, + }, + RegexFilter: nil, + AutoCreateMissingGroups: false, + }, + NotMember: true, + Groups: map[uuid.UUID]bool{ + ids.ID("ow-foo"): false, + ids.ID("ow-bar"): false, + }, + }, + { + Name: "NoSettingsNoUser", + Settings: nil, + Groups: map[uuid.UUID]bool{}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + if tc.OrgID == uuid.Nil { + tc.OrgID = uuid.New() + } + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewStoreManager(db) + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + }, + ) + + ctx := testutil.Context(t, testutil.WaitMedium) + user := dbgen.User(t, db, database.User{}) + SetupOrganization(t, s, db, user, tc) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + tc.Assert(t, tc.OrgID, db, user) + }) + } +} + +func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, def orgSetupDefinition) { + org := dbgen.Organization(t, db, database.Organization{ + ID: def.OrgID, + }) + + manager := runtimeconfig.NewStoreManager(db) + orgResolver := manager.Scoped(org.ID.String()) + err := s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) + require.NoError(t, err) + + if !def.NotMember { + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + } + for groupID, in := range def.Groups { + dbgen.Group(t, db, database.Group{ + ID: groupID, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: groupID, + }) + } + } +} + +type orgSetupDefinition struct { + Name string + OrgID uuid.UUID + // True if the user is a member of the group + Groups map[uuid.UUID]bool + NotMember bool + + Settings *idpsync.GroupSyncSettings + ExpectedGroups []uuid.UUID +} + +func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) { + t.Helper() + + t.Run(o.Name+"-Assert", func(t *testing.T) { + ctx := context.Background() + + members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: orgID, + UserID: user.ID, + }) + require.NoError(t, err) + if o.NotMember { + require.Len(t, members, 0, "should not be a member") + } else { + require.Len(t, members, 1, "should be a member") + } + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: user.ID, + }) + require.NoError(t, err) + if o.ExpectedGroups == nil { + o.ExpectedGroups = make([]uuid.UUID, 0) + } + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") + }) +} diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 5ad2ffb52ff12..3ff8d78fd5174 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,7 +3,6 @@ package idpsync import ( "context" "net/http" - "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -16,7 +15,6 @@ import ( "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/site" - "github.com/coder/serpent" ) // IDPSync is an interface, so we can implement this as AGPL and as enterprise, @@ -45,8 +43,7 @@ type IDPSync interface { // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger - Manager runtimeconfig.Manager + Logger slog.Logger SyncSettings } @@ -91,22 +88,25 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings type SyncSettings struct { DeploymentSyncSettings - Group runtimeconfig.Entry[*serpent.Struct[GroupSyncSettings]] + Group runtimeconfig.Entry[*GroupSyncSettings] - // Group options here are set by the deployment config and only apply to - // the default organization. - GroupField string - CreateMissingGroups bool - GroupMapping map[string]string - GroupFilter *regexp.Regexp + //// Group options here are set by the deployment config and only apply to + //// the default organization. + //GroupField string + //CreateMissingGroups bool + //GroupMapping map[string]string + //GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), - Manager: manager, + Logger: logger.Named("idp-sync"), SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, + // Default to '{}' if the group sync settings are not set. + // TODO: Feels strange to have to define the type as a string. I should be + // able to pass in an empty struct. + Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings", "{}"), }, } } diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 441f847c6a450..2ecc8703e29cd 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -64,7 +64,7 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw } return idpsync.GroupParams{ - SyncEnabled: e.OrganizationSyncEnabled(), + SyncEnabled: true, MergedClaims: mergedClaims, }, nil } diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 149c57dadd79a..138d2954712de 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/testutil" ) @@ -17,6 +18,14 @@ import ( func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() + entitled := entitlements.New() + entitled.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{ + Entitlement: codersdk.EntitlementEntitled, + Enabled: true, + } + }) + t.Run("NoEntitlements", func(t *testing.T) { t.Parallel() @@ -32,4 +41,56 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { require.False(t, params.SyncEnabled) }) + + t.Run("NotInAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + // Try with incorrect group + _, err := s.ParseGroupClaims(ctx, jwt.MapClaims{ + "groups": []string{"bar"}, + }) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + + // Try with no groups + _, err = s.ParseGroupClaims(ctx, jwt.MapClaims{}) + require.NotNil(t, err) + require.Equal(t, 403, err.Code) + }) + + t.Run("InAllowList", func(t *testing.T) { + t.Parallel() + + s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), + runtimeconfig.NewNoopManager(), + entitled, + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + GroupAllowList: map[string]struct{}{ + "foo": {}, + }, + }) + + ctx := testutil.Context(t, testutil.WaitMedium) + + claims := jwt.MapClaims{ + "groups": []string{"foo", "bar"}, + } + params, err := s.ParseGroupClaims(ctx, claims) + require.Nil(t, err) + require.True(t, params.SyncEnabled) + require.Equal(t, claims, params.MergedClaims) + }) } From b9476ac14070ee3917ac935606bfb6bbc4523a2b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 4 Sep 2024 14:29:54 -0500 Subject: [PATCH 09/38] fixup tests, account for existing groups --- coderd/database/queries.sql.go | 7 +- coderd/database/queries/groups.sql | 4 ++ coderd/idpsync/group.go | 20 +++++- coderd/idpsync/group_test.go | 105 +++++++++++++++++++++-------- 4 files changed, 106 insertions(+), 30 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3e6d6ce61c6fb..b87ad6f857bb9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1677,11 +1677,16 @@ WHERE ) ELSE true END + AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN + name = ANY($3) + ELSE true + END ` type GetGroupsParams struct { OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"` + GroupNames []string `db:"group_names" json:"group_names"` } type GetGroupsRow struct { @@ -1691,7 +1696,7 @@ type GetGroupsRow struct { } func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) { - rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID) + rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 1752ccd112ea7..628395b8a81b0 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -52,6 +52,10 @@ WHERE ) ELSE true END + AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN + name = ANY(@group_names) + ELSE true + END ; -- name: InsertGroup :one diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d45d79bf04cac..de1a3eee6597a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -244,16 +244,34 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { if !s.AutoCreateMissingGroups { - // Remove all groups that are missing, they will not be created + // construct the list of groups to search by name to see if they exist. + var lookups []string filter := make([]uuid.UUID, 0) for _, expected := range add { if expected.GroupID != nil { filter = append(filter, *expected.GroupID) + } else if expected.GroupName != nil { + lookups = append(lookups, *expected.GroupName) + } + } + + if len(lookups) > 0 { + newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: uuid.UUID{}, + HasMemberID: uuid.UUID{}, + GroupNames: lookups, + }) + if err != nil { + return nil, xerrors.Errorf("get groups by names: %w", err) + } + for _, g := range newGroups { + filter = append(filter, g.Group.ID) } } return filter, nil } + // All expected that are missing IDs means the group does not exist // in the database. Either remove them, or create them if auto create is // turned on. diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 42465f115488e..6b63b13e76ae5 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -8,6 +8,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/coderdtest" @@ -152,9 +153,26 @@ func TestGroupSyncTable(t *testing.T) { AutoCreateMissingGroups: true, }, Groups: map[uuid.UUID]bool{}, - ExpectedGroups: []uuid.UUID{ - ids.ID("create-bar"), - ids.ID("create-baz"), + ExpectedGroupNames: []string{ + "create-bar", + "create-baz", + }, + }, + { + Name: "GroupNamesNoMapping", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile(".*"), + AutoCreateMissingGroups: false, + }, + GroupNames: map[string]bool{ + "foo": false, + "bar": false, + "goob": true, + }, + ExpectedGroupNames: []string{ + "foo", + "bar", }, }, { @@ -219,10 +237,12 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, org := dbgen.Organization(t, db, database.Organization{ ID: def.OrgID, }) + _, err := db.InsertAllUsersGroup(context.Background(), org.ID) + require.NoError(t, err, "Everyone group for an org") manager := runtimeconfig.NewStoreManager(db) orgResolver := manager.Scoped(org.ID.String()) - err := s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) + err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) if !def.NotMember { @@ -243,47 +263,76 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, }) } } + for groupName, in := range def.GroupNames { + group := dbgen.Group(t, db, database.Group{ + Name: groupName, + OrganizationID: org.ID, + }) + if in { + dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } } type orgSetupDefinition struct { Name string OrgID uuid.UUID // True if the user is a member of the group - Groups map[uuid.UUID]bool - NotMember bool + Groups map[uuid.UUID]bool + GroupNames map[string]bool + NotMember bool - Settings *idpsync.GroupSyncSettings - ExpectedGroups []uuid.UUID + Settings *idpsync.GroupSyncSettings + ExpectedGroups []uuid.UUID + ExpectedGroupNames []string } func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) { t.Helper() - t.Run(o.Name+"-Assert", func(t *testing.T) { - ctx := context.Background() + ctx := context.Background() - members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ - OrganizationID: orgID, - UserID: user.ID, - }) - require.NoError(t, err) - if o.NotMember { - require.Len(t, members, 0, "should not be a member") - } else { - require.Len(t, members, 1, "should be a member") - } + members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{ + OrganizationID: orgID, + UserID: user.ID, + }) + require.NoError(t, err) + if o.NotMember { + require.Len(t, members, 0, "should not be a member") + } else { + require.Len(t, members, 1, "should be a member") + } + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + HasMemberID: user.ID, + }) + require.NoError(t, err) + if o.ExpectedGroups == nil { + o.ExpectedGroups = make([]uuid.UUID, 0) + } + if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 { + t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive") + } + + // Everyone groups mess up our asserts + userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool { + return row.Group.ID == row.Group.OrganizationID + }) - userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ - OrganizationID: orgID, - HasMemberID: user.ID, + if len(o.ExpectedGroupNames) > 0 { + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string { + return g.Group.Name }) - require.NoError(t, err) - if o.ExpectedGroups == nil { - o.ExpectedGroups = make([]uuid.UUID, 0) - } + require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name") + } else { + // Check by ID, recommended found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { return g.Group.ID }) require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") - }) + } } From ee8e4e4b07e54611a4f99c0fec383493c0198bf7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 12:28:30 -0500 Subject: [PATCH 10/38] fix compile issues --- coderd/idpsync/group.go | 5 +---- coderd/idpsync/group_test.go | 10 +++++----- coderd/idpsync/idpsync.go | 13 ++++++------- coderd/idpsync/organizations_test.go | 4 ++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index de1a3eee6597a..cedcb8ba8eaae 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -13,7 +13,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -44,8 +43,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat ctx = dbauthz.AsSystemRestricted(ctx) db.InTx(func(tx database.Store) error { - manager := runtimeconfig.NewStoreManager(tx) - userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -63,7 +60,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // For each org, we need to fetch the sync settings orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { - orgResolver := manager.Scoped(orgID.String()) + orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 6b63b13e76ae5..456e0752ebc1e 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -28,7 +28,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{}) ctx := testutil.Context(t, testutil.WaitMedium) @@ -44,7 +44,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ GroupField: "groups", GroupAllowList: map[string]struct{}{ @@ -209,7 +209,7 @@ func TestGroupSyncTable(t *testing.T) { } db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager(db) + manager := runtimeconfig.NewStoreManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, idpsync.DeploymentSyncSettings{ @@ -240,8 +240,8 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, _, err := db.InsertAllUsersGroup(context.Background(), org.ID) require.NoError(t, err, "Everyone group for an org") - manager := runtimeconfig.NewStoreManager(db) - orgResolver := manager.Scoped(org.ID.String()) + manager := runtimeconfig.NewStoreManager() + orgResolver := manager.OrganizationResolver(db, org.ID) err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 3ff8d78fd5174..b462f5da01bdb 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -43,7 +43,8 @@ type IDPSync interface { // AGPLIDPSync is the configuration for syncing user information from an external // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { - Logger slog.Logger + Logger slog.Logger + Manager runtimeconfig.Manager SyncSettings } @@ -88,7 +89,7 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings type SyncSettings struct { DeploymentSyncSettings - Group runtimeconfig.Entry[*GroupSyncSettings] + Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] //// Group options here are set by the deployment config and only apply to //// the default organization. @@ -100,13 +101,11 @@ type SyncSettings struct { func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ - Logger: logger.Named("idp-sync"), + Logger: logger.Named("idp-sync"), + Manager: manager, SyncSettings: SyncSettings{ DeploymentSyncSettings: settings, - // Default to '{}' if the group sync settings are not set. - // TODO: Feels strange to have to define the type as a string. I should be - // able to pass in an empty struct. - Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings", "{}"), + Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"), }, } } diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index b0e7728b0640a..934d7d83816ab 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -20,7 +20,7 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, @@ -42,7 +42,7 @@ func TestParseOrganizationClaims(t *testing.T) { // AGPL has limited behavior s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "orgs", OrganizationMapping: map[string][]uuid.UUID{ From d5ff0f7bfa82b6abe2be2f5d7c030c66393da913 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 12:36:56 -0500 Subject: [PATCH 11/38] add comment for test helper --- coderd/coderdtest/uuids.go | 4 ++++ coderd/coderdtest/uuids_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 coderd/coderdtest/uuids_test.go diff --git a/coderd/coderdtest/uuids.go b/coderd/coderdtest/uuids.go index aefa6e83c0b3c..1ff60bf26c572 100644 --- a/coderd/coderdtest/uuids.go +++ b/coderd/coderdtest/uuids.go @@ -2,6 +2,10 @@ package coderdtest import "github.com/google/uuid" +// DeterministicUUIDGenerator allows "naming" uuids for unit tests. +// An example of where this is useful, is when a tabled test references +// a UUID that is not yet known. An alternative to this would be to +// hard code some UUID strings, but these strings are not human friendly. type DeterministicUUIDGenerator struct { Named map[string]uuid.UUID } diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go new file mode 100644 index 0000000000000..bb92d6faffabd --- /dev/null +++ b/coderd/coderdtest/uuids_test.go @@ -0,0 +1,33 @@ +package coderdtest_test + +import ( + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/coderdtest" +) + +func ExampleNewDeterministicUUIDGenerator() { + det := coderdtest.NewDeterministicUUIDGenerator() + testCases := []struct { + CreateUsers []uuid.UUID + ExpectedIDs []uuid.UUID + }{ + { + CreateUsers: []uuid.UUID{ + det.ID("player1"), + det.ID("player2"), + }, + ExpectedIDs: []uuid.UUID{ + det.ID("player1"), + det.ID("player2"), + }, + }, + } + + for _, tc := range testCases { + tc := tc + var _ = tc + // Do the test with CreateUsers as the setup, and the expected IDs + // will match. + } +} From 86c0f6f52eb79c142920a9abce15b6d40e4f315c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:41:42 -0500 Subject: [PATCH 12/38] handle legacy params --- coderd/database/queries.sql.go | 2 +- coderd/database/queries/groups.sql | 2 +- coderd/idpsync/group.go | 34 ++++++++++++++++++++++++++++++ coderd/idpsync/idpsync.go | 24 +++++++++++++++------ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b87ad6f857bb9..7c7fbbf0f88f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1678,7 +1678,7 @@ WHERE ELSE true END AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN - name = ANY($3) + groups.name = ANY($3) ELSE true END ` diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 628395b8a81b0..0df848d6a6d05 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -53,7 +53,7 @@ WHERE ELSE true END AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN - name = ANY(@group_names) + groups.name = ANY(@group_names) ELSE true END ; diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index cedcb8ba8eaae..5acf9665f80ce 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -39,6 +39,18 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + // Only care about the default org for deployment settings if the + // legacy deployment settings exist. + defaultOrgID := uuid.Nil + // Default organization is configured via legacy deployment values + if s.DeploymentSyncSettings.Legacy.GroupField != "" { + defaultOrganization, err := db.GetDefaultOrganization(ctx) + if err != nil { + return xerrors.Errorf("get default organization: %w", err) + } + defaultOrgID = defaultOrganization.ID + } + // nolint:gocritic // all syncing is done as a system user ctx = dbauthz.AsSystemRestricted(ctx) @@ -66,6 +78,16 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("resolve group sync settings: %w", err) } orgSettings[orgID] = *settings + + // Legacy deployment settings will override empty settings. + if orgID == defaultOrgID && settings.GroupField == "" { + settings = &GroupSyncSettings{ + GroupField: s.Legacy.GroupField, + LegacyGroupNameMapping: s.Legacy.GroupMapping, + RegexFilter: s.Legacy.GroupFilter, + AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, + } + } } // collect all diffs to do 1 sql update for all orgs @@ -175,6 +197,12 @@ type GroupSyncSettings struct { GroupMapping map[string][]uuid.UUID `json:"mapping"` RegexFilter *regexp.Regexp `json:"regex_filter"` AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` + // LegacyGroupNameMapping is deprecated. It remaps an IDP group name to + // a Coder group name. Since configuration is now done at runtime, + // group IDs are used to account for group renames. + // For legacy configurations, this config option has to remain. + // Deprecated: Use GroupMapping instead. + LegacyGroupNameMapping map[string]string } func (s *GroupSyncSettings) Set(v string) error { @@ -232,6 +260,12 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr } continue } + + mappedGroupName, ok := s.LegacyGroupNameMapping[group] + if ok { + groups = append(groups, ExpectedGroup{GroupName: &mappedGroupName}) + continue + } group := group groups = append(groups, ExpectedGroup{GroupName: &group}) } diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index b462f5da01bdb..a01e3bc14f745 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "net/http" + "regexp" "strings" "github.com/golang-jwt/jwt/v4" @@ -69,6 +70,15 @@ type DeploymentSyncSettings struct { // have at least one group in this list. // A map representation is used for easier lookup. GroupAllowList map[string]struct{} + // Legacy deployment settings that only apply to the default org. + Legacy DefaultOrgLegacySettings +} + +type DefaultOrgLegacySettings struct { + GroupField string + GroupMapping map[string]string + GroupFilter *regexp.Regexp + CreateMissingGroups bool } func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings { @@ -80,8 +90,15 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings OrganizationMapping: dv.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), + // TODO: Separate group field for allow list from default org GroupField: dv.OIDC.GroupField.Value(), GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), + Legacy: DefaultOrgLegacySettings{ + GroupField: dv.OIDC.GroupField.Value(), + GroupMapping: dv.OIDC.GroupMapping.Value, + GroupFilter: dv.OIDC.GroupRegexFilter.Value(), + CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(), + }, } } @@ -90,13 +107,6 @@ type SyncSettings struct { DeploymentSyncSettings Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] - - //// Group options here are set by the deployment config and only apply to - //// the default organization. - //GroupField string - //CreateMissingGroups bool - //GroupMapping map[string]string - //GroupFilter *regexp.Regexp } func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { From 2f03e182b2554c449e1ee1eb13a16cdc5321270a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:44:00 -0500 Subject: [PATCH 13/38] make gen --- coderd/database/dbmock/dbmock.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index fe2e444ff5c67..c5d579e1c2656 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4118,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1) } +// RemoveUserFromGroups mocks base method. +func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1) + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups. +func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1) +} + // RevokeDBCryptKey mocks base method. func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() From ec8092d25c3fe8c81edef30afb2af9790c9117a1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 14:46:36 -0500 Subject: [PATCH 14/38] cleanup --- coderd/coderdtest/uuids_test.go | 2 +- coderd/database/dbauthz/dbauthz.go | 6 +++++- coderd/database/dbmem/dbmem.go | 4 ++-- coderd/idpsync/group.go | 8 +++++++- coderd/idpsync/idpsync.go | 1 - enterprise/coderd/enidpsync/groups_test.go | 6 +++--- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go index bb92d6faffabd..5a0e10935bd50 100644 --- a/coderd/coderdtest/uuids_test.go +++ b/coderd/coderdtest/uuids_test.go @@ -26,7 +26,7 @@ func ExampleNewDeterministicUUIDGenerator() { for _, tc := range testCases { tc := tc - var _ = tc + _ = tc // Do the test with CreateUsers as the setup, and the expected IDs // will match. } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index eaf994e849fc5..077d704be1300 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3109,7 +3109,11 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) } func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - panic("not implemented") + // This is a system function to clear user groups in group sync. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.RemoveUserFromGroups(ctx, arg) } func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 37811063997db..6f0c04eb4e512 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7015,7 +7015,7 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam return user, nil } -func (q *FakeQuerier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { +func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { err := validateDatabaseType(arg) if err != nil { return nil, err @@ -7637,7 +7637,7 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI return nil } -func (q *FakeQuerier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { +func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { err := validateDatabaseType(arg) if err != nil { return nil, err diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 5acf9665f80ce..07ead53cd52c2 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -54,7 +54,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // nolint:gocritic // all syncing is done as a system user ctx = dbauthz.AsSystemRestricted(ctx) - db.InTx(func(tx database.Store) error { + err := db.InTx(func(tx database.Store) error { userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, }) @@ -188,6 +188,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil }, nil) + if err != nil { + return err + } + return nil } @@ -208,6 +212,7 @@ type GroupSyncSettings struct { func (s *GroupSyncSettings) Set(v string) error { return json.Unmarshal([]byte(v), s) } + func (s *GroupSyncSettings) String() string { v, err := json.Marshal(s) if err != nil { @@ -215,6 +220,7 @@ func (s *GroupSyncSettings) String() string { } return string(v) } + func (s *GroupSyncSettings) Type() string { return "GroupSyncSettings" } diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index a01e3bc14f745..bc3e5cd479064 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -100,7 +100,6 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(), }, } - } type SyncSettings struct { diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 138d2954712de..8103f8a002937 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -30,7 +30,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitlements.New(), idpsync.DeploymentSyncSettings{}) @@ -46,7 +46,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", @@ -74,7 +74,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewNoopManager(), + runtimeconfig.NewStoreManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", From d63727d5288de948555f6355643e2d45bc608d21 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 15:08:52 -0500 Subject: [PATCH 15/38] add unit test for legacy behavior --- coderd/idpsync/group.go | 30 ++++++++++++++++++++++++------ coderd/idpsync/group_test.go | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 07ead53cd52c2..c64b08ee07553 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -206,7 +206,7 @@ type GroupSyncSettings struct { // group IDs are used to account for group renames. // For legacy configurations, this config option has to remain. // Deprecated: Use GroupMapping instead. - LegacyGroupNameMapping map[string]string + LegacyGroupNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` } func (s *GroupSyncSettings) Set(v string) error { @@ -251,6 +251,12 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr groups := make([]ExpectedGroup, 0) for _, group := range parsedGroups { + // Legacy group mappings happen before the regex filter. + mappedGroupName, ok := s.LegacyGroupNameMapping[group] + if ok { + group = mappedGroupName + } + // Only allow through groups that pass the regex if s.RegexFilter != nil { if !s.RegexFilter.MatchString(group) { @@ -267,11 +273,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr continue } - mappedGroupName, ok := s.LegacyGroupNameMapping[group] - if ok { - groups = append(groups, ExpectedGroup{GroupName: &mappedGroupName}) - continue - } group := group groups = append(groups, ExpectedGroup{GroupName: &group}) } @@ -332,6 +333,23 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. if err != nil { return nil, xerrors.Errorf("insert missing groups: %w", err) } + + if len(missingGroups) != len(createdMissingGroups) { + // This is unfortunate, but if legacy params are used, then some existing groups + // can come as params. So we need to fetch them + allGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + OrganizationID: orgID, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("get groups by names: %w", err) + } + + createdMissingGroups = db2sdk.List(allGroups, func(g database.GetGroupsRow) database.Group { + return g.Group + }) + } + for _, created := range createdMissingGroups { addIDs = append(addIDs, created.ID) } diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 456e0752ebc1e..406df099167c3 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -197,6 +197,26 @@ func TestGroupSyncTable(t *testing.T) { Settings: nil, Groups: map[uuid.UUID]bool{}, }, + { + Name: "LegacyMapping", + Settings: &idpsync.GroupSyncSettings{ + GroupField: "groups", + RegexFilter: regexp.MustCompile("^legacy"), + LegacyGroupNameMapping: map[string]string{ + "create-bar": "legacy-bar", + "foo": "legacy-foo", + }, + AutoCreateMissingGroups: true, + }, + Groups: map[uuid.UUID]bool{}, + GroupNames: map[string]bool{ + "legacy-foo": false, + }, + ExpectedGroupNames: []string{ + "legacy-bar", + "legacy-foo", + }, + }, } for _, tc := range testCases { From 2a1769c7fdcd34b5c7a82f335d7f1e4f05b696ce Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 15:41:41 -0500 Subject: [PATCH 16/38] work on batching removal by name or id --- coderd/database/dbmem/dbmem.go | 41 ++++++--- coderd/database/queries.sql.go | 19 +++- coderd/database/queries/groupmembers.sql | 12 ++- coderd/idpsync/group.go | 108 +++++++++++++++-------- coderd/idpsync/group_test.go | 5 +- 5 files changed, 130 insertions(+), 55 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 6f0c04eb4e512..fd97fb0d701bf 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -682,6 +682,17 @@ func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobI return resources, nil } +func (q *FakeQuerier) getGroupByNameNoLock(arg database.NameOrganizationPair) (database.Group, error) { + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return group, nil + } + } + + return database.Group{}, sql.ErrNoRows +} + func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { for _, group := range q.groups { if group.ID == id { @@ -2613,14 +2624,10 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr q.mutex.RLock() defer q.mutex.RUnlock() - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows + return q.getGroupByNameNoLock(database.NameOrganizationPair{ + Name: arg.Name, + OrganizationID: arg.OrganizationID, + }) } func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { @@ -7648,14 +7655,24 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov removed := make([]uuid.UUID, 0) q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { + // Delete all group members that match the arguments. if groupMember.UserID != arg.UserID { + // Not the right user, ignore. return false } - if !slices.Contains(arg.GroupIds, groupMember.GroupID) { - return false + + matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID) + matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool { + _, err := q.getGroupByNameNoLock(name) + return err == nil + }) + + if matchesByName || matchesByID { + removed = append(removed, groupMember.GroupID) + return true } - removed = append(removed, groupMember.GroupID) - return true + + return false }) return removed, nil diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7c7fbbf0f88f0..04c111bcee78f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1540,19 +1540,30 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU const removeUserFromGroups = `-- name: RemoveUserFromGroups :many DELETE FROM group_members + USING groups WHERE + group_members.group_id = groups.id AND user_id = $1 AND - group_id = ANY($2 :: uuid []) + ( + CASE WHEN array_length($2 :: name_organization_pair[], 1) > 0 THEN + -- Using 'coalesce' to avoid troubles with null literals being an empty string. + (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY ($2::name_organization_pair[]) + ELSE false + END + OR + group_id = ANY ($3 :: uuid[]) + ) RETURNING group_id ` type RemoveUserFromGroupsParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupNames []NameOrganizationPair `db:"group_names" json:"group_names"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` } func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupNames), pq.Array(arg.GroupIds)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 814f878cb9232..5345d976fcd4e 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -57,9 +57,19 @@ WHERE -- name: RemoveUserFromGroups :many DELETE FROM group_members + USING groups WHERE + group_members.group_id = groups.id AND user_id = @user_id AND - group_id = ANY(@group_ids :: uuid []) + ( + CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN + -- Using 'coalesce' to avoid troubles with null literals being an empty string. + (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[]) + ELSE false + END + OR + group_id = ANY (@group_ids :: uuid[]) + ) RETURNING group_id; -- name: InsertGroupMember :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index c64b08ee07553..3560238e13b4a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -91,8 +91,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // collect all diffs to do 1 sql update for all orgs - groupsToAdd := make([]uuid.UUID, 0) - groupsToRemove := make([]uuid.UUID, 0) + groupIDsToAdd := make([]uuid.UUID, 0) + groupsToRemove := make([]ExpectedGroup, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.GroupField == "" { @@ -112,7 +112,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Everyone group is always implied. expectedGroups = append(expectedGroups, ExpectedGroup{ - GroupID: &orgID, + OrganizationID: orgID, + GroupID: &orgID, }) // Now we know what groups the user should be in for a given org, @@ -121,8 +122,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat existingGroups := userOrgs[orgID] existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup { return ExpectedGroup{ - GroupID: &f.Group.ID, - GroupName: &f.Group.Name, + OrganizationID: orgID, + GroupID: &f.Group.ID, + GroupName: &f.Group.Name, } }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { @@ -144,52 +146,75 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("handle missing groups: %w", err) } - for _, removeGroup := range remove { - // This should always be the case. - // TODO: make sure this is always the case - if removeGroup.GroupID != nil { - groupsToRemove = append(groupsToRemove, *removeGroup.GroupID) - } - } + groupsToRemove = append(groupsToRemove, remove...) + groupIDsToAdd = append(groupIDsToAdd, assignGroups...) + } - groupsToAdd = append(groupsToAdd, assignGroups...) + err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove) + if err != nil { + return xerrors.Errorf("apply group difference: %w", err) } - assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ - UserID: user.ID, - GroupIds: groupsToAdd, + return nil + }, nil) + + if err != nil { + return err + } + + return nil +} + +func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error { + // Always do group removal before group add. This way if there is an error, + // we error on the underprivileged side. + removeIDs := make([]uuid.UUID, 0) + removeNames := make([]database.NameOrganizationPair, 0) + for _, r := range remove { + if r.GroupID != nil { + removeIDs = append(removeIDs, *r.GroupID) + } else if r.GroupName != nil { + removeNames = append(removeNames, database.NameOrganizationPair{ + Name: *r.GroupName, + OrganizationID: r.OrganizationID, + }) + } + } + + // If there is something to remove, do it. + if len(removeIDs) > 0 || len(removeNames) > 0 { + removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + UserID: user.ID, + GroupNames: removeNames, + GroupIds: removeIDs, }) if err != nil { - return xerrors.Errorf("insert user into %d groups: %w", len(groupsToAdd), err) + return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) } - if len(assignedGroupIDs) != len(groupsToAdd) { - s.Logger.Debug(ctx, "failed to assign all groups to user", + if len(removedGroupIDs) != len(removeIDs) { + s.Logger.Debug(ctx, "failed to remove user from all groups", slog.F("user_id", user.ID), - slog.F("groups_assigned_count", len(assignedGroupIDs)), - slog.F("expected_count", len(groupsToAdd)), + slog.F("groups_removed_count", len(removedGroupIDs)), + slog.F("expected_count", len(removeIDs)), ) } + } - removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ + if len(add) > 0 { + assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ UserID: user.ID, - GroupIds: groupsToRemove, + GroupIds: add, }) if err != nil { - return xerrors.Errorf("remove user from %d groups: %w", len(groupsToRemove), err) + return xerrors.Errorf("insert user into %d groups: %w", len(add), err) } - if len(removedGroupIDs) != len(groupsToRemove) { - s.Logger.Debug(ctx, "failed to remove user from all groups", + if len(assignedGroupIDs) != len(add) { + s.Logger.Debug(ctx, "failed to assign all groups to user", slog.F("user_id", user.ID), - slog.F("groups_removed_count", len(removedGroupIDs)), - slog.F("expected_count", len(groupsToRemove)), + slog.F("groups_assigned_count", len(assignedGroupIDs)), + slog.F("expected_count", len(add)), ) } - - return nil - }, nil) - - if err != nil { - return err } return nil @@ -226,8 +251,9 @@ func (s *GroupSyncSettings) Type() string { } type ExpectedGroup struct { - GroupID *uuid.UUID - GroupName *string + OrganizationID uuid.UUID + GroupID *uuid.UUID + GroupName *string } // ParseClaims will take the merged claims from the IDP and return the groups @@ -280,13 +306,20 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr return groups, nil } +// HandleMissingGroups ensures all ExpectedGroups convert to uuids. +// Groups can be referenced by name via legacy params or IDP group names. +// These group names are converted to IDs for easier assignment. +// Missing groups are created if AutoCreate is enabled. +// TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { if !s.AutoCreateMissingGroups { - // construct the list of groups to search by name to see if they exist. + // If we are not creating groups, then just construct a db lookup for + // all groups by name. var lookups []string filter := make([]uuid.UUID, 0) for _, expected := range add { if expected.GroupID != nil { + // Groups with IDs are easy! filter = append(filter, *expected.GroupID) } else if expected.GroupName != nil { lookups = append(lookups, *expected.GroupName) @@ -294,6 +327,7 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } if len(lookups) > 0 { + // Do name lookups for all groups that are missing IDs. newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ OrganizationID: uuid.UUID{}, HasMemberID: uuid.UUID{}, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 406df099167c3..e1d0ac9d6c095 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -208,9 +208,12 @@ func TestGroupSyncTable(t *testing.T) { }, AutoCreateMissingGroups: true, }, - Groups: map[uuid.UUID]bool{}, + Groups: map[uuid.UUID]bool{ + ids.ID("lg-foo"): true, + }, GroupNames: map[string]bool{ "legacy-foo": false, + "extra": true, }, ExpectedGroupNames: []string{ "legacy-bar", From 640e86e47d633feb767de942b7faa3a437c8bc03 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 16:08:36 -0500 Subject: [PATCH 17/38] group sync adjustments --- coderd/database/dbmem/dbmem.go | 18 ++-- coderd/database/queries.sql.go | 19 +--- coderd/database/queries/groupmembers.sql | 12 +-- coderd/idpsync/group.go | 125 +++++++++-------------- coderd/idpsync/group_test.go | 2 + 5 files changed, 63 insertions(+), 113 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index fd97fb0d701bf..7e761de411615 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2730,6 +2730,10 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) continue } + if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) { + continue + } + orgDetails, ok := orgDetailsCache[group.ID] if !ok { for _, org := range q.organizations { @@ -7661,18 +7665,12 @@ func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.Remov return false } - matchesByID := slices.Contains(arg.GroupIds, groupMember.GroupID) - matchesByName := slices.ContainsFunc(arg.GroupNames, func(name database.NameOrganizationPair) bool { - _, err := q.getGroupByNameNoLock(name) - return err == nil - }) - - if matchesByName || matchesByID { - removed = append(removed, groupMember.GroupID) - return true + if !slices.Contains(arg.GroupIds, groupMember.GroupID) { + return false } - return false + removed = append(removed, groupMember.GroupID) + return true }) return removed, nil diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 04c111bcee78f..7c7fbbf0f88f0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1540,30 +1540,19 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU const removeUserFromGroups = `-- name: RemoveUserFromGroups :many DELETE FROM group_members - USING groups WHERE - group_members.group_id = groups.id AND user_id = $1 AND - ( - CASE WHEN array_length($2 :: name_organization_pair[], 1) > 0 THEN - -- Using 'coalesce' to avoid troubles with null literals being an empty string. - (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY ($2::name_organization_pair[]) - ELSE false - END - OR - group_id = ANY ($3 :: uuid[]) - ) + group_id = ANY($2 :: uuid []) RETURNING group_id ` type RemoveUserFromGroupsParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - GroupNames []NameOrganizationPair `db:"group_names" json:"group_names"` - GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"` } func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupNames), pq.Array(arg.GroupIds)) + rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds)) if err != nil { return nil, err } diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 5345d976fcd4e..814f878cb9232 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -57,19 +57,9 @@ WHERE -- name: RemoveUserFromGroups :many DELETE FROM group_members - USING groups WHERE - group_members.group_id = groups.id AND user_id = @user_id AND - ( - CASE WHEN array_length(@group_names :: name_organization_pair[], 1) > 0 THEN - -- Using 'coalesce' to avoid troubles with null literals being an empty string. - (groups.name, coalesce(groups.organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@group_names::name_organization_pair[]) - ELSE false - END - OR - group_id = ANY (@group_ids :: uuid[]) - ) + group_id = ANY(@group_ids :: uuid []) RETURNING group_id; -- name: InsertGroupMember :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 3560238e13b4a..f076c7c5d5c87 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -3,6 +3,7 @@ package idpsync import ( "context" "encoding/json" + "fmt" "regexp" "github.com/golang-jwt/jwt/v4" @@ -92,7 +93,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // collect all diffs to do 1 sql update for all orgs groupIDsToAdd := make([]uuid.UUID, 0) - groupsToRemove := make([]ExpectedGroup, 0) + groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.GroupField == "" { @@ -100,7 +101,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat continue } - expectedGroups, err := settings.ParseClaims(params.MergedClaims) + expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims) if err != nil { s.Logger.Debug(ctx, "failed to parse claims for groups", slog.F("organization_field", s.GroupField), @@ -128,6 +129,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { + // Must match + if a.OrganizationID != b.OrganizationID { + return false + } // Only the name or the name needs to be checked, priority is given to the ID. if a.GroupID != nil && b.GroupID != nil { return *a.GroupID == *b.GroupID @@ -138,6 +143,20 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return false }) + for _, r := range remove { + // This should never happen. All group removals come from the + // existing set, which come from the db. All groups from the + // database have IDs. This code is purely defensive. + if r.GroupID == nil { + detail := "user:" + user.Username + if r.GroupName != nil { + detail += fmt.Sprintf(" from group %s", *r.GroupName) + } + return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail) + } + groupIDsToRemove = append(groupIDsToRemove, *r.GroupID) + } + // HandleMissingGroups will add the new groups to the org if // the settings specify. It will convert all group names into uuids // for easier assignment. @@ -146,11 +165,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return xerrors.Errorf("handle missing groups: %w", err) } - groupsToRemove = append(groupsToRemove, remove...) groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } - err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupsToRemove) + err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) } @@ -165,28 +183,13 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } -func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, remove []ExpectedGroup) error { +func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { // Always do group removal before group add. This way if there is an error, // we error on the underprivileged side. - removeIDs := make([]uuid.UUID, 0) - removeNames := make([]database.NameOrganizationPair, 0) - for _, r := range remove { - if r.GroupID != nil { - removeIDs = append(removeIDs, *r.GroupID) - } else if r.GroupName != nil { - removeNames = append(removeNames, database.NameOrganizationPair{ - Name: *r.GroupName, - OrganizationID: r.OrganizationID, - }) - } - } - - // If there is something to remove, do it. - if len(removeIDs) > 0 || len(removeNames) > 0 { + if len(removeIDs) > 0 { removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ - UserID: user.ID, - GroupNames: removeNames, - GroupIds: removeIDs, + UserID: user.ID, + GroupIds: removeIDs, }) if err != nil { return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) @@ -264,7 +267,7 @@ type ExpectedGroup struct { // the group "UUID 1234" is renamed, we want to maintain the mapping. // We have to keep names because group sync supports syncing groups by name if // the external IDP group name matches the Coder one. -func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { +func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { groupsRaw, ok := mergedClaims[s.GroupField] if !ok { return []ExpectedGroup{}, nil @@ -294,13 +297,13 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr if ok { for _, gid := range mappedGroupIDs { gid := gid - groups = append(groups, ExpectedGroup{GroupID: &gid}) + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid}) } continue } group := group - groups = append(groups, ExpectedGroup{GroupName: &group}) + groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group}) } return groups, nil @@ -312,38 +315,6 @@ func (s GroupSyncSettings) ParseClaims(mergedClaims jwt.MapClaims) ([]ExpectedGr // Missing groups are created if AutoCreate is enabled. // TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { - if !s.AutoCreateMissingGroups { - // If we are not creating groups, then just construct a db lookup for - // all groups by name. - var lookups []string - filter := make([]uuid.UUID, 0) - for _, expected := range add { - if expected.GroupID != nil { - // Groups with IDs are easy! - filter = append(filter, *expected.GroupID) - } else if expected.GroupName != nil { - lookups = append(lookups, *expected.GroupName) - } - } - - if len(lookups) > 0 { - // Do name lookups for all groups that are missing IDs. - newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ - OrganizationID: uuid.UUID{}, - HasMemberID: uuid.UUID{}, - GroupNames: lookups, - }) - if err != nil { - return nil, xerrors.Errorf("get groups by names: %w", err) - } - for _, g := range newGroups { - filter = append(filter, g.Group.ID) - } - } - - return filter, nil - } - // All expected that are missing IDs means the group does not exist // in the database. Either remove them, or create them if auto create is // turned on. @@ -359,33 +330,33 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } } - createdMissingGroups, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ - OrganizationID: orgID, - Source: database.GroupSourceOidc, - GroupNames: missingGroups, - }) - if err != nil { - return nil, xerrors.Errorf("insert missing groups: %w", err) + if s.AutoCreateMissingGroups && len(missingGroups) > 0 { + // Insert any missing groups. If the groups already exist, this is a noop. + _, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ + OrganizationID: orgID, + Source: database.GroupSourceOidc, + GroupNames: missingGroups, + }) + if err != nil { + return nil, xerrors.Errorf("insert missing groups: %w", err) + } } - if len(missingGroups) != len(createdMissingGroups) { - // This is unfortunate, but if legacy params are used, then some existing groups - // can come as params. So we need to fetch them - allGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ + // Fetch any missing groups by name. If they exist, their IDs will be + // matched and returned. + if len(missingGroups) > 0 { + // Do name lookups for all groups that are missing IDs. + newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ OrganizationID: orgID, + HasMemberID: uuid.UUID{}, GroupNames: missingGroups, }) if err != nil { return nil, xerrors.Errorf("get groups by names: %w", err) } - - createdMissingGroups = db2sdk.List(allGroups, func(g database.GetGroupsRow) database.Group { - return g.Group - }) - } - - for _, created := range createdMissingGroups { - addIDs = append(addIDs, created.ID) + for _, g := range newGroups { + addIDs = append(addIDs, g.Group.ID) + } } return addIDs, nil diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index e1d0ac9d6c095..2207c52fd6830 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -205,6 +205,7 @@ func TestGroupSyncTable(t *testing.T) { LegacyGroupNameMapping: map[string]string{ "create-bar": "legacy-bar", "foo": "legacy-foo", + "bop": "legacy-bop", }, AutoCreateMissingGroups: true, }, @@ -214,6 +215,7 @@ func TestGroupSyncTable(t *testing.T) { GroupNames: map[string]bool{ "legacy-foo": false, "extra": true, + "legacy-bop": true, }, ExpectedGroupNames: []string{ "legacy-bar", From c544a293e30ffe9f087c4bdeb8f5a3b92ced1209 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 16:46:30 -0500 Subject: [PATCH 18/38] test legacy params --- coderd/idpsync/group.go | 5 +- coderd/idpsync/group_test.go | 113 +++++++++++++++++++++++++++++++---- 2 files changed, 104 insertions(+), 14 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index f076c7c5d5c87..a6799d5e50ece 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -78,7 +78,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat if err != nil { return xerrors.Errorf("resolve group sync settings: %w", err) } - orgSettings[orgID] = *settings // Legacy deployment settings will override empty settings. if orgID == defaultOrgID && settings.GroupField == "" { @@ -89,6 +88,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, } } + orgSettings[orgID] = *settings } // collect all diffs to do 1 sql update for all orgs @@ -280,6 +280,8 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai groups := make([]ExpectedGroup, 0) for _, group := range parsedGroups { + group := group + // Legacy group mappings happen before the regex filter. mappedGroupName, ok := s.LegacyGroupNameMapping[group] if ok { @@ -302,7 +304,6 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai continue } - group := group groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group}) } diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 2207c52fd6830..82b057422a787 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/idpsync" @@ -71,6 +72,7 @@ func TestGroupSyncTable(t *testing.T) { "groups": []string{ "foo", "bar", "baz", "create-bar", "create-baz", + "legacy-bar", }, } @@ -229,10 +231,6 @@ func TestGroupSyncTable(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() - if tc.OrgID == uuid.Nil { - tc.OrgID = uuid.New() - } - db, _ := dbtestutil.NewDB(t) manager := runtimeconfig.NewStoreManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), @@ -242,9 +240,10 @@ func TestGroupSyncTable(t *testing.T) { }, ) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitSuperLong) user := dbgen.User(t, db, database.User{}) - SetupOrganization(t, s, db, user, tc) + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) // Do the group sync! err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ @@ -253,17 +252,106 @@ func TestGroupSyncTable(t *testing.T) { }) require.NoError(t, err) - tc.Assert(t, tc.OrgID, db, user) + tc.Assert(t, orgID, db, user) }) } + + // AllTogether runs the entire tabled test as a singular user and + // deployment. This tests all organizations being synced together. + // The reason we do them individually, is that it is much easier to + // debug a single test case. + t.Run("AllTogether", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewStoreManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + // Also sync the default org! + idpsync.DeploymentSyncSettings{ + GroupField: "groups", + Legacy: idpsync.DefaultOrgLegacySettings{ + GroupField: "groups", + GroupMapping: map[string]string{ + "foo": "legacy-foo", + "baz": "legacy-baz", + }, + GroupFilter: regexp.MustCompile("^legacy"), + CreateMissingGroups: true, + }, + }, + ) + + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + + var asserts []func(t *testing.T) + // The default org is also going to do something + def := orgSetupDefinition{ + Name: "DefaultOrg", + GroupNames: map[string]bool{ + "legacy-foo": false, + "legacy-baz": true, + "random": true, + }, + // No settings, because they come from the deployment values + Settings: nil, + ExpectedGroups: nil, + ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"}, + } + + //nolint:gocritic // testing + defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) + require.NoError(t, err) + SetupOrganization(t, s, db, user, defOrg.ID, def) + asserts = append(asserts, func(t *testing.T) { + t.Run(def.Name, func(t *testing.T) { + t.Parallel() + def.Assert(t, defOrg.ID, db, user) + }) + }) + + for _, tc := range testCases { + tc := tc + + orgID := uuid.New() + SetupOrganization(t, s, db, user, orgID, tc) + asserts = append(asserts, func(t *testing.T) { + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + tc.Assert(t, orgID, db, user) + }) + }) + } + + asserts = append(asserts, func(t *testing.T) { + t.Helper() + def.Assert(t, defOrg.ID, db, user) + }) + + // Do the group sync! + err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: true, + MergedClaims: userClaims, + }) + require.NoError(t, err) + + for _, assert := range asserts { + assert(t) + } + }) } -func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, def orgSetupDefinition) { +func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { + t.Helper() + org := dbgen.Organization(t, db, database.Organization{ - ID: def.OrgID, + ID: orgID, }) _, err := db.InsertAllUsersGroup(context.Background(), org.ID) - require.NoError(t, err, "Everyone group for an org") + if !database.IsUniqueViolation(err) { + require.NoError(t, err, "Everyone group for an org") + } manager := runtimeconfig.NewStoreManager() orgResolver := manager.OrganizationResolver(db, org.ID) @@ -303,8 +391,7 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, } type orgSetupDefinition struct { - Name string - OrgID uuid.UUID + Name string // True if the user is a member of the group Groups map[uuid.UUID]bool GroupNames map[string]bool @@ -353,11 +440,13 @@ func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.St return g.Group.Name }) require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name") + require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty") } else { // Check by ID, recommended found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { return g.Group.ID }) require.ElementsMatch(t, o.ExpectedGroups, found, "user groups") + require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty") } } From 476be45195870bdc133d77533e137e87925b12da Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:07:00 -0500 Subject: [PATCH 19/38] add unit test for ApplyGroupDifference --- coderd/database/dbmem/dbmem.go | 8 +- coderd/idpsync/group.go | 5 +- coderd/idpsync/group_test.go | 152 +++++++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 6 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 7e761de411615..2e4e737ed5428 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2702,18 +2702,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) q.mutex.RLock() defer q.mutex.RUnlock() - groupIDs := make(map[uuid.UUID]struct{}) + userGroupIDs := make(map[uuid.UUID]struct{}) if arg.HasMemberID != uuid.Nil { for _, member := range q.groupMembers { if member.UserID == arg.HasMemberID { - groupIDs[member.GroupID] = struct{}{} + userGroupIDs[member.GroupID] = struct{}{} } } // Handle the everyone group for _, orgMember := range q.organizationMembers { if orgMember.UserID == arg.HasMemberID { - groupIDs[orgMember.OrganizationID] = struct{}{} + userGroupIDs[orgMember.OrganizationID] = struct{}{} } } } @@ -2725,7 +2725,7 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) continue } - _, ok := groupIDs[group.ID] + _, ok := userGroupIDs[group.ID] if arg.HasMemberID != uuid.Nil && !ok { continue } diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index a6799d5e50ece..0930ede7cc545 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -168,7 +168,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } - err = s.applyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) + err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) } @@ -183,7 +183,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } -func (s AGPLIDPSync) applyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { +// ApplyGroupDifference will add and remove the user from the specified groups. +func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { // Always do group removal before group add. This way if there is an error, // we error on the underprivileged side. if len(removeIDs) > 0 { diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 82b057422a787..aa9e3e6c68b46 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -342,6 +342,158 @@ func TestGroupSyncTable(t *testing.T) { }) } +// TestApplyGroupDifference is mainly testing the database functions +func TestApplyGroupDifference(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCase := []struct { + Name string + Before map[uuid.UUID]bool + Add []uuid.UUID + Remove []uuid.UUID + Expect []uuid.UUID + }{ + { + Name: "Empty", + }, + { + Name: "AddFromNone", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): false, + }, + Add: []uuid.UUID{ + ids.ID("g1"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + }, + }, + { + Name: "AddSome", + Before: map[uuid.UUID]bool{ + ids.ID("g1"): true, + ids.ID("g2"): false, + ids.ID("g3"): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{ + ids.ID("g1"), + ids.ID("g2"), + ids.ID("g3"), + }, + }, + { + Name: "RemoveAll", + Before: map[uuid.UUID]bool{ + uuid.New(): false, + ids.ID("g2"): true, + ids.ID("g3"): true, + }, + Remove: []uuid.UUID{ + ids.ID("g2"), + ids.ID("g3"), + }, + Expect: []uuid.UUID{}, + }, + { + Name: "Mixed", + Before: map[uuid.UUID]bool{ + // adds + ids.ID("a1"): true, + ids.ID("a2"): true, + ids.ID("a3"): false, + ids.ID("a4"): false, + // removes + ids.ID("r1"): true, + ids.ID("r2"): true, + ids.ID("r3"): false, + ids.ID("r4"): false, + // stable + ids.ID("s1"): true, + ids.ID("s2"): true, + // noise + uuid.New(): false, + uuid.New(): false, + }, + Add: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), + ids.ID("a3"), ids.ID("a4"), + // Double up to try and confuse + ids.ID("a1"), + ids.ID("a4"), + }, + Remove: []uuid.UUID{ + ids.ID("r1"), ids.ID("r2"), + ids.ID("r3"), ids.ID("r4"), + // Double up to try and confuse + ids.ID("r1"), + ids.ID("r4"), + }, + Expect: []uuid.UUID{ + ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"), + ids.ID("s1"), ids.ID("s2"), + }, + }, + } + + for _, tc := range testCase { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + mgr := runtimeconfig.NewStoreManager() + db, _ := dbtestutil.NewDB(t) + + ctx := testutil.Context(t, testutil.WaitMedium) + //nolint:gocritic // testing + ctx = dbauthz.AsSystemRestricted(ctx) + + org := dbgen.Organization(t, db, database.Organization{}) + _, err := db.InsertAllUsersGroup(ctx, org.ID) + require.NoError(t, err) + + user := dbgen.User(t, db, database.User{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + + for gid, in := range tc.Before { + group := dbgen.Group(t, db, database.Group{ + ID: gid, + OrganizationID: org.ID, + }) + if in { + _ = dbgen.GroupMember(t, db, database.GroupMemberTable{ + UserID: user.ID, + GroupID: group.ID, + }) + } + } + + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t))) + err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove) + require.NoError(t, err) + + userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{ + HasMemberID: user.ID, + }) + require.NoError(t, err) + + // assert + found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID { + return g.Group.ID + }) + + // Add everyone group + require.ElementsMatch(t, append(tc.Expect, org.ID), found) + }) + } +} + func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() From 164aeacebac6f544d85c73d4cd83fe028ce5259a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:19:20 -0500 Subject: [PATCH 20/38] chore: remove old group sync code --- coderd/coderd.go | 11 --- coderd/idpsync/group.go | 8 +- coderd/idpsync/idpsync.go | 4 +- coderd/userauth.go | 178 ++++++---------------------------- enterprise/coderd/coderd.go | 1 - enterprise/coderd/userauth.go | 66 ------------- 6 files changed, 36 insertions(+), 232 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index b829d37a06773..e04f13d367c6e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -181,7 +181,6 @@ type Options struct { NetworkTelemetryBatchFrequency time.Duration NetworkTelemetryBatchMaxSize int SwaggerEndpoint bool - SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] @@ -374,16 +373,6 @@ func New(options *Options) *API { if options.TracerProvider == nil { options.TracerProvider = trace.NewNoopTracerProvider() } - if options.SetUserGroups == nil { - options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license", - slog.F("user_id", userID), - slog.F("groups", orgGroupNames), - slog.F("create_missing_groups", createMissingGroups), - ) - return nil - } - } if options.SetUserSiteRoles == nil { options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license", diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 0930ede7cc545..660d0b9b9c23e 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -14,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/coderd/util/slice" ) @@ -76,7 +77,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - return xerrors.Errorf("resolve group sync settings: %w", err) + if xerrors.Is(err, runtimeconfig.EntryNotFound) { + // Default to not being configured + settings = &GroupSyncSettings{} + } else { + return xerrors.Errorf("resolve group sync settings: %w", err) + } } // Legacy deployment settings will override empty settings. diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index bc3e5cd479064..7fac0e7329d3d 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -27,7 +27,7 @@ type IDPSync interface { OrganizationSyncEnabled() bool // ParseOrganizationClaims takes claims from an OIDC provider, and returns the // organization sync params for assigning users into organizations. - ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError) + ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError) // SyncOrganizations assigns and removed users from organizations based on the // provided params. SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error @@ -35,7 +35,7 @@ type IDPSync interface { GroupSyncEnabled() bool // ParseGroupClaims takes claims from an OIDC provider, and returns the params // for group syncing. Most of the logic happens in SyncGroups. - ParseGroupClaims(ctx context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) + ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError) // SyncGroups assigns and removes users from groups based on the provided params. SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error diff --git a/coderd/userauth.go b/coderd/userauth.go index a1abadc63f31a..76d29a7c1a9ec 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -20,7 +20,6 @@ import ( "github.com/google/go-github/v43/github" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" - "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -659,6 +658,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { AvatarURL: ghUser.GetAvatarURL(), Name: normName, DebugContext: OauthDebugContext{}, + GroupSync: idpsync.GroupParams{ + SyncEnabled: false, + }, OrganizationSync: idpsync.OrganizationParams{ SyncEnabled: false, IncludeDefault: true, @@ -1004,11 +1006,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name)) - usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims) - if groupErr != nil { - groupErr.Write(rw, r) - return - } roles, roleErr := api.oidcRoles(ctx, mergedClaims) if roleErr != nil { @@ -1032,6 +1029,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { return } + groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims) + if groupSyncErr != nil { + groupSyncErr.Write(rw, r) + return + } + // If a new user is authenticating for the first time // the audit action is 'register', not 'login' if user.ID == uuid.Nil { @@ -1039,23 +1042,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } params := (&oauthLoginParams{ - User: user, - Link: link, - State: state, - LinkedID: oidcLinkedID(idToken), - LoginType: database.LoginTypeOIDC, - AllowSignups: api.OIDCConfig.AllowSignups, - Email: email, - Username: username, - Name: name, - AvatarURL: picture, - UsingRoles: api.OIDCConfig.RoleSyncEnabled(), - Roles: roles, - UsingGroups: usingGroups, - Groups: groups, - OrganizationSync: orgSync, - CreateMissingGroups: api.OIDCConfig.CreateMissingGroups, - GroupFilter: api.OIDCConfig.GroupFilter, + User: user, + Link: link, + State: state, + LinkedID: oidcLinkedID(idToken), + LoginType: database.LoginTypeOIDC, + AllowSignups: api.OIDCConfig.AllowSignups, + Email: email, + Username: username, + Name: name, + AvatarURL: picture, + UsingRoles: api.OIDCConfig.RoleSyncEnabled(), + Roles: roles, + OrganizationSync: orgSync, + GroupSync: groupSync, DebugContext: OauthDebugContext{ IDTokenClaims: idtokenClaims, UserInfoClaims: userInfoClaims, @@ -1091,79 +1091,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } -// oidcGroups returns the groups for the user from the OIDC claims. -func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) { - logger := api.Logger.Named(userAuthLoggerName) - usingGroups := false - var groups []string - - // If the GroupField is the empty string, then groups from OIDC are not used. - // This is so we can support manual group assignment. - if api.OIDCConfig.GroupField != "" { - // If the allow list is empty, then the user is allowed to log in. - // Otherwise, they must belong to at least 1 group in the allow list. - inAllowList := len(api.OIDCConfig.GroupAllowList) == 0 - - usingGroups = true - groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField] - if ok { - parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw) - if err != nil { - api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims", - slog.F("type", fmt.Sprintf("%T", groupsRaw)), - slog.Error(err), - ) - return false, nil, &idpsync.HTTPError{ - Code: http.StatusBadRequest, - Msg: "Failed to sync groups from OIDC claims", - Detail: err.Error(), - RenderStaticPage: false, - } - } - - api.Logger.Debug(ctx, "groups returned in oidc claims", - slog.F("len", len(parsedGroups)), - slog.F("groups", parsedGroups), - ) - - for _, group := range parsedGroups { - if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok { - group = mappedGroup - } - if _, ok := api.OIDCConfig.GroupAllowList[group]; ok { - inAllowList = true - } - groups = append(groups, group) - } - } - - if !inAllowList { - logger.Debug(ctx, "oidc group claim not in allow list, rejecting login", - slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)), - slog.F("user_group_count", len(groups)), - ) - detail := "Ask an administrator to add one of your groups to the allow list" - if len(groups) == 0 { - detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login." - } - return usingGroups, groups, &idpsync.HTTPError{ - Code: http.StatusForbidden, - Msg: "Not a member of an allowed group", - Detail: detail, - RenderStaticPage: true, - } - } - } - - // This conditional is purely to warn the user they might have misconfigured their OIDC - // configuration. - if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists { - logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings") - } - - return usingGroups, groups, nil -} - // oidcRoles returns the roles for the user from the OIDC claims. // If the function returns false, then the caller should return early. // All writes to the response writer are handled by this function. @@ -1278,14 +1205,7 @@ type oauthLoginParams struct { AvatarURL string // OrganizationSync has the organizations that the user will be assigned to. OrganizationSync idpsync.OrganizationParams - // Is UsingGroups is true, then the user will be assigned - // to the Groups provided. - UsingGroups bool - CreateMissingGroups bool - // These are the group names from the IDP. Internally, they will map to - // some organization groups. - Groups []string - GroupFilter *regexp.Regexp + GroupSync idpsync.GroupParams // Is UsingRoles is true, then the user will be assigned // the roles provided. UsingRoles bool @@ -1491,53 +1411,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C return xerrors.Errorf("sync organizations: %w", err) } - // Ensure groups are correct. - // This places all groups into the default organization. - // To go multi-org, we need to add a mapping feature here to know which - // groups go to which orgs. - if params.UsingGroups { - filtered := params.Groups - if params.GroupFilter != nil { - filtered = make([]string, 0, len(params.Groups)) - for _, group := range params.Groups { - if params.GroupFilter.MatchString(group) { - filtered = append(filtered, group) - } - } - } - - //nolint:gocritic // No user present in the context. - defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) - if err != nil { - // If there is no default org, then we can't assign groups. - // By default, we assume all groups belong to the default org. - return xerrors.Errorf("get default organization: %w", err) - } - - //nolint:gocritic // No user present in the context. - memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{ - UserID: user.ID, - OrganizationID: uuid.Nil, - }) - if err != nil { - return xerrors.Errorf("get organization memberships: %w", err) - } - - // If the user is not in the default organization, then we can't assign groups. - // A user cannot be in groups to an org they are not a member of. - if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool { - return member.OrganizationMember.OrganizationID == defaultOrganization.ID - }) { - return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID) - } - - //nolint:gocritic - err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{ - defaultOrganization.ID: filtered, - }, params.CreateMissingGroups) - if err != nil { - return xerrors.Errorf("set user groups: %w", err) - } + err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync) + if err != nil { + return xerrors.Errorf("sync groups: %w", err) } // Ensure roles are correct. diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ce55bae8ec8d0..f9ab3e452ac04 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -145,7 +145,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } return c.Subject, c.Trial, nil } - api.AGPL.Options.SetUserGroups = api.setUserGroups api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) { // If the user can read the workspace proxy resource, return that. diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index 65c4a3473f3f7..60cba28cc37f3 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -8,75 +8,9 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" ) -// nolint: revive -func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) { - return nil - } - - return db.InTx(func(tx database.Store) error { - // When setting the user's groups, it's easier to just clear their groups and re-add them. - // This ensures that the user's groups are always in sync with the auth provider. - orgs, err := tx.GetOrganizationsByUserID(ctx, userID) - if err != nil { - return xerrors.Errorf("get user orgs: %w", err) - } - if len(orgs) != 1 { - return xerrors.Errorf("expected 1 org, got %d", len(orgs)) - } - - // Delete all groups the user belongs to. - // nolint:gocritic // Requires system context to remove user from all groups. - err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID) - if err != nil { - return xerrors.Errorf("delete user groups: %w", err) - } - - // TODO: This could likely be improved by making these single queries. - // Either by batching or some other means. This for loop could be really - // inefficient if there are a lot of organizations. There was deployments - // on v1 with >100 orgs. - for orgID, groupNames := range orgGroupNames { - // Create the missing groups for each organization. - if createMissingGroups { - // This is the system creating these additional groups, so we use the system restricted context. - // nolint:gocritic - created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{ - OrganizationID: orgID, - GroupNames: groupNames, - Source: database.GroupSourceOidc, - }) - if err != nil { - return xerrors.Errorf("insert missing groups: %w", err) - } - if len(created) > 0 { - logger.Debug(ctx, "auto created missing groups", - slog.F("org_id", orgID.ID), - slog.F("created", created), - slog.F("num", len(created)), - ) - } - } - - // Re-add the user to all groups returned by the auth provider. - err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{ - UserID: userID, - OrganizationID: orgID, - GroupNames: groupNames, - }) - if err != nil { - return xerrors.Errorf("insert user groups: %w", err) - } - } - - return nil - }, nil) -} - func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", From 986498d5fb0ad4fe8a79b481617a71a233b99db7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 5 Sep 2024 17:38:33 -0500 Subject: [PATCH 21/38] switch oidc test config to deployment values --- cli/server.go | 5 -- coderd/idpsync/group.go | 50 ++++++++++-------- coderd/idpsync/group_test.go | 38 +++++++------- coderd/userauth.go | 24 +-------- enterprise/coderd/userauth_test.go | 83 ++++++++++++++++++------------ 5 files changed, 100 insertions(+), 100 deletions(-) diff --git a/cli/server.go b/cli/server.go index 4e3b1e16a1482..c2cd476edfaa4 100644 --- a/cli/server.go +++ b/cli/server.go @@ -187,11 +187,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De EmailField: vals.OIDC.EmailField.String(), AuthURLParams: vals.OIDC.AuthURLParams.Value, IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(), - GroupField: vals.OIDC.GroupField.String(), - GroupFilter: vals.OIDC.GroupRegexFilter.Value(), - GroupAllowList: groupAllowList, - CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(), - GroupMapping: vals.OIDC.GroupMapping.Value, UserRoleField: vals.OIDC.UserRoleField.String(), UserRoleMapping: vals.OIDC.UserRoleMapping.Value, UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(), diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 660d0b9b9c23e..69915125acc71 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -41,6 +41,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil } + // nolint:gocritic // all syncing is done as a system user + ctx = dbauthz.AsSystemRestricted(ctx) + // Only care about the default org for deployment settings if the // legacy deployment settings exist. defaultOrgID := uuid.Nil @@ -53,9 +56,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat defaultOrgID = defaultOrganization.ID } - // nolint:gocritic // all syncing is done as a system user - ctx = dbauthz.AsSystemRestricted(ctx) - err := db.InTx(func(tx database.Store) error { userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: user.ID, @@ -86,12 +86,12 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Legacy deployment settings will override empty settings. - if orgID == defaultOrgID && settings.GroupField == "" { + if orgID == defaultOrgID && settings.Field == "" { settings = &GroupSyncSettings{ - GroupField: s.Legacy.GroupField, - LegacyGroupNameMapping: s.Legacy.GroupMapping, - RegexFilter: s.Legacy.GroupFilter, - AutoCreateMissingGroups: s.Legacy.CreateMissingGroups, + Field: s.Legacy.GroupField, + LegacyNameMapping: s.Legacy.GroupMapping, + RegexFilter: s.Legacy.GroupFilter, + AutoCreateMissing: s.Legacy.CreateMissingGroups, } } orgSettings[orgID] = *settings @@ -102,7 +102,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { - if settings.GroupField == "" { + if settings.Field == "" { // No group sync enabled for this org, so do nothing. continue } @@ -231,17 +231,25 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store } type GroupSyncSettings struct { - GroupField string `json:"field"` - // GroupMapping maps from an OIDC group --> Coder group ID - GroupMapping map[string][]uuid.UUID `json:"mapping"` - RegexFilter *regexp.Regexp `json:"regex_filter"` - AutoCreateMissingGroups bool `json:"auto_create_missing_groups"` - // LegacyGroupNameMapping is deprecated. It remaps an IDP group name to + // Field selects the claim field to be used as the created user's + // groups. If the group field is the empty string, then no group updates + // will ever come from the OIDC provider. + Field string `json:"field"` + // Mapping maps from an OIDC group --> Coder group ID + Mapping map[string][]uuid.UUID `json:"mapping"` + // RegexFilter is a regular expression that filters the groups returned by + // the OIDC provider. Any group not matched by this regex will be ignored. + // If the group filter is nil, then no group filtering will occur. + RegexFilter *regexp.Regexp `json:"regex_filter"` + // AutoCreateMissing controls whether groups returned by the OIDC provider + // are automatically created in Coder if they are missing. + AutoCreateMissing bool `json:"auto_create_missing_groups"` + // LegacyNameMapping is deprecated. It remaps an IDP group name to // a Coder group name. Since configuration is now done at runtime, // group IDs are used to account for group renames. // For legacy configurations, this config option has to remain. - // Deprecated: Use GroupMapping instead. - LegacyGroupNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` + // Deprecated: Use Mapping instead. + LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"` } func (s *GroupSyncSettings) Set(v string) error { @@ -275,7 +283,7 @@ type ExpectedGroup struct { // We have to keep names because group sync supports syncing groups by name if // the external IDP group name matches the Coder one. func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) { - groupsRaw, ok := mergedClaims[s.GroupField] + groupsRaw, ok := mergedClaims[s.Field] if !ok { return []ExpectedGroup{}, nil } @@ -290,7 +298,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai group := group // Legacy group mappings happen before the regex filter. - mappedGroupName, ok := s.LegacyGroupNameMapping[group] + mappedGroupName, ok := s.LegacyNameMapping[group] if ok { group = mappedGroupName } @@ -302,7 +310,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai } } - mappedGroupIDs, ok := s.GroupMapping[group] + mappedGroupIDs, ok := s.Mapping[group] if ok { for _, gid := range mappedGroupIDs { gid := gid @@ -338,7 +346,7 @@ func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database. } } - if s.AutoCreateMissingGroups && len(missingGroups) > 0 { + if s.AutoCreateMissing && len(missingGroups) > 0 { // Insert any missing groups. If the groups already exist, this is a noop. _, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{ OrganizationID: orgID, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index aa9e3e6c68b46..4e56260400114 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -81,8 +81,8 @@ func TestGroupSyncTable(t *testing.T) { { Name: "SwitchGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")}, "bar": {ids.ID("sg-bar")}, "baz": {ids.ID("sg-baz")}, @@ -107,10 +107,10 @@ func TestGroupSyncTable(t *testing.T) { { Name: "StayInGroup", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", + Field: "groups", // Only match foo, so bar does not map RegexFilter: regexp.MustCompile("^foo$"), - GroupMapping: map[string][]uuid.UUID{ + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("gg-foo"), uuid.New()}, "bar": {ids.ID("gg-bar")}, "baz": {ids.ID("gg-baz")}, @@ -127,8 +127,8 @@ func TestGroupSyncTable(t *testing.T) { { Name: "UserJoinsGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ "foo": {ids.ID("ng-foo"), uuid.New()}, "bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")}, "baz": {ids.ID("ng-baz")}, @@ -150,9 +150,9 @@ func TestGroupSyncTable(t *testing.T) { { Name: "CreateGroups", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - RegexFilter: regexp.MustCompile("^create"), - AutoCreateMissingGroups: true, + Field: "groups", + RegexFilter: regexp.MustCompile("^create"), + AutoCreateMissing: true, }, Groups: map[uuid.UUID]bool{}, ExpectedGroupNames: []string{ @@ -163,9 +163,9 @@ func TestGroupSyncTable(t *testing.T) { { Name: "GroupNamesNoMapping", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - RegexFilter: regexp.MustCompile(".*"), - AutoCreateMissingGroups: false, + Field: "groups", + RegexFilter: regexp.MustCompile(".*"), + AutoCreateMissing: false, }, GroupNames: map[string]bool{ "foo": false, @@ -180,13 +180,13 @@ func TestGroupSyncTable(t *testing.T) { { Name: "NoUser", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", - GroupMapping: map[string][]uuid.UUID{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ // Extra ID that does not map to a group "foo": {ids.ID("ow-foo"), uuid.New()}, }, - RegexFilter: nil, - AutoCreateMissingGroups: false, + RegexFilter: nil, + AutoCreateMissing: false, }, NotMember: true, Groups: map[uuid.UUID]bool{ @@ -202,14 +202,14 @@ func TestGroupSyncTable(t *testing.T) { { Name: "LegacyMapping", Settings: &idpsync.GroupSyncSettings{ - GroupField: "groups", + Field: "groups", RegexFilter: regexp.MustCompile("^legacy"), - LegacyGroupNameMapping: map[string]string{ + LegacyNameMapping: map[string]string{ "create-bar": "legacy-bar", "foo": "legacy-foo", "bop": "legacy-bop", }, - AutoCreateMissingGroups: true, + AutoCreateMissing: true, }, Groups: map[uuid.UUID]bool{ ids.ID("lg-foo"): true, diff --git a/coderd/userauth.go b/coderd/userauth.go index 76d29a7c1a9ec..a2c8140c65be5 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "net/mail" - "regexp" "sort" "strconv" "strings" @@ -659,7 +658,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { Name: normName, DebugContext: OauthDebugContext{}, GroupSync: idpsync.GroupParams{ - SyncEnabled: false, + SyncEnabled: false, }, OrganizationSync: idpsync.OrganizationParams{ SyncEnabled: false, @@ -743,27 +742,6 @@ type OIDCConfig struct { // support the userinfo endpoint, or if the userinfo endpoint causes // undesirable behavior. IgnoreUserInfo bool - - // TODO: Move all idp fields into the IDPSync struct - // GroupField selects the claim field to be used as the created user's - // groups. If the group field is the empty string, then no group updates - // will ever come from the OIDC provider. - GroupField string - // CreateMissingGroups controls whether groups returned by the OIDC provider - // are automatically created in Coder if they are missing. - CreateMissingGroups bool - // GroupFilter is a regular expression that filters the groups returned by - // the OIDC provider. Any group not matched by this regex will be ignored. - // If the group filter is nil, then no group filtering will occur. - GroupFilter *regexp.Regexp - // GroupAllowList is a list of groups that are allowed to log in. - // If the list length is 0, then the allow list will not be applied and - // this feature is disabled. - GroupAllowList map[string]bool - // GroupMapping controls how groups returned by the OIDC provider get mapped - // to groups within Coder. - // map[oidcGroupName]coderGroupName - GroupMapping map[string]string // UserRoleField selects the claim field to be used as the created user's // roles. If the field is the empty string, then no role updates // will ever come from the OIDC provider. diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 3e94a25a1c013..0ab67542cc2c7 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -402,7 +402,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -433,8 +435,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{oidcGroupName: coderGroupName}} }, }) @@ -468,7 +472,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -502,7 +508,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -537,7 +545,9 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim }, }) @@ -559,8 +569,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -582,8 +594,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.CreateMissingGroups = true + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAutoCreate = true }, }) @@ -606,8 +620,10 @@ func TestUserOIDC(t *testing.T) { runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { cfg.AllowSignups = true - cfg.GroupField = groupClaim - cfg.GroupAllowList = map[string]bool{allowedGroup: true} + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = groupClaim + dv.OIDC.GroupAllowList = []string{allowedGroup} }, }) @@ -697,6 +713,7 @@ func TestGroupSync(t *testing.T) { testCases := []struct { name string modCfg func(cfg *coderd.OIDCConfig) + modDV func(dv *codersdk.DeploymentValues) // initialOrgGroups is initial groups in the org initialOrgGroups []string // initialUserGroups is initial groups for the user @@ -718,10 +735,10 @@ func TestGroupSync(t *testing.T) { }, { name: "GroupSyncDisabled", - modCfg: func(cfg *coderd.OIDCConfig) { + modDV: func(dv *codersdk.DeploymentValues) { // Disable group sync - cfg.GroupField = "" - cfg.GroupFilter = regexp.MustCompile(".*") + dv.OIDC.GroupField = "" + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"b", "c", "d"}, @@ -732,10 +749,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d name: "ChangeUserGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupMapping = map[string]string{ - "D": "d", - } + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"D": "d"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -749,8 +764,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> [] name: "RemoveAllGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.GroupFilter = regexp.MustCompile(".*") + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*")) }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -763,8 +778,8 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroups", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -777,14 +792,11 @@ func TestGroupSync(t *testing.T) { { // From a,c,b -> b,c,d,e,f name: "CreateMissingGroupsFilter", - modCfg: func(cfg *coderd.OIDCConfig) { - cfg.CreateMissingGroups = true + modDV: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupAutoCreate = true // Only single letter groups - cfg.GroupFilter = regexp.MustCompile("^[a-z]$") - cfg.GroupMapping = map[string]string{ - // Does not match the filter, but does after being mapped! - "zebra": "z", - } + dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$")) + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"zebra": "z"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -806,8 +818,15 @@ func TestGroupSync(t *testing.T) { t.Parallel() runner := setupOIDCTest(t, oidcTestConfig{ Config: func(cfg *coderd.OIDCConfig) { - cfg.GroupField = "groups" - tc.modCfg(cfg) + if tc.modCfg != nil { + tc.modCfg(cfg) + } + }, + DeploymentValues: func(dv *codersdk.DeploymentValues) { + dv.OIDC.GroupField = "groups" + if tc.modDV != nil { + tc.modDV(dv) + } }, }) From 290cfa51aeaaa14d7edc47a841995e096173ed97 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 09:44:36 -0500 Subject: [PATCH 22/38] fix err name --- coderd/idpsync/group.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 69915125acc71..38c7260b80b0a 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -77,7 +77,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - if xerrors.Is(err, runtimeconfig.EntryNotFound) { + if xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { // Default to not being configured settings = &GroupSyncSettings{} } else { From c563b10717abb7bb1cb509ddf1785cc540eddc6d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 09:52:32 -0500 Subject: [PATCH 23/38] some linting cleanup --- coderd/database/models.go | 2 +- coderd/database/querier.go | 2 +- coderd/database/queries.sql.go | 10 +++++----- coderd/idpsync/group.go | 1 - enterprise/coderd/enidpsync/organizations_test.go | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/coderd/database/models.go b/coderd/database/models.go index 950c2674ab310..9e0283ba859c1 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3cedeeade34b7..315f2d6fa1cfd 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7c7fbbf0f88f0..52044e4e7e90d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.25.0 package database @@ -3126,7 +3126,7 @@ func (q *sqlQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, } const upsertJFrogXrayScanByWorkspaceAndAgentID = `-- name: UpsertJFrogXrayScanByWorkspaceAndAgentID :exec -INSERT INTO +INSERT INTO jfrog_xray_scans ( agent_id, workspace_id, @@ -3135,7 +3135,7 @@ INSERT INTO medium, results_url ) -VALUES +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (agent_id, workspace_id) DO UPDATE SET critical = $3, high = $4, medium = $5, results_url = $6 @@ -5863,7 +5863,7 @@ FROM provisioner_keys WHERE organization_id = $1 -AND +AND lower(name) = lower($2) ` @@ -7616,7 +7616,7 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI } const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec -UPDATE +UPDATE tailnet_peers SET status = $2 diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 38c7260b80b0a..d5709b5b9f722 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -181,7 +181,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat return nil }, nil) - if err != nil { return err } diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index 8978fa6b46ee1..e01ae5a18d98b 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -237,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(rdb), caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From d2c247fc8bba073a75ed9598cc28420ce0c7c5b4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:11:05 -0500 Subject: [PATCH 24/38] dbauthz test for new query --- coderd/database/dbauthz/dbauthz_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f9b9fb49b71fc..4b4874f34247c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -408,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() { _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() })) + s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) + check.Args(database.RemoveUserFromGroupsParams{ + UserID: u1.ID, + GroupIds: []uuid.UUID{g1.ID, g2.ID}, + }).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID)) + })) s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { g := dbgen.Group(s.T(), db, database.Group{}) check.Args(database.UpdateGroupByIDParams{ From 12685bd985c4d3446edd31f718dbe2cabe8ba6bc Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:39:49 -0500 Subject: [PATCH 25/38] fixup comments --- coderd/idpsync/group.go | 31 +++++++++++++++++++++------ enterprise/coderd/enidpsync/groups.go | 3 ++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index d5709b5b9f722..743b368b094f3 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -65,6 +65,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // Figure out which organizations the user is a member of. + // The "Everyone" group is always included, so we can infer organization + // membership via the groups the user is in. userOrgs := make(map[uuid.UUID][]database.GetGroupsRow) for _, g := range userGroups { g := g @@ -72,6 +74,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat } // For each org, we need to fetch the sync settings + // This loop also handles any legacy settings for the default + // organization. orgSettings := make(map[uuid.UUID]GroupSyncSettings) for orgID := range userOrgs { orgResolver := s.Manager.OrganizationResolver(tx, orgID) @@ -97,16 +101,23 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgSettings[orgID] = *settings } - // collect all diffs to do 1 sql update for all orgs + // groupIDsToAdd & groupIDsToRemove are the final group differences + // needed to be applied to user. The loop below will iterate over all + // organizations the user is in, and determine the diffs. + // The diffs are applied as a batch sql query, rather than each + // organization having to execute a query. groupIDsToAdd := make([]uuid.UUID, 0) groupIDsToRemove := make([]uuid.UUID, 0) // For each org, determine which groups the user should land in for orgID, settings := range orgSettings { if settings.Field == "" { // No group sync enabled for this org, so do nothing. + // The user can remain in their groups for this org. continue } + // expectedGroups is the set of groups the IDP expects the + // user to be a member of. expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims) if err != nil { s.Logger.Debug(ctx, "failed to parse claims for groups", @@ -117,7 +128,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // Unsure where to raise this error on the UI or database. continue } - // Everyone group is always implied. + // Everyone group is always implied, so include it. expectedGroups = append(expectedGroups, ExpectedGroup{ OrganizationID: orgID, GroupID: &orgID, @@ -134,6 +145,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat GroupName: &f.Group.Name, } }) + add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { // Must match if a.OrganizationID != b.OrganizationID { @@ -150,10 +162,10 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat }) for _, r := range remove { - // This should never happen. All group removals come from the - // existing set, which come from the db. All groups from the - // database have IDs. This code is purely defensive. if r.GroupID == nil { + // This should never happen. All group removals come from the + // existing set, which come from the db. All groups from the + // database have IDs. This code is purely defensive. detail := "user:" + user.Username if r.GroupName != nil { detail += fmt.Sprintf(" from group %s", *r.GroupName) @@ -166,6 +178,11 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // HandleMissingGroups will add the new groups to the org if // the settings specify. It will convert all group names into uuids // for easier assignment. + // TODO: This code should be batched at the end of the for loop. + // Optimizing this is being pushed because if AutoCreate is disabled, + // this code will only add cost on the first login for each user. + // AutoCreate is usually disabled for large deployments. + // For small deployments, this is less of a problem. assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add) if err != nil { return xerrors.Errorf("handle missing groups: %w", err) @@ -174,6 +191,8 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat groupIDsToAdd = append(groupIDsToAdd, assignGroups...) } + // ApplyGroupDifference will take the total adds and removes, and apply + // them. err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove) if err != nil { return xerrors.Errorf("apply group difference: %w", err) @@ -190,8 +209,6 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat // ApplyGroupDifference will add and remove the user from the specified groups. func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error { - // Always do group removal before group add. This way if there is an error, - // we error on the underprivileged side. if len(removeIDs) > 0 { removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{ UserID: user.ID, diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 2ecc8703e29cd..932357e2772fe 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -17,7 +17,8 @@ func (e EnterpriseIDPSync) GroupSyncEnabled() bool { // ParseGroupClaims parses the user claims and handles deployment wide group behavior. // Almost all behavior is deferred since each organization configures it's own // group sync settings. -// TODO: Implement group allow_list behavior here since that is deployment wide. +// GroupAllowList is implemented here to prevent login by unauthorized users. +// TODO: GroupAllowList overlaps with the default organization group sync settings. func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) { if !e.GroupSyncEnabled() { return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims) From bf0d4edac865146f84456658bf768775dfd27e77 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 10:42:46 -0500 Subject: [PATCH 26/38] fixup compile issues from rebase --- coderd/idpsync/group_test.go | 12 ++++++------ coderd/idpsync/idpsync.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 4e56260400114..0ef4e18b40bec 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -29,7 +29,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{}) ctx := testutil.Context(t, testutil.WaitMedium) @@ -45,7 +45,7 @@ func TestParseGroupClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ GroupField: "groups", GroupAllowList: map[string]struct{}{ @@ -232,7 +232,7 @@ func TestGroupSyncTable(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, idpsync.DeploymentSyncSettings{ @@ -264,7 +264,7 @@ func TestGroupSyncTable(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), manager, // Also sync the default org! @@ -444,7 +444,7 @@ func TestApplyGroupDifference(t *testing.T) { for _, tc := range testCase { tc := tc t.Run(tc.Name, func(t *testing.T) { - mgr := runtimeconfig.NewStoreManager() + mgr := runtimeconfig.NewManager() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitMedium) @@ -505,7 +505,7 @@ func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, require.NoError(t, err, "Everyone group for an org") } - manager := runtimeconfig.NewStoreManager() + manager := runtimeconfig.NewManager() orgResolver := manager.OrganizationResolver(db, org.ID) err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings) require.NoError(t, err) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 7fac0e7329d3d..2c2b185c619c9 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -45,7 +45,7 @@ type IDPSync interface { // IDP. All related code to syncing user information should be in this package. type AGPLIDPSync struct { Logger slog.Logger - Manager runtimeconfig.Manager + Manager *runtimeconfig.Manager SyncSettings } @@ -108,7 +108,7 @@ type SyncSettings struct { Group runtimeconfig.RuntimeEntry[*GroupSyncSettings] } -func NewAGPLSync(logger slog.Logger, manager runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { +func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync { return &AGPLIDPSync{ Logger: logger.Named("idp-sync"), Manager: manager, From f95128e14401ba41eca811df86be8e2398c8dbe5 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 11:23:28 -0500 Subject: [PATCH 27/38] add test for disabled sync --- coderd/idpsync/group.go | 17 ++---- coderd/idpsync/group_test.go | 54 +++++++++++++++++++ coderd/idpsync/organizations_test.go | 4 +- coderd/runtimeconfig/entry.go | 9 ++++ enterprise/coderd/enidpsync/enidpsync.go | 2 +- enterprise/coderd/enidpsync/groups_test.go | 6 +-- .../coderd/enidpsync/organizations_test.go | 2 +- 7 files changed, 74 insertions(+), 20 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 743b368b094f3..a54f6fbfa09cf 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -81,12 +81,11 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat orgResolver := s.Manager.OrganizationResolver(tx, orgID) settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver) if err != nil { - if xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { - // Default to not being configured - settings = &GroupSyncSettings{} - } else { + if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) { return xerrors.Errorf("resolve group sync settings: %w", err) } + // Default to not being configured + settings = &GroupSyncSettings{} } // Legacy deployment settings will override empty settings. @@ -273,15 +272,7 @@ func (s *GroupSyncSettings) Set(v string) error { } func (s *GroupSyncSettings) String() string { - v, err := json.Marshal(s) - if err != nil { - return "decode failed: " + err.Error() - } - return string(v) -} - -func (s *GroupSyncSettings) Type() string { - return "GroupSyncSettings" + return runtimeconfig.JSONString(s) } type ExpectedGroup struct { diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 0ef4e18b40bec..07c9052881fad 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -342,6 +342,60 @@ func TestGroupSyncTable(t *testing.T) { }) } +func TestSyncDisabled(t *testing.T) { + t.Parallel() + + if dbtestutil.WillUsePostgres() { + t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") + } + + db, _ := dbtestutil.NewDB(t) + manager := runtimeconfig.NewManager() + s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), + manager, + idpsync.DeploymentSyncSettings{}, + ) + + ids := coderdtest.NewDeterministicUUIDGenerator() + ctx := testutil.Context(t, testutil.WaitSuperLong) + user := dbgen.User(t, db, database.User{}) + orgID := uuid.New() + + def := orgSetupDefinition{ + Name: "SyncDisabled", + Groups: map[uuid.UUID]bool{ + ids.ID("foo"): true, + ids.ID("bar"): true, + ids.ID("baz"): false, + ids.ID("bop"): false, + }, + Settings: &idpsync.GroupSyncSettings{ + Field: "groups", + Mapping: map[string][]uuid.UUID{ + "foo": {ids.ID("foo")}, + "baz": {ids.ID("baz")}, + }, + }, + ExpectedGroups: []uuid.UUID{ + ids.ID("foo"), + ids.ID("bar"), + }, + } + + SetupOrganization(t, s, db, user, orgID, def) + + // Do the group sync! + err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{ + SyncEnabled: false, + MergedClaims: jwt.MapClaims{ + "groups": []string{"baz", "bop"}, + }, + }) + require.NoError(t, err) + + def.Assert(t, orgID, db, user) +} + // TestApplyGroupDifference is mainly testing the database functions func TestApplyGroupDifference(t *testing.T) { t.Parallel() diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go index 934d7d83816ab..1670beaaedc75 100644 --- a/coderd/idpsync/organizations_test.go +++ b/coderd/idpsync/organizations_test.go @@ -20,7 +20,7 @@ func TestParseOrganizationClaims(t *testing.T) { t.Parallel() s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "", OrganizationMapping: nil, @@ -42,7 +42,7 @@ func TestParseOrganizationClaims(t *testing.T) { // AGPL has limited behavior s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), idpsync.DeploymentSyncSettings{ OrganizationField: "orgs", OrganizationMapping: map[string][]uuid.UUID{ diff --git a/coderd/runtimeconfig/entry.go b/coderd/runtimeconfig/entry.go index 780138a89d03b..c0260b0268ddb 100644 --- a/coderd/runtimeconfig/entry.go +++ b/coderd/runtimeconfig/entry.go @@ -2,6 +2,7 @@ package runtimeconfig import ( "context" + "encoding/json" "fmt" "golang.org/x/xerrors" @@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) { return e.n, nil } + +func JSONString(v any) string { + s, err := json.Marshal(v) + if err != nil { + return "decode failed: " + err.Error() + } + return string(s) +} diff --git a/enterprise/coderd/enidpsync/enidpsync.go b/enterprise/coderd/enidpsync/enidpsync.go index a7ff1eaa07257..c7ba8dd3ecdc6 100644 --- a/enterprise/coderd/enidpsync/enidpsync.go +++ b/enterprise/coderd/enidpsync/enidpsync.go @@ -17,7 +17,7 @@ type EnterpriseIDPSync struct { *idpsync.AGPLIDPSync } -func NewSync(logger slog.Logger, manager runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { +func NewSync(logger slog.Logger, manager *runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync { return &EnterpriseIDPSync{ entitlements: set, AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings), diff --git a/enterprise/coderd/enidpsync/groups_test.go b/enterprise/coderd/enidpsync/groups_test.go index 8103f8a002937..77b078cd9e3f0 100644 --- a/enterprise/coderd/enidpsync/groups_test.go +++ b/enterprise/coderd/enidpsync/groups_test.go @@ -30,7 +30,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitlements.New(), idpsync.DeploymentSyncSettings{}) @@ -46,7 +46,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", @@ -74,7 +74,7 @@ func TestEnterpriseParseGroupClaims(t *testing.T) { t.Parallel() s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}), - runtimeconfig.NewStoreManager(), + runtimeconfig.NewManager(), entitled, idpsync.DeploymentSyncSettings{ GroupField: "groups", diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index e01ae5a18d98b..cb6da2723b2f5 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -237,7 +237,7 @@ func TestOrganizationSync(t *testing.T) { } // Create a new sync object - sync := enidpsync.NewSync(logger, runtimeconfig.NewStoreManager(), caseData.Entitlements, caseData.Settings) + sync := enidpsync.NewSync(logger, runtimeconfig.NewManager(), caseData.Entitlements, caseData.Settings) for _, exp := range caseData.Exps { t.Run(exp.Name, func(t *testing.T) { params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) From 88b0ad9b86a9be67e83210015e332e43a5dbe10f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 12:51:52 -0500 Subject: [PATCH 28/38] linting --- coderd/database/queries.sql.go | 8 ++++---- coderd/idpsync/group_test.go | 2 ++ enterprise/coderd/userauth_test.go | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 52044e4e7e90d..191cf291102ad 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3126,7 +3126,7 @@ func (q *sqlQuerier) GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, } const upsertJFrogXrayScanByWorkspaceAndAgentID = `-- name: UpsertJFrogXrayScanByWorkspaceAndAgentID :exec -INSERT INTO +INSERT INTO jfrog_xray_scans ( agent_id, workspace_id, @@ -3135,7 +3135,7 @@ INSERT INTO medium, results_url ) -VALUES +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (agent_id, workspace_id) DO UPDATE SET critical = $3, high = $4, medium = $5, results_url = $6 @@ -5863,7 +5863,7 @@ FROM provisioner_keys WHERE organization_id = $1 -AND +AND lower(name) = lower($2) ` @@ -7616,7 +7616,7 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUI } const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec -UPDATE +UPDATE tailnet_peers SET status = $2 diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 07c9052881fad..a3c9140577b8c 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -498,6 +498,8 @@ func TestApplyGroupDifference(t *testing.T) { for _, tc := range testCase { tc := tc t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + mgr := runtimeconfig.NewManager() db, _ := dbtestutil.NewDB(t) diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 0ab67542cc2c7..3b42dc1aeec5f 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -438,7 +438,7 @@ func TestUserOIDC(t *testing.T) { }, DeploymentValues: func(dv *codersdk.DeploymentValues) { dv.OIDC.GroupField = groupClaim - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{oidcGroupName: coderGroupName}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{oidcGroupName: coderGroupName}} }, }) @@ -750,7 +750,7 @@ func TestGroupSync(t *testing.T) { // From a,c,b -> b,c,d name: "ChangeUserGroups", modDV: func(dv *codersdk.DeploymentValues) { - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"D": "d"}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"D": "d"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, @@ -796,7 +796,7 @@ func TestGroupSync(t *testing.T) { dv.OIDC.GroupAutoCreate = true // Only single letter groups dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$")) - dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{map[string]string{"zebra": "z"}} + dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"zebra": "z"}} }, initialOrgGroups: []string{"a", "b", "c", "d"}, initialUserGroups: []string{"a", "b", "c"}, From 6491f6ac295dfc3a765d95ef883d14f61647d4fc Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:06:23 -0500 Subject: [PATCH 29/38] chore: handle db conflicts gracefully --- coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 2 ++ coderd/database/queries/groupmembers.sql | 2 ++ coderd/idpsync/group.go | 2 ++ coderd/idpsync/group_test.go | 16 ++++++++++++---- 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 315f2d6fa1cfd..ee9a64f12076d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -370,6 +370,7 @@ type sqlcQuerier interface { InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) // InsertUserGroupsByID adds a user to all provided groups, if they exist. + // If there is a conflict, the user is already a member InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) // InsertUserGroupsByName adds a user to all provided groups, if they exist. InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 191cf291102ad..c9f1d1de145d9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1462,6 +1462,7 @@ SELECT groups.id FROM groups +ON CONFLICT DO NOTHING RETURNING group_id ` @@ -1471,6 +1472,7 @@ type InsertUserGroupsByIDParams struct { } // InsertUserGroupsByID adds a user to all provided groups, if they exist. +// If there is a conflict, the user is already a member func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) { rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds)) if err != nil { diff --git a/coderd/database/queries/groupmembers.sql b/coderd/database/queries/groupmembers.sql index 814f878cb9232..4efe9bf488590 100644 --- a/coderd/database/queries/groupmembers.sql +++ b/coderd/database/queries/groupmembers.sql @@ -46,6 +46,8 @@ SELECT groups.id FROM groups +-- If there is a conflict, the user is already a member +ON CONFLICT DO NOTHING RETURNING group_id; -- name: RemoveUserFromAllGroups :exec diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index a54f6fbfa09cf..704fd1b10ea75 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -226,6 +226,8 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store } if len(add) > 0 { + add = slice.Unique(add) + // Defensive programming to only insert uniques. assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{ UserID: user.ID, GroupIds: add, diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index a3c9140577b8c..0f9d0345f1e60 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -2,6 +2,7 @@ package idpsync_test import ( "context" + "database/sql" "regexp" "testing" @@ -9,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" + "golang.org/x/xerrors" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/coderdtest" @@ -64,6 +66,7 @@ func TestParseGroupClaims(t *testing.T) { func TestGroupSyncTable(t *testing.T) { t.Parallel() + // Last checked, takes 30s with postgres on a fast machine. if dbtestutil.WillUsePostgres() { t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.") } @@ -553,10 +556,15 @@ func TestApplyGroupDifference(t *testing.T) { func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() - org := dbgen.Organization(t, db, database.Organization{ - ID: orgID, - }) - _, err := db.InsertAllUsersGroup(context.Background(), org.ID) + // Account that the org might be the default organization + org, err := db.GetOrganizationByID(context.Background(), orgID) + if xerrors.Is(err, sql.ErrNoRows) { + org = dbgen.Organization(t, db, database.Organization{ + ID: orgID, + }) + } + + _, err = db.InsertAllUsersGroup(context.Background(), org.ID) if !database.IsUniqueViolation(err) { require.NoError(t, err, "Everyone group for an org") } From bd2328836d951ae48993e0d19cd6e07e0b881e44 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:21:54 -0500 Subject: [PATCH 30/38] test expected group equality --- coderd/idpsync/group.go | 38 ++++++--- coderd/idpsync/group_test.go | 146 +++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 12 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 704fd1b10ea75..c779b7ed15df3 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -146,18 +146,7 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat }) add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool { - // Must match - if a.OrganizationID != b.OrganizationID { - return false - } - // Only the name or the name needs to be checked, priority is given to the ID. - if a.GroupID != nil && b.GroupID != nil { - return *a.GroupID == *b.GroupID - } - if a.GroupName != nil && b.GroupName != nil { - return *a.GroupName == *b.GroupName - } - return false + return a.Equal(b) }) for _, r := range remove { @@ -283,6 +272,31 @@ type ExpectedGroup struct { GroupName *string } +// Equal compares two ExpectedGroups. The org id must be the same. +// If the group ID is set, it will be compared and take priorty, ignoring the +// name value. So 2 groups with the same ID but different names will be +// considered equal. +func (a ExpectedGroup) Equal(b ExpectedGroup) bool { + // Must match + if a.OrganizationID != b.OrganizationID { + return false + } + // Only the name or the name needs to be checked, priority is given to the ID. + if a.GroupID != nil && b.GroupID != nil { + return *a.GroupID == *b.GroupID + } + if a.GroupName != nil && b.GroupName != nil { + return *a.GroupName == *b.GroupName + } + + // If everything is nil, it is equal. Although a bit pointless + if a.GroupID == nil && b.GroupID == nil && + a.GroupName == nil && b.GroupName == nil { + return true + } + return false +} + // ParseClaims will take the merged claims from the IDP and return the groups // the user is expected to be a member of. The expected group can either be a // name or an ID. diff --git a/coderd/idpsync/group_test.go b/coderd/idpsync/group_test.go index 0f9d0345f1e60..cf312a576d720 100644 --- a/coderd/idpsync/group_test.go +++ b/coderd/idpsync/group_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/runtimeconfig" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/testutil" ) @@ -553,6 +554,151 @@ func TestApplyGroupDifference(t *testing.T) { } } +func TestExpectedGroupEqual(t *testing.T) { + t.Parallel() + + ids := coderdtest.NewDeterministicUUIDGenerator() + testCases := []struct { + Name string + A idpsync.ExpectedGroup + B idpsync.ExpectedGroup + Equal bool + }{ + { + Name: "Empty", + A: idpsync.ExpectedGroup{}, + B: idpsync.ExpectedGroup{}, + Equal: true, + }, + { + Name: "DifferentOrgs", + A: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: uuid.New(), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameID", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + Equal: true, + }, + { + Name: "DifferentIDs", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(uuid.New()), + GroupName: nil, + }, + Equal: false, + }, + { + Name: "SameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + Equal: true, + }, + { + Name: "DifferentName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + // Edge cases + { + // A bit strange, but valid as ID takes priority. + // We assume 2 groups with the same ID are equal, even if + // their names are different. Names are mutable, IDs are not, + // so there is 0% chance they are different groups. + Name: "DifferentIDSameName", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("foo"), + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: ptr.Ref("bar"), + }, + Equal: true, + }, + { + Name: "MixedNils", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: ptr.Ref("bar"), + }, + Equal: false, + }, + { + Name: "NoComparable", + A: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: ptr.Ref(ids.ID("g1")), + GroupName: nil, + }, + B: idpsync.ExpectedGroup{ + OrganizationID: ids.ID("org"), + GroupID: nil, + GroupName: nil, + }, + Equal: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.Equal, tc.A.Equal(tc.B)) + }) + } +} + func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) { t.Helper() From a390ec4cba6db241cbd1ff42115c7ad42907656b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:30:50 -0500 Subject: [PATCH 31/38] cleanup comments --- coderd/idpsync/group.go | 7 +++---- coderd/idpsync/idpsync.go | 4 +++- coderd/userauth.go | 2 ++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index c779b7ed15df3..8a097dca37f47 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -206,7 +206,7 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err) } if len(removedGroupIDs) != len(removeIDs) { - s.Logger.Debug(ctx, "failed to remove user from all groups", + s.Logger.Debug(ctx, "user not removed from expected number of groups", slog.F("user_id", user.ID), slog.F("groups_removed_count", len(removedGroupIDs)), slog.F("expected_count", len(removeIDs)), @@ -225,7 +225,7 @@ func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store return xerrors.Errorf("insert user into %d groups: %w", len(add), err) } if len(assignedGroupIDs) != len(add) { - s.Logger.Debug(ctx, "failed to assign all groups to user", + s.Logger.Debug(ctx, "user not assigned to expected number of groups", slog.F("user_id", user.ID), slog.F("groups_assigned_count", len(assignedGroupIDs)), slog.F("expected_count", len(add)), @@ -355,8 +355,7 @@ func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClai // TODO: Batching this would be better, as this is 1 or 2 db calls per organization. func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) { // All expected that are missing IDs means the group does not exist - // in the database. Either remove them, or create them if auto create is - // turned on. + // in the database, or it is a legacy mapping, and we need to do a lookup. var missingGroups []string addIDs := make([]uuid.UUID, 0) diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2c2b185c619c9..2c99e780ffee6 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -90,7 +90,9 @@ func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings OrganizationMapping: dv.OIDC.OrganizationMapping.Value, OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(), - // TODO: Separate group field for allow list from default org + // TODO: Separate group field for allow list from default org. + // Right now you cannot disable group sync from the default org and + // configure an allow list. GroupField: dv.OIDC.GroupField.Value(), GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()), Legacy: DefaultOrgLegacySettings{ diff --git a/coderd/userauth.go b/coderd/userauth.go index a2c8140c65be5..223f697c09bb9 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1389,6 +1389,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C return xerrors.Errorf("sync organizations: %w", err) } + // Group sync needs to occur after org sync, since a user can join an org, + // then have their groups sync to said org. err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync) if err != nil { return xerrors.Errorf("sync groups: %w", err) From a0a1c53bdfcdd7f2259403ce0165ff5c081c98b3 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 13:34:33 -0500 Subject: [PATCH 32/38] spelling mistake --- coderd/idpsync/group.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 8a097dca37f47..91e440c38b668 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -273,7 +273,7 @@ type ExpectedGroup struct { } // Equal compares two ExpectedGroups. The org id must be the same. -// If the group ID is set, it will be compared and take priorty, ignoring the +// If the group ID is set, it will be compared and take priority, ignoring the // name value. So 2 groups with the same ID but different names will be // considered equal. func (a ExpectedGroup) Equal(b ExpectedGroup) bool { From a86ba834180aad44a6ac675783abb1e2a6263dbd Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 6 Sep 2024 14:35:13 -0500 Subject: [PATCH 33/38] linting: --- enterprise/coderd/enidpsync/groups.go | 1 - 1 file changed, 1 deletion(-) diff --git a/enterprise/coderd/enidpsync/groups.go b/enterprise/coderd/enidpsync/groups.go index 932357e2772fe..dc8456fc6b1c9 100644 --- a/enterprise/coderd/enidpsync/groups.go +++ b/enterprise/coderd/enidpsync/groups.go @@ -61,7 +61,6 @@ func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jw RenderStaticPage: true, } } - } return idpsync.GroupParams{ From 0df7f28209e5b32747f7859871cfa7c954faf6f6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 9 Sep 2024 14:31:15 -0500 Subject: [PATCH 34/38] add interface method to allow api crud --- coderd/idpsync/group.go | 3 +++ coderd/idpsync/idpsync.go | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 91e440c38b668..153f5db91199f 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -28,6 +28,9 @@ func (AGPLIDPSync) GroupSyncEnabled() bool { // AGPL does not support syncing groups. return false } +func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] { + return s.Group +} func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) { return GroupParams{ diff --git a/coderd/idpsync/idpsync.go b/coderd/idpsync/idpsync.go index 2c99e780ffee6..2c8ed10ce9bcc 100644 --- a/coderd/idpsync/idpsync.go +++ b/coderd/idpsync/idpsync.go @@ -36,9 +36,12 @@ type IDPSync interface { // ParseGroupClaims takes claims from an OIDC provider, and returns the params // for group syncing. Most of the logic happens in SyncGroups. ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError) - // SyncGroups assigns and removes users from groups based on the provided params. SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error + // GroupSyncSettings is exposed for the API to implement CRUD operations + // on the settings used by IDPSync. This entry is thread safe and can be + // accessed concurrently. The settings are stored in the database. + GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] } // AGPLIDPSync is the configuration for syncing user information from an external From 7a802a9196b1501b77fb1d979bb1c14456d090c7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 11:49:14 -0500 Subject: [PATCH 35/38] Remove testable example --- coderd/coderdtest/uuids_test.go | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/coderd/coderdtest/uuids_test.go b/coderd/coderdtest/uuids_test.go index 5a0e10935bd50..935be36eb8b15 100644 --- a/coderd/coderdtest/uuids_test.go +++ b/coderd/coderdtest/uuids_test.go @@ -1,33 +1,17 @@ package coderdtest_test import ( - "github.com/google/uuid" + "testing" + + "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" ) -func ExampleNewDeterministicUUIDGenerator() { - det := coderdtest.NewDeterministicUUIDGenerator() - testCases := []struct { - CreateUsers []uuid.UUID - ExpectedIDs []uuid.UUID - }{ - { - CreateUsers: []uuid.UUID{ - det.ID("player1"), - det.ID("player2"), - }, - ExpectedIDs: []uuid.UUID{ - det.ID("player1"), - det.ID("player2"), - }, - }, - } +func TestDeterministicUUIDGenerator(t *testing.T) { + t.Parallel() - for _, tc := range testCases { - tc := tc - _ = tc - // Do the test with CreateUsers as the setup, and the expected IDs - // will match. - } + ids := coderdtest.NewDeterministicUUIDGenerator() + require.Equal(t, ids.ID("g1"), ids.ID("g1")) + require.NotEqual(t, ids.ID("g1"), ids.ID("g2")) } From 611f1e3a6a7b74ce159d561369352db81bab8143 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 11:51:51 -0500 Subject: [PATCH 36/38] fix formatting of sql, add a comment --- coderd/database/queries.sql.go | 2 +- coderd/database/queries/groups.sql | 2 +- coderd/idpsync/group.go | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c9f1d1de145d9..3616fcb66d3fb 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1680,7 +1680,7 @@ WHERE ELSE true END AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN - groups.name = ANY($3) + groups.name = ANY($3) ELSE true END ` diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 0df848d6a6d05..780c0d0154740 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -53,7 +53,7 @@ WHERE ELSE true END AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN - groups.name = ANY(@group_names) + groups.name = ANY(@group_names) ELSE true END ; diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 153f5db91199f..7c61aeb2fe4ef 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -128,6 +128,9 @@ func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user dat slog.Error(err), ) // Unsure where to raise this error on the UI or database. + // TODO: This error prevents group sync, but we have no way + // to raise this to an org admin. Come up with a solution to + // notify the admin and user of this issue. continue } // Everyone group is always implied, so include it. From 7f28a5359be59fb978170f2ddd3444391b200510 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 12:09:27 -0500 Subject: [PATCH 37/38] remove function only used in 1 place --- coderd/database/dbmem/dbmem.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 2e4e737ed5428..ed766d48ecd43 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -682,17 +682,6 @@ func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobI return resources, nil } -func (q *FakeQuerier) getGroupByNameNoLock(arg database.NameOrganizationPair) (database.Group, error) { - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { for _, group := range q.groups { if group.ID == id { @@ -2624,10 +2613,14 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr q.mutex.RLock() defer q.mutex.RUnlock() - return q.getGroupByNameNoLock(database.NameOrganizationPair{ - Name: arg.Name, - OrganizationID: arg.OrganizationID, - }) + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return group, nil + } + } + + return database.Group{}, sql.ErrNoRows } func (q *FakeQuerier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) { From 41994d2195e5980b7a5fea352a77aeb41a61cbfb Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 11 Sep 2024 12:14:14 -0500 Subject: [PATCH 38/38] make fmt --- coderd/idpsync/group.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 7c61aeb2fe4ef..1b6b8f76dc685 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -28,6 +28,7 @@ func (AGPLIDPSync) GroupSyncEnabled() bool { // AGPL does not support syncing groups. return false } + func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] { return s.Group } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy