Skip to content

Commit 2d92458

Browse files
committed
feat: set peers lost when disconnected from coordinator
1 parent 2daadfc commit 2d92458

File tree

4 files changed

+226
-28
lines changed

4 files changed

+226
-28
lines changed

tailnet/conn.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
356356
return nil
357357
}
358358

359+
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
360+
func (c *Conn) SetAllPeersLost() {
361+
c.configMaps.setAllPeersLost()
362+
}
363+
359364
// NodeAddresses returns the addresses of a node from the NetworkMap.
360365
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
361366
return c.configMaps.nodeAddresses(publicKey)

tailnet/coordinator.go

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type Node struct {
9797
// Conn.
9898
type Coordinatee interface {
9999
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
100+
SetAllPeersLost()
100101
SetNodeCallback(func(*Node))
101102
}
102103

@@ -107,20 +108,28 @@ type Coordination interface {
107108

108109
type remoteCoordination struct {
109110
sync.Mutex
110-
closed bool
111-
errChan chan error
112-
coordinatee Coordinatee
113-
logger slog.Logger
114-
protocol proto.DRPCTailnet_CoordinateClient
111+
closed bool
112+
errChan chan error
113+
coordinatee Coordinatee
114+
logger slog.Logger
115+
protocol proto.DRPCTailnet_CoordinateClient
116+
respLoopDone chan struct{}
115117
}
116118

117-
func (c *remoteCoordination) Close() error {
119+
func (c *remoteCoordination) Close() (retErr error) {
118120
c.Lock()
119121
defer c.Unlock()
120122
if c.closed {
121123
return nil
122124
}
123125
c.closed = true
126+
defer func() {
127+
protoErr := c.protocol.Close()
128+
<-c.respLoopDone
129+
if retErr == nil {
130+
retErr = protoErr
131+
}
132+
}()
124133
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
125134
if err != nil {
126135
return xerrors.Errorf("send disconnect: %w", err)
@@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
140149
}
141150

142151
func (c *remoteCoordination) respLoop() {
152+
defer func() {
153+
c.coordinatee.SetAllPeersLost()
154+
close(c.respLoopDone)
155+
}()
143156
for {
144157
resp, err := c.protocol.Recv()
145158
if err != nil {
@@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
162175
tunnelTarget uuid.UUID,
163176
) Coordination {
164177
c := &remoteCoordination{
165-
errChan: make(chan error, 1),
166-
coordinatee: coordinatee,
167-
logger: logger,
168-
protocol: protocol,
178+
errChan: make(chan error, 1),
179+
coordinatee: coordinatee,
180+
logger: logger,
181+
protocol: protocol,
182+
respLoopDone: make(chan struct{}),
169183
}
170184
if tunnelTarget != uuid.Nil {
171185
c.Lock()
@@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
200214

201215
type inMemoryCoordination struct {
202216
sync.Mutex
203-
ctx context.Context
204-
errChan chan error
205-
closed bool
206-
closedCh chan struct{}
207-
coordinatee Coordinatee
208-
logger slog.Logger
209-
resps <-chan *proto.CoordinateResponse
210-
reqs chan<- *proto.CoordinateRequest
217+
ctx context.Context
218+
errChan chan error
219+
closed bool
220+
closedCh chan struct{}
221+
respLoopDone chan struct{}
222+
coordinatee Coordinatee
223+
logger slog.Logger
224+
resps <-chan *proto.CoordinateResponse
225+
reqs chan<- *proto.CoordinateRequest
211226
}
212227

213228
func (c *inMemoryCoordination) sendErr(err error) {
@@ -238,11 +253,12 @@ func NewInMemoryCoordination(
238253
thisID = clientID
239254
}
240255
c := &inMemoryCoordination{
241-
ctx: ctx,
242-
errChan: make(chan error, 1),
243-
coordinatee: coordinatee,
244-
logger: logger,
245-
closedCh: make(chan struct{}),
256+
ctx: ctx,
257+
errChan: make(chan error, 1),
258+
coordinatee: coordinatee,
259+
logger: logger,
260+
closedCh: make(chan struct{}),
261+
respLoopDone: make(chan struct{}),
246262
}
247263

248264
// use the background context since we will depend exclusively on closing the req channel to
@@ -285,6 +301,10 @@ func NewInMemoryCoordination(
285301
}
286302

287303
func (c *inMemoryCoordination) respLoop() {
304+
defer func() {
305+
c.coordinatee.SetAllPeersLost()
306+
close(c.respLoopDone)
307+
}()
288308
for {
289309
select {
290310
case <-c.closedCh:
@@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
315335
defer close(c.reqs)
316336
c.closed = true
317337
close(c.closedCh)
338+
<-c.respLoopDone
318339
select {
319340
case <-c.ctx.Done():
320341
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())

tailnet/coordinator_test.go

Lines changed: 167 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@ import (
66
"net"
77
"net/http"
88
"net/http/httptest"
9+
"sync"
10+
"sync/atomic"
911
"testing"
1012
"time"
1113

12-
"nhooyr.io/websocket"
13-
14-
"cdr.dev/slog"
15-
"cdr.dev/slog/sloggers/slogtest"
16-
1714
"github.com/google/uuid"
1815
"github.com/stretchr/testify/assert"
1916
"github.com/stretchr/testify/require"
17+
"go.uber.org/mock/gomock"
18+
"nhooyr.io/websocket"
19+
"tailscale.com/tailcfg"
20+
"tailscale.com/types/key"
2021

22+
"cdr.dev/slog"
23+
"cdr.dev/slog/sloggers/slogtest"
2124
"github.com/coder/coder/v2/tailnet"
25+
"github.com/coder/coder/v2/tailnet/proto"
26+
"github.com/coder/coder/v2/tailnet/tailnettest"
2227
"github.com/coder/coder/v2/tailnet/test"
2328
"github.com/coder/coder/v2/testutil"
2429
)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
400405
require.True(t, ok)
401406
return client, server
402407
}
408+
409+
func TestInMemoryCoordination(t *testing.T) {
410+
t.Parallel()
411+
ctx := testutil.Context(t, testutil.WaitShort)
412+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
413+
clientID := uuid.UUID{1}
414+
agentID := uuid.UUID{2}
415+
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
416+
fConn := &fakeCoordinatee{}
417+
418+
reqs := make(chan *proto.CoordinateRequest, 100)
419+
resps := make(chan *proto.CoordinateResponse, 100)
420+
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
421+
Times(1).Return(reqs, resps)
422+
423+
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
424+
defer uut.Close()
425+
426+
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
427+
428+
select {
429+
case err := <-uut.Error():
430+
require.NoError(t, err)
431+
default:
432+
// OK!
433+
}
434+
}
435+
436+
func TestRemoteCoordination(t *testing.T) {
437+
t.Parallel()
438+
ctx := testutil.Context(t, testutil.WaitShort)
439+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
440+
clientID := uuid.UUID{1}
441+
agentID := uuid.UUID{2}
442+
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
443+
fConn := &fakeCoordinatee{}
444+
445+
reqs := make(chan *proto.CoordinateRequest, 100)
446+
resps := make(chan *proto.CoordinateResponse, 100)
447+
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
448+
Times(1).Return(reqs, resps)
449+
450+
var coord tailnet.Coordinator = mCoord
451+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
452+
coordPtr.Store(&coord)
453+
svc, err := tailnet.NewClientService(
454+
logger.Named("svc"), &coordPtr,
455+
time.Hour,
456+
func() *tailcfg.DERPMap { panic("not implemented") },
457+
)
458+
require.NoError(t, err)
459+
sC, cC := net.Pipe()
460+
461+
serveErr := make(chan error, 1)
462+
go func() {
463+
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
464+
serveErr <- err
465+
}()
466+
467+
client, err := tailnet.NewDRPCClient(cC)
468+
require.NoError(t, err)
469+
protocol, err := client.Coordinate(ctx)
470+
require.NoError(t, err)
471+
472+
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
473+
defer uut.Close()
474+
475+
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
476+
477+
select {
478+
case err := <-uut.Error():
479+
require.ErrorContains(t, err, "stream terminated by sending close")
480+
default:
481+
// OK!
482+
}
483+
}
484+
485+
// coordinationTest tests that a coordination behaves correctly
486+
func coordinationTest(
487+
ctx context.Context, t *testing.T,
488+
uut tailnet.Coordination, fConn *fakeCoordinatee,
489+
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
490+
agentID uuid.UUID,
491+
) {
492+
// It should add the tunnel, since we configured as a client
493+
req := testutil.RequireRecvCtx(ctx, t, reqs)
494+
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
495+
496+
// when we call the callback, it should send a node update
497+
require.NotNil(t, fConn.callback)
498+
fConn.callback(&tailnet.Node{PreferredDERP: 1})
499+
500+
req = testutil.RequireRecvCtx(ctx, t, reqs)
501+
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
502+
503+
// When we send a peer update, it should update the coordinatee
504+
nk, err := key.NewNode().Public().MarshalBinary()
505+
require.NoError(t, err)
506+
dk, err := key.NewDisco().Public().MarshalText()
507+
require.NoError(t, err)
508+
updates := []*proto.CoordinateResponse_PeerUpdate{
509+
{
510+
Id: agentID[:],
511+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
512+
Node: &proto.Node{
513+
Id: 2,
514+
Key: nk,
515+
Disco: string(dk),
516+
},
517+
},
518+
}
519+
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
520+
require.Eventually(t, func() bool {
521+
fConn.Lock()
522+
defer fConn.Unlock()
523+
return len(fConn.updates) > 0
524+
}, testutil.WaitShort, testutil.IntervalFast)
525+
require.Len(t, fConn.updates[0], 1)
526+
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
527+
528+
err = uut.Close()
529+
require.NoError(t, err)
530+
uut.Error()
531+
532+
// When we close, it should gracefully disconnect
533+
req = testutil.RequireRecvCtx(ctx, t, reqs)
534+
require.NotNil(t, req.Disconnect)
535+
536+
// It should set all peers lost on the coordinatee
537+
require.Equal(t, 1, fConn.setAllPeersLostCalls)
538+
}
539+
540+
type fakeCoordinatee struct {
541+
sync.Mutex
542+
callback func(*tailnet.Node)
543+
updates [][]*proto.CoordinateResponse_PeerUpdate
544+
setAllPeersLostCalls int
545+
}
546+
547+
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
548+
f.Lock()
549+
defer f.Unlock()
550+
f.updates = append(f.updates, updates)
551+
return nil
552+
}
553+
554+
func (f *fakeCoordinatee) SetAllPeersLost() {
555+
f.Lock()
556+
defer f.Unlock()
557+
f.setAllPeersLostCalls++
558+
}
559+
560+
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
561+
f.Lock()
562+
defer f.Unlock()
563+
f.callback = callback
564+
}

testutil/ctx.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
2222
return a
2323
}
2424
}
25+
26+
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
27+
t.Helper()
28+
select {
29+
case <-ctx.Done():
30+
t.Fatal("timeout")
31+
case c <- a:
32+
// OK!
33+
}
34+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy