From 9cf8f03bf5099d6449d535a98e66d6528df8f196 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 18 Jan 2024 09:15:27 +0400 Subject: [PATCH] feat: set peers lost when disconnected from coordinator --- tailnet/conn.go | 5 ++ tailnet/coordinator.go | 67 +++++++++----- tailnet/coordinator_test.go | 172 ++++++++++++++++++++++++++++++++++-- testutil/ctx.go | 10 +++ 4 files changed, 226 insertions(+), 28 deletions(-) diff --git a/tailnet/conn.go b/tailnet/conn.go index 0b4b942f8f4b5..4048567946791 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error return nil } +// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator. +func (c *Conn) SetAllPeersLost() { + c.configMaps.setAllPeersLost() +} + // NodeAddresses returns the addresses of a node from the NetworkMap. func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { return c.configMaps.nodeAddresses(publicKey) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 0fa62fc922209..3c4b1aeb24d3c 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -97,6 +97,7 @@ type Node struct { // Conn. 
type Coordinatee interface { UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error + SetAllPeersLost() SetNodeCallback(func(*Node)) } @@ -107,20 +108,28 @@ type Coordination interface { type remoteCoordination struct { sync.Mutex - closed bool - errChan chan error - coordinatee Coordinatee - logger slog.Logger - protocol proto.DRPCTailnet_CoordinateClient + closed bool + errChan chan error + coordinatee Coordinatee + logger slog.Logger + protocol proto.DRPCTailnet_CoordinateClient + respLoopDone chan struct{} } -func (c *remoteCoordination) Close() error { +func (c *remoteCoordination) Close() (retErr error) { c.Lock() defer c.Unlock() if c.closed { return nil } c.closed = true + defer func() { + protoErr := c.protocol.Close() + <-c.respLoopDone + if retErr == nil { + retErr = protoErr + } + }() err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) if err != nil { return xerrors.Errorf("send disconnect: %w", err) @@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) { } func (c *remoteCoordination) respLoop() { + defer func() { + c.coordinatee.SetAllPeersLost() + close(c.respLoopDone) + }() for { resp, err := c.protocol.Recv() if err != nil { @@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger, tunnelTarget uuid.UUID, ) Coordination { c := &remoteCoordination{ - errChan: make(chan error, 1), - coordinatee: coordinatee, - logger: logger, - protocol: protocol, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + protocol: protocol, + respLoopDone: make(chan struct{}), } if tunnelTarget != uuid.Nil { c.Lock() @@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger, type inMemoryCoordination struct { sync.Mutex - ctx context.Context - errChan chan error - closed bool - closedCh chan struct{} - coordinatee Coordinatee - logger slog.Logger - resps <-chan *proto.CoordinateResponse - reqs chan<- *proto.CoordinateRequest + ctx context.Context + errChan 
chan error + closed bool + closedCh chan struct{} + respLoopDone chan struct{} + coordinatee Coordinatee + logger slog.Logger + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest } func (c *inMemoryCoordination) sendErr(err error) { @@ -238,11 +253,12 @@ func NewInMemoryCoordination( thisID = clientID } c := &inMemoryCoordination{ - ctx: ctx, - errChan: make(chan error, 1), - coordinatee: coordinatee, - logger: logger, - closedCh: make(chan struct{}), + ctx: ctx, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + closedCh: make(chan struct{}), + respLoopDone: make(chan struct{}), } // use the background context since we will depend exclusively on closing the req channel to @@ -285,6 +301,10 @@ func NewInMemoryCoordination( } func (c *inMemoryCoordination) respLoop() { + defer func() { + c.coordinatee.SetAllPeersLost() + close(c.respLoopDone) + }() for { select { case <-c.closedCh: @@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error { defer close(c.reqs) c.closed = true close(c.closedCh) + <-c.respLoopDone select { case <-c.ctx.Done(): return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 1aad59e5a2f68..7207f93d78bd0 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -6,19 +6,24 @@ import ( "net" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" "time" - "nhooyr.io/websocket" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "nhooyr.io/websocket" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/test" 
"github.com/coder/coder/v2/testutil" ) @@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n require.True(t, ok) return client, server } + +func TestInMemoryCoordination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). + Times(1).Return(reqs, resps) + + uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) + defer uut.Close() + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + select { + case err := <-uut.Error(): + require.NoError(t, err) + default: + // OK! + } +} + +func TestRemoteCoordination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). 
+ Times(1).Return(reqs, resps) + + var coord tailnet.Coordinator = mCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + svc, err := tailnet.NewClientService( + logger.Named("svc"), &coordPtr, + time.Hour, + func() *tailcfg.DERPMap { panic("not implemented") }, + ) + require.NoError(t, err) + sC, cC := net.Pipe() + + serveErr := make(chan error, 1) + go func() { + err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID) + serveErr <- err + }() + + client, err := tailnet.NewDRPCClient(cC) + require.NoError(t, err) + protocol, err := client.Coordinate(ctx) + require.NoError(t, err) + + uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID) + defer uut.Close() + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + select { + case err := <-uut.Error(): + require.ErrorContains(t, err, "stream terminated by sending close") + default: + // OK! + } +} + +// coordinationTest tests that a coordination behaves correctly +func coordinationTest( + ctx context.Context, t *testing.T, + uut tailnet.Coordination, fConn *fakeCoordinatee, + reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse, + agentID uuid.UUID, +) { + // It should add the tunnel, since we configured as a client + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) + + // when we call the callback, it should send a node update + require.NotNil(t, fConn.callback) + fConn.callback(&tailnet.Node{PreferredDERP: 1}) + + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp()) + + // When we send a peer update, it should update the coordinatee + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + updates := []*proto.CoordinateResponse_PeerUpdate{ + { + Id: agentID[:], + Kind: 
proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{ + Id: 2, + Key: nk, + Disco: string(dk), + }, + }, + } + testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates}) + require.Eventually(t, func() bool { + fConn.Lock() + defer fConn.Unlock() + return len(fConn.updates) > 0 + }, testutil.WaitShort, testutil.IntervalFast) + require.Len(t, fConn.updates[0], 1) + require.Equal(t, agentID[:], fConn.updates[0][0].Id) + + err = uut.Close() + require.NoError(t, err) + uut.Error() + + // When we close, it should gracefully disconnect + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, req.Disconnect) + + // It should set all peers lost on the coordinatee + require.Equal(t, 1, fConn.setAllPeersLostCalls) +} + +type fakeCoordinatee struct { + sync.Mutex + callback func(*tailnet.Node) + updates [][]*proto.CoordinateResponse_PeerUpdate + setAllPeersLostCalls int +} + +func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error { + f.Lock() + defer f.Unlock() + f.updates = append(f.updates, updates) + return nil +} + +func (f *fakeCoordinatee) SetAllPeersLost() { + f.Lock() + defer f.Unlock() + f.setAllPeersLostCalls++ +} + +func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) { + f.Lock() + defer f.Unlock() + f.callback = callback +} diff --git a/testutil/ctx.go b/testutil/ctx.go index 2cc44c5bad8d7..c8f8c1769fe7f 100644 --- a/testutil/ctx.go +++ b/testutil/ctx.go @@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) return a } } + +func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout") + case c <- a: + // OK! + } +}