
Commit 964bfe9

review
1 parent b11895f commit 964bfe9

10 files changed (+215, -96 lines)

cli/server.go

Lines changed: 0 additions & 7 deletions
@@ -722,13 +722,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 			options.Database = dbmetrics.NewDBMetrics(options.Database, options.Logger, options.PrometheusRegistry)
 		}
 
-		wsUpdates, err := coderd.NewUpdatesProvider(logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)
-		if err != nil {
-			return xerrors.Errorf("create workspace updates provider: %w", err)
-		}
-		options.WorkspaceUpdatesProvider = wsUpdates
-		defer wsUpdates.Close()
-
 		var deploymentID string
 		err = options.Database.InTx(func(tx database.Store) error {
 			// This will block until the lock is acquired, and will be

coderd/coderd.go

Lines changed: 8 additions & 0 deletions
@@ -495,6 +495,13 @@ func New(options *Options) *API {
 		}
 	}
 
+	if options.WorkspaceUpdatesProvider == nil {
+		options.WorkspaceUpdatesProvider, err = NewUpdatesProvider(options.Logger.Named("workspace_updates"), options.Pubsub, options.Database, options.Authorizer)
+		if err != nil {
+			options.Logger.Critical(ctx, "failed to properly instantiate workspace updates provider", slog.Error(err))
+		}
+	}
+
 	// Start a background process that rotates keys. We intentionally start this after the caches
 	// are created to force initial requests for a key to populate the caches. This helps catch
 	// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
@@ -1495,6 +1502,7 @@ func (api *API) Close() error {
 	_ = api.OIDCConvertKeyCache.Close()
 	_ = api.AppSigningKeyCache.Close()
 	_ = api.AppEncryptionKeyCache.Close()
+	_ = api.WorkspaceUpdatesProvider.Close()
 	return nil
 }
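
Taken together with the cli/server.go deletion above, construction of the workspace updates provider moves out of the CLI and into coderd.New behind a nil check, so callers (tests, wrappers) can still inject their own, and API.Close now owns the provider's shutdown. A minimal, self-contained sketch of this nil-default options pattern; the names here are illustrative, not coderd's real Options fields:

package main

import "fmt"

type Provider interface{ Close() error }

type defaultProvider struct{}

func (defaultProvider) Close() error { return nil }

type Options struct {
	Provider Provider // optional; New fills in a default when nil
}

type API struct{ opts *Options }

func New(opts *Options) *API {
	// Nil-guarded default: only construct the provider if the caller
	// didn't supply one.
	if opts.Provider == nil {
		opts.Provider = defaultProvider{}
	}
	return &API{opts: opts}
}

// Close mirrors the diff's API.Close, which now also closes the provider it owns.
func (a *API) Close() error { return a.opts.Provider.Close() }

func main() {
	api := New(&Options{}) // no provider injected; the default is used
	fmt.Printf("%T\n", api.opts.Provider) // main.defaultProvider
	_ = api.Close()
}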

coderd/workspaceupdates.go

Lines changed: 14 additions & 16 deletions
@@ -14,7 +14,6 @@ import (
 	"github.com/coder/coder/v2/coderd/database/dbauthz"
 	"github.com/coder/coder/v2/coderd/database/pubsub"
 	"github.com/coder/coder/v2/coderd/rbac"
-	"github.com/coder/coder/v2/coderd/rbac/policy"
 	"github.com/coder/coder/v2/coderd/util/slice"
 	"github.com/coder/coder/v2/coderd/wspubsub"
 	"github.com/coder/coder/v2/codersdk"
@@ -23,7 +22,8 @@ import (
 )
 
 type UpdatesQuerier interface {
-	GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID, prep rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
+	// GetWorkspacesAndAgentsByOwnerID requires a context with an actor set
+	GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error)
 	GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error)
 }
 
@@ -45,11 +45,10 @@ type sub struct {
 	ctx      context.Context
 	cancelFn context.CancelFunc
 
-	mu       sync.RWMutex
-	userID   uuid.UUID
-	ch       chan *proto.WorkspaceUpdate
-	prev     workspacesByID
-	readPrep rbac.PreparedAuthorized
+	mu     sync.RWMutex
+	userID uuid.UUID
+	ch     chan *proto.WorkspaceUpdate
+	prev   workspacesByID
 
 	db UpdatesQuerier
 	ps pubsub.Pubsub
@@ -76,7 +75,7 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
 		}
 	}
 
-	rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
+	rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(ctx, s.userID)
 	if err != nil {
 		s.logger.Warn(ctx, "failed to get workspaces and agents by owner ID", slog.Error(err))
 		return
@@ -97,7 +96,7 @@ func (s *sub) handleEvent(ctx context.Context, event wspubsub.WorkspaceEvent, er
 }
 
 func (s *sub) start(ctx context.Context) (err error) {
-	rows, err := s.db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, s.userID, s.readPrep)
+	rows, err := s.db.GetWorkspacesAndAgentsByOwnerID(ctx, s.userID)
 	if err != nil {
 		return xerrors.Errorf("get workspaces and agents by owner ID: %w", err)
 	}
@@ -168,17 +167,17 @@ func (u *updatesProvider) Close() error {
 	return nil
 }
 
+// Subscribe subscribes to workspace updates for a user, for the workspaces
+// that user is authorized to `ActionRead` on. The provided context must have
+// a dbauthz actor set.
 func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tailnet.Subscription, error) {
 	actor, ok := dbauthz.ActorFromContext(ctx)
 	if !ok {
 		return nil, xerrors.Errorf("actor not found in context")
 	}
-	readPrep, err := u.auth.Prepare(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace.Type)
-	if err != nil {
-		return nil, xerrors.Errorf("prepare read action: %w", err)
-	}
+	ctx, cancel := context.WithCancel(u.ctx)
+	ctx = dbauthz.As(ctx, actor)
 	ch := make(chan *proto.WorkspaceUpdate, 1)
-	ctx, cancel := context.WithCancel(ctx)
 	sub := &sub{
 		ctx:      ctx,
 		cancelFn: cancel,
@@ -188,9 +187,8 @@ func (u *updatesProvider) Subscribe(ctx context.Context, userID uuid.UUID) (tail
 		ps:       u.ps,
 		logger:   u.logger.Named(fmt.Sprintf("workspace_updates_subscriber_%s", userID)),
 		prev:     workspacesByID{},
-		readPrep: readPrep,
 	}
-	err = sub.start(ctx)
+	err := sub.start(ctx)
 	if err != nil {
 		_ = sub.Close()
 		return nil, err
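
The net effect: per-subscription authorization moves from a prepared rbac filter held by each sub to the dbauthz actor carried on the context, and the subscription's lifetime now derives from the provider's own context rather than the caller's. A hedged sketch of the resulting caller contract, using only identifiers that appear in this commit:

	// The context handed to Subscribe must carry an actor (see the new doc
	// comment); dbauthz.As attaches one.
	sub, err := provider.Subscribe(dbauthz.As(ctx, actor), userID)
	if err != nil {
		return err
	}
	defer sub.Close()
	// Each received update is already filtered to workspaces the actor can
	// ActionRead.
	update := <-sub.Updates()
	_ = update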

coderd/workspaceupdates_test.go

Lines changed: 39 additions & 16 deletions
@@ -25,22 +25,23 @@ import (
 
 func TestWorkspaceUpdates(t *testing.T) {
 	t.Parallel()
-	ctx := context.Background()
 
-	ws1ID := uuid.New()
+	ws1ID := uuid.UUID{0x01}
 	ws1IDSlice := tailnet.UUIDToByteSlice(ws1ID)
-	agent1ID := uuid.New()
+	agent1ID := uuid.UUID{0x02}
 	agent1IDSlice := tailnet.UUIDToByteSlice(agent1ID)
-	ws2ID := uuid.New()
+	ws2ID := uuid.UUID{0x03}
 	ws2IDSlice := tailnet.UUIDToByteSlice(ws2ID)
-	ws3ID := uuid.New()
+	ws3ID := uuid.UUID{0x04}
 	ws3IDSlice := tailnet.UUIDToByteSlice(ws3ID)
-	agent2ID := uuid.New()
+	agent2ID := uuid.UUID{0x05}
 	agent2IDSlice := tailnet.UUIDToByteSlice(agent2ID)
-	ws4ID := uuid.New()
+	ws4ID := uuid.UUID{0x06}
 	ws4IDSlice := tailnet.UUIDToByteSlice(ws4ID)
+	agent3ID := uuid.UUID{0x07}
+	agent3IDSlice := tailnet.UUIDToByteSlice(agent3ID)
 
-	ownerID := uuid.New()
+	ownerID := uuid.UUID{0x08}
 	memberRole, err := rbac.RoleByName(rbac.RoleMember())
 	require.NoError(t, err)
 	ownerSubject := rbac.Subject{
@@ -53,9 +54,11 @@ func TestWorkspaceUpdates(t *testing.T) {
 	t.Run("Basic", func(t *testing.T) {
 		t.Parallel()
 
+		ctx := testutil.Context(t, testutil.WaitShort)
+
 		db := &mockWorkspaceStore{
 			orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
-				// Gains a new agent
+				// Gains agent2
 				{
 					ID:   ws1ID,
 					Name: "ws1",
@@ -81,6 +84,12 @@ func TestWorkspaceUpdates(t *testing.T) {
 					Name:       "ws3",
 					JobStatus:  database.ProvisionerJobStatusSucceeded,
 					Transition: database.WorkspaceTransitionStop,
+					Agents: []database.AgentIDNamePair{
+						{
+							ID:   agent3ID,
+							Name: "agent3",
+						},
+					},
 				},
 			},
 		}
@@ -97,13 +106,15 @@ func TestWorkspaceUpdates(t *testing.T) {
 
 		sub, err := updateProvider.Subscribe(dbauthz.As(ctx, ownerSubject), ownerID)
 		require.NoError(t, err)
-		ch := sub.Updates()
+		defer sub.Close()
 
-		update, ok := <-ch
-		require.True(t, ok)
+		update := testutil.RequireRecvCtx(ctx, t, sub.Updates())
 		slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
 			return strings.Compare(a.Name, b.Name)
 		})
+		slices.SortFunc(update.UpsertedAgents, func(a, b *proto.Agent) int {
+			return strings.Compare(a.Name, b.Name)
+		})
 		require.Equal(t, &proto.WorkspaceUpdate{
 			UpsertedWorkspaces: []*proto.Workspace{
 				{
@@ -128,6 +139,11 @@ func TestWorkspaceUpdates(t *testing.T) {
 					Name:        "agent1",
 					WorkspaceId: ws1IDSlice,
 				},
+				{
+					Id:          agent3IDSlice,
+					Name:        "agent3",
+					WorkspaceId: ws3IDSlice,
+				},
 			},
 			DeletedWorkspaces: []*proto.Workspace{},
 			DeletedAgents:     []*proto.Agent{},
@@ -169,8 +185,7 @@ func TestWorkspaceUpdates(t *testing.T) {
 			WorkspaceID: ws1ID,
 		})
 
-		update, ok = <-ch
-		require.True(t, ok)
+		update = testutil.RequireRecvCtx(ctx, t, sub.Updates())
 		slices.SortFunc(update.UpsertedWorkspaces, func(a, b *proto.Workspace) int {
 			return strings.Compare(a.Name, b.Name)
 		})
@@ -203,13 +218,21 @@ func TestWorkspaceUpdates(t *testing.T) {
 					Status: proto.Workspace_STOPPED,
 				},
 			},
-			DeletedAgents: []*proto.Agent{},
+			DeletedAgents: []*proto.Agent{
+				{
+					Id:          agent3IDSlice,
+					Name:        "agent3",
+					WorkspaceId: ws3IDSlice,
+				},
+			},
 		}, update)
 	})
 
 	t.Run("Resubscribe", func(t *testing.T) {
 		t.Parallel()
 
+		ctx := testutil.Context(t, testutil.WaitShort)
+
 		db := &mockWorkspaceStore{
 			orderedRows: []database.GetWorkspacesAndAgentsByOwnerIDRow{
 				{
@@ -290,7 +313,7 @@ type mockWorkspaceStore struct {
 }
 
 // GetAuthorizedWorkspacesAndAgentsByOwnerID implements coderd.UpdatesQuerier.
-func (m *mockWorkspaceStore) GetAuthorizedWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID, rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
+func (m *mockWorkspaceStore) GetWorkspacesAndAgentsByOwnerID(context.Context, uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) {
 	return m.orderedRows, nil
 }
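
One detail worth calling out: the test trades uuid.New() for fixed byte literals, which makes the IDs (and therefore the sorted, asserted protobufs) deterministic across runs. A quick self-contained illustration of what those literals evaluate to:

package main

import (
	"fmt"

	"github.com/google/uuid"
)

func main() {
	// uuid.UUID is a [16]byte array; a literal sets the leading bytes and
	// leaves the rest zero, giving stable, ordered IDs for tests.
	fmt.Println(uuid.UUID{0x01}) // 01000000-0000-0000-0000-000000000000
	fmt.Println(uuid.UUID{0x08}) // 08000000-0000-0000-0000-000000000000
}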

enterprise/tailnet/connio.go

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ var errDisconnect = xerrors.New("graceful disconnect")
 
 func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
 	c.logger.Debug(c.peerCtx, "got request")
-	err := c.auth.Authorize(c.coordCtx, req)
+	err := c.auth.Authorize(c.peerCtx, req)
 	if err != nil {
 		c.logger.Warn(c.peerCtx, "unauthorized request", slog.Error(err))
 		return xerrors.Errorf("authorize request: %w", err)

enterprise/tailnet/pgcoord_test.go

Lines changed: 36 additions & 0 deletions
@@ -913,6 +913,42 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
 	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
 }
 
+// TestPGCoordinatorPropogatedPeerContext tests that the context for a specific peer
+// is propagated through to the `Authorize` method of the coordinatee auth
+func TestPGCoordinatorPropogatedPeerContext(t *testing.T) {
+	t.Parallel()
+
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	store, ps := dbtestutil.NewDB(t)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	peerCtx := context.WithValue(ctx, agpltest.FakeSubjectKey{}, struct{}{})
+	peerID := uuid.UUID{0x01}
+	agentID := uuid.UUID{0x02}
+
+	c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
+	require.NoError(t, err)
+	defer func() {
+		err := c1.Close()
+		require.NoError(t, err)
+	}()
+
+	ch := make(chan struct{})
+	auth := agpltest.FakeCoordinateeAuth{
+		Chan: ch,
+	}
+
+	reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)
+
+	testutil.RequireSendCtx(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agpl.UUIDToByteSlice(agentID)}})
+
+	_ = testutil.RequireRecvCtx(ctx, t, ch)
+}
+
 func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
 	t.Helper()
 	assert.Eventually(t, func() bool {

tailnet/coordinator_test.go

Lines changed: 33 additions & 0 deletions
@@ -529,3 +529,36 @@ func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
 	defer f.Unlock()
 	f.callback = callback
 }
+
+// TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
+// is propagated through to the `Authorize` method of the coordinatee auth
+func TestCoordinatorPropogatedPeerContext(t *testing.T) {
+	t.Parallel()
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+	peerCtx := context.WithValue(ctx, test.FakeSubjectKey{}, struct{}{})
+	peerCtx, peerCtxCancel := context.WithCancel(peerCtx)
+	peerID := uuid.UUID{0x01}
+	agentID := uuid.UUID{0x02}
+
+	c1 := tailnet.NewCoordinator(logger)
+	t.Cleanup(func() {
+		err := c1.Close()
+		require.NoError(t, err)
+	})
+
+	ch := make(chan struct{})
+	auth := test.FakeCoordinateeAuth{
+		Chan: ch,
+	}
+
+	reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)
+
+	testutil.RequireSendCtx(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(agentID)}})
+	_ = testutil.RequireRecvCtx(ctx, t, ch)
+	// If we don't cancel the context, the coordinator close will wait until the
+	// peer request loop finishes, which will be after the timeout.
+	peerCtxCancel()
+}

tailnet/service.go

Lines changed: 2 additions & 15 deletions
@@ -220,28 +220,15 @@ func (s *DRPCService) WorkspaceUpdates(req *proto.WorkspaceUpdatesRequest, strea
 	defer stream.Close()
 
 	ctx := stream.Context()
-	streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
-	if !ok {
-		return xerrors.New("no Stream ID")
-	}
 
 	ownerID, err := uuid.FromBytes(req.WorkspaceOwnerId)
 	if err != nil {
 		return xerrors.Errorf("parse workspace owner ID: %w", err)
 	}
 
-	var sub Subscription
-	switch auth := streamID.Auth.(type) {
-	case ClientUserCoordinateeAuth:
-		sub, err = s.WorkspaceUpdatesProvider.Subscribe(ctx, ownerID)
-		if err != nil {
-			err = xerrors.Errorf("subscribe to workspace updates: %w", err)
-		}
-	default:
-		err = xerrors.Errorf("workspace updates not supported by auth name %T", auth)
-	}
+	sub, err := s.WorkspaceUpdatesProvider.Subscribe(ctx, ownerID)
 	if err != nil {
-		return err
+		return xerrors.Errorf("subscribe to workspace updates: %w", err)
 	}
 	defer sub.Close()
 
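
With the StreamID type switch gone, the RPC's only remaining requirement is that the stream context satisfy Subscribe (i.e., carry a dbauthz actor). The wire format is unchanged: the owner ID arrives as 16 raw bytes in req.WorkspaceOwnerId. A small self-contained reminder of that byte-slice round trip; uuid.FromBytes is google/uuid's inverse of taking the UUID's 16 bytes, which is presumably what tailnet.UUIDToByteSlice produces:

package main

import (
	"fmt"

	"github.com/google/uuid"
)

func main() {
	ownerID := uuid.UUID{0x08}
	wire := ownerID[:] // uuid.UUID is [16]byte; this is the wire form

	parsed, err := uuid.FromBytes(wire)
	fmt.Println(parsed, err) // 08000000-0000-0000-0000-000000000000 <nil>
}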
