Skip to content

Commit d0f1a9a

Browse files
committed
feat: add support for WorkspaceUpdates to WebsocketDialer
1 parent 40fd326 commit d0f1a9a

File tree

8 files changed

+305
-63
lines changed

8 files changed

+305
-63
lines changed

Makefile

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,13 @@ DB_GEN_FILES := \
482482
coderd/database/dbauthz/dbauthz.go \
483483
coderd/database/dbmock/dbmock.go
484484

485+
TAILNETTEST_MOCKS := \
486+
tailnet/tailnettest/coordinatormock.go \
487+
tailnet/tailnettest/coordinateemock.go \
488+
tailnet/tailnettest/workspaceupdatesprovidermock.go \
489+
tailnet/tailnettest/subscriptionmock.go
490+
491+
485492
# all gen targets should be added here and to gen/mark-fresh
486493
gen: \
487494
tailnet/proto/tailnet.pb.go \
@@ -505,8 +512,7 @@ gen: \
505512
site/e2e/provisionerGenerated.ts \
506513
site/src/theme/icons.json \
507514
examples/examples.gen.json \
508-
tailnet/tailnettest/coordinatormock.go \
509-
tailnet/tailnettest/coordinateemock.go \
515+
$(TAILNETTEST_MOCKS) \
510516
coderd/database/pubsub/psmock/psmock.go
511517
.PHONY: gen
512518

@@ -534,8 +540,7 @@ gen/mark-fresh:
534540
site/e2e/provisionerGenerated.ts \
535541
site/src/theme/icons.json \
536542
examples/examples.gen.json \
537-
tailnet/tailnettest/coordinatormock.go \
538-
tailnet/tailnettest/coordinateemock.go \
543+
$(TAILNETTEST_MOCKS) \
539544
coderd/database/pubsub/psmock/psmock.go \
540545
"
541546

@@ -568,7 +573,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
568573
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
569574
go generate ./coderd/database/pubsub/psmock
570575

571-
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
576+
$(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
572577
go generate ./tailnet/tailnettest/
573578

574579
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto

codersdk/workspacesdk/dialer.go

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@ var permanentErrorStatuses = []int{
2525
}
2626

2727
type WebsocketDialer struct {
28-
logger slog.Logger
29-
dialOptions *websocket.DialOptions
30-
url *url.URL
28+
logger slog.Logger
29+
dialOptions *websocket.DialOptions
30+
url *url.URL
31+
// workspaceUpdatesReq != nil means that the dialer should call the WorkspaceUpdates RPC and
32+
// return the corresponding client
33+
workspaceUpdatesReq *proto.WorkspaceUpdatesRequest
34+
3135
resumeTokenFailed bool
3236
connected chan error
3337
isFirst bool
3438
}
3539

40+
type WebsocketDialerOption func(*WebsocketDialer)
41+
42+
func WithWorkspaceUpdates(req *proto.WorkspaceUpdatesRequest) WebsocketDialerOption {
43+
return func(w *WebsocketDialer) {
44+
w.workspaceUpdatesReq = req
45+
}
46+
}
47+
3648
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
3749
) (
3850
tailnet.ControlProtocolClients, error,
@@ -41,14 +53,27 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
4153

4254
u := new(url.URL)
4355
*u = *w.url
56+
q := u.Query()
4457
if r != nil && !w.resumeTokenFailed {
4558
if token, ok := r.Token(); ok {
46-
q := u.Query()
4759
q.Set("resume_token", token)
48-
u.RawQuery = q.Encode()
4960
w.logger.Debug(ctx, "using resume token on dial")
5061
}
5162
}
63+
// The current version includes additions
64+
//
65+
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
66+
// 2.2 PostTelemetry on the Tailnet API
67+
// 2.3 RefreshResumeToken, WorkspaceUpdates
68+
//
69+
// Resume tokens and telemetry are optional, and fail gracefully. So we use version 2.0 for
70+
// maximum compatibility if we don't need WorkspaceUpdates. If we do, we use 2.3.
71+
if w.workspaceUpdatesReq != nil {
72+
q.Add("version", "2.3")
73+
} else {
74+
q.Add("version", "2.0")
75+
}
76+
u.RawQuery = q.Encode()
5277

5378
// nolint:bodyclose
5479
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
@@ -115,25 +140,43 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
115140
return tailnet.ControlProtocolClients{}, err
116141
}
117142

143+
var updates tailnet.WorkspaceUpdatesClient
144+
if w.workspaceUpdatesReq != nil {
145+
updates, err = client.WorkspaceUpdates(context.Background(), w.workspaceUpdatesReq)
146+
if err != nil {
147+
w.logger.Debug(ctx, "failed to create WorkspaceUpdates stream", slog.Error(err))
148+
_ = ws.Close(websocket.StatusInternalError, "")
149+
return tailnet.ControlProtocolClients{}, err
150+
}
151+
}
152+
118153
return tailnet.ControlProtocolClients{
119-
Closer: client.DRPCConn(),
120-
Coordinator: coord,
121-
DERP: derps,
122-
ResumeToken: client,
123-
Telemetry: client,
154+
Closer: client.DRPCConn(),
155+
Coordinator: coord,
156+
DERP: derps,
157+
ResumeToken: client,
158+
Telemetry: client,
159+
WorkspaceUpdates: updates,
124160
}, nil
125161
}
126162

127163
func (w *WebsocketDialer) Connected() <-chan error {
128164
return w.connected
129165
}
130166

131-
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
132-
return &WebsocketDialer{
167+
func NewWebsocketDialer(
168+
logger slog.Logger, u *url.URL, websocketOptions *websocket.DialOptions,
169+
dialerOptions ...WebsocketDialerOption,
170+
) *WebsocketDialer {
171+
w := &WebsocketDialer{
133172
logger: logger,
134-
dialOptions: opts,
173+
dialOptions: websocketOptions,
135174
url: u,
136175
connected: make(chan error, 1),
137176
isFirst: true,
138177
}
178+
for _, o := range dialerOptions {
179+
o(w)
180+
}
181+
return w
139182
}

codersdk/workspacesdk/dialer_test.go

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/google/uuid"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
15+
"go.uber.org/mock/gomock"
1416
"nhooyr.io/websocket"
1517
"tailscale.com/tailcfg"
1618

@@ -21,7 +23,7 @@ import (
2123
"github.com/coder/coder/v2/codersdk"
2224
"github.com/coder/coder/v2/codersdk/workspacesdk"
2325
"github.com/coder/coder/v2/tailnet"
24-
"github.com/coder/coder/v2/tailnet/proto"
26+
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
2729
)
@@ -102,6 +104,7 @@ func TestWebsocketDialer_TokenController(t *testing.T) {
102104
require.Equal(t, "", gotToken)
103105

104106
clients = testutil.RequireRecvCtx(ctx, t, clientCh)
107+
require.Nil(t, clients.WorkspaceUpdates)
105108
clients.Closer.Close()
106109

107110
err = testutil.RequireRecvCtx(ctx, t, wsErr)
@@ -273,7 +276,7 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
273276
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
274277

275278
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
276-
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
279+
sVer := apiversion.New(2, 2)
277280

278281
// the following matches what Coderd does;
279282
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
@@ -291,7 +294,10 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
291294
svrURL, err := url.Parse(svr.URL)
292295
require.NoError(t, err)
293296

294-
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
297+
uut := workspacesdk.NewWebsocketDialer(
298+
logger, svrURL, &websocket.DialOptions{},
299+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
300+
)
295301

296302
errCh := make(chan error, 1)
297303
go func() {
@@ -307,6 +313,84 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
307313
require.NotEmpty(t, sdkErr.Helper)
308314
}
309315

316+
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
317+
t.Parallel()
318+
ctx := testutil.Context(t, testutil.WaitShort)
319+
logger := slogtest.Make(t, &slogtest.Options{
320+
IgnoreErrors: true,
321+
}).Leveled(slog.LevelDebug)
322+
323+
fCoord := tailnettest.NewFakeCoordinator()
324+
var coord tailnet.Coordinator = fCoord
325+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
326+
coordPtr.Store(&coord)
327+
ctrl := gomock.NewController(t)
328+
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
329+
330+
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
331+
Logger: logger,
332+
CoordPtr: &coordPtr,
333+
DERPMapUpdateFrequency: time.Hour,
334+
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
335+
WorkspaceUpdatesProvider: mProvider,
336+
})
337+
require.NoError(t, err)
338+
339+
wsErr := make(chan error, 1)
340+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341+
// need 2.3 for WorkspaceUpdates RPC
342+
cVer := r.URL.Query().Get("version")
343+
assert.Equal(t, "2.3", cVer)
344+
345+
sws, err := websocket.Accept(w, r, nil)
346+
if !assert.NoError(t, err) {
347+
return
348+
}
349+
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
350+
// streamID can be empty because we don't call RPCs in this test.
351+
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
352+
}))
353+
defer svr.Close()
354+
svrURL, err := url.Parse(svr.URL)
355+
require.NoError(t, err)
356+
357+
userID := uuid.UUID{88}
358+
359+
mSub := tailnettest.NewMockSubscription(ctrl)
360+
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
361+
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
362+
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
363+
mSub.EXPECT().Close().Times(1).Return(nil)
364+
365+
uut := workspacesdk.NewWebsocketDialer(
366+
logger, svrURL, &websocket.DialOptions{},
367+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
368+
WorkspaceOwnerId: userID[:],
369+
}),
370+
)
371+
372+
clients, err := uut.Dial(ctx, nil)
373+
require.NoError(t, err)
374+
require.NotNil(t, clients.WorkspaceUpdates)
375+
376+
wsID := uuid.UUID{99}
377+
expectedUpdate := &tailnetproto.WorkspaceUpdate{
378+
UpsertedWorkspaces: []*tailnetproto.Workspace{
379+
{Id: wsID[:]},
380+
},
381+
}
382+
updateCh <- expectedUpdate
383+
384+
gotUpdate, err := clients.WorkspaceUpdates.Recv()
385+
require.NoError(t, err)
386+
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
387+
388+
clients.Closer.Close()
389+
390+
err = testutil.RequireRecvCtx(ctx, t, wsErr)
391+
require.NoError(t, err)
392+
}
393+
310394
type fakeResumeTokenController struct {
311395
ctx context.Context
312396
t testing.TB

codersdk/workspacesdk/workspacesdk.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,6 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
216216
if err != nil {
217217
return nil, xerrors.Errorf("parse url: %w", err)
218218
}
219-
q := coordinateURL.Query()
220-
// The current version includes additions
221-
//
222-
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
223-
// 2.2 PostTelemetry on the Tailnet API
224-
// 2.3 RefreshResumeToken, WorkspaceUpdates
225-
//
226-
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
227-
// WorkspaceUpdates to talk to a single agent, we ask for version 2.0 for maximum compatibility
228-
q.Add("version", "2.0")
229-
coordinateURL.RawQuery = q.Encode()
230219

231220
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
232221
HTTPClient: c.client.HTTPClient,

tailnet/service_test.go

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/google/uuid"
1212
"github.com/stretchr/testify/assert"
1313
"github.com/stretchr/testify/require"
14+
"go.uber.org/mock/gomock"
1415
"golang.org/x/xerrors"
1516
"tailscale.com/tailcfg"
1617

@@ -236,8 +237,8 @@ func TestClientUserCoordinateeAuth(t *testing.T) {
236237
agentID2 := uuid.UUID{0x02}
237238
clientID := uuid.UUID{0x03}
238239

239-
updatesCh := make(chan *proto.WorkspaceUpdate, 1)
240-
updatesProvider := &fakeUpdatesProvider{ch: updatesCh}
240+
ctrl := gomock.NewController(t)
241+
updatesProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
241242

242243
fCoord, client := createUpdateService(t, ctx, clientID, updatesProvider)
243244

@@ -271,8 +272,10 @@ func TestWorkspaceUpdates(t *testing.T) {
271272
t.Parallel()
272273

273274
ctx := testutil.Context(t, testutil.WaitShort)
275+
ctrl := gomock.NewController(t)
276+
updatesProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
277+
mSub := tailnettest.NewMockSubscription(ctrl)
274278
updatesCh := make(chan *proto.WorkspaceUpdate, 1)
275-
updatesProvider := &fakeUpdatesProvider{ch: updatesCh}
276279

277280
clientID := uuid.UUID{0x03}
278281
wsID := uuid.UUID{0x04}
@@ -293,6 +296,11 @@ func TestWorkspaceUpdates(t *testing.T) {
293296
DeletedAgents: []*proto.Agent{},
294297
}
295298
updatesCh <- expected
299+
updatesProvider.EXPECT().Subscribe(gomock.Any(), clientID).
300+
Times(1).
301+
Return(mSub, nil)
302+
mSub.EXPECT().Updates().MinTimes(1).Return(updatesCh)
303+
mSub.EXPECT().Close().Times(1).Return(nil)
296304

297305
updatesStream, err := client.WorkspaceUpdates(ctx, &proto.WorkspaceUpdatesRequest{
298306
WorkspaceOwnerId: tailnet.UUIDToByteSlice(clientID),
@@ -354,34 +362,6 @@ func createUpdateService(t *testing.T, ctx context.Context, clientID uuid.UUID,
354362
return fCoord, client
355363
}
356364

357-
type fakeUpdatesProvider struct {
358-
ch chan *proto.WorkspaceUpdate
359-
}
360-
361-
func (*fakeUpdatesProvider) Close() error {
362-
return nil
363-
}
364-
365-
func (f *fakeUpdatesProvider) Subscribe(context.Context, uuid.UUID) (tailnet.Subscription, error) {
366-
return &fakeSubscription{ch: f.ch}, nil
367-
}
368-
369-
type fakeSubscription struct {
370-
ch chan *proto.WorkspaceUpdate
371-
}
372-
373-
func (*fakeSubscription) Close() error {
374-
return nil
375-
}
376-
377-
func (f *fakeSubscription) Updates() <-chan *proto.WorkspaceUpdate {
378-
return f.ch
379-
}
380-
381-
var _ tailnet.Subscription = (*fakeSubscription)(nil)
382-
383-
var _ tailnet.WorkspaceUpdatesProvider = (*fakeUpdatesProvider)(nil)
384-
385365
type fakeTunnelAuth struct{}
386366

387367
// AuthorizeTunnel implements tailnet.TunnelAuthorizer.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy