diff --git a/cli/server.go b/cli/server.go index 602f05d028b66..26d0c8f110403 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1101,7 +1101,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value()) defer autobuildTicker.Stop() autobuildExecutor := autobuild.NewExecutor( - ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) + ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) autobuildExecutor.Run() jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value()) diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index d49bf831515d0..234a72de04c50 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -42,6 +42,7 @@ type Executor struct { templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] accessControlStore *atomic.Pointer[dbauthz.AccessControlStore] auditor *atomic.Pointer[audit.Auditor] + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] log slog.Logger tick <-chan time.Time statsCh chan<- Stats @@ -65,7 +66,7 @@ type Stats struct { } // New returns a new wsactions executor. -func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { +func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { factory := promauto.With(reg) le := &Executor{ //nolint:gocritic // Autostart has a limited set of permissions. @@ -78,6 +79,7 @@ func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *f log: log.Named("autobuild"), auditor: auditor, accessControlStore: acs, + buildUsageChecker: buildUsageChecker, notificationsEnqueuer: enqueuer, reg: reg, experiments: exp, @@ -279,7 +281,7 @@ func (e *Executor) runOnce(t time.Time) Stats { } if nextTransition != "" { - builder := wsbuilder.New(ws, nextTransition). + builder := wsbuilder.New(ws, nextTransition, *e.buildUsageChecker.Load()). SetLastWorkspaceBuildInTx(&latestBuild). SetLastWorkspaceBuildJobInTx(&latestJob). Experiments(e.experiments). diff --git a/coderd/coderd.go b/coderd/coderd.go index fa10846a7d0a6..9115888fc566b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/oauth2provider" "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/andybalholm/brotli" "github.com/go-chi/chi/v5" @@ -559,6 +560,13 @@ func New(options *Options) *API { // bugs that may only occur when a key isn't precached in tests and the latency cost is minimal. cryptokeys.StartRotator(ctx, options.Logger, options.Database) + // AGPL uses a no-op build usage checker as there are no license + // entitlements to enforce. This is swapped out in + // enterprise/coderd/coderd.go. + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + api := &API{ ctx: ctx, cancel: cancel, @@ -579,6 +587,7 @@ func New(options *Options) *API { TemplateScheduleStore: options.TemplateScheduleStore, UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore, AccessControlStore: options.AccessControlStore, + BuildUsageChecker: &buildUsageChecker, FileCache: files.New(options.PrometheusRegistry, options.Authorizer), Experiments: experiments, WebpushDispatcher: options.WebPushDispatcher, @@ -1650,6 +1659,9 @@ type API struct { FileCache *files.Cache PrebuildsClaimer atomic.Pointer[prebuilds.Claimer] PrebuildsReconciler atomic.Pointer[prebuilds.ReconciliationOrchestrator] + // BuildUsageChecker is a pointer as it's passed around to multiple + // components. + BuildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] UpdatesProvider tailnet.WorkspaceUpdatesProvider diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 96030b215e5dd..7085068e97ff4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/quartz" "github.com/coder/coder/v2/coderd" @@ -364,6 +365,10 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } connectionLogger.Store(&options.ConnectionLogger) + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + ctx, cancelFunc := context.WithCancel(context.Background()) experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments) lifecycleExecutor := autobuild.NewExecutor( @@ -375,6 +380,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can &templateScheduleStore, &auditor, accessControlStore, + &buildUsageChecker, *options.Logger, options.AutobuildTicker, options.NotificationsEnqueuer, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a12db9aa6919f..257cbc6e6b142 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2193,6 +2193,14 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) { return q.db.GetLogoURL(ctx) } +func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + // Must be able to read all workspaces to check usage. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace); err != nil { + return 0, xerrors.Errorf("authorize read all workspaces: %w", err) + } + return q.db.GetManagedAgentCount(ctx, arg) +} + func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil { return nil, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2b0801024eb8d..bcf0caa95c365 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -17,20 +17,18 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/notifications" - "github.com/coder/coder/v2/coderd/rbac/policy" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" ) @@ -903,6 +901,14 @@ func (s *MethodTestSuite) TestLicense() { require.NoError(s.T(), err) check.Args().Asserts().Returns("value") })) + s.Run("GetManagedAgentCount", s.Subtest(func(db database.Store, check *expects) { + start := dbtime.Now() + end := start.Add(time.Hour) + check.Args(database.GetManagedAgentCountParams{ + StartTime: start, + EndTime: end, + }).Asserts(rbac.ResourceWorkspace, policy.ActionRead).Returns(int64(0)) + })) } func (s *MethodTestSuite) TestOrganization() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index d4e1db1612790..811d945ac7da9 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -964,6 +964,13 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) { return url, err } +func (m queryMetricsStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.GetManagedAgentCount(ctx, arg) + m.queryLatencies.WithLabelValues("GetManagedAgentCount").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { start := time.Now() r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f3ed6c2bc78ca..b20c3d06209b5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2012,6 +2012,21 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx) } +// GetManagedAgentCount mocks base method. +func (m *MockStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetManagedAgentCount", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetManagedAgentCount indicates an expected call of GetManagedAgentCount. +func (mr *MockStoreMockRecorder) GetManagedAgentCount(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedAgentCount", reflect.TypeOf((*MockStore)(nil).GetManagedAgentCount), ctx, arg) +} + // GetNotificationMessagesByStatus mocks base method. func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6471d79defa6c..baa5d8590b1d7 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -216,6 +216,8 @@ type sqlcQuerier interface { GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) + // This isn't strictly a license query, but it's related to license enforcement. + GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) // Fetch the notification report generator log indicating recent activity. GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 47d46a4e74a8b..4bf01000de0ec 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4286,6 +4286,44 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getManagedAgentCount = `-- name: GetManagedAgentCount :one +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. + AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz +` + +type GetManagedAgentCountParams struct { + StartTime time.Time `db:"start_time" json:"start_time"` + EndTime time.Time `db:"end_time" json:"end_time"` +} + +// This isn't strictly a license query, but it's related to license enforcement. +func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime) + var count int64 + err := row.Scan(&count) + return count, err +} + const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many SELECT id, uploaded_at, jwt, exp, uuid FROM licenses diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index 3512a46514787..ac864a94d1792 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -35,3 +35,28 @@ DELETE FROM licenses WHERE id = $1 RETURNING id; + +-- name: GetManagedAgentCount :one +-- This isn't strictly a license query, but it's related to license enforcement. +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. + AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz; diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 88774c63368ca..884a963405007 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -335,7 +335,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { return } - builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition)). + builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition), *api.BuildUsageChecker.Load()). Initiator(apiKey.UserID). RichParameterValues(createBuild.RichParameterValues). LogLevel(string(createBuild.LogLevel)). diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 32b412946907e..0f3f0a24c75d3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -701,7 +701,7 @@ func createWorkspace( return xerrors.Errorf("get workspace by ID: %w", err) } - builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart). + builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart, *api.BuildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(initiatorID). ActiveVersion(). diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index d608682c58eee..52567b463baac 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -56,6 +56,7 @@ type Builder struct { logLevel string deploymentValues *codersdk.DeploymentValues experiments codersdk.Experiments + usageChecker UsageChecker richParameterValues []codersdk.WorkspaceBuildParameter initiator uuid.UUID @@ -89,7 +90,24 @@ type Builder struct { verifyNoLegacyParametersOnce bool } -type Option func(Builder) Builder +type UsageChecker interface { + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error) +} + +type UsageCheckResponse struct { + Permitted bool + Message string +} + +type NoopUsageChecker struct{} + +var _ UsageChecker = NoopUsageChecker{} + +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) { + return UsageCheckResponse{ + Permitted: true, + }, nil +} // versionTarget expresses how to determine the template version for the build. // @@ -121,8 +139,8 @@ type stateTarget struct { explicit *[]byte } -func New(w database.Workspace, t database.WorkspaceTransition) Builder { - return Builder{workspace: w, trans: t} +func New(w database.Workspace, t database.WorkspaceTransition, uc UsageChecker) Builder { + return Builder{workspace: w, trans: t, usageChecker: uc} } // Methods that customize the build are public, have a struct receiver and return a new Builder. @@ -321,6 +339,10 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object if err != nil { return nil, nil, nil, err } + err = b.checkUsage() + if err != nil { + return nil, nil, nil, err + } err = b.checkRunningBuild() if err != nil { return nil, nil, nil, err @@ -1253,6 +1275,23 @@ func (b *Builder) checkTemplateJobStatus() error { return nil } +func (b *Builder) checkUsage() error { + templateVersion, err := b.getTemplateVersion() + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} + } + + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to check build usage", err} + } + if !resp.Permitted { + return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil} + } + + return nil +} + func (b *Builder) checkRunningBuild() error { job, err := b.getLastBuildJob() if xerrors.Is(err, sql.ErrNoRows) { diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index 41ea3fe2c9921..ee421a8adb649 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -5,30 +5,30 @@ import ( "database/sql" "encoding/json" "net/http" + "sync/atomic" "testing" "time" - "github.com/prometheus/client_golang/prometheus" - - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/files" - "github.com/coder/coder/v2/coderd/httpapi/httperror" - "github.com/coder/coder/v2/provisionersdk" - "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" "go.uber.org/mock/gomock" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/provisionersdk" ) var ( @@ -102,7 +102,7 @@ func TestBuilder_NoOptions(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -142,7 +142,8 @@ func TestBuilder_Initiator(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -188,7 +189,8 @@ func TestBuilder_Baggage(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{IP: "127.0.0.1"}) req.NoError(err) @@ -227,7 +229,8 @@ func TestBuilder_Reason(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Reason(database.BuildReasonAutostart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Reason(database.BuildReasonAutostart) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -271,7 +274,8 @@ func TestBuilder_ActiveVersion(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).ActiveVersion() + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + ActiveVersion() // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -386,7 +390,8 @@ func TestWorkspaceBuildWithTags(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(buildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(buildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -469,7 +474,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -517,7 +523,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -555,7 +562,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) + // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} req.ErrorAs(err, &bldErr) @@ -591,7 +599,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} @@ -656,7 +665,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -720,7 +729,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -782,7 +791,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) // nolint: dogsled @@ -849,7 +858,7 @@ func TestWorkspaceBuildWithPreset(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). ActiveVersion(). TemplateVersionPresetID(presetID) // nolint: dogsled @@ -916,7 +925,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled @@ -993,7 +1002,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -1001,6 +1010,115 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }) } +func TestWorkspaceBuildUsageChecker(t *testing.T) { + t.Parallel() + + t.Run("Permitted", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + }, + } + + mDB := expectDB(t, + // Inputs + withTemplate, + withInactiveVersion(nil), + withLastBuildFound, + withTemplateVersionVariables(inactiveVersionID, nil), + withRichParameters(nil), + withParameterSchemas(inactiveJobID, nil), + withWorkspaceTags(inactiveVersionID, nil), + withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), + + // Outputs + expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), + withInTx, + expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), + withBuild, + expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + require.NoError(t, err) + require.EqualValues(t, 1, calls) + }) + + // The failure cases are mostly identical from a test perspective. + const message = "fake test message" + cases := []struct { + name string + response wsbuilder.UsageCheckResponse + responseErr error + assertions func(t *testing.T, err error) + }{ + { + name: "NotPermitted", + response: wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: message, + }, + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, message) + var buildErr wsbuilder.BuildError + require.ErrorAs(t, err, &buildErr) + require.Equal(t, http.StatusForbidden, buildErr.Status) + }, + }, + { + name: "Error", + responseErr: xerrors.New("fake error"), + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, "fake error") + require.ErrorAs(t, err, &wsbuilder.BuildError{}) + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return c.response, c.responseErr + }, + } + + mDB := expectDB(t, + withTemplate, + withInactiveVersionNoParams(), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker). + VersionID(inactiveVersionID) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + c.assertions(t, err) + require.EqualValues(t, 1, calls) + }) + } +} + func TestWsbuildError(t *testing.T) { t.Parallel() @@ -1366,3 +1484,11 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner mTx.EXPECT().GetEligibleProvisionerDaemonsByProvisionerJobIDs(gomock.Any(), gomock.Any()).Return(provisionerDaemons, nil) } } + +type fakeUsageChecker struct { + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) +} + +func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion) +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 0d176567713a2..d6e47f4cfdf00 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -22,6 +22,7 @@ import ( agplportsharing "github.com/coder/coder/v2/coderd/portsharing" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" @@ -916,10 +917,70 @@ func (api *API) updateEntitlements(ctx context.Context) error { reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg) } reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption + + // If there's a license installed, we will use the enterprise build + // limit checker. + // This checker currently only enforces the managed agent limit. + if reloadedEntitlements.HasLicense { + var checker wsbuilder.UsageChecker = api + api.AGPL.BuildUsageChecker.Store(&checker) + } else { + // Don't check any usage, just like AGPL. + var checker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + api.AGPL.BuildUsageChecker.Store(&checker) + } + return reloadedEntitlements, nil }) } +var _ wsbuilder.UsageChecker = &API{} + +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + // We assume that if this function is called, a valid license is installed. + // When there are no licenses installed, a noop usage checker is used + // instead. + + // If the template version doesn't have an AI task, we don't need to check + // usage. + if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil + } + + // Otherwise, we need to check that we haven't breached the managed agent + // limit. + managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit) + if !ok || !managedAgentLimit.Enabled || managedAgentLimit.Limit == nil || managedAgentLimit.UsagePeriod == nil { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "Your license is not entitled to managed agents. Please contact sales to continue using managed agents.", + }, nil + } + + // This check is intentionally not committed to the database. It's fine if + // it's not 100% accurate or allows for minor breaches due to build races. + managedAgentCount, err := store.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: managedAgentLimit.UsagePeriod.Start, + EndTime: managedAgentLimit.UsagePeriod.End, + }) + if err != nil { + return wsbuilder.UsageCheckResponse{}, xerrors.Errorf("get managed agent count: %w", err) + } + + if managedAgentCount >= *managedAgentLimit.Limit { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "You have breached the managed agent limit in your license. Please contact sales to continue using managed agents.", + }, nil + } + + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil +} + // getProxyDERPStartingRegionID returns the starting region ID that should be // used for workspace proxies. A proxy's actual region ID is the return value // from this function + it's RegionID field. @@ -1186,6 +1247,6 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio } reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds, - api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer) + api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker) return reconciler, prebuilds.NewEnterpriseClaimer(api.Database) } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 52301f6dae034..42645a98b06c2 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -32,6 +32,8 @@ import ( "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/enterprise/coderd/prebuilds" + "github.com/coder/coder/v2/provisioner/echo" + "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/retry" @@ -621,6 +623,88 @@ func TestSCIMDisabled(t *testing.T) { } } +func TestManagedAgentLimit(t *testing.T) { + t.Parallel() + + cli, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + LicenseOptions: (&coderdenttest.LicenseOptions{}).ManagedAgentLimit(1, 1), + }) + + // It's fine that the app ID is only used in a single successful workspace + // build. + appID := uuid.NewString() + echoRes := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Plan: []byte("{}"), + ModuleFiles: []byte{}, + HasAiTasks: true, + }, + }, + }, + }, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Apply{ + Apply: &proto.ApplyComplete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + Auth: &proto.Agent_Token{ + Token: uuid.NewString(), + }, + Apps: []*proto.App{{ + Id: appID, + Slug: "test", + Url: "http://localhost:1234", + }}, + }}, + }}, + AiTasks: []*proto.AITask{{ + Id: uuid.NewString(), + SidebarApp: &proto.AITaskSidebarApp{ + Id: appID, + }, + }}, + }, + }, + }}, + } + + // Create two templates, one with AI and one without. + aiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, echoRes) + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, aiVersion.ID) + aiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, aiVersion.ID) + noAiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, nil) // use default responses + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, noAiVersion.ID) + noAiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, noAiVersion.ID) + + // Create one AI workspace, which should succeed. + workspace := coderdtest.CreateWorkspace(t, cli, aiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) + + // Create a second AI workspace, which should fail. This needs to be done + // manually because coderdtest.CreateWorkspace expects it to succeed. + _, err := cli.CreateUserWorkspace(context.Background(), codersdk.Me, codersdk.CreateWorkspaceRequest{ //nolint:gocritic // owners must still be subject to the limit + TemplateID: aiTemplate.ID, + Name: coderdtest.RandomUsername(t), + AutomaticUpdates: codersdk.AutomaticUpdatesNever, + }) + require.ErrorContains(t, err, "You have breached the managed agent limit in your license") + + // Create a third non-AI workspace, which should succeed. + workspace = coderdtest.CreateWorkspace(t, cli, noAiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) +} + // testDBAuthzRole returns a context with a subject that has a role // with permissions required for test setup. func testDBAuthzRole(ctx context.Context) context.Context { diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index 9371c10c138d8..7776557522f86 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -94,15 +94,15 @@ func Entitlements( return codersdk.Entitlements{}, xerrors.Errorf("query active user count: %w", err) } - // always shows active user count regardless of license entitlements, err := LicensesEntitlements(ctx, now, licenses, enablements, keys, FeatureArguments{ ActiveUserCount: activeUserCount, ReplicaCount: replicaCount, ExternalAuthCount: externalAuthCount, - ManagedAgentCountFn: func(_ context.Context, _ time.Time, _ time.Time) (int64, error) { - // TODO(@deansheather): replace this with a real implementation in a - // follow up PR. - return 0, nil + ManagedAgentCountFn: func(ctx context.Context, startTime time.Time, endTime time.Time) (int64, error) { + return db.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: startTime, + EndTime: endTime, + }) }, }) if err != nil { diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index fac1d2b44bb63..d8203117039cb 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -10,8 +10,10 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" @@ -678,6 +680,67 @@ func TestEntitlements(t *testing.T) { require.Len(t, entitlements.Warnings, 1) require.Equal(t, "You have multiple External Auth Providers configured but your license is expired. Reduce to one.", entitlements.Warnings[0]) }) + + t.Run("ManagedAgentLimitHasValue", func(t *testing.T) { + t.Parallel() + + // Use a mock database for this test so I don't need to make real + // workspace builds. + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + licenseOpts := (&coderdenttest.LicenseOptions{ + FeatureSet: codersdk.FeatureSetPremium, + IssuedAt: dbtime.Now().Add(-2 * time.Hour).Truncate(time.Second), + NotBefore: dbtime.Now().Add(-time.Hour).Truncate(time.Second), + GraceAt: dbtime.Now().Add(time.Hour * 24 * 60).Truncate(time.Second), // 60 days to remove warning + ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90).Truncate(time.Second), // 90 days to remove warning + }). + UserLimit(100). + ManagedAgentLimit(100, 200) + + lic := database.License{ + ID: 1, + JWT: coderdenttest.GenerateLicense(t, *licenseOpts), + Exp: licenseOpts.ExpiresAt, + } + + mDB.EXPECT(). + GetUnexpiredLicenses(gomock.Any()). + Return([]database.License{lic}, nil) + mDB.EXPECT(). + GetActiveUserCount(gomock.Any(), false). + Return(int64(1), nil) + mDB.EXPECT(). + GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool { + // gomock doesn't seem to compare times very nicely. + if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) { + return false + } + if !assert.WithinDuration(t, licenseOpts.ExpiresAt, params.EndTime, time.Second) { + return false + } + return true + })). + Return(int64(175), nil) + + entitlements, err := license.Entitlements(context.Background(), mDB, 1, 0, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + + managedAgentLimit, ok := entitlements.Features[codersdk.FeatureManagedAgentLimit] + require.True(t, ok) + require.NotNil(t, managedAgentLimit.SoftLimit) + require.EqualValues(t, 100, *managedAgentLimit.SoftLimit) + require.NotNil(t, managedAgentLimit.Limit) + require.EqualValues(t, 200, *managedAgentLimit.Limit) + require.NotNil(t, managedAgentLimit.Actual) + require.EqualValues(t, 175, *managedAgentLimit.Actual) + + // Should've also populated a warning. + require.Len(t, entitlements.Warnings, 1) + require.Equal(t, "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.", entitlements.Warnings[0]) + }) } func TestLicenseEntitlements(t *testing.T) { diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go index 67c1f0dd21ade..01195e3485016 100644 --- a/enterprise/coderd/prebuilds/claim_test.go +++ b/enterprise/coderd/prebuilds/claim_test.go @@ -166,7 +166,7 @@ func TestClaimPrebuild(t *testing.T) { defer provisionerCloser.Close() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy) api.AGPL.PrebuildsClaimer.Store(&claimer) diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go index 96c3d071ac48a..1e9f3f5082806 100644 --- a/enterprise/coderd/prebuilds/metricscollector_test.go +++ b/enterprise/coderd/prebuilds/metricscollector_test.go @@ -201,7 +201,7 @@ func TestMetricsCollector(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) createdUsers := []uuid.UUID{database.PrebuildsSystemUserID} @@ -338,7 +338,7 @@ func TestMetricsCollector_DuplicateTemplateNames(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) collector := prebuilds.NewMetricsCollector(db, logger, reconciler) @@ -491,7 +491,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Ensure no pause setting is set (default state) @@ -520,7 +520,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation to paused @@ -549,7 +549,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation back to not paused diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index 049568c7e7f0c..214d1643bb228 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -39,15 +39,16 @@ import ( ) type StoreReconciler struct { - store database.Store - cfg codersdk.PrebuildsConfig - pubsub pubsub.Pubsub - fileCache *files.Cache - logger slog.Logger - clock quartz.Clock - registerer prometheus.Registerer - metrics *MetricsCollector - notifEnq notifications.Enqueuer + store database.Store + cfg codersdk.PrebuildsConfig + pubsub pubsub.Pubsub + fileCache *files.Cache + logger slog.Logger + clock quartz.Clock + registerer prometheus.Registerer + metrics *MetricsCollector + notifEnq notifications.Enqueuer + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] cancelFn context.CancelCauseFunc running atomic.Bool @@ -66,6 +67,7 @@ func NewStoreReconciler(store database.Store, clock quartz.Clock, registerer prometheus.Registerer, notifEnq notifications.Enqueuer, + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], ) *StoreReconciler { reconciler := &StoreReconciler{ store: store, @@ -76,6 +78,7 @@ func NewStoreReconciler(store database.Store, clock: clock, registerer: registerer, notifEnq: notifEnq, + buildUsageChecker: buildUsageChecker, done: make(chan struct{}, 1), provisionNotifyCh: make(chan database.ProvisionerJob, 10), } @@ -738,7 +741,7 @@ func (c *StoreReconciler) provision( }) } - builder := wsbuilder.New(workspace, transition). + builder := wsbuilder.New(workspace, transition, *c.buildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(database.PrebuildsSystemUserID). MarkPrebuild() diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index 5ba36912ce5c8..8d2a81e1ade83 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" "sync" + "sync/atomic" "testing" "time" @@ -19,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/notifications/notificationstest" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/coderd/wsbuilder" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/google/uuid" @@ -56,7 +58,7 @@ func TestNoReconciliationActionsIfNoPresets(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given a template version with no presets org := dbgen.Organization(t, db, database.Organization{}) @@ -102,7 +104,7 @@ func TestNoReconciliationActionsIfNoPrebuilds(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given there are presets, but no prebuilds org := dbgen.Organization(t, db, database.Organization{}) @@ -382,7 +384,7 @@ func TestPrebuildReconciliation(t *testing.T) { pubSub = &brokenPublisher{Pubsub: pubSub} } cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Run the reconciliation multiple times to ensure idempotency // 8 was arbitrary, but large enough to reasonably trust the result @@ -460,7 +462,7 @@ func TestMultiplePresetsPerTemplateVersion(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -586,7 +588,7 @@ func TestPrebuildScheduling(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -691,7 +693,7 @@ func TestInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -756,7 +758,7 @@ func TestDeletionOfPrebuiltWorkspaceWithInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -853,7 +855,7 @@ func TestSkippingHardLimitedPresets(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -997,7 +999,7 @@ func TestHardLimitedPresetShouldNotBlockDeletion(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -1191,7 +1193,7 @@ func TestRunLoop(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -1322,7 +1324,7 @@ func TestFailedBuildBackoff(t *testing.T) { ).Leveled(slog.LevelDebug) db, ps := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Given: an active template version with presets and prebuilds configured. const desiredInstances = 2 @@ -1447,7 +1449,8 @@ func TestReconciliationLock(t *testing.T) { slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug), quartz.NewMock(t), prometheus.NewRegistry(), - newNoopEnqueuer()) + newNoopEnqueuer(), + newNoopUsageCheckerPtr()) reconciler.WithReconciliationLock(ctx, logger, func(_ context.Context, _ database.Store) error { lockObtained := mutex.TryLock() // As long as the postgres lock is held, this mutex should always be unlocked when we get here. @@ -1481,7 +1484,7 @@ func TestTrackResourceReplacement(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Given: a template admin to receive a notification. templateAdmin := dbgen.User(t, db, database.User{ @@ -1637,7 +1640,7 @@ func TestExpiredPrebuildsMultipleActions(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset ownerID := uuid.New() @@ -1800,6 +1803,13 @@ func newFakeEnqueuer() *notificationstest.FakeEnqueuer { return notificationstest.NewFakeEnqueuer() } +func newNoopUsageCheckerPtr() *atomic.Pointer[wsbuilder.UsageChecker] { + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker := atomic.Pointer[wsbuilder.UsageChecker]{} + buildUsageChecker.Store(&noopUsageChecker) + return &buildUsageChecker +} + // nolint:revive // It's a control flag, but this is a test. func setupTestDBTemplate( t *testing.T, @@ -2270,7 +2280,7 @@ func TestReconciliationRespectsPauseSetting(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Setup a template with a preset that should create prebuilds org := dbgen.Organization(t, db, database.Organization{}) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index d622748899aa0..2278fb2a71939 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -1864,6 +1864,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2004,6 +2005,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2134,6 +2136,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2266,6 +2269,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2376,6 +2380,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: