Skip to content

Commit af125c3

Browse files
authored
chore: refactor entitlements to be a safe object to use (#14406)
* chore: refactor entitlements to be passable as an argument Previously, all usage of entitlements requires mutex usage on the api struct directly. This prevents passing the entitlements to a sub package. It also creates the possibility for misuse.
1 parent cb6a472 commit af125c3

17 files changed

+247
-124
lines changed

coderd/coderd.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"tailscale.com/util/singleflight"
3838

3939
"cdr.dev/slog"
40+
"github.com/coder/coder/v2/coderd/entitlements"
4041
"github.com/coder/quartz"
4142
"github.com/coder/serpent"
4243

@@ -157,6 +158,9 @@ type Options struct {
157158
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
158159
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
159160
RefreshEntitlements func(ctx context.Context) error
161+
// Entitlements can come from the enterprise caller if enterprise code is
162+
// included.
163+
Entitlements *entitlements.Set
160164
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
161165
// after a successful authentication.
162166
// This is somewhat janky, but seemingly the only reasonable way to add a header
@@ -263,6 +267,9 @@ func New(options *Options) *API {
263267
if options == nil {
264268
options = &Options{}
265269
}
270+
if options.Entitlements == nil {
271+
options.Entitlements = entitlements.New()
272+
}
266273
if options.NewTicker == nil {
267274
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
268275
ticker := time.NewTicker(duration)
@@ -500,6 +507,7 @@ func New(options *Options) *API {
500507
DocsURL: options.DeploymentValues.DocsURL.String(),
501508
AppearanceFetcher: &api.AppearanceFetcher,
502509
BuildInfo: buildInfo,
510+
Entitlements: options.Entitlements,
503511
})
504512
api.SiteHandler.Experiments.Store(&experiments)
505513

coderd/entitlements/entitlements.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package entitlements
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"sync"
7+
"time"
8+
9+
"github.com/coder/coder/v2/codersdk"
10+
)
11+
12+
type Set struct {
13+
entitlementsMu sync.RWMutex
14+
entitlements codersdk.Entitlements
15+
}
16+
17+
func New() *Set {
18+
return &Set{
19+
// Some defaults for an unlicensed instance.
20+
// These will be updated when coderd is initialized.
21+
entitlements: codersdk.Entitlements{
22+
Features: map[codersdk.FeatureName]codersdk.Feature{},
23+
Warnings: nil,
24+
Errors: nil,
25+
HasLicense: false,
26+
Trial: false,
27+
RequireTelemetry: false,
28+
RefreshedAt: time.Time{},
29+
},
30+
}
31+
}
32+
33+
// AllowRefresh returns whether the entitlements are allowed to be refreshed.
34+
// If it returns false, that means it was recently refreshed and the caller should
35+
// wait the returned duration before trying again.
36+
func (l *Set) AllowRefresh(now time.Time) (bool, time.Duration) {
37+
l.entitlementsMu.RLock()
38+
defer l.entitlementsMu.RUnlock()
39+
40+
diff := now.Sub(l.entitlements.RefreshedAt)
41+
if diff < time.Minute {
42+
return false, time.Minute - diff
43+
}
44+
45+
return true, 0
46+
}
47+
48+
func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) {
49+
l.entitlementsMu.RLock()
50+
defer l.entitlementsMu.RUnlock()
51+
52+
f, ok := l.entitlements.Features[name]
53+
return f, ok
54+
}
55+
56+
func (l *Set) Enabled(feature codersdk.FeatureName) bool {
57+
l.entitlementsMu.RLock()
58+
defer l.entitlementsMu.RUnlock()
59+
60+
f, ok := l.entitlements.Features[feature]
61+
if !ok {
62+
return false
63+
}
64+
return f.Enabled
65+
}
66+
67+
// AsJSON is used to return this to the api without exposing the entitlements for
68+
// mutation.
69+
func (l *Set) AsJSON() json.RawMessage {
70+
l.entitlementsMu.RLock()
71+
defer l.entitlementsMu.RUnlock()
72+
73+
b, _ := json.Marshal(l.entitlements)
74+
return b
75+
}
76+
77+
func (l *Set) Replace(entitlements codersdk.Entitlements) {
78+
l.entitlementsMu.Lock()
79+
defer l.entitlementsMu.Unlock()
80+
81+
l.entitlements = entitlements
82+
}
83+
84+
func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
85+
l.entitlementsMu.Lock()
86+
defer l.entitlementsMu.Unlock()
87+
88+
do(&l.entitlements)
89+
}
90+
91+
func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) {
92+
l.entitlementsMu.RLock()
93+
defer l.entitlementsMu.RUnlock()
94+
95+
oldFeature := l.entitlements.Features[featureName]
96+
if oldFeature.Enabled != newFeature.Enabled {
97+
return false, true, newFeature.Enabled
98+
}
99+
return false, false, newFeature.Enabled
100+
}
101+
102+
func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
103+
l.entitlementsMu.RLock()
104+
defer l.entitlementsMu.RUnlock()
105+
106+
for _, warning := range l.entitlements.Warnings {
107+
header.Add(codersdk.EntitlementsWarningHeader, warning)
108+
}
109+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package entitlements_test
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/coderd/entitlements"
10+
"github.com/coder/coder/v2/codersdk"
11+
)
12+
13+
func TestUpdate(t *testing.T) {
14+
t.Parallel()
15+
16+
set := entitlements.New()
17+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
18+
19+
set.Update(func(entitlements *codersdk.Entitlements) {
20+
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
21+
Enabled: true,
22+
Entitlement: codersdk.EntitlementEntitled,
23+
}
24+
})
25+
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
26+
}
27+
28+
func TestAllowRefresh(t *testing.T) {
29+
t.Parallel()
30+
31+
now := time.Now()
32+
set := entitlements.New()
33+
set.Update(func(entitlements *codersdk.Entitlements) {
34+
entitlements.RefreshedAt = now
35+
})
36+
37+
ok, wait := set.AllowRefresh(now)
38+
require.False(t, ok)
39+
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)
40+
41+
set.Update(func(entitlements *codersdk.Entitlements) {
42+
entitlements.RefreshedAt = now.Add(time.Minute * -2)
43+
})
44+
45+
ok, wait = set.AllowRefresh(now)
46+
require.True(t, ok)
47+
require.Equal(t, time.Duration(0), wait)
48+
}
49+
50+
func TestReplace(t *testing.T) {
51+
t.Parallel()
52+
53+
set := entitlements.New()
54+
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
55+
set.Replace(codersdk.Entitlements{
56+
Features: map[codersdk.FeatureName]codersdk.Feature{
57+
codersdk.FeatureMultipleOrganizations: {
58+
Enabled: true,
59+
},
60+
},
61+
})
62+
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
63+
}

codersdk/deployment.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ const (
3535
EntitlementNotEntitled Entitlement = "not_entitled"
3636
)
3737

38+
// Entitled returns if the entitlement can be used. So this is true if it
39+
// is entitled or still in it's grace period.
40+
func (e Entitlement) Entitled() bool {
41+
return e == EntitlementEntitled || e == EntitlementGracePeriod
42+
}
43+
3844
// Weight converts the enum types to a numerical value for easier
3945
// comparisons. Easier than sets of if statements.
4046
func (e Entitlement) Weight() int {

enterprise/coderd/coderd.go

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/coder/v2/buildinfo"
1616
"github.com/coder/coder/v2/coderd/appearance"
1717
"github.com/coder/coder/v2/coderd/database"
18+
"github.com/coder/coder/v2/coderd/entitlements"
1819
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
1920
"github.com/coder/coder/v2/coderd/rbac/policy"
2021
"github.com/coder/coder/v2/enterprise/coderd/portsharing"
@@ -103,19 +104,26 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
103104
}
104105
return nil, xerrors.Errorf("init database encryption: %w", err)
105106
}
107+
108+
entitlementsSet := entitlements.New()
106109
options.Database = cryptDB
107110
api := &API{
108-
ctx: ctx,
109-
cancel: cancelFunc,
110-
Options: options,
111+
ctx: ctx,
112+
cancel: cancelFunc,
113+
Options: options,
114+
entitlements: entitlementsSet,
111115
provisionerDaemonAuth: &provisionerDaemonAuth{
112116
psk: options.ProvisionerDaemonPSK,
113117
authorizer: options.Authorizer,
114118
db: options.Database,
115119
},
120+
licenseMetricsCollector: &license.MetricsCollector{
121+
Entitlements: entitlementsSet,
122+
},
116123
}
117124
// This must happen before coderd initialization!
118125
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
126+
options.Options.Entitlements = api.entitlements
119127
api.AGPL = coderd.New(options.Options)
120128
defer func() {
121129
if err != nil {
@@ -493,7 +501,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
493501
}
494502
api.AGPL.WorkspaceProxiesFetchUpdater.Store(&fetchUpdater)
495503

496-
err = api.PrometheusRegistry.Register(&api.licenseMetricsCollector)
504+
err = api.PrometheusRegistry.Register(api.licenseMetricsCollector)
497505
if err != nil {
498506
return nil, xerrors.Errorf("unable to register license metrics collector")
499507
}
@@ -553,13 +561,11 @@ type API struct {
553561
// ProxyHealth checks the reachability of all workspace proxies.
554562
ProxyHealth *proxyhealth.ProxyHealth
555563

556-
entitlementsUpdateMu sync.Mutex
557-
entitlementsMu sync.RWMutex
558-
entitlements codersdk.Entitlements
564+
entitlements *entitlements.Set
559565

560566
provisionerDaemonAuth *provisionerDaemonAuth
561567

562-
licenseMetricsCollector license.MetricsCollector
568+
licenseMetricsCollector *license.MetricsCollector
563569
tailnetService *tailnet.ClientService
564570
}
565571

@@ -588,11 +594,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade
588594
// has no roles. This is a normal user!
589595
return
590596
}
591-
api.entitlementsMu.RLock()
592-
defer api.entitlementsMu.RUnlock()
593-
for _, warning := range api.entitlements.Warnings {
594-
header.Add(codersdk.EntitlementsWarningHeader, warning)
595-
}
597+
598+
api.entitlements.WriteEntitlementWarningHeaders(header)
596599
}
597600

598601
func (api *API) Close() error {
@@ -614,9 +617,6 @@ func (api *API) Close() error {
614617
}
615618

616619
func (api *API) updateEntitlements(ctx context.Context) error {
617-
api.entitlementsUpdateMu.Lock()
618-
defer api.entitlementsUpdateMu.Unlock()
619-
620620
replicas := api.replicaManager.AllPrimary()
621621
agedReplicas := make([]database.Replica, 0, len(replicas))
622622
for _, replica := range replicas {
@@ -632,7 +632,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
632632
agedReplicas = append(agedReplicas, replica)
633633
}
634634

635-
entitlements, err := license.Entitlements(
635+
reloadedEntitlements, err := license.Entitlements(
636636
ctx, api.Database,
637637
len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{
638638
codersdk.FeatureAuditLog: api.AuditLogging,
@@ -652,29 +652,24 @@ func (api *API) updateEntitlements(ctx context.Context) error {
652652
return err
653653
}
654654

655-
if entitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
655+
if reloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
656656
// We can't fail because then the user couldn't remove the offending
657657
// license w/o a restart.
658658
//
659659
// We don't simply append to entitlement.Errors since we don't want any
660660
// enterprise features enabled.
661-
api.entitlements.Errors = []string{
662-
"License requires telemetry but telemetry is disabled",
663-
}
661+
api.entitlements.Update(func(entitlements *codersdk.Entitlements) {
662+
entitlements.Errors = []string{
663+
"License requires telemetry but telemetry is disabled",
664+
}
665+
})
666+
664667
api.Logger.Error(ctx, "license requires telemetry enabled")
665668
return nil
666669
}
667670

668671
featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) {
669-
if api.entitlements.Features == nil {
670-
return true, false, entitlements.Features[featureName].Enabled
671-
}
672-
oldFeature := api.entitlements.Features[featureName]
673-
newFeature := entitlements.Features[featureName]
674-
if oldFeature.Enabled != newFeature.Enabled {
675-
return false, true, newFeature.Enabled
676-
}
677-
return false, false, newFeature.Enabled
672+
return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName])
678673
}
679674

680675
shouldUpdate := func(initial, changed, enabled bool) bool {
@@ -831,20 +826,16 @@ func (api *API) updateEntitlements(ctx context.Context) error {
831826
}
832827

833828
// External token encryption is soft-enforced
834-
featureExternalTokenEncryption := entitlements.Features[codersdk.FeatureExternalTokenEncryption]
829+
featureExternalTokenEncryption := reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption]
835830
featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0
836831
if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled {
837832
msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize())
838833
api.Logger.Warn(ctx, msg)
839-
entitlements.Warnings = append(entitlements.Warnings, msg)
834+
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
840835
}
841-
entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
836+
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
842837

843-
api.entitlementsMu.Lock()
844-
defer api.entitlementsMu.Unlock()
845-
api.entitlements = entitlements
846-
api.licenseMetricsCollector.Entitlements.Store(&entitlements)
847-
api.AGPL.SiteHandler.Entitlements.Store(&entitlements)
838+
api.entitlements.Replace(reloadedEntitlements)
848839
return nil
849840
}
850841

@@ -1024,10 +1015,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(*
10241015
// @Router /entitlements [get]
10251016
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
10261017
ctx := r.Context()
1027-
api.entitlementsMu.RLock()
1028-
entitlements := api.entitlements
1029-
api.entitlementsMu.RUnlock()
1030-
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
1018+
httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON())
10311019
}
10321020

10331021
func (api *API) runEntitlementsLoop(ctx context.Context) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy