diff --git a/cli/features.go b/cli/features.go index 5d631fc04977f..1995153275eaf 100644 --- a/cli/features.go +++ b/cli/features.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "encoding/json" "fmt" "strings" @@ -53,12 +54,14 @@ func featuresList() *cobra.Command { return xerrors.Errorf("render table: %w", err) } case "json": - outBytes, err := json.Marshal(entitlements) + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetIndent("", " ") + err = enc.Encode(entitlements) if err != nil { return xerrors.Errorf("marshal features to JSON: %w", err) } - - out = string(outBytes) + out = buf.String() default: return xerrors.Errorf(`unknown output format %q, only "table" and "json" are supported`, outputFormat) } diff --git a/coderd/coderd.go b/coderd/coderd.go index f7b8603367b5e..be089523ec503 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -68,6 +68,7 @@ type Options struct { TracerProvider *sdktrace.TracerProvider AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler + FeaturesService FeaturesService } // New constructs a Coder API handler. @@ -97,6 +98,9 @@ func New(options *Options) *API { if options.LicenseHandler == nil { options.LicenseHandler = licenses() } + if options.FeaturesService == nil { + options.FeaturesService = featuresService{} + } siteCacheDir := options.CacheDir if siteCacheDir != "" { @@ -406,7 +410,7 @@ func New(options *Options) *API { }) r.Route("/entitlements", func(r chi.Router) { r.Use(apiKeyMiddleware) - r.Get("/", entitlements) + r.Get("/", api.FeaturesService.EntitlementsAPI) }) r.Route("/licenses", func(r chi.Router) { r.Use(apiKeyMiddleware) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 590d3c5685bd4..2b42d0cca7ad2 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -246,6 +246,19 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } +func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + active := int64(0) + for _, u := range q.users { + if u.Status == database.UserStatusActive { + active++ + } + } + return active, nil +} + func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2322,6 +2335,21 @@ func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) return results, nil } +func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + now := time.Now() + var results []database.License + for _, l := range q.licenses { + if l.Exp.After(now) { + results = append(results, l) + } + } + sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) + return results, nil +} + func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9603608ebbb05..389f15b385d6d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -25,6 +25,7 @@ type querier interface { DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) + GetActiveUserCount(ctx context.Context) (int64, error) // GetAuditLogsBefore retrieves `limit` number of audit logs before the provided // ID. GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error) @@ -63,6 +64,7 @@ type querier interface { GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) GetTemplates(ctx context.Context) ([]Template, error) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) + GetUnexpiredLicenses(ctx context.Context) ([]License, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9a22d86d888a8..1e4a194fa740d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -522,6 +522,41 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many +SELECT id, uploaded_at, jwt, exp +FROM licenses +WHERE exp > NOW() +ORDER BY (id) +` + +func (q *sqlQuerier) GetUnexpiredLicenses(ctx context.Context) ([]License, error) { + rows, err := q.db.QueryContext(ctx, getUnexpiredLicenses) + if err != nil { + return nil, err + } + defer rows.Close() + var items []License + for rows.Next() { + var i License + if err := rows.Scan( + &i.ID, + &i.UploadedAt, + &i.JWT, + &i.Exp, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertLicense = `-- name: InsertLicense :one INSERT INTO licenses ( @@ -2664,6 +2699,22 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke return i, err } +const getActiveUserCount = `-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status +` + +func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, getActiveUserCount) + var count int64 + err := row.Scan(&count) + return count, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index e299589087119..39419c301761d 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -13,6 +13,12 @@ SELECT * FROM licenses ORDER BY (id); +-- name: GetUnexpiredLicenses :many +SELECT * +FROM licenses +WHERE exp > NOW() +ORDER BY (id); + -- name: DeleteLicense :one DELETE FROM licenses diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 1d9caa758625e..12751fe064b47 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -28,6 +28,14 @@ SELECT FROM users; +-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status; + -- name: InsertUser :one INSERT INTO users ( diff --git a/coderd/features.go b/coderd/features.go index a6eaeca9c545b..55ddd2af895f9 100644 --- a/coderd/features.go +++ b/coderd/features.go @@ -7,7 +7,20 @@ import ( "github.com/coder/coder/codersdk" ) -func entitlements(rw http.ResponseWriter, _ *http.Request) { +// FeaturesService is the interface for interacting with enterprise features. +type FeaturesService interface { + EntitlementsAPI(w http.ResponseWriter, r *http.Request) + + // TODO + // Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a + // struct type containing feature interfaces as fields. The FeatureService sets all fields to + // the correct implementations depending on whether the features are turned on. + // Get(s any) error +} + +type featuresService struct{} + +func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) { features := make(map[string]codersdk.Feature) for _, f := range codersdk.FeatureNames { features[f] = codersdk.Feature{ diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go index 50c7e8f53e397..d06fc96e19626 100644 --- a/coderd/features_internal_test.go +++ b/coderd/features_internal_test.go @@ -18,7 +18,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) rw := httptest.NewRecorder() - entitlements(rw, r) + featuresService{}.EntitlementsAPI(rw, r) resp := rw.Result() defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/codersdk/features.go b/codersdk/features.go index 6bdfd5ff53bd4..37b0113c37dfb 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -24,8 +24,8 @@ var FeatureNames = []string{FeatureUserLimit, FeatureAuditLog} type Feature struct { Entitlement Entitlement `json:"entitlement"` Enabled bool `json:"enabled"` - Limit *int64 `json:"limit"` - Actual *int64 `json:"actual"` + Limit *int64 `json:"limit,omitempty"` + Actual *int64 `json:"actual,omitempty"` } type Entitlements struct { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2be49052d3658..598c32f11b367 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -1,12 +1,18 @@ package coderd import ( + "context" + "os" + "strings" + "golang.org/x/xerrors" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/rbac" ) +const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE" + func NewEnterprise(options *coderd.Options) *coderd.API { var eOpts = *options if eOpts.Authorizer == nil { @@ -26,5 +32,18 @@ func NewEnterprise(options *coderd.Options) *coderd.API { Authorizer: eOpts.Authorizer, Logger: eOpts.Logger, }).handler() + en := Enablements{AuditLogs: true} + auditLog := os.Getenv(EnvAuditLogEnable) + auditLog = strings.ToLower(auditLog) + if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" { + en.AuditLogs = false + } + eOpts.FeaturesService = newFeaturesService( + context.Background(), + eOpts.Logger, + eOpts.Database, + eOpts.Pubsub, + en, + ) return coderd.New(&eOpts) } diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go new file mode 100644 index 0000000000000..2102cdc0eb122 --- /dev/null +++ b/enterprise/coderd/features.go @@ -0,0 +1,261 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "fmt" + "net/http" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + + "cdr.dev/slog" + + agpl "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/codersdk" +) + +type Enablements struct { + AuditLogs bool +} + +type featuresService struct { + logger slog.Logger + database database.Store + pubsub database.Pubsub + keys map[string]ed25519.PublicKey + enablements Enablements + resyncInterval time.Duration + + mu sync.RWMutex + entitlements entitlements +} + +// newFeaturesService creates a FeaturesService and starts it. It will continue running for the +// duration of the passed ctx. +func newFeaturesService( + ctx context.Context, + logger slog.Logger, + db database.Store, + pubsub database.Pubsub, + enablements Enablements, +) agpl.FeaturesService { + fs := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: keys, + enablements: enablements, + resyncInterval: 10 * time.Minute, + entitlements: entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + }, + } + go fs.syncEntitlements(ctx) + return fs +} + +func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) { + s.mu.RLock() + e := s.entitlements + s.mu.RUnlock() + + resp := codersdk.Entitlements{ + Features: make(map[string]codersdk.Feature), + Warnings: make([]string, 0), + HasLicense: e.hasLicense, + } + + // User limit + uf := codersdk.Feature{ + Entitlement: e.activeUsers.state.toSDK(), + Enabled: true, + } + if !e.activeUsers.unlimited { + n, err := s.database.GetActiveUserCount(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Unable to query database", + Detail: err.Error(), + }) + return + } + uf.Actual = &n + uf.Limit = &e.activeUsers.limit + if n > e.activeUsers.limit { + resp.Warnings = append(resp.Warnings, + fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d", + n, e.activeUsers.limit)) + } + } + resp.Features[codersdk.FeatureUserLimit] = uf + + // Audit logs + resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: e.auditLogs.state.toSDK(), + Enabled: s.enablements.AuditLogs, + } + if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs { + resp.Warnings = append(resp.Warnings, + "Audit logging is enabled but your license for this feature is expired") + } + + httpapi.Write(rw, http.StatusOK, resp) +} + +type entitlementState int + +const ( + notEntitled entitlementState = iota + gracePeriod + entitled +) + +type entitlementLimit struct { + unlimited bool + limit int64 +} + +type entitlement struct { + state entitlementState +} + +func (s entitlementState) toSDK() codersdk.Entitlement { + switch s { + case notEntitled: + return codersdk.EntitlementNotEntitled + case gracePeriod: + return codersdk.EntitlementGracePeriod + case entitled: + return codersdk.EntitlementEntitled + default: + panic("unknown entitlementState") + } +} + +type numericalEntitlement struct { + entitlement + entitlementLimit +} + +type entitlements struct { + hasLicense bool + activeUsers numericalEntitlement + auditLogs entitlement +} + +func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { + licenses, err := s.database.GetUnexpiredLicenses(ctx) + if err != nil { + return entitlements{}, err + } + now := time.Now() + e := entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + } + for _, l := range licenses { + claims, err := validateDBLicense(l, s.keys) + if err != nil { + s.logger.Debug(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + e.hasLicense = true + thisEntitlement := entitled + if now.After(claims.LicenseExpires.Time) { + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + thisEntitlement = gracePeriod + } + if claims.Features.UserLimit > 0 { + e.activeUsers.state = thisEntitlement + e.activeUsers.unlimited = false + e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) + } + if claims.Features.AuditLog > 0 { + e.auditLogs.state = thisEntitlement + } + } + return e, nil +} + +func (s *featuresService) syncEntitlements(ctx context.Context) { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + b := backoff.WithContext(eb, ctx) + updates := make(chan struct{}, 1) + subscribed := false + + for { + select { + case <-ctx.Done(): + return + default: + // pass + } + if !subscribed { + cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) { + // don't block. If the channel is full, drop the event, as there is a resync + // scheduled already. + select { + case updates <- struct{}{}: + // pass + default: + // pass + } + }) + if err != nil { + s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + // nolint: revive + defer cancel() + subscribed = true + s.logger.Debug(ctx, "successfully subscribed to pubsub") + } + + s.logger.Info(ctx, "syncing licensed entitlements") + ents, err := s.getEntitlements(ctx) + if err != nil { + s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + b.Reset() + + s.mu.Lock() + s.entitlements = ents + s.mu.Unlock() + s.logger.Debug(ctx, "synced licensed entitlements") + + select { + case <-ctx.Done(): + return + case <-time.After(s.resyncInterval): + continue + case <-updates: + s.logger.Debug(ctx, "got pubsub update") + continue + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go new file mode 100644 index 0000000000000..bb1b14a57606d --- /dev/null +++ b/enterprise/coderd/features_internal_test.go @@ -0,0 +1,337 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +func TestFeaturesService_EntitlementsAPI(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + + // Note that these are not actually used because we don't run the syncEntitlements + // routine in this test. + pubsub := database.NewPubsubInMemory() + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + result := requestEntitlements(t, uut) + assert.False(t, result.HasLicense) + assert.Empty(t, result.Warnings) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement) + }) + + t.Run("FullLicense", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{entitled}, + entitlementLimit{ + unlimited: false, + limit: 100, + }, + }, + auditLogs: entitlement{entitled}, + }, + } + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) + assert.Equal(t, int64(100), *ul.Limit) + assert.Equal(t, int64(1), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Empty(t, result.Warnings) + }) + + t.Run("Warnings", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{gracePeriod}, + entitlementLimit{ + unlimited: false, + limit: 4, + }, + }, + auditLogs: entitlement{gracePeriod}, + }, + } + for i := byte(0); i < 5; i++ { + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{i}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + } + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) + assert.Equal(t, int64(4), *ul.Limit) + assert.Equal(t, int64(5), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Len(t, result.Warnings, 2) + assert.Contains(t, result.Warnings, + "Your deployment has 5 active users but is only licensed for 4") + assert.Contains(t, result.Warnings, + "Audit logging is enabled but your license for this feature is expired") + }) +} + +func TestFeaturesServiceSyncEntitlements(t *testing.T) { + t.Parallel() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + // This tests that pubsub updates work by setting the resync interval very long + t.Run("PubSub", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: time.Hour, // no resyncs during test + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) + + // This tests that periodic resyncs work by setting the resync interval very fast and + // not sending any pubsub updates. + t.Run("Resyncs", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: 10 * time.Millisecond, + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) +} + +func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements { + t.Helper() + r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) + rw := httptest.NewRecorder() + uut.EntitlementsAPI(rw, r) + resp := rw.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + dec := json.NewDecoder(resp.Body) + var result codersdk.Entitlements + err := dec.Decode(&result) + require.NoError(t, err) + return result +} + +func putLicense( + ctx context.Context, t *testing.T, db database.Store, + k ed25519.PrivateKey, keyID string, userLimit int64, + timeToGrace, timeToExpire time.Duration, +) database.License { + t.Helper() + c := &Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@testing.test", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }, + LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)), + Version: CurrentVersion, + Features: Features{ + UserLimit: userLimit, + AuditLog: 1, + }, + } + j, err := makeLicense(c, k, keyID) + require.NoError(t, err) + l, err := db.InsertLicense(ctx, database.InsertLicenseParams{ + UploadedAt: c.IssuedAt.Time, + JWT: j, + Exp: c.ExpiresAt.Time, + }) + require.NoError(t, err) + return l +} + +func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { + return func(_ context.Context) bool { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.entitlements.activeUsers.limit == limit + } +} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 16592fcde2654..02bef91d52bcb 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -64,8 +64,9 @@ type Claims struct { } var ( - ErrInvalidVersion = xerrors.New("license must be version 3") - ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrInvalidVersion = xerrors.New("license must be version 3") + ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ) // parseLicense parses the license and returns the claims. If the license's signature is invalid or @@ -92,6 +93,30 @@ func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, e return nil, xerrors.New("unable to parse Claims") } +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + if err != nil { + return nil, err + } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} + func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { return func(j *jwt.Token) (interface{}, error) { keyID, ok := j.Header[HeaderKeyID].(string) @@ -297,5 +322,11 @@ func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) { }) return } + + err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete")) + if err != nil { + a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) + // don't fail the HTTP request, since we did write it successfully to the database + } rw.WriteHeader(http.StatusOK) }
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: