From 8c52b2ada656f5eababcb4878536b8eb8d18ba6f Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 11:58:13 -0700 Subject: [PATCH 1/6] Use licenses for entitlements API Signed-off-by: Spike Curtis --- cli/features.go | 9 +- coderd/coderd.go | 6 +- coderd/database/databasefake/databasefake.go | 13 + coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 16 ++ coderd/database/queries/users.sql | 8 + coderd/features.go | 8 +- coderd/features_internal_test.go | 2 +- codersdk/features.go | 4 +- enterprise/coderd/coderd.go | 19 ++ enterprise/coderd/features.go | 264 +++++++++++++++++++ enterprise/coderd/licenses.go | 35 ++- 12 files changed, 375 insertions(+), 10 deletions(-) create mode 100644 enterprise/coderd/features.go diff --git a/cli/features.go b/cli/features.go index 5d631fc04977f..1995153275eaf 100644 --- a/cli/features.go +++ b/cli/features.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "encoding/json" "fmt" "strings" @@ -53,12 +54,14 @@ func featuresList() *cobra.Command { return xerrors.Errorf("render table: %w", err) } case "json": - outBytes, err := json.Marshal(entitlements) + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetIndent("", " ") + err = enc.Encode(entitlements) if err != nil { return xerrors.Errorf("marshal features to JSON: %w", err) } - - out = string(outBytes) + out = buf.String() default: return xerrors.Errorf(`unknown output format %q, only "table" and "json" are supported`, outputFormat) } diff --git a/coderd/coderd.go b/coderd/coderd.go index 7bfdffe1d2382..ac8d82a1408a1 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -68,6 +68,7 @@ type Options struct { TracerProvider *sdktrace.TracerProvider AutoImportTemplates []AutoImportTemplate LicenseHandler http.Handler + FeaturesService FeaturesService } // New constructs a Coder API handler. @@ -97,6 +98,9 @@ func New(options *Options) *API { if options.LicenseHandler == nil { options.LicenseHandler = licenses() } + if options.FeaturesService == nil { + options.FeaturesService = featuresService{} + } siteCacheDir := options.CacheDir if siteCacheDir != "" { @@ -404,7 +408,7 @@ func New(options *Options) *API { }) r.Route("/entitlements", func(r chi.Router) { r.Use(apiKeyMiddleware) - r.Get("/", entitlements) + r.Get("/", api.FeaturesService.EntitlementsAPI) }) r.Route("/licenses", func(r chi.Router) { r.Use(apiKeyMiddleware) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 21634eab29099..e6e64be63057b 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -246,6 +246,19 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } +func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + active := int64(0) + for _, u := range q.users { + if u.Status == database.UserStatusActive { + active++ + } + } + return active, nil +} + func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9603608ebbb05..49bf69a0bc5f8 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -25,6 +25,7 @@ type querier interface { DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) + GetActiveUserCount(ctx context.Context) (int64, error) // GetAuditLogsBefore retrieves `limit` number of audit logs before the provided // ID. GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9a22d86d888a8..edb817f8aebc0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2664,6 +2664,22 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke return i, err } +const getActiveUserCount = `-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status +` + +func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, getActiveUserCount) + var count int64 + err := row.Scan(&count) + return count, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 1d9caa758625e..12751fe064b47 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -28,6 +28,14 @@ SELECT FROM users; +-- name: GetActiveUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + status = 'active'::public.user_status; + -- name: InsertUser :one INSERT INTO users ( diff --git a/coderd/features.go b/coderd/features.go index a6eaeca9c545b..02fc00a068548 100644 --- a/coderd/features.go +++ b/coderd/features.go @@ -7,7 +7,13 @@ import ( "github.com/coder/coder/codersdk" ) -func entitlements(rw http.ResponseWriter, _ *http.Request) { +type FeaturesService interface { + EntitlementsAPI(w http.ResponseWriter, r *http.Request) +} + +type featuresService struct{} + +func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) { features := make(map[string]codersdk.Feature) for _, f := range codersdk.FeatureNames { features[f] = codersdk.Feature{ diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go index 50c7e8f53e397..d06fc96e19626 100644 --- a/coderd/features_internal_test.go +++ b/coderd/features_internal_test.go @@ -18,7 +18,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) rw := httptest.NewRecorder() - entitlements(rw, r) + featuresService{}.EntitlementsAPI(rw, r) resp := rw.Result() defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/codersdk/features.go b/codersdk/features.go index 6bdfd5ff53bd4..37b0113c37dfb 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -24,8 +24,8 @@ var FeatureNames = []string{FeatureUserLimit, FeatureAuditLog} type Feature struct { Entitlement Entitlement `json:"entitlement"` Enabled bool `json:"enabled"` - Limit *int64 `json:"limit"` - Actual *int64 `json:"actual"` + Limit *int64 `json:"limit,omitempty"` + Actual *int64 `json:"actual,omitempty"` } type Entitlements struct { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2be49052d3658..598c32f11b367 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -1,12 +1,18 @@ package coderd import ( + "context" + "os" + "strings" + "golang.org/x/xerrors" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/rbac" ) +const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE" + func NewEnterprise(options *coderd.Options) *coderd.API { var eOpts = *options if eOpts.Authorizer == nil { @@ -26,5 +32,18 @@ func NewEnterprise(options *coderd.Options) *coderd.API { Authorizer: eOpts.Authorizer, Logger: eOpts.Logger, }).handler() + en := Enablements{AuditLogs: true} + auditLog := os.Getenv(EnvAuditLogEnable) + auditLog = strings.ToLower(auditLog) + if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" { + en.AuditLogs = false + } + eOpts.FeaturesService = newFeaturesService( + context.Background(), + eOpts.Logger, + eOpts.Database, + eOpts.Pubsub, + en, + ) return coderd.New(&eOpts) } diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go new file mode 100644 index 0000000000000..78814e33fb56e --- /dev/null +++ b/enterprise/coderd/features.go @@ -0,0 +1,264 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "fmt" + "net/http" + "sync" + "time" + + "cdr.dev/slog" + "github.com/cenkalti/backoff/v4" + + agpl "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/codersdk" +) + +type Enablements struct { + AuditLogs bool +} + +type featuresService struct { + logger slog.Logger + database database.Store + pubsub database.Pubsub + keys map[string]ed25519.PublicKey + enablements Enablements + resyncInterval time.Duration + + mu sync.RWMutex + entitlements entitlements +} + +// newFeaturesService creates a FeaturesService and starts it. It will continue running for the +// duration of the passed ctx. +func newFeaturesService( + ctx context.Context, + logger slog.Logger, + db database.Store, + pubsub database.Pubsub, + enablements Enablements, +) agpl.FeaturesService { + fs := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: keys, + enablements: enablements, + resyncInterval: 10 * time.Minute, + entitlements: entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + }, + } + go fs.syncEntitlements(ctx) + return fs +} + +func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) { + s.mu.RLock() + e := s.entitlements + s.mu.RUnlock() + s.logger.Info(r.Context(), "entitlements now", slog.F("entitlements", e)) + + resp := codersdk.Entitlements{ + Features: make(map[string]codersdk.Feature), + Warnings: make([]string, 0), + HasLicense: e.hasLicense, + } + + // User limit + uf := codersdk.Feature{ + Entitlement: e.activeUsers.state.toSDK(), + Enabled: true, + } + if !e.activeUsers.unlimited { + n, err := s.database.GetActiveUserCount(r.Context()) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Unable to query database", + Detail: err.Error(), + }) + return + } + uf.Actual = &n + uf.Limit = &e.activeUsers.limit + if n > e.activeUsers.limit { + resp.Warnings = append(resp.Warnings, + fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d", + n, e.activeUsers.limit)) + } + } + resp.Features[codersdk.FeatureUserLimit] = uf + + // Audit logs + resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: e.auditLogs.state.toSDK(), + Enabled: s.enablements.AuditLogs, + } + if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs { + resp.Warnings = append(resp.Warnings, + "Audit logging is enabled but your license for this feature is expired") + } + + httpapi.Write(rw, http.StatusOK, resp) +} + +type entitlementState int + +const ( + notEntitled entitlementState = iota + gracePeriod + entitled +) + +type entitlementLimit struct { + unlimited bool + limit int64 +} + +type entitlement struct { + state entitlementState +} + +func (s entitlementState) toSDK() codersdk.Entitlement { + switch s { + case notEntitled: + return codersdk.EntitlementNotEntitled + case gracePeriod: + return codersdk.EntitlementGracePeriod + case entitled: + return codersdk.EntitlementGracePeriod + default: + panic("unknown entitlementState") + } +} + +type numericalEntitlement struct { + entitlement + entitlementLimit +} + +type entitlements struct { + hasLicense bool + activeUsers numericalEntitlement + auditLogs entitlement +} + +func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { + licenses, err := s.database.GetLicenses(ctx) + if err != nil { + return entitlements{}, err + } + now := time.Now() + e := entitlements{ + activeUsers: numericalEntitlement{ + entitlementLimit: entitlementLimit{ + unlimited: true, + }, + }, + } + s.logger.Info(ctx, "Got licenses", slog.F("num", len(licenses))) + for _, l := range licenses { + claims, err := validateDBLicense(l, s.keys) + if err != nil { + s.logger.Info(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + e.hasLicense = true + thisEntitlement := entitled + if now.After(claims.LicenseExpires.Time) { + s.logger.Info(ctx, "grace period license") + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + thisEntitlement = gracePeriod + } + if claims.Features.UserLimit > 0 { + s.logger.Info(ctx, "user limit", slog.F("user_limit", claims.Features.UserLimit)) + e.activeUsers.state = thisEntitlement + e.activeUsers.unlimited = false + e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) + } + if claims.Features.AuditLog > 0 { + e.auditLogs.state = thisEntitlement + } + } + return e, nil +} + +func (s *featuresService) syncEntitlements(ctx context.Context) { + s.logger.Info(ctx, "starting license sync function") + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + b := backoff.WithContext(eb, ctx) + updates := make(chan struct{}, 1) + subscribed := false + + for { + select { + case <-ctx.Done(): + return + default: + // pass + } + if !subscribed { + cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) { + // don't block. If the channel is full, drop the event, as there is a resync + // scheduled already. + select { + case updates <- struct{}{}: + // pass + default: + // pass + } + }) + if err != nil { + s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + defer cancel() + subscribed = true + s.logger.Info(ctx, "successfully subscribed to pubsub") + } + + s.logger.Info(ctx, "syncing licensed entitlements") + ents, err := s.getEntitlements(ctx) + if err != nil { + s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err)) + time.Sleep(b.NextBackOff()) + continue + } + b.Reset() + + s.mu.Lock() + s.entitlements = ents + s.mu.Unlock() + s.logger.Info(ctx, "synced licensed entitlements") + + select { + case <-ctx.Done(): + return + case <-time.After(s.resyncInterval): + continue + case <-updates: + s.logger.Info(ctx, "got pubsub update") + continue + } + } +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 16592fcde2654..02bef91d52bcb 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -64,8 +64,9 @@ type Claims struct { } var ( - ErrInvalidVersion = xerrors.New("license must be version 3") - ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrInvalidVersion = xerrors.New("license must be version 3") + ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ) // parseLicense parses the license and returns the claims. If the license's signature is invalid or @@ -92,6 +93,30 @@ func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, e return nil, xerrors.New("unable to parse Claims") } +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + if err != nil { + return nil, err + } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} + func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { return func(j *jwt.Token) (interface{}, error) { keyID, ok := j.Header[HeaderKeyID].(string) @@ -297,5 +322,11 @@ func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) { }) return } + + err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete")) + if err != nil { + a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err)) + // don't fail the HTTP request, since we did write it successfully to the database + } rw.WriteHeader(http.StatusOK) } From f7380cdedb0ea88f223045a978aa7480f9140dd1 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 14:33:10 -0700 Subject: [PATCH 2/6] Tests for entitlements API Signed-off-by: Spike Curtis --- coderd/database/databasefake/databasefake.go | 15 + coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 35 ++ coderd/database/queries/licenses.sql | 6 + enterprise/coderd/features.go | 4 +- enterprise/coderd/features_internal_test.go | 336 +++++++++++++++++++ 6 files changed, 395 insertions(+), 2 deletions(-) create mode 100644 enterprise/coderd/features_internal_test.go diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index e6e64be63057b..1ee445c8e4f22 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -2341,6 +2341,21 @@ func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) return results, nil } +func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + now := time.Now() + var results []database.License + for _, l := range q.licenses { + if l.Exp.After(now) { + results = append(results, l) + } + } + sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) + return results, nil +} + func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 49bf69a0bc5f8..389f15b385d6d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -64,6 +64,7 @@ type querier interface { GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) GetTemplates(ctx context.Context) ([]Template, error) GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) + GetUnexpiredLicenses(ctx context.Context) ([]License, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserCount(ctx context.Context) (int64, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index edb817f8aebc0..1e4a194fa740d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -522,6 +522,41 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many +SELECT id, uploaded_at, jwt, exp +FROM licenses +WHERE exp > NOW() +ORDER BY (id) +` + +func (q *sqlQuerier) GetUnexpiredLicenses(ctx context.Context) ([]License, error) { + rows, err := q.db.QueryContext(ctx, getUnexpiredLicenses) + if err != nil { + return nil, err + } + defer rows.Close() + var items []License + for rows.Next() { + var i License + if err := rows.Scan( + &i.ID, + &i.UploadedAt, + &i.JWT, + &i.Exp, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertLicense = `-- name: InsertLicense :one INSERT INTO licenses ( diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index e299589087119..39419c301761d 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -13,6 +13,12 @@ SELECT * FROM licenses ORDER BY (id); +-- name: GetUnexpiredLicenses :many +SELECT * +FROM licenses +WHERE exp > NOW() +ORDER BY (id); + -- name: DeleteLicense :one DELETE FROM licenses diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go index 78814e33fb56e..1e9d3392cd448 100644 --- a/enterprise/coderd/features.go +++ b/enterprise/coderd/features.go @@ -135,7 +135,7 @@ func (s entitlementState) toSDK() codersdk.Entitlement { case gracePeriod: return codersdk.EntitlementGracePeriod case entitled: - return codersdk.EntitlementGracePeriod + return codersdk.EntitlementEntitled default: panic("unknown entitlementState") } @@ -153,7 +153,7 @@ type entitlements struct { } func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) { - licenses, err := s.database.GetLicenses(ctx) + licenses, err := s.database.GetUnexpiredLicenses(ctx) if err != nil { return entitlements{}, err } diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go new file mode 100644 index 0000000000000..ef639126999f0 --- /dev/null +++ b/enterprise/coderd/features_internal_test.go @@ -0,0 +1,336 @@ +package coderd + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" +) + +func TestFeaturesService_EntitlementsAPI(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + + // Note that these are not actually used because we don't run the syncEntitlements + // routine in this test. + pubsub := database.NewPubsubInMemory() + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + result := requestEntitlements(t, uut) + assert.False(t, result.HasLicense) + assert.Empty(t, result.Warnings) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement) + assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement) + }) + + t.Run("FullLicense", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{entitled}, + entitlementLimit{ + unlimited: false, + limit: 100, + }, + }, + auditLogs: entitlement{entitled}, + }, + } + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement) + assert.Equal(t, int64(100), *ul.Limit) + assert.Equal(t, int64(1), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Empty(t, result.Warnings) + }) + + t.Run("Warnings", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + entitlements: entitlements{ + hasLicense: true, + activeUsers: numericalEntitlement{ + entitlement{gracePeriod}, + entitlementLimit{ + unlimited: false, + limit: 4, + }, + }, + auditLogs: entitlement{gracePeriod}, + }, + } + for i := byte(0); i < 5; i++ { + _, err = db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.UUID{i}, + Email: "", + Username: "", + HashedPassword: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + RBACRoles: nil, + LoginType: "", + }) + require.NoError(t, err) + } + result := requestEntitlements(t, uut) + assert.True(t, result.HasLicense) + ul := result.Features[codersdk.FeatureUserLimit] + assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) + assert.Equal(t, int64(4), *ul.Limit) + assert.Equal(t, int64(5), *ul.Actual) + assert.True(t, ul.Enabled) + al := result.Features[codersdk.FeatureAuditLog] + assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) + assert.True(t, al.Enabled) + assert.Nil(t, al.Limit) + assert.Nil(t, al.Actual) + assert.Len(t, result.Warnings, 2) + assert.Contains(t, result.Warnings, + "Your deployment has 5 active users but is only licensed for 4") + assert.Contains(t, result.Warnings, + "Audit logging is enabled but your license for this feature is expired") + }) +} + +func TestFeaturesServiceSyncEntitlements(t *testing.T) { + t.Parallel() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + + // This tests that pubsub updates work by setting the resync interval very long + t.Run("PubSub", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: time.Hour, // no resyncs during test + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + err = pubsub.Publish(PubSubEventLicenses, []byte("add")) + require.NoError(t, err) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + err = pubsub.Publish(PubSubEventLicenses, []byte("delete")) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) + + // This tests that periodic resyncs work by setting the resync interval very fast and + // not sending any pubsub updates. + t.Run("Resyncs", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + logger := slogtest.Make(t, nil) + pubsub := database.NewPubsubInMemory() + db := databasefake.New() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + resyncInterval: 10 * time.Millisecond, + entitlements: entitlements{}, + } + + _, invalidKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + // Start of day, 3 licenses, one expired, one invalid + _ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour) + _ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour) + l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour) + + go uut.syncEntitlements(ctx) + + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + // New license + l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour) + + // User limit goes up, because 305 > 300 + testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast) + + // New license with lower limit + _ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour) + + // Need to delete the others before the limit lowers + _, err = db.DeleteLicense(ctx, l1.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast) + + _, err = db.DeleteLicense(ctx, l0.ID) + require.NoError(t, err) + testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast) + }) +} + +func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements { + t.Helper() + r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) + rw := httptest.NewRecorder() + uut.EntitlementsAPI(rw, r) + resp := rw.Result() + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + dec := json.NewDecoder(resp.Body) + var result codersdk.Entitlements + err := dec.Decode(&result) + require.NoError(t, err) + return result +} + +func putLicense( + ctx context.Context, t *testing.T, db database.Store, + k ed25519.PrivateKey, keyID string, userLimit int64, + timeToGrace, timeToExpire time.Duration, +) database.License { + t.Helper() + c := &Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "test@testing.test", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)), + NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + }, + LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)), + Version: CurrentVersion, + Features: Features{ + UserLimit: userLimit, + AuditLog: 1, + }, + } + j, err := makeLicense(c, k, keyID) + require.NoError(t, err) + l, err := db.InsertLicense(ctx, database.InsertLicenseParams{ + UploadedAt: c.IssuedAt.Time, + JWT: j, + Exp: c.ExpiresAt.Time, + }) + require.NoError(t, err) + return l +} + +func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { + return func(_ context.Context) bool { + fs.mu.RLock() + defer fs.mu.RUnlock() + return fs.entitlements.activeUsers.limit == limit + } +} From c4b897f890f33f473e49f946f6a21b3b17689e83 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 14:39:43 -0700 Subject: [PATCH 3/6] Add commentary about FeatureService Signed-off-by: Spike Curtis --- coderd/features.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/coderd/features.go b/coderd/features.go index 02fc00a068548..55ddd2af895f9 100644 --- a/coderd/features.go +++ b/coderd/features.go @@ -7,8 +7,15 @@ import ( "github.com/coder/coder/codersdk" ) +// FeaturesService is the interface for interacting with enterprise features. type FeaturesService interface { EntitlementsAPI(w http.ResponseWriter, r *http.Request) + + // TODO + // Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a + // struct type containing feature interfaces as fields. The FeatureService sets all fields to + // the correct implementations depending on whether the features are turned on. + // Get(s any) error } type featuresService struct{} From 1e0f4be1a4637a2ee5f4ecbcd05546fb0081b6ec Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 14:41:36 -0700 Subject: [PATCH 4/6] Lint Signed-off-by: Spike Curtis --- enterprise/coderd/features.go | 3 ++- enterprise/coderd/features_internal_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go index 1e9d3392cd448..34fb986250eef 100644 --- a/enterprise/coderd/features.go +++ b/enterprise/coderd/features.go @@ -8,9 +8,10 @@ import ( "sync" "time" - "cdr.dev/slog" "github.com/cenkalti/backoff/v4" + "cdr.dev/slog" + agpl "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go index ef639126999f0..bb1b14a57606d 100644 --- a/enterprise/coderd/features_internal_test.go +++ b/enterprise/coderd/features_internal_test.go @@ -12,11 +12,12 @@ import ( "github.com/golang-jwt/jwt/v4" - "cdr.dev/slog/sloggers/slogtest" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" From 4f6a8cef01d0bcacb86494923b5e24174826614c Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 15:13:22 -0700 Subject: [PATCH 5/6] Quiet down the logs Signed-off-by: Spike Curtis --- enterprise/coderd/features.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go index 34fb986250eef..66d37a1d443bf 100644 --- a/enterprise/coderd/features.go +++ b/enterprise/coderd/features.go @@ -66,7 +66,6 @@ func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Reques s.mu.RLock() e := s.entitlements s.mu.RUnlock() - s.logger.Info(r.Context(), "entitlements now", slog.F("entitlements", e)) resp := codersdk.Entitlements{ Features: make(map[string]codersdk.Feature), @@ -166,24 +165,21 @@ func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, er }, }, } - s.logger.Info(ctx, "Got licenses", slog.F("num", len(licenses))) for _, l := range licenses { claims, err := validateDBLicense(l, s.keys) if err != nil { - s.logger.Info(ctx, "skipping invalid license", + s.logger.Debug(ctx, "skipping invalid license", slog.F("id", l.ID), slog.Error(err)) continue } e.hasLicense = true thisEntitlement := entitled if now.After(claims.LicenseExpires.Time) { - s.logger.Info(ctx, "grace period license") // if the grace period were over, the validation fails, so if we are after // LicenseExpires we must be in grace period. thisEntitlement = gracePeriod } if claims.Features.UserLimit > 0 { - s.logger.Info(ctx, "user limit", slog.F("user_limit", claims.Features.UserLimit)) e.activeUsers.state = thisEntitlement e.activeUsers.unlimited = false e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit) @@ -196,7 +192,6 @@ func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, er } func (s *featuresService) syncEntitlements(ctx context.Context) { - s.logger.Info(ctx, "starting license sync function") eb := backoff.NewExponentialBackOff() eb.MaxElapsedTime = 0 // retry indefinitely b := backoff.WithContext(eb, ctx) @@ -228,7 +223,7 @@ func (s *featuresService) syncEntitlements(ctx context.Context) { } defer cancel() subscribed = true - s.logger.Info(ctx, "successfully subscribed to pubsub") + s.logger.Debug(ctx, "successfully subscribed to pubsub") } s.logger.Info(ctx, "syncing licensed entitlements") @@ -243,7 +238,7 @@ func (s *featuresService) syncEntitlements(ctx context.Context) { s.mu.Lock() s.entitlements = ents s.mu.Unlock() - s.logger.Info(ctx, "synced licensed entitlements") + s.logger.Debug(ctx, "synced licensed entitlements") select { case <-ctx.Done(): @@ -251,7 +246,7 @@ func (s *featuresService) syncEntitlements(ctx context.Context) { case <-time.After(s.resyncInterval): continue case <-updates: - s.logger.Info(ctx, "got pubsub update") + s.logger.Debug(ctx, "got pubsub update") continue } } From ac3bd47b3df2c5cb310debf0a24ff47befb49730 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 26 Aug 2022 16:00:24 -0700 Subject: [PATCH 6/6] Tell revive it's ok Signed-off-by: Spike Curtis --- enterprise/coderd/features.go | 1 + 1 file changed, 1 insertion(+) diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go index 66d37a1d443bf..2102cdc0eb122 100644 --- a/enterprise/coderd/features.go +++ b/enterprise/coderd/features.go @@ -221,6 +221,7 @@ func (s *featuresService) syncEntitlements(ctx context.Context) { time.Sleep(b.NextBackOff()) continue } + // nolint: revive defer cancel() subscribed = true s.logger.Debug(ctx, "successfully subscribed to pubsub") pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy