Skip to content

Commit 89b90c2

Browse files
committed
feat: add support for WorkspaceUpdates to WebsocketDialer
1 parent 615f316 commit 89b90c2

File tree

9 files changed

+305
-78
lines changed

9 files changed

+305
-78
lines changed

Makefile

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,13 @@ DB_GEN_FILES := \
482482
coderd/database/dbauthz/dbauthz.go \
483483
coderd/database/dbmock/dbmock.go
484484

485+
TAILNETTEST_MOCKS := \
486+
tailnet/tailnettest/coordinatormock.go \
487+
tailnet/tailnettest/coordinateemock.go \
488+
tailnet/tailnettest/workspaceupdatesprovidermock.go \
489+
tailnet/tailnettest/subscriptionmock.go
490+
491+
485492
# all gen targets should be added here and to gen/mark-fresh
486493
gen: \
487494
tailnet/proto/tailnet.pb.go \
@@ -506,8 +513,7 @@ gen: \
506513
site/e2e/provisionerGenerated.ts \
507514
site/src/theme/icons.json \
508515
examples/examples.gen.json \
509-
tailnet/tailnettest/coordinatormock.go \
510-
tailnet/tailnettest/coordinateemock.go \
516+
$(TAILNETTEST_MOCKS) \
511517
coderd/database/pubsub/psmock/psmock.go
512518
.PHONY: gen
513519

@@ -536,8 +542,7 @@ gen/mark-fresh:
536542
site/e2e/provisionerGenerated.ts \
537543
site/src/theme/icons.json \
538544
examples/examples.gen.json \
539-
tailnet/tailnettest/coordinatormock.go \
540-
tailnet/tailnettest/coordinateemock.go \
545+
$(TAILNETTEST_MOCKS) \
541546
coderd/database/pubsub/psmock/psmock.go \
542547
"
543548

@@ -570,7 +575,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
570575
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
571576
go generate ./coderd/database/pubsub/psmock
572577

573-
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
578+
$(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
574579
go generate ./tailnet/tailnettest/
575580

576581
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto

codersdk/workspacesdk/dialer.go

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@ var permanentErrorStatuses = []int{
2525
}
2626

2727
type WebsocketDialer struct {
28-
logger slog.Logger
29-
dialOptions *websocket.DialOptions
30-
url *url.URL
28+
logger slog.Logger
29+
dialOptions *websocket.DialOptions
30+
url *url.URL
31+
// workspaceUpdatesReq != nil means that the dialer should call the WorkspaceUpdates RPC and
32+
// return the corresponding client
33+
workspaceUpdatesReq *proto.WorkspaceUpdatesRequest
34+
3135
resumeTokenFailed bool
3236
connected chan error
3337
isFirst bool
3438
}
3539

40+
type WebsocketDialerOption func(*WebsocketDialer)
41+
42+
func WithWorkspaceUpdates(req *proto.WorkspaceUpdatesRequest) WebsocketDialerOption {
43+
return func(w *WebsocketDialer) {
44+
w.workspaceUpdatesReq = req
45+
}
46+
}
47+
3648
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
3749
) (
3850
tailnet.ControlProtocolClients, error,
@@ -41,14 +53,27 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
4153

4254
u := new(url.URL)
4355
*u = *w.url
56+
q := u.Query()
4457
if r != nil && !w.resumeTokenFailed {
4558
if token, ok := r.Token(); ok {
46-
q := u.Query()
4759
q.Set("resume_token", token)
48-
u.RawQuery = q.Encode()
4960
w.logger.Debug(ctx, "using resume token on dial")
5061
}
5162
}
63+
// The current version includes additions
64+
//
65+
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
66+
// 2.2 PostTelemetry on the Tailnet API
67+
// 2.3 RefreshResumeToken, WorkspaceUpdates
68+
//
69+
// Resume tokens and telemetry are optional, and fail gracefully. So we use version 2.0 for
70+
// maximum compatibility if we don't need WorkspaceUpdates. If we do, we use 2.3.
71+
if w.workspaceUpdatesReq != nil {
72+
q.Add("version", "2.3")
73+
} else {
74+
q.Add("version", "2.0")
75+
}
76+
u.RawQuery = q.Encode()
5277

5378
// nolint:bodyclose
5479
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
@@ -115,25 +140,43 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
115140
return tailnet.ControlProtocolClients{}, err
116141
}
117142

143+
var updates tailnet.WorkspaceUpdatesClient
144+
if w.workspaceUpdatesReq != nil {
145+
updates, err = client.WorkspaceUpdates(context.Background(), w.workspaceUpdatesReq)
146+
if err != nil {
147+
w.logger.Debug(ctx, "failed to create WorkspaceUpdates stream", slog.Error(err))
148+
_ = ws.Close(websocket.StatusInternalError, "")
149+
return tailnet.ControlProtocolClients{}, err
150+
}
151+
}
152+
118153
return tailnet.ControlProtocolClients{
119-
Closer: client.DRPCConn(),
120-
Coordinator: coord,
121-
DERP: derps,
122-
ResumeToken: client,
123-
Telemetry: client,
154+
Closer: client.DRPCConn(),
155+
Coordinator: coord,
156+
DERP: derps,
157+
ResumeToken: client,
158+
Telemetry: client,
159+
WorkspaceUpdates: updates,
124160
}, nil
125161
}
126162

127163
func (w *WebsocketDialer) Connected() <-chan error {
128164
return w.connected
129165
}
130166

131-
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
132-
return &WebsocketDialer{
167+
func NewWebsocketDialer(
168+
logger slog.Logger, u *url.URL, websocketOptions *websocket.DialOptions,
169+
dialerOptions ...WebsocketDialerOption,
170+
) *WebsocketDialer {
171+
w := &WebsocketDialer{
133172
logger: logger,
134-
dialOptions: opts,
173+
dialOptions: websocketOptions,
135174
url: u,
136175
connected: make(chan error, 1),
137176
isFirst: true,
138177
}
178+
for _, o := range dialerOptions {
179+
o(w)
180+
}
181+
return w
139182
}

codersdk/workspacesdk/dialer_test.go

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/google/uuid"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
15+
"go.uber.org/mock/gomock"
1416
"nhooyr.io/websocket"
1517
"tailscale.com/tailcfg"
1618

@@ -21,7 +23,7 @@ import (
2123
"github.com/coder/coder/v2/codersdk"
2224
"github.com/coder/coder/v2/codersdk/workspacesdk"
2325
"github.com/coder/coder/v2/tailnet"
24-
"github.com/coder/coder/v2/tailnet/proto"
26+
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
2729
)
@@ -102,6 +104,7 @@ func TestWebsocketDialer_TokenController(t *testing.T) {
102104
require.Equal(t, "", gotToken)
103105

104106
clients = testutil.RequireRecvCtx(ctx, t, clientCh)
107+
require.Nil(t, clients.WorkspaceUpdates)
105108
clients.Closer.Close()
106109

107110
err = testutil.RequireRecvCtx(ctx, t, wsErr)
@@ -273,7 +276,7 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
273276
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
274277

275278
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
276-
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
279+
sVer := apiversion.New(2, 2)
277280

278281
// the following matches what Coderd does;
279282
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
@@ -291,7 +294,10 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
291294
svrURL, err := url.Parse(svr.URL)
292295
require.NoError(t, err)
293296

294-
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
297+
uut := workspacesdk.NewWebsocketDialer(
298+
logger, svrURL, &websocket.DialOptions{},
299+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
300+
)
295301

296302
errCh := make(chan error, 1)
297303
go func() {
@@ -307,6 +313,84 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
307313
require.NotEmpty(t, sdkErr.Helper)
308314
}
309315

316+
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
317+
t.Parallel()
318+
ctx := testutil.Context(t, testutil.WaitShort)
319+
logger := slogtest.Make(t, &slogtest.Options{
320+
IgnoreErrors: true,
321+
}).Leveled(slog.LevelDebug)
322+
323+
fCoord := tailnettest.NewFakeCoordinator()
324+
var coord tailnet.Coordinator = fCoord
325+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
326+
coordPtr.Store(&coord)
327+
ctrl := gomock.NewController(t)
328+
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
329+
330+
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
331+
Logger: logger,
332+
CoordPtr: &coordPtr,
333+
DERPMapUpdateFrequency: time.Hour,
334+
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
335+
WorkspaceUpdatesProvider: mProvider,
336+
})
337+
require.NoError(t, err)
338+
339+
wsErr := make(chan error, 1)
340+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341+
// need 2.3 for WorkspaceUpdates RPC
342+
cVer := r.URL.Query().Get("version")
343+
assert.Equal(t, "2.3", cVer)
344+
345+
sws, err := websocket.Accept(w, r, nil)
346+
if !assert.NoError(t, err) {
347+
return
348+
}
349+
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
350+
// streamID can be empty because we don't call RPCs in this test.
351+
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
352+
}))
353+
defer svr.Close()
354+
svrURL, err := url.Parse(svr.URL)
355+
require.NoError(t, err)
356+
357+
userID := uuid.UUID{88}
358+
359+
mSub := tailnettest.NewMockSubscription(ctrl)
360+
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
361+
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
362+
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
363+
mSub.EXPECT().Close().Times(1).Return(nil)
364+
365+
uut := workspacesdk.NewWebsocketDialer(
366+
logger, svrURL, &websocket.DialOptions{},
367+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
368+
WorkspaceOwnerId: userID[:],
369+
}),
370+
)
371+
372+
clients, err := uut.Dial(ctx, nil)
373+
require.NoError(t, err)
374+
require.NotNil(t, clients.WorkspaceUpdates)
375+
376+
wsID := uuid.UUID{99}
377+
expectedUpdate := &tailnetproto.WorkspaceUpdate{
378+
UpsertedWorkspaces: []*tailnetproto.Workspace{
379+
{Id: wsID[:]},
380+
},
381+
}
382+
updateCh <- expectedUpdate
383+
384+
gotUpdate, err := clients.WorkspaceUpdates.Recv()
385+
require.NoError(t, err)
386+
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
387+
388+
clients.Closer.Close()
389+
390+
err = testutil.RequireRecvCtx(ctx, t, wsErr)
391+
require.NoError(t, err)
392+
}
393+
310394
type fakeResumeTokenController struct {
311395
ctx context.Context
312396
t testing.TB

codersdk/workspacesdk/workspacesdk.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,6 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
216216
if err != nil {
217217
return nil, xerrors.Errorf("parse url: %w", err)
218218
}
219-
q := coordinateURL.Query()
220-
// The current version includes additions
221-
//
222-
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
223-
// 2.2 PostTelemetry on the Tailnet API
224-
// 2.3 RefreshResumeToken, WorkspaceUpdates
225-
//
226-
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
227-
// WorkspaceUpdates to talk to a single agent, we ask for version 2.0 for maximum compatibility
228-
q.Add("version", "2.0")
229-
coordinateURL.RawQuery = q.Encode()
230219

231220
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
232221
HTTPClient: c.client.HTTPClient,

enterprise/wsproxy/wsproxysdk/wsproxysdk.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,6 @@ import (
2121
agpl "github.com/coder/coder/v2/tailnet"
2222
)
2323

24-
// TailnetAPIVersion is the version of the Tailnet API we use for wsproxy.
25-
//
26-
// # The current version of the Tailnet API includes additions
27-
//
28-
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
29-
// 2.2 PostTelemetry on the Tailnet API
30-
// 2.3 RefreshResumeToken, WorkspaceUpdates
31-
//
32-
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
33-
// WorkspaceUpdates in the wsproxy, we ask for version 2.0 for maximum compatibility
34-
const TailnetAPIVersion = "2.0"
35-
3624
// Client is a HTTP client for a subset of Coder API routes that external
3725
// proxies need.
3826
type Client struct {
@@ -518,9 +506,6 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
518506
if err != nil {
519507
return nil, xerrors.Errorf("parse url: %w", err)
520508
}
521-
q := coordinateURL.Query()
522-
q.Add("version", TailnetAPIVersion)
523-
coordinateURL.RawQuery = q.Encode()
524509
coordinateHeaders := make(http.Header)
525510
tokenHeader := codersdk.SessionTokenHeader
526511
if c.SDKClient.SessionTokenHeader != "" {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy