diff --git a/Makefile b/Makefile
index f1da02f2f53c8..d771fb0233d9e 100644
--- a/Makefile
+++ b/Makefile
@@ -475,7 +475,8 @@ gen: \
 	site/.eslintignore \
 	site/e2e/provisionerGenerated.ts \
 	site/src/theme/icons.json \
-	examples/examples.gen.json
+	examples/examples.gen.json \
+	tailnet/tailnettest/coordinatormock.go
 .PHONY: gen

 # Mark all generated files as fresh so make thinks they're up-to-date. This is
@@ -502,6 +503,7 @@ gen/mark-fresh:
 		site/e2e/provisionerGenerated.ts \
 		site/src/theme/icons.json \
 		examples/examples.gen.json \
+		tailnet/tailnettest/coordinatormock.go \
 	"
 	for file in $$files; do
 		echo "$$file"
@@ -529,6 +531,9 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $
 coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go
 	go generate ./coderd/database/dbmock/

+tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go
+	go generate ./tailnet/tailnettest/
+
 tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
 	protoc \
 		--go_out=. \
diff --git a/agent/agent.go b/agent/agent.go
index 514e10a7af3c0..25e24215d90bb 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -30,6 +30,7 @@ import (
 	"golang.org/x/exp/slices"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/xerrors"
+	"storj.io/drpc"
 	"tailscale.com/net/speedtest"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/netlogtype"
@@ -47,6 +48,7 @@ import (
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
 	"github.com/coder/coder/v2/tailnet"
+	tailnetproto "github.com/coder/coder/v2/tailnet/proto"
 )

 const (
@@ -86,7 +88,7 @@ type Options struct {
 type Client interface {
 	Manifest(ctx context.Context) (agentsdk.Manifest, error)
-	Listen(ctx context.Context) (net.Conn, error)
+	Listen(ctx context.Context) (drpc.Conn, error)
 	DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error)
 	ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
 	PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
@@ -1058,20 +1060,34 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()

-	coordinator, err := a.client.Listen(ctx)
+	conn, err := a.client.Listen(ctx)
 	if err != nil {
 		return err
 	}
-	defer coordinator.Close()
+	defer func() {
+		cErr := conn.Close()
+		if cErr != nil {
+			a.logger.Debug(ctx, "error closing drpc connection", slog.Error(cErr))
+		}
+	}()
+
+	tClient := tailnetproto.NewDRPCTailnetClient(conn)
+	coordinate, err := tClient.Coordinate(ctx)
+	if err != nil {
+		return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
+	}
+	defer func() {
+		cErr := coordinate.Close()
+		if cErr != nil {
+			a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(cErr))
+		}
+	}()
 	a.logger.Info(ctx, "connected to coordination endpoint")
-	sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
-		return network.UpdateNodes(nodes, false)
-	})
-	network.SetNodeCallback(sendNodes)
+	coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
 	select {
 	case <-ctx.Done():
 		return ctx.Err()
-	case err := <-errChan:
+	case err := <-coordination.Error():
 		return err
 	}
 }
diff --git a/agent/agent_test.go b/agent/agent_test.go
index f884918c83dba..163c64b78841d 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -1664,9 +1664,11 @@ func TestAgent_UpdatedDERP(t *testing.T) {
require.NotNil(t, originalDerpMap) coordinator := tailnet.NewCoordinator(logger) - defer func() { + // use t.Cleanup so the coordinator closing doesn't deadlock with in-memory + // coordination + t.Cleanup(func() { _ = coordinator.Close() - }() + }) agentID := uuid.New() statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() @@ -1681,41 +1683,42 @@ func TestAgent_UpdatedDERP(t *testing.T) { statsCh, coordinator, ) - closer := agent.New(agent.Options{ + uut := agent.New(agent.Options{ Client: client, Filesystem: fs, Logger: logger.Named("agent"), ReconnectingPTYTimeout: time.Minute, }) - defer func() { - _ = closer.Close() - }() + t.Cleanup(func() { + _ = uut.Close() + }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap) *codersdk.WorkspaceAgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, DERPMap: derpMap, - Logger: logger.Named("client"), + Logger: logger.Named(name), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() + t.Logf("closing conn %s", name) _ = conn.Close() - <-serveClientDone }) - go func() { - defer close(serveClientDone) - err := coordinator.ServeClient(serverConn, uuid.New(), agentID) - assert.NoError(t, err) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, agentID, + coordinator, conn) + t.Cleanup(func() { + t.Logf("closing coordination %s", name) + err := coordination.Close() + if err != nil { + t.Logf("error closing in-memory coordination: %s", err.Error()) + } }) - conn.SetNodeCallback(sendNode) // Force DERP. conn.SetBlockEndpoints(true) @@ -1724,6 +1727,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { CloseFunc: func() error { return codersdk.ErrSkipClose }, }) t.Cleanup(func() { + t.Logf("closing sdkConn %s", name) _ = sdkConn.Close() }) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -1734,7 +1738,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { return sdkConn } - conn1 := newClientConn(originalDerpMap) + conn1 := newClientConn(originalDerpMap, "client1") // Change the DERP map. newDerpMap, _ := tailnettest.RunDERPAndSTUN(t) @@ -1753,27 +1757,34 @@ func TestAgent_UpdatedDERP(t *testing.T) { DERPMap: newDerpMap, }) require.NoError(t, err) + t.Logf("client Pushed DERPMap update") require.Eventually(t, func() bool { - conn := closer.TailnetConn() + conn := uut.TailnetConn() if conn == nil { return false } regionIDs := conn.DERPMap().RegionIDs() - return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2 + preferredDERP := conn.Node().PreferredDERP + t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP) + return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2 }, testutil.WaitLong, testutil.IntervalFast) + t.Logf("agent got the new DERPMap") // Connect from a second client and make sure it uses the new DERP map. 
- conn2 := newClientConn(newDerpMap) + conn2 := newClientConn(newDerpMap, "client2") require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. conn1.SetDERPMap(newDerpMap) require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() require.True(t, conn1.AwaitReachable(ctx)) + t.Log("conn1 reached agent with new DERP") } func TestAgent_Speedtest(t *testing.T) { @@ -2050,22 +2061,22 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati Logger: logger.Named("client"), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() - <-serveClientDone }) - go func() { - defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, metadata.AgentID, + coordinator, conn) + t.Cleanup(func() { + err := coordination.Close() + if err != nil { + t.Logf("error closing in-mem coordination: %s", err.Error()) + } }) - conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: metadata.AgentID, }) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index c63962f0e4a24..ddea2d749e39c 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -3,19 +3,26 @@ package agenttest import ( "context" "io" - "net" "sync" + "sync/atomic" "testing" "time" "github.com/google/uuid" + "github.com/stretchr/testify/require" "golang.org/x/exp/maps" "golang.org/x/xerrors" + "storj.io/drpc" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -24,11 +31,31 @@ func NewClient(t testing.TB, agentID uuid.UUID, manifest agentsdk.Manifest, statsChan chan *agentsdk.Stats, - coordinator tailnet.CoordinatorV1, + coordinator tailnet.Coordinator, ) *Client { if manifest.AgentID == uuid.Nil { manifest.AgentID = agentID } + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coordinator) + mux := drpcmux.New() + drpcService := &tailnet.DRPCService{ + CoordPtr: &coordPtr, + Logger: logger, + // TODO: handle DERPMap too! 
+ DerpMapUpdateFrequency: time.Hour, + DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + } + err := proto.DRPCRegisterTailnet(mux, drpcService) + require.NoError(t, err) + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) return &Client{ t: t, logger: logger.Named("client"), @@ -36,6 +63,7 @@ func NewClient(t testing.TB, manifest: manifest, statsChan: statsChan, coordinator: coordinator, + server: server, derpMapUpdates: make(chan agentsdk.DERPMapUpdate), } } @@ -47,7 +75,8 @@ type Client struct { manifest agentsdk.Manifest metadata map[string]agentsdk.Metadata statsChan chan *agentsdk.Stats - coordinator tailnet.CoordinatorV1 + coordinator tailnet.Coordinator + server *drpcserver.Server LastWorkspaceAgent func() PatchWorkspaceLogs func() error GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) @@ -63,20 +92,29 @@ func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { return c.manifest, nil } -func (c *Client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() +func (c *Client) Listen(_ context.Context) (drpc.Conn, error) { + conn, lis := drpcsdk.MemTransportPipe() closed := make(chan struct{}) c.LastWorkspaceAgent = func() { - _ = serverConn.Close() - _ = clientConn.Close() + _ = conn.Close() + _ = lis.Close() <-closed } c.t.Cleanup(c.LastWorkspaceAgent) + serveCtx, cancel := context.WithCancel(context.Background()) + c.t.Cleanup(cancel) + auth := tailnet.AgentTunnelAuth{} + streamID := tailnet.StreamID{ + Name: "agenttest", + ID: c.agentID, + Auth: auth, + } + serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + _ = c.server.Serve(serveCtx, lis) close(closed) }() - return clientConn, nil + return conn, nil } func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 8d7c12974650f..4c98feffb7546 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -33,6 +33,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/tailnet" + tailnetproto "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -98,14 +99,32 @@ func TestDERP(t *testing.T) { w2Ready := make(chan struct{}) w2ReadyOnce := sync.Once{} + w1ID := uuid.New() w1.SetNodeCallback(func(node *tailnet.Node) { - w2.UpdateNodes([]*tailnet.Node{node}, false) + pn, err := tailnet.NodeToProto(node) + if !assert.NoError(t, err) { + return + } + w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{ + Id: w1ID[:], + Node: pn, + Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE, + }}) w2ReadyOnce.Do(func() { close(w2Ready) }) }) + w2ID := uuid.New() w2.SetNodeCallback(func(node *tailnet.Node) { - w1.UpdateNodes([]*tailnet.Node{node}, false) + pn, err := tailnet.NodeToProto(node) + if !assert.NoError(t, err) { + return + } + w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{ + Id: w2ID[:], + Node: pn, + Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE, + }}) }) conn := make(chan struct{}) @@ -199,7 +218,11 @@ func TestDERPForceWebSockets(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - 
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, + &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("client"), + }, + ) require.NoError(t, err) defer func() { _ = conn.Close() diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 6521d79149b48..5f3300711aad0 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -121,12 +121,23 @@ func NewServerTailnet( } tn.agentConn.Store(&agentConn) - err = tn.getAgentConn().UpdateSelf(conn.Node()) + pn, err := tailnet.NodeToProto(conn.Node()) if err != nil { - tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + } else { + err = tn.getAgentConn().UpdateSelf(pn) + if err != nil { + tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + } } + conn.SetNodeCallback(func(node *tailnet.Node) { - err := tn.getAgentConn().UpdateSelf(node) + pn, err := tailnet.NodeToProto(node) + if err != nil { + tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + return + } + err = tn.getAgentConn().UpdateSelf(pn) if err != nil { tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) } @@ -191,21 +202,9 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) { // If no one has connected since the cutoff and there are no active // connections, remove the agent. if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { - deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{ - ID: tailnet.NodeID(agentID), - IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128), - }) - if err != nil { - s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err)) - continue - } - if !deleted { - s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err)) - } - deletedCount++ delete(s.agentConnectionTimes, agentID) - err = agentConn.UnsubscribeAgent(agentID) + err := agentConn.UnsubscribeAgent(agentID) if err != nil { s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) } @@ -221,7 +220,7 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) { func (s *ServerTailnet) watchAgentUpdates() { for { conn := s.getAgentConn() - nodes, ok := conn.NextUpdate(s.ctx) + resp, ok := conn.NextUpdate(s.ctx) if !ok { if conn.IsClosed() && s.ctx.Err() == nil { s.logger.Warn(s.ctx, "multiagent closed, reinitializing") @@ -231,7 +230,7 @@ func (s *ServerTailnet) watchAgentUpdates() { return } - err := s.conn.UpdateNodes(nodes, false) + err := s.conn.UpdatePeers(resp.GetPeerUpdates()) if err != nil { if xerrors.Is(err, tailnet.ErrConnClosed) { s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err)) diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 2a0b0dfdbae70..392bc8d306f49 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -3,7 +3,6 @@ package coderd_test import ( "context" "fmt" - "net" "net/http" "net/http/httptest" "net/netip" @@ -204,22 +203,20 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A Logger: logger.Named("client"), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() 
- <-serveClientDone }) - go func() { - defer close(serveClientDone) - coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { - return conn.UpdateNodes(node, false) + clientID := uuid.New() + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, manifest.AgentID, + coord, conn, + ) + t.Cleanup(func() { + _ = coordination.Close() }) - conn.SetNodeCallback(sendNode) return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: manifest.AgentID, AgentIP: codersdk.WorkspaceAgentIP, diff --git a/coderd/util/apiversion/apiversion.go b/coderd/util/apiversion/apiversion.go index 7decaeab325c7..225fe01785724 100644 --- a/coderd/util/apiversion/apiversion.go +++ b/coderd/util/apiversion/apiversion.go @@ -30,6 +30,10 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion { return v } +func (v *APIVersion) String() string { + return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor) +} + // Validate validates the given version against the given constraints: // A given major.minor version is valid iff: // 1. The requested major version is contained within v.supportedMajors @@ -42,10 +46,6 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion { // - 1.x is supported, // - 2.0, 2.1, and 2.2 are supported, // - 2.3+ is not supported. -func (v *APIVersion) String() string { - return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor) -} - func (v *APIVersion) Validate(version string) error { major, minor, err := Parse(version) if err != nil { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 1e48ea0e7a088..ad508eebed25e 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -857,8 +857,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req // Deprecated: use api.tailnet.AgentConn instead. // See: https://github.com/coder/coder/issues/8218 func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - clientConn, serverConn := net.Pipe() - derpMap := api.DERPMap() conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -868,8 +866,6 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(), }) if err != nil { - _ = clientConn.Close() - _ = serverConn.Close() return nil, xerrors.Errorf("create tailnet conn: %w", err) } ctx, cancel := context.WithCancel(api.ctx) @@ -887,10 +883,10 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa return left }) - sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, true) - }) - conn.SetNodeCallback(sendNodes) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger, + clientID, agentID, + *(api.TailnetCoordinator.Load()), conn) // Check for updated DERP map every 5 seconds. 
go func() {
@@ -920,27 +916,13 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
 		AgentID: agentID,
 		AgentIP: codersdk.WorkspaceAgentIP,
 		CloseFunc: func() error {
+			_ = coordination.Close()
 			cancel()
-			_ = clientConn.Close()
-			_ = serverConn.Close()
 			return nil
 		},
 	})
-	go func() {
-		err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
-		if err != nil {
-			// Sometimes, we get benign closed pipe errors when the server is
-			// shutting down.
-			if api.ctx.Err() == nil {
-				api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
-			}
-			_ = agentConn.Close()
-		}
-	}()
 	if !agentConn.AwaitReachable(ctx) {
 		_ = agentConn.Close()
-		_ = serverConn.Close()
-		_ = clientConn.Close()
 		cancel()
 		return nil, xerrors.Errorf("agent not reachable")
 	}
diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go
index 0d620c991e6dd..9d5fd8da1befd 100644
--- a/coderd/workspaceagents_test.go
+++ b/coderd/workspaceagents_test.go
@@ -535,7 +535,6 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
 	})
 	require.NoError(t, err)
 	defer conn.Close()
-	require.True(t, conn.BlockEndpoints())
 	require.True(t, conn.AwaitReachable(ctx))
 	_, p2p, _, err := conn.Ping(ctx)
diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go
index c824159a810ed..8a66e3ba0364f 100644
--- a/coderd/wsconncache/wsconncache_test.go
+++ b/coderd/wsconncache/wsconncache_test.go
@@ -12,14 +12,19 @@ import (
 	"net/url"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"

 	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"go.uber.org/atomic"
 	"go.uber.org/goleak"
+	"golang.org/x/xerrors"
+	"storj.io/drpc"
+	"storj.io/drpc/drpcmux"
+	"storj.io/drpc/drpcserver"
+	"tailscale.com/tailcfg"

 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
@@ -27,7 +32,9 @@ import (
 	"github.com/coder/coder/v2/coderd/wsconncache"
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
+	drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
 	"github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
 	"github.com/coder/coder/v2/tailnet/tailnettest"
 	"github.com/coder/coder/v2/testutil"
 )
@@ -41,7 +48,7 @@ func TestCache(t *testing.T) {
 	t.Run("Same", func(t *testing.T) {
 		t.Parallel()
 		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
-			return setupAgent(t, agentsdk.Manifest{}, 0), nil
+			return setupAgent(t, agentsdk.Manifest{}, 0)
 		}, 0)
 		defer func() {
 			_ = cache.Close()
 		}()
@@ -54,10 +61,10 @@ func TestCache(t *testing.T) {
 	})
 	t.Run("Expire", func(t *testing.T) {
 		t.Parallel()
-		called := atomic.NewInt32(0)
+		called := int32(0)
 		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
-			called.Add(1)
-			return setupAgent(t, agentsdk.Manifest{}, 0), nil
+			atomic.AddInt32(&called, 1)
+			return setupAgent(t, agentsdk.Manifest{}, 0)
 		}, time.Microsecond)
 		defer func() {
 			_ = cache.Close()
 		}()
@@ -70,12 +77,12 @@ func TestCache(t *testing.T) {
 		require.NoError(t, err)
 		release()
 		<-conn.Closed()
-		require.Equal(t, int32(2), called.Load())
+		require.Equal(t, int32(2), atomic.LoadInt32(&called))
 	})
 	t.Run("NoExpireWhenLocked", func(t *testing.T) {
 		t.Parallel()
 		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
-			return setupAgent(t, agentsdk.Manifest{}, 0), nil
+			return setupAgent(t, agentsdk.Manifest{}, 0)
 		}, time.Microsecond)
 		defer func() {
 			_ = cache.Close()
 		}()
@@ -108,7 +115,7 @@ func TestCache(t *testing.T) {
 		go 
server.Serve(random) cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - return setupAgent(t, agentsdk.Manifest{}, 0), nil + return setupAgent(t, agentsdk.Manifest{}, 0) }, time.Microsecond) defer func() { _ = cache.Close() @@ -154,7 +161,7 @@ func TestCache(t *testing.T) { }) } -func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { +func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) { t.Helper() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) @@ -184,18 +191,25 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati DERPForceWebSockets: manifest.DERPForceWebSockets, Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug), }) - require.NoError(t, err) - clientConn, serverConn := net.Pipe() + // setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the + // test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks. + if !assert.NoError(t, err) { + return nil, err + } t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() }) - go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + clientID := uuid.New() + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, manifest.AgentID, + coordinator, conn, + ) + t.Cleanup(func() { + _ = coordination.Close() }) - conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: manifest.AgentID, AgentIP: codersdk.WorkspaceAgentIP, @@ -206,16 +220,20 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() if !agentConn.AwaitReachable(ctx) { - t.Fatal("agent not reachable") + // setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end + // the test, which in turn closes the wsconncache, which in turn waits for the Dialer and + // deadlocks. 
+ t.Error("agent not reachable") + return nil, xerrors.New("agent not reachable") } - return agentConn + return agentConn, nil } type client struct { t *testing.T agentID uuid.UUID manifest agentsdk.Manifest - coordinator tailnet.CoordinatorV1 + coordinator tailnet.Coordinator } func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { @@ -240,19 +258,53 @@ func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, }, nil } -func (c *client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() +func (c *client) Listen(_ context.Context) (drpc.Conn, error) { + logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc") + conn, lis := drpcsdk.MemTransportPipe() closed := make(chan struct{}) c.t.Cleanup(func() { - _ = serverConn.Close() - _ = clientConn.Close() + _ = conn.Close() + _ = lis.Close() <-closed }) + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&c.coordinator) + mux := drpcmux.New() + drpcService := &tailnet.DRPCService{ + CoordPtr: &coordPtr, + Logger: logger, + // TODO: handle DERPMap too! + DerpMapUpdateFrequency: time.Hour, + DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + } + err := proto.DRPCRegisterTailnet(mux, drpcService) + if err != nil { + return nil, xerrors.Errorf("register DRPC service: %w", err) + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) + serveCtx, cancel := context.WithCancel(context.Background()) + c.t.Cleanup(cancel) + auth := tailnet.AgentTunnelAuth{} + streamID := tailnet.StreamID{ + Name: "wsconncache_test-agent", + ID: c.agentID, + Auth: auth, + } + serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + server.Serve(serveCtx, lis) close(closed) }() - return clientConn, nil + return conn, nil } func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) { diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index b1960bc7d260a..2b65f3a316ff9 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -14,12 +14,15 @@ import ( "cloud.google.com/go/compute/metadata" "github.com/google/uuid" + "github.com/hashicorp/yamux" "golang.org/x/xerrors" "nhooyr.io/websocket" + "storj.io/drpc" "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/v2/codersdk" + drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/retry" ) @@ -280,8 +283,8 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C // Listen connects to the workspace agent coordinate WebSocket // that handles connection negotiation. 
-func (c *Client) Listen(ctx context.Context) (net.Conn, error) { - coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/coordinate") +func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { + coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } @@ -312,14 +315,21 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) { ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate") - return &closeNetConn{ + netConn := &closeNetConn{ Conn: wsNetConn, closeFunc: func() { cancelFunc() _ = conn.Close(websocket.StatusGoingAway, "Listen closed") <-pingClosed }, - }, nil + } + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(netConn, config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return drpcsdk.MultiplexedConn(session), nil } type PostAppHealthsRequest struct { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index d3cfabcb63469..8165b78c125de 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -313,6 +313,9 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } + q := coordinateURL.Query() + q.Add("version", tailnet.CurrentVersion.String()) + coordinateURL.RawQuery = q.Encode() closedCoordinator := make(chan struct{}) // Must only ever be used once, send error OR close to avoid // reassignment race. Buffered so we don't hang in goroutine. @@ -344,12 +347,22 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } - sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) - }) - conn.SetNodeCallback(sendNode) + client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary)) + if err != nil { + options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + continue + } + coordinate, err := client.Coordinate(ctx) + if err != nil { + options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + continue + } + + coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID) options.Logger.Debug(ctx, "serving coordinator") - err = <-errChan + err = <-coordination.Error() if errors.Is(err, context.Canceled) { _ = ws.Close(websocket.StatusGoingAway, "") return diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index bf291e45cecfb..4fe25827b52cc 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -8,6 +8,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" @@ -53,6 +54,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request ctx := r.Context() version := "1.0" + msgType := websocket.MessageText qv := 
r.URL.Query().Get("version") if qv != "" { version = qv @@ -66,6 +68,11 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request }) return } + maj, _, _ := apiversion.Parse(version) + if maj >= 2 { + // Versions 2+ use dRPC over a binary connection + msgType = websocket.MessageBinary + } api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitGroup.Add(1) @@ -81,7 +88,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request return } - ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) + ctx, nc := websocketNetConn(ctx, conn, msgType) defer nc.Close() id := uuid.New() diff --git a/enterprise/coderd/workspaceproxycoordinator_test.go b/enterprise/coderd/workspaceproxycoordinator_test.go index 83bbb5c49d1fa..38ba957bf61df 100644 --- a/enterprise/coderd/workspaceproxycoordinator_test.go +++ b/enterprise/coderd/workspaceproxycoordinator_test.go @@ -10,6 +10,7 @@ import ( "github.com/moby/moby/pkg/namesgenerator" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/types/key" "cdr.dev/slog/sloggers/slogtest" @@ -20,6 +21,7 @@ import ( "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -27,6 +29,12 @@ import ( func Test_agentIsLegacy(t *testing.T) { t.Parallel() + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nkBin, err := nodeKey.MarshalBinary() + require.NoError(t, err) + dkBin, err := discoKey.MarshalText() + require.NoError(t, err) t.Run("Legacy", func(t *testing.T) { t.Parallel() @@ -54,18 +62,18 @@ func Test_agentIsLegacy(t *testing.T) { nodeID := uuid.New() ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() - require.NoError(t, ma.UpdateSelf(&agpl.Node{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ + require.NoError(t, ma.UpdateSelf(&proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: nkBin, + Disco: string(dkBin), + PreferredDerp: 0, + DerpLatency: map[string]float64{ "0": 1.0, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, - AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + DerpForcedWebsocket: map[int32]string{}, + Addresses: []string{codersdk.WorkspaceAgentIP.String() + "/128"}, + AllowedIps: []string{codersdk.WorkspaceAgentIP.String() + "/128"}, Endpoints: []string{"192.168.1.1:18842"}, })) require.Eventually(t, func() bool { @@ -114,18 +122,18 @@ func Test_agentIsLegacy(t *testing.T) { nodeID := uuid.New() ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() - require.NoError(t, ma.UpdateSelf(&agpl.Node{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ + require.NoError(t, ma.UpdateSelf(&proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: nkBin, + Disco: string(dkBin), + PreferredDerp: 0, + DerpLatency: map[string]float64{ "0": 1.0, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, - AllowedIPs: 
[]netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, + DerpForcedWebsocket: map[int32]string{}, + Addresses: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()}, + AllowedIps: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()}, Endpoints: []string{"192.168.1.1:18842"}, })) require.Eventually(t, func() bool { diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index e51cab881482b..c9f8f73fe93f8 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -6,12 +6,15 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "tailscale.com/types/key" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -39,22 +42,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -86,23 +86,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) - require.NoError(t, ma1.Close()) + ma1.unsubscribeAgent(agent1.id) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -134,37 +131,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})) + 
ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + ma1.unsubscribeAgent(agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() - require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9})) + ma1.sendNodeWithDERP(9) assertNeverHasDERPs(ctx, t, agent1, 9) }() func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() agent1.sendNode(&agpl.Node{PreferredDERP: 8}) - assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8) + ma1.assertNeverHasDERPs(ctx, 8) }() - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -201,22 +196,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord2.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord2) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -254,22 +246,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord2.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord2) + defer ma1.close() - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -316,33 +305,129 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { defer agent1.close() agent2.sendNode(&agpl.Node{PreferredDERP: 6}) - id := uuid.New() - ma1 := coord3.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord3) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.SubscribeAgent(agent2.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6) + ma1.subscribeAgent(agent2.id) + ma1.assertEventuallyHasDERPs(ctx, 6) agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2) + 
ma1.assertEventuallyHasDERPs(ctx, 2) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent2, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) require.NoError(t, agent2.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } + +type testMultiAgent struct { + t testing.TB + id uuid.UUID + a agpl.MultiAgentConn + nodeKey []byte + discoKey string +} + +func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent { + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} + m.a = coord.ServeMultiAgent(m.id) + return m +} + +func (m *testMultiAgent) sendNodeWithDERP(derp int32) { + m.t.Helper() + err := m.a.UpdateSelf(&proto.Node{ + Key: m.nodeKey, + Disco: m.discoKey, + PreferredDerp: derp, + }) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) close() { + m.t.Helper() + err := m.a.Close() + require.NoError(m.t, err) +} + +func (m *testMultiAgent) subscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.SubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.UnsubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + require.True(m.t, ok) + nodes, err := agpl.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + if !ok { + return + } + nodes, err := agpl.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 63ee818eae45c..c26418ff6677b 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -818,56 +818,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte } } -func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { - t.Helper() - for { - nodes, ok := ma.NextUpdate(ctx) - require.True(t, ok) - if len(nodes) != len(expected) { - t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range 
expected { - if !slices.Contains(derps, e) { - t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - -func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { - t.Helper() - for { - nodes, ok := ma.NextUpdate(ctx) - if !ok { - return - } - if len(nodes) != len(expected) { - t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { t.Helper() assert.Eventually(t, func() bool { diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index 0471c076b0485..d8f64aa3985a7 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -96,7 +96,11 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC return xerrors.Errorf("unsubscribe agent: %w", err) } case wsproxysdk.CoordinateMessageTypeNodeUpdate: - err := ma.UpdateSelf(msg.Node) + pn, err := agpl.NodeToProto(msg.Node) + if err != nil { + return err + } + err = ma.UpdateSelf(pn) if err != nil { return xerrors.Errorf("update self: %w", err) } @@ -110,11 +114,14 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { var lastData []byte for { - nodes, ok := ma.NextUpdate(ctx) + resp, ok := ma.NextUpdate(ctx) if !ok { return xerrors.New("multiagent is closed") } - + nodes, err := agpl.OnlyNodeUpdates(resp) + if err != nil { + return xerrors.Errorf("failed to convert response: %w", err) + } data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes}) if err != nil { return err diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index cbf9695bd77b6..fe4b1d3b22a0b 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -158,7 +158,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { // TODO: Probably do some version checking here info, err := client.SDKClient.BuildInfo(ctx) if err != nil { - return nil, fmt.Errorf("buildinfo: %w", errors.Join( + return nil, xerrors.Errorf("buildinfo: %w", errors.Join( xerrors.Errorf("unable to fetch build info from primary coderd. 
Are you sure %q is a coderd instance?", opts.DashboardURL),
 			err,
 		))
diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go
index 142d0b5c1ee57..f8d8c22543b47 100644
--- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go
+++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go
@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/url"
 	"sync"
@@ -23,6 +22,7 @@ import (
 	"github.com/coder/coder/v2/coderd/workspaceapps"
 	"github.com/coder/coder/v2/codersdk"
 	agpl "github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
 )

 // Client is a HTTP client for a subset of Coder API routes that external
@@ -438,6 +438,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
 		cancel()
 		return nil, xerrors.Errorf("parse url: %w", err)
 	}
+	q := coordinateURL.Query()
+	q.Add("version", agpl.CurrentVersion.String())
+	coordinateURL.RawQuery = q.Encode()
 	coordinateHeaders := make(http.Header)
 	tokenHeader := codersdk.SessionTokenHeader
 	if c.SDKClient.SessionTokenHeader != "" {
@@ -457,10 +460,24 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro

 	go httpapi.HeartbeatClose(ctx, logger, cancel, conn)

-	nc := websocket.NetConn(ctx, conn, websocket.MessageText)
+	nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
+	client, err := agpl.NewDRPCClient(nc)
+	if err != nil {
+		logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
+		_ = conn.Close(websocket.StatusInternalError, "")
+		return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
+	}
+	protocol, err := client.Coordinate(ctx)
+	if err != nil {
+		logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
+		_ = conn.Close(websocket.StatusInternalError, "")
+		return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
+	}
+
 	rma := remoteMultiAgentHandler{
 		sdk:              c,
-		nc:               nc,
+		logger:           logger,
+		protocol:         protocol,
 		cancel:           cancel,
 		legacyAgentCache: map[uuid.UUID]bool{},
 	}
@@ -471,103 +488,75 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
 		OnSubscribe:   rma.OnSubscribe,
 		OnUnsubscribe: rma.OnUnsubscribe,
 		OnNodeUpdate:  rma.OnNodeUpdate,
-		OnRemove:      func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
+		OnRemove:      rma.OnRemove,
 	}).Init()

 	go func() {
 		<-ctx.Done()
 		ma.Close()
+		_ = conn.Close(websocket.StatusGoingAway, "closed")
 	}()

-	go func() {
-		defer cancel()
-		dec := json.NewDecoder(nc)
-		for {
-			var msg CoordinateNodes
-			err := dec.Decode(&msg)
-			if err != nil {
-				if xerrors.Is(err, io.EOF) {
-					logger.Info(ctx, "websocket connection severed", slog.Error(err))
-					return
-				}
-
-				logger.Error(ctx, "decode coordinator nodes", slog.Error(err))
-				return
-			}
-
-			err = ma.Enqueue(msg.Nodes)
-			if err != nil {
-				logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
-				continue
-			}
-		}
-	}()
+	rma.ma = ma
+	go rma.respLoop()

 	return ma, nil
 }

 type remoteMultiAgentHandler struct {
-	sdk    *Client
-	nc     net.Conn
-	cancel func()
+	sdk      *Client
+	logger   slog.Logger
+	protocol proto.DRPCTailnet_CoordinateClient
+	ma       *agpl.MultiAgent
+	cancel   func()

 	legacyMu           sync.RWMutex
 	legacyAgentCache   map[uuid.UUID]bool
 	legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse]
 }

-func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
-	data, err := json.Marshal(v)
-	if err != nil {
-		return xerrors.Errorf("json marshal message: %w", err)
-	}
+func (a *remoteMultiAgentHandler) respLoop() {
+	defer a.cancel()
+	for {
+		resp, err := a.protocol.Recv()
+		if err != nil {
+			if xerrors.Is(err, io.EOF) {
+				a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
+				return
+			}

-	// Set a deadline so that hung connections don't put back pressure on the system.
-	// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
-	err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
-	if err != nil {
-		a.cancel()
-		return xerrors.Errorf("set write deadline: %w", err)
-	}
-	_, err = a.nc.Write(data)
-	if err != nil {
-		a.cancel()
-		return xerrors.Errorf("write message: %w", err)
-	}
+			a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
+			return
+		}

-	// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
-	// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
-	// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
-	// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
-	// our successful write, it is important that we reset the deadline before it fires.
-	err = a.nc.SetWriteDeadline(time.Time{})
-	if err != nil {
-		a.cancel()
-		return xerrors.Errorf("clear write deadline: %w", err)
+		err = a.ma.Enqueue(resp)
+		if err != nil {
+			a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
+			continue
+		}
 	}
-
-	return nil
 }

-func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
-	return a.writeJSON(CoordinateMessage{
-		Type: CoordinateMessageTypeNodeUpdate,
-		Node: node,
-	})
+func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
+	return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
 }

-func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
-	return nil, a.writeJSON(CoordinateMessage{
-		Type:    CoordinateMessageTypeSubscribe,
-		AgentID: agentID,
-	})
+func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
+	return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
 }

 func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
-	return a.writeJSON(CoordinateMessage{
-		Type:    CoordinateMessageTypeUnsubscribe,
-		AgentID: agentID,
-	})
+	return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
+}
+
+func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
+	err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
+	if err != nil {
+		a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
+	}
+	_ = a.protocol.CloseSend()
 }

 func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {
diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go
index 1901b3207be15..8cf8b1ee18d08 100644
--- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go
+++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go
@@ -18,8 +18,9 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
-	"golang.org/x/xerrors"
+	"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket" + "tailscale.com/tailcfg" "tailscale.com/types/key" "cdr.dev/slog" @@ -30,6 +31,7 @@ import ( "github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -156,25 +158,48 @@ func TestDialCoordinator(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() var ( - ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) - logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID = uuid.New() - serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t)) - r = chi.NewRouter() - srv = httptest.NewServer(r) + ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agentID = uuid.UUID{33} + proxyID = uuid.UUID{44} + mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t)) + coord agpl.Coordinator = mCoord + r = chi.NewRouter() + srv = httptest.NewServer(r) ) defer cancel() + defer srv.Close() + + coordPtr := atomic.Pointer[agpl.Coordinator]{} + coordPtr.Store(&coord) + cSrv, err := tailnet.NewClientService( + logger, &coordPtr, + time.Hour, + func() *tailcfg.DERPMap { panic("not implemented") }, + ) + require.NoError(t, err) + + // buffer the channels here, so we don't need to read and write in goroutines to + // avoid blocking + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}). + Times(1). + Return(reqs, resps) + serveMACErr := make(chan error, 1) r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - nc := websocket.NetConn(r.Context(), conn, websocket.MessageText) - defer serverMultiAgent.Close() - - err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent) - if !xerrors.Is(err, io.EOF) { - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } + version := r.URL.Query().Get("version") + if !assert.Equal(t, version, agpl.CurrentVersion.String()) { + return } + nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary) + err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID) + serveMACErr <- err }) r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{ @@ -188,51 +213,50 @@ func TestDialCoordinator(t *testing.T) { client := wsproxysdk.New(u) client.SDKClient.SetLogger(logger) - expected := []*agpl.Node{{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ - "0": 1.0, + peerID := uuid.UUID{55} + peerNodeKey, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + peerDiscoKey, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{ + Id: peerID[:], + Node: &proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: peerNodeKey[:], + Disco: string(peerDiscoKey), + PreferredDerp: 0, + DerpLatency: map[string]float64{ + "0": 1.0, + }, + DerpForcedWebsocket: 
map[int32]string{}, + Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()}, + AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()}, + Endpoints: []string{"192.168.1.1:18842"}, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, - Endpoints: []string{"192.168.1.1:18842"}, - }} - sendNode := make(chan struct{}) - - serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes(). - DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) { - select { - case <-sendNode: - return expected, true - case <-ctx.Done(): - return nil, false - } - }) + }}} rma, err := client.DialCoordinator(ctx) require.NoError(t, err) // Subscribe { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) { - close(ch) - }) require.NoError(t, rma.SubscribeAgent(agentID)) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) } // Read updated agent node { - sendNode <- struct{}{} - got, ok := rma.NextUpdate(ctx) + resps <- expected + + resp, ok := rma.NextUpdate(ctx) assert.True(t, ok) - got[0].AsOf = got[0].AsOf.In(time.Local) - assert.Equal(t, *expected[0], *got[0]) + updates := resp.GetPeerUpdates() + assert.Len(t, updates, 1) + eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode()) + assert.NoError(t, err) + assert.True(t, eq) } // Check legacy { @@ -241,45 +265,38 @@ func TestDialCoordinator(t *testing.T) { } // UpdateSelf { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) { - node.AsOf = node.AsOf.In(time.Local) - assert.Equal(t, expected[0], node) - close(ch) - }) - require.NoError(t, rma.UpdateSelf(expected[0])) - waitOrCancel(ctx, t, ch) + require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode())) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode()) + require.NoError(t, err) + require.True(t, eq) } // Unsubscribe { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) { - close(ch) - }) require.NoError(t, rma.UnsubscribeAgent(agentID)) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId()) } // Close { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().Close().Do(func() { - close(ch) - }) require.NoError(t, rma.Close()) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, req.Disconnect) + close(resps) + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for req close") + case _, ok := <-reqs: + require.False(t, ok, "didn't close requests") + } + require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr)) } }) } -func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) { - t.Helper() - select { - case <-ch: - case <-ctx.Done(): - t.Fatal("timed out waiting for channel") - } -} - type ResponseRecorder struct { rw *httptest.ResponseRecorder wasWritten atomic.Bool diff --git a/tailnet/configmaps.go b/tailnet/configmaps.go index 49200aa5fd875..9c9fe7ee8d733 100644 --- a/tailnet/configmaps.go +++ b/tailnet/configmaps.go @@ -490,6 +490,18 @@ func (c *configMaps) 
protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) { }, nil } +// nodeAddresses returns the addresses for the peer with the given publicKey, if known. +func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { + c.L.Lock() + defer c.L.Unlock() + for _, lc := range c.peers { + if lc.node.Key == publicKey { + return lc.node.Addresses, true + } + } + return nil, false +} + type peerLifecycle struct { peerID uuid.UUID node *tailcfg.Node diff --git a/tailnet/conn.go b/tailnet/conn.go index 34712ee0ffb9f..0b4b942f8f4b5 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -3,48 +3,40 @@ package tailnet import ( "context" "encoding/binary" - "errors" "fmt" "net" "net/http" "net/netip" "os" - "reflect" "strconv" "sync" "time" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "go4.org/netipx" "golang.org/x/xerrors" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/net/connstats" - "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" - "tailscale.com/types/ipproto" "tailscale.com/types/key" tslogger "tailscale.com/types/logger" "tailscale.com/types/netlogtype" - "tailscale.com/types/netmap" "tailscale.com/wgengine" - "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/netstack" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg/nmcfg" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/cryptorand" + "github.com/coder/coder/v2/tailnet/proto" ) var ErrConnClosed = xerrors.New("connection closed") @@ -128,42 +120,6 @@ func NewConn(options *Options) (conn *Conn, err error) { } nodePrivateKey := key.NewNode() - nodePublicKey := nodePrivateKey.Public() - - netMap := &netmap.NetworkMap{ - DERPMap: options.DERPMap, - NodeKey: nodePublicKey, - PrivateKey: nodePrivateKey, - Addresses: options.Addresses, - PacketFilter: []filter.Match{{ - // Allow any protocol! - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP}, - // Allow traffic sourced from anywhere. - Srcs: []netip.Prefix{ - netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), - netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), - }, - // Allow traffic to route anywhere. - Dsts: []filter.NetPortRange{ - { - Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - Caps: []filter.CapMatch{}, - }}, - } - var nodeID tailcfg.NodeID // If we're provided with a UUID, use it to populate our node ID. 
@@ -177,14 +133,6 @@ func NewConn(options *Options) (conn *Conn, err error) { nodeID = tailcfg.NodeID(uid) } - // This is used by functions below to identify the node via key - netMap.SelfNode = &tailcfg.Node{ - ID: nodeID, - Key: nodePublicKey, - Addresses: options.Addresses, - AllowedIPs: options.Addresses, - } - wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor"))) if err != nil { return nil, xerrors.Errorf("create wireguard link monitor: %w", err) @@ -243,7 +191,6 @@ func NewConn(options *Options) (conn *Conn, err error) { if err != nil { return nil, xerrors.Errorf("set node private key: %w", err) } - netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() netStack, err := netstack.Create( Logger(options.Logger.Named("net.netstack")), @@ -262,44 +209,46 @@ func NewConn(options *Options) (conn *Conn, err error) { } netStack.ProcessLocalIPs = true wireguardEngine = wgengine.NewWatchdog(wireguardEngine) - wireguardEngine.SetDERPMap(options.DERPMap) - netMapCopy := *netMap - options.Logger.Debug(context.Background(), "updating network map") - wireguardEngine.SetNetworkMap(&netMapCopy) - - localIPSet := netipx.IPSetBuilder{} - for _, addr := range netMap.Addresses { - localIPSet.AddPrefix(addr) - } - localIPs, _ := localIPSet.IPSet() - logIPSet := netipx.IPSetBuilder{} - logIPs, _ := logIPSet.IPSet() - wireguardEngine.SetFilter(filter.New( - netMap.PacketFilter, - localIPs, - logIPs, + + cfgMaps := newConfigMaps( + options.Logger, + wireguardEngine, + nodeID, + nodePrivateKey, + magicConn.DiscoPublicKey(), + ) + cfgMaps.setAddresses(options.Addresses) + cfgMaps.setDERPMap(DERPMapToProto(options.DERPMap)) + cfgMaps.setBlockEndpoints(options.BlockEndpoints) + + nodeUp := newNodeUpdater( + options.Logger, nil, - Logger(options.Logger.Named("net.packet-filter")), - )) + nodeID, + nodePrivateKey.Public(), + magicConn.DiscoPublicKey(), + ) + nodeUp.setAddresses(options.Addresses) + nodeUp.setBlockEndpoints(options.BlockEndpoints) + wireguardEngine.SetStatusCallback(nodeUp.setStatus) + wireguardEngine.SetNetInfoCallback(nodeUp.setNetInfo) + magicConn.SetDERPForcedWebsocketCallback(nodeUp.setDERPForcedWebsocket) server := &Conn{ - blockEndpoints: options.BlockEndpoints, - derpForceWebSockets: options.DERPForceWebSockets, - closed: make(chan struct{}), - logger: options.Logger, - magicConn: magicConn, - dialer: dialer, - listeners: map[listenKey]*listener{}, - peerMap: map[tailcfg.NodeID]*tailcfg.Node{}, - lastDERPForcedWebSockets: map[int]string{}, - tunDevice: sys.Tun.Get(), - netMap: netMap, - netStack: netStack, - wireguardMonitor: wireguardMonitor, + closed: make(chan struct{}), + logger: options.Logger, + magicConn: magicConn, + dialer: dialer, + listeners: map[listenKey]*listener{}, + tunDevice: sys.Tun.Get(), + netStack: netStack, + wireguardMonitor: wireguardMonitor, wireguardRouter: &router.Config{ - LocalAddrs: netMap.Addresses, + LocalAddrs: options.Addresses, }, wireguardEngine: wireguardEngine, + configMaps: cfgMaps, + nodeUpdater: nodeUp, } defer func() { if err != nil { @@ -307,52 +256,6 @@ func NewConn(options *Options) (conn *Conn, err error) { } }() - wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { - server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) - if err != nil { - return - } - server.lastMutex.Lock() - if s.AsOf.Before(server.lastStatus) { - // Don't process outdated status! 
- server.lastMutex.Unlock() - return - } - server.lastStatus = s.AsOf - if endpointsEqual(s.LocalAddrs, server.lastEndpoints) { - // No need to update the node if nothing changed! - server.lastMutex.Unlock() - return - } - server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...) - server.lastMutex.Unlock() - server.sendNode() - }) - - wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { - server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) - server.lastMutex.Lock() - if reflect.DeepEqual(server.lastNetInfo, ni) { - server.lastMutex.Unlock() - return - } - server.lastNetInfo = ni.Clone() - server.lastMutex.Unlock() - server.sendNode() - }) - - magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) { - server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason)) - server.lastMutex.Lock() - if server.lastDERPForcedWebSockets[region] == reason { - server.lastMutex.Unlock() - return - } - server.lastDERPForcedWebSockets[region] = reason - server.lastMutex.Unlock() - server.sendNode() - }) - netStack.GetTCPHandlerForFlow = server.forwardTCP err = netStack.Start(nil) @@ -389,16 +292,14 @@ func IPFromUUID(uid uuid.UUID) netip.Addr { // Conn is an actively listening Wireguard connection. type Conn struct { - mutex sync.Mutex - closed chan struct{} - logger slog.Logger - blockEndpoints bool - derpForceWebSockets bool + mutex sync.Mutex + closed chan struct{} + logger slog.Logger dialer *tsdial.Dialer tunDevice *tstun.Wrapper - peerMap map[tailcfg.NodeID]*tailcfg.Node - netMap *netmap.NetworkMap + configMaps *configMaps + nodeUpdater *nodeUpdater netStack *netstack.Impl magicConn *magicsock.Conn wireguardMonitor *netmon.Monitor @@ -406,17 +307,6 @@ type Conn struct { wireguardEngine wgengine.Engine listeners map[listenKey]*listener - lastMutex sync.Mutex - nodeSending bool - nodeChanged bool - // It's only possible to store these values via status functions, - // so the values must be stored for retrieval later on. - lastStatus time.Time - lastEndpoints []tailcfg.Endpoint - lastDERPForcedWebSockets map[int]string - lastNetInfo *tailcfg.NetInfo - nodeCallback func(node *Node) - trafficStats *connstats.Statistics } @@ -425,57 +315,30 @@ func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) { } func (c *Conn) SetAddresses(ips []netip.Prefix) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.netMap.Addresses = ips - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err := c.reconfig() - if err != nil { - return xerrors.Errorf("reconfig: %w", err) - } - + c.configMaps.setAddresses(ips) + c.nodeUpdater.setAddresses(ips) return nil } -func (c *Conn) Addresses() []netip.Prefix { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.netMap.Addresses -} - func (c *Conn) SetNodeCallback(callback func(node *Node)) { - c.lastMutex.Lock() - c.nodeCallback = callback - c.lastMutex.Unlock() - c.sendNode() + c.nodeUpdater.setCallback(callback) } // SetDERPMap updates the DERPMap of a connection. 
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap)) - c.wireguardEngine.SetDERPMap(derpMap) - c.netMap.DERPMap = derpMap - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) + c.configMaps.setDERPMap(DERPMapToProto(derpMap)) } func (c *Conn) SetDERPForceWebSockets(v bool) { + c.logger.Info(context.Background(), "setting DERP Force Websockets", slog.F("force_derp_websockets", v)) c.magicConn.SetDERPForceWebsockets(v) } -// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting +// SetBlockEndpoints sets whether to block P2P endpoints. This setting // will only apply to new peers. func (c *Conn) SetBlockEndpoints(blockEndpoints bool) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.blockEndpoints = blockEndpoints + c.configMaps.setBlockEndpoints(blockEndpoints) + c.nodeUpdater.setBlockEndpoints(blockEndpoints) } // SetDERPRegionDialer updates the dialer to use for connecting to DERP regions. @@ -483,186 +346,24 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail c.magicConn.SetDERPRegionDialer(dialer) } -// UpdateNodes connects with a set of peers. This can be constantly updated, -// and peers will continually be reconnected as necessary. If replacePeers is -// true, all peers will be removed before adding the new ones. -// -//nolint:revive // Complains about replacePeers. -func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { - c.mutex.Lock() - defer c.mutex.Unlock() - +// UpdatePeers connects with a set of peers. This can be constantly updated, +// and peers will continually be reconnected as necessary. +func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error { if c.isClosed() { return ErrConnClosed } - - status := c.Status() - if replacePeers { - c.netMap.Peers = []*tailcfg.Node{} - c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} - } - for _, peer := range c.netMap.Peers { - peerStatus, ok := status.Peer[peer.Key] - if !ok { - continue - } - // If this peer was added in the last 5 minutes, assume it - // could still be active. - if time.Since(peer.Created) < 5*time.Minute { - continue - } - // We double-check that it's safe to remove by ensuring no - // handshake has been sent in the past 5 minutes as well. Connections that - // are actively exchanging IP traffic will handshake every 2 minutes. - if time.Since(peerStatus.LastHandshake) < 5*time.Minute { - continue - } - - c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago", - slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake), - ) - delete(c.peerMap, peer.ID) - } - - for _, node := range nodes { - // If no preferred DERP is provided, we can't reach the node. 
- if node.PreferredDERP == 0 { - c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node)) - continue - } - c.logger.Debug(context.Background(), "adding node", slog.F("node", node)) - - peerStatus, ok := status.Peer[node.Key] - peerNode := &tailcfg.Node{ - ID: node.ID, - Created: time.Now(), - Key: node.Key, - DiscoKey: node.DiscoKey, - Addresses: node.Addresses, - AllowedIPs: node.AllowedIPs, - Endpoints: node.Endpoints, - DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP), - Hostinfo: (&tailcfg.Hostinfo{}).View(), - // Starting KeepAlive messages at the initialization of a connection - // causes a race condition. If we handshake before the peer has our - // node, we'll have wait for 5 seconds before trying again. Ideally, - // the first handshake starts when the user first initiates a - // connection to the peer. After a successful connection we enable - // keep alives to persist the connection and keep it from becoming - // idle. SSH connections don't send send packets while idle, so we - // use keep alives to avoid random hangs while we set up the - // connection again after inactivity. - KeepAlive: ok && peerStatus.Active, - } - if c.blockEndpoints { - peerNode.Endpoints = nil - } - c.peerMap[node.ID] = peerNode - } - - c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) - for _, peer := range c.peerMap { - c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) - } - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err := c.reconfig() - if err != nil { - return xerrors.Errorf("reconfig: %w", err) - } - - return nil -} - -// PeerSelector is used to select a peer from within a Tailnet. -type PeerSelector struct { - ID tailcfg.NodeID - IP netip.Prefix -} - -func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.isClosed() { - return false, ErrConnClosed - } - - deleted = false - for _, peer := range c.peerMap { - if peer.ID == selector.ID { - delete(c.peerMap, peer.ID) - deleted = true - break - } - - for _, peerIP := range peer.Addresses { - if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 { - delete(c.peerMap, peer.ID) - deleted = true - break - } - } - } - if !deleted { - return false, nil - } - - c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) - for _, peer := range c.peerMap { - c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) - } - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err = c.reconfig() - if err != nil { - return false, xerrors.Errorf("reconfig: %w", err) - } - - return true, nil -} - -func (c *Conn) reconfig() error { - cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "") - if err != nil { - return xerrors.Errorf("update wireguard config: %w", err) - } - - err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - if c.isClosed() { - return nil - } - if errors.Is(err, wgengine.ErrNoChanges) { - return nil - } - return xerrors.Errorf("reconfig: %w", err) - } - + c.configMaps.updatePeers(updates) return nil } // NodeAddresses returns the addresses of a node from the NetworkMap. 
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { - c.mutex.Lock() - defer c.mutex.Unlock() - for _, node := range c.netMap.Peers { - if node.Key == publicKey { - return node.Addresses, true - } - } - return nil, false + return c.configMaps.nodeAddresses(publicKey) } // Status returns the current ipnstate of a connection. func (c *Conn) Status() *ipnstate.Status { - sb := &ipnstate.StatusBuilder{WantPeers: true} - c.wireguardEngine.UpdateStatus(sb) - return sb.Status() + return c.configMaps.status() } // Ping sends a ping to the Wireguard engine. @@ -689,16 +390,9 @@ func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *i // DERPMap returns the currently set DERP mapping. func (c *Conn) DERPMap() *tailcfg.DERPMap { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.netMap.DERPMap -} - -// BlockEndpoints returns whether or not P2P is blocked. -func (c *Conn) BlockEndpoints() bool { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.blockEndpoints + c.configMaps.L.Lock() + defer c.configMaps.L.Unlock() + return c.configMaps.derpMapLocked() } // AwaitReachable pings the provided IP continually until the @@ -759,6 +453,9 @@ func (c *Conn) Closed() <-chan struct{} { // Close shuts down the Wireguard connection. func (c *Conn) Close() error { + c.logger.Info(context.Background(), "closing tailnet Conn") + c.configMaps.close() + c.nodeUpdater.close() c.mutex.Lock() select { case <-c.closed: @@ -808,91 +505,11 @@ func (c *Conn) isClosed() bool { } } -func (c *Conn) sendNode() { - c.lastMutex.Lock() - defer c.lastMutex.Unlock() - if c.nodeSending { - c.nodeChanged = true - return - } - node := c.selfNode() - // Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP - // set to non-zero, since we cannot reach nodes without DERP for discovery. - // Therefore, there is no point in sending the node without this, and we can - // save ourselves from churn in the tailscale/wireguard layer. - if node.PreferredDERP == 0 { - c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node)) - return - } - nodeCallback := c.nodeCallback - if nodeCallback == nil { - return - } - c.nodeSending = true - go func() { - c.logger.Debug(context.Background(), "sending node", slog.F("node", node)) - nodeCallback(node) - c.lastMutex.Lock() - c.nodeSending = false - if c.nodeChanged { - c.nodeChanged = false - c.lastMutex.Unlock() - c.sendNode() - return - } - c.lastMutex.Unlock() - }() -} - // Node returns the last node that was sent to the node callback. func (c *Conn) Node() *Node { - c.lastMutex.Lock() - defer c.lastMutex.Unlock() - return c.selfNode() -} - -func (c *Conn) selfNode() *Node { - endpoints := make([]string, 0, len(c.lastEndpoints)) - for _, addr := range c.lastEndpoints { - endpoints = append(endpoints, addr.Addr.String()) - } - var preferredDERP int - var derpLatency map[string]float64 - derpForcedWebsocket := make(map[int]string, 0) - if c.lastNetInfo != nil { - preferredDERP = c.lastNetInfo.PreferredDERP - derpLatency = c.lastNetInfo.DERPLatency - - if c.derpForceWebSockets { - // We only need to store this for a single region, since this is - // mostly used for debugging purposes and doesn't actually have a - // code purpose. 
- derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets" - } else { - for k, v := range c.lastDERPForcedWebSockets { - derpForcedWebsocket[k] = v - } - } - } - - node := &Node{ - ID: c.netMap.SelfNode.ID, - AsOf: dbtime.Now(), - Key: c.netMap.SelfNode.Key, - Addresses: c.netMap.SelfNode.Addresses, - AllowedIPs: c.netMap.SelfNode.AllowedIPs, - DiscoKey: c.magicConn.DiscoPublicKey(), - Endpoints: endpoints, - PreferredDERP: preferredDERP, - DERPLatency: derpLatency, - DERPForcedWebsocket: derpForcedWebsocket, - } - c.mutex.Lock() - if c.blockEndpoints { - node.Endpoints = nil - } - c.mutex.Unlock() - return node + c.nodeUpdater.L.Lock() + defer c.nodeUpdater.L.Unlock() + return c.nodeUpdater.nodeLocked() } // This and below is taken _mostly_ verbatim from Tailscale: @@ -1056,15 +673,3 @@ func Logger(logger slog.Logger) tslogger.Logf { logger.Debug(context.Background(), fmt.Sprintf(format, args...)) }) } - -func endpointsEqual(x, y []tailcfg.Endpoint) bool { - if len(x) != len(y) { - return false - } - for i := range x { - if x[i] != y[i] { - return false - } - } - return true -} diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 7554c94cec682..f3bc96e242f9e 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -12,6 +13,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -22,10 +24,10 @@ func TestMain(m *testing.M) { func TestTailnet(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) derpMap, _ := tailnettest.RunDERPAndSTUN(t) t.Run("InstantClose", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Logger: logger.Named("w1"), @@ -37,6 +39,8 @@ func TestTailnet(t *testing.T) { }) t.Run("Connect", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitLong) w1IP := tailnet.IP() w1, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, @@ -55,14 +59,8 @@ func TestTailnet(t *testing.T) { _ = w1.Close() _ = w2.Close() }) - w1.SetNodeCallback(func(node *tailnet.Node) { - err := w2.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - w2.SetNodeCallback(func(node *tailnet.Node) { - err := w1.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) + stitch(t, w2, w1) + stitch(t, w1, w2) require.True(t, w2.AwaitReachable(context.Background(), w1IP)) conn := make(chan struct{}, 1) go func() { @@ -89,7 +87,7 @@ func TestTailnet(t *testing.T) { default: } }) - node := <-nodes + node := testutil.RequireRecvCtx(ctx, t, nodes) // Ensure this connected over DERP! 
require.Len(t, node.DERPForcedWebsocket, 0) @@ -99,6 +97,7 @@ func TestTailnet(t *testing.T) { t.Run("ForcesWebSockets", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) ctx := testutil.Context(t, testutil.WaitMedium) w1IP := tailnet.IP() @@ -122,14 +121,8 @@ func TestTailnet(t *testing.T) { _ = w1.Close() _ = w2.Close() }) - w1.SetNodeCallback(func(node *tailnet.Node) { - err := w2.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - w2.SetNodeCallback(func(node *tailnet.Node) { - err := w1.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) + stitch(t, w2, w1) + stitch(t, w1, w2) require.True(t, w2.AwaitReachable(ctx, w1IP)) conn := make(chan struct{}, 1) go func() { @@ -243,11 +236,16 @@ func TestConn_UpdateDERP(t *testing.T) { err := client1.Close() assert.NoError(t, err) }() - client1.SetNodeCallback(func(node *tailnet.Node) { - err := conn.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + stitch(t, conn, client1) + pn, err := tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + connID := uuid.New() + err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) + require.NoError(t, err) awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel1() @@ -288,7 +286,13 @@ parentLoop: // ... unless the client updates it's derp map and nodes. client1.SetDERPMap(derpMap2) - client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + pn, err = tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel3() require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip)) @@ -306,13 +310,34 @@ parentLoop: err := client2.Close() assert.NoError(t, err) }() - client2.SetNodeCallback(func(node *tailnet.Node) { - err := conn.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - client2.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + stitch(t, conn, client2) + pn, err = tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel4() require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip)) } + +// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of +// like the Coordinator would, but without actually needing a Coordinator. 
+func stitch(t *testing.T, dst, src *tailnet.Conn) {
+	srcID := uuid.New()
+	src.SetNodeCallback(func(node *tailnet.Node) {
+		pn, err := tailnet.NodeToProto(node)
+		if !assert.NoError(t, err) {
+			return
+		}
+		err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
+			Id:   srcID[:],
+			Node: pn,
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+		}})
+		assert.NoError(t, err)
+	})
+}
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 04c5fc0ee3e7e..0fa62fc922209 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -3,6 +3,7 @@ package tailnet
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 	"html/template"
 	"io"
 	"net"
@@ -92,6 +93,237 @@ type Node struct {
 	Endpoints []string `json:"endpoints"`
 }
 
+// Coordinatee is something that can be coordinated over the Coordinate protocol. Usually this is a
+// Conn.
+type Coordinatee interface {
+	UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
+	SetNodeCallback(func(*Node))
+}
+
+type Coordination interface {
+	io.Closer
+	Error() <-chan error
+}
+
+type remoteCoordination struct {
+	sync.Mutex
+	closed      bool
+	errChan     chan error
+	coordinatee Coordinatee
+	logger      slog.Logger
+	protocol    proto.DRPCTailnet_CoordinateClient
+}
+
+func (c *remoteCoordination) Close() error {
+	c.Lock()
+	defer c.Unlock()
+	if c.closed {
+		return nil
+	}
+	c.closed = true
+	err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
+	if err != nil {
+		return xerrors.Errorf("send disconnect: %w", err)
+	}
+	return nil
+}
+
+func (c *remoteCoordination) Error() <-chan error {
+	return c.errChan
+}
+
+func (c *remoteCoordination) sendErr(err error) {
+	select {
+	case c.errChan <- err:
+	default:
+	}
+}
+
+func (c *remoteCoordination) respLoop() {
+	for {
+		resp, err := c.protocol.Recv()
+		if err != nil {
+			c.sendErr(xerrors.Errorf("read: %w", err))
+			return
+		}
+		err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
+		if err != nil {
+			c.sendErr(xerrors.Errorf("update peers: %w", err))
+			return
+		}
+	}
+}
+
+// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
+// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
+// a client---agents should NOT set this!).
+func NewRemoteCoordination(logger slog.Logger, + protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee, + tunnelTarget uuid.UUID, +) Coordination { + c := &remoteCoordination{ + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + protocol: protocol, + } + if tunnelTarget != uuid.Nil { + c.Lock() + err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}}) + c.Unlock() + if err != nil { + c.sendErr(err) + } + } + + coordinatee.SetNodeCallback(func(node *Node) { + pn, err := NodeToProto(node) + if err != nil { + c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + c.sendErr(err) + return + } + c.Lock() + defer c.Unlock() + if c.closed { + c.logger.Debug(context.Background(), "ignored node update because coordination is closed") + return + } + err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) + if err != nil { + c.sendErr(xerrors.Errorf("write: %w", err)) + } + }) + go c.respLoop() + return c +} + +type inMemoryCoordination struct { + sync.Mutex + ctx context.Context + errChan chan error + closed bool + closedCh chan struct{} + coordinatee Coordinatee + logger slog.Logger + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest +} + +func (c *inMemoryCoordination) sendErr(err error) { + select { + case c.errChan <- err: + default: + } +} + +func (c *inMemoryCoordination) Error() <-chan error { + return c.errChan +} + +// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing +// or local clients. Set ClientID to uuid.Nil for an agent. +func NewInMemoryCoordination( + ctx context.Context, logger slog.Logger, + clientID, agentID uuid.UUID, + coordinator Coordinator, coordinatee Coordinatee, +) Coordination { + thisID := agentID + logger = logger.With(slog.F("agent_id", agentID)) + var auth TunnelAuth = AgentTunnelAuth{} + if clientID != uuid.Nil { + // this is a client connection + auth = ClientTunnelAuth{AgentID: agentID} + logger = logger.With(slog.F("client_id", clientID)) + thisID = clientID + } + c := &inMemoryCoordination{ + ctx: ctx, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + closedCh: make(chan struct{}), + } + + // use the background context since we will depend exclusively on closing the req channel to + // tell the coordinator we are done. + c.reqs, c.resps = coordinator.Coordinate(context.Background(), + thisID, fmt.Sprintf("inmemory%s", thisID), + auth, + ) + go c.respLoop() + if agentID != uuid.Nil { + select { + case <-ctx.Done(): + c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err())) + return c + case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}: + // OK! 
+ } + } + coordinatee.SetNodeCallback(func(n *Node) { + pn, err := NodeToProto(n) + if err != nil { + c.logger.Critical(ctx, "failed to convert node", slog.Error(err)) + c.sendErr(err) + return + } + c.Lock() + defer c.Unlock() + if c.closed { + return + } + select { + case <-ctx.Done(): + c.logger.Info(ctx, "context expired before sending node update") + return + case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}: + c.logger.Debug(ctx, "sent node in-memory to coordinator") + } + }) + return c +} + +func (c *inMemoryCoordination) respLoop() { + for { + select { + case <-c.closedCh: + c.logger.Debug(context.Background(), "in-memory coordination closed") + return + case resp, ok := <-c.resps: + if !ok { + c.logger.Debug(context.Background(), "in-memory response channel closed") + return + } + c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp)) + err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) + if err != nil { + c.sendErr(xerrors.Errorf("failed to update peers: %w", err)) + return + } + } + } +} + +func (c *inMemoryCoordination) Close() error { + c.Lock() + defer c.Unlock() + c.logger.Debug(context.Background(), "closing in-memory coordination") + if c.closed { + return nil + } + defer close(c.reqs) + c.closed = true + close(c.closedCh) + select { + case <-c.ctx.Done(): + return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) + case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}: + c.logger.Debug(context.Background(), "sent graceful disconnect in-memory") + return nil + } +} + // ServeCoordinator matches the RW structure of a coordinator to exchange node messages. func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) { errChan := make(chan error, 1) @@ -237,21 +469,17 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge } return false }, - OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) { + OnSubscribe: func(enq Queue, agent uuid.UUID) error { err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}}) - return c.Node(agent), err + return err }, OnUnsubscribe: func(enq Queue, agent uuid.UUID) error { err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}}) return err }, - OnNodeUpdate: func(id uuid.UUID, node *Node) error { - pn, err := NodeToProto(node) - if err != nil { - return err - } + OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error { return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ - Node: pn, + Node: node, }}) }, OnRemove: func(_ Queue) { @@ -285,7 +513,7 @@ const ( type Queue interface { UniqueID() uuid.UUID Kind() QueueKind - Enqueue(n []*Node) error + Enqueue(resp *proto.CoordinateResponse) error Name() string Stats() (start, lastWrite int64) Overwrites() int64 @@ -793,18 +1021,7 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg return } logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp)) - nodes, err := OnlyNodeUpdates(resp) - if err != nil { - logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err)) - _ = q.CoordinatorClose() - return - } - // don't send empty updates - if len(nodes) == 0 { - logger.Debug(ctx, "v1RespLoop skipping enqueueing 
0-length v1 update") - continue - } - err = q.Enqueue(nodes) + err = q.Enqueue(resp) if err != nil && !xerrors.Is(err, context.Canceled) { logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err)) } diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 5c3412a595152..621f6bc6b197d 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -8,13 +8,15 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/tailnet/proto" ) type MultiAgentConn interface { - UpdateSelf(node *Node) error + UpdateSelf(node *proto.Node) error SubscribeAgent(agentID uuid.UUID) error UnsubscribeAgent(agentID uuid.UUID) error - NextUpdate(ctx context.Context) ([]*Node, bool) + NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) AgentIsLegacy(agentID uuid.UUID) bool Close() error IsClosed() bool @@ -26,16 +28,16 @@ type MultiAgent struct { ID uuid.UUID AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) + OnSubscribe func(enq Queue, agent uuid.UUID) error OnUnsubscribe func(enq Queue, agent uuid.UUID) error - OnNodeUpdate func(id uuid.UUID, node *Node) error + OnNodeUpdate func(id uuid.UUID, node *proto.Node) error OnRemove func(enq Queue) ctx context.Context ctxCancel func() closed bool - updates chan []*Node + updates chan *proto.CoordinateResponse closeOnce sync.Once start int64 lastWrite int64 @@ -45,7 +47,7 @@ type MultiAgent struct { } func (m *MultiAgent) Init() *MultiAgent { - m.updates = make(chan []*Node, 128) + m.updates = make(chan *proto.CoordinateResponse, 128) m.start = time.Now().Unix() m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m @@ -65,7 +67,7 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { var ErrMultiAgentClosed = xerrors.New("multiagent is closed") -func (m *MultiAgent) UpdateSelf(node *Node) error { +func (m *MultiAgent) UpdateSelf(node *proto.Node) error { m.mu.RLock() defer m.mu.RUnlock() if m.closed { @@ -82,15 +84,11 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { return ErrMultiAgentClosed } - node, err := m.OnSubscribe(m, agentID) + err := m.OnSubscribe(m, agentID) if err != nil { return err } - if node != nil { - return m.enqueueLocked([]*Node{node}) - } - return nil } @@ -104,17 +102,17 @@ func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { return m.OnUnsubscribe(m, agentID) } -func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { +func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) { select { case <-ctx.Done(): return nil, false - case nodes, ok := <-m.updates: - return nodes, ok + case resp, ok := <-m.updates: + return resp, ok } } -func (m *MultiAgent) Enqueue(nodes []*Node) error { +func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error { m.mu.RLock() defer m.mu.RUnlock() @@ -122,14 +120,14 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error { return nil } - return m.enqueueLocked(nodes) + return m.enqueueLocked(resp) } -func (m *MultiAgent) enqueueLocked(nodes []*Node) error { +func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error { atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) select { - case m.updates <- nodes: + case m.updates <- resp: return nil default: return ErrWouldBlock diff --git a/tailnet/service.go b/tailnet/service.go index 191319d16c5f4..7347afbb32a48 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -75,7 +75,9 @@ func NewClientService( } server := 
drpcserver.NewWithOptions(mux, drpcserver.Options{ Log: func(err error) { - if xerrors.Is(err, io.EOF) { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) { return } logger.Debug(context.Background(), "drpc server error", slog.Error(err)) diff --git a/tailnet/tailnettest/coordinatormock.go b/tailnet/tailnettest/coordinatormock.go new file mode 100644 index 0000000000000..b0ae36d3f8c89 --- /dev/null +++ b/tailnet/tailnettest/coordinatormock.go @@ -0,0 +1,142 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinator) +// +// Generated by this command: +// +// mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator +// + +// Package tailnettest is a generated GoMock package. +package tailnettest + +import ( + context "context" + net "net" + http "net/http" + reflect "reflect" + + tailnet "github.com/coder/coder/v2/tailnet" + proto "github.com/coder/coder/v2/tailnet/proto" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" +) + +// MockCoordinator is a mock of Coordinator interface. +type MockCoordinator struct { + ctrl *gomock.Controller + recorder *MockCoordinatorMockRecorder +} + +// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator. +type MockCoordinatorMockRecorder struct { + mock *MockCoordinator +} + +// NewMockCoordinator creates a new mock instance. +func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator { + mock := &MockCoordinator{ctrl: ctrl} + mock.recorder = &MockCoordinatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockCoordinator) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCoordinator)(nil).Close)) +} + +// Coordinate mocks base method. +func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(chan<- *proto.CoordinateRequest) + ret1, _ := ret[1].(<-chan *proto.CoordinateResponse) + return ret0, ret1 +} + +// Coordinate indicates an expected call of Coordinate. +func (mr *MockCoordinatorMockRecorder) Coordinate(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Coordinate", reflect.TypeOf((*MockCoordinator)(nil).Coordinate), arg0, arg1, arg2, arg3) +} + +// Node mocks base method. +func (m *MockCoordinator) Node(arg0 uuid.UUID) *tailnet.Node { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Node", arg0) + ret0, _ := ret[0].(*tailnet.Node) + return ret0 +} + +// Node indicates an expected call of Node. +func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0) +} + +// ServeAgent mocks base method. 
+func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ServeAgent indicates an expected call of ServeAgent. +func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2) +} + +// ServeClient mocks base method. +func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ServeClient indicates an expected call of ServeClient. +func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2) +} + +// ServeHTTPDebug mocks base method. +func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ServeHTTPDebug", arg0, arg1) +} + +// ServeHTTPDebug indicates an expected call of ServeHTTPDebug. +func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1) +} + +// ServeMultiAgent mocks base method. +func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeMultiAgent", arg0) + ret0, _ := ret[0].(tailnet.MultiAgentConn) + return ret0 +} + +// ServeMultiAgent indicates an expected call of ServeMultiAgent. +func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0) +} diff --git a/tailnet/tailnettest/multiagentmock.go b/tailnet/tailnettest/multiagentmock.go deleted file mode 100644 index fd03a0e7f21a4..0000000000000 --- a/tailnet/tailnettest/multiagentmock.go +++ /dev/null @@ -1,141 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn) -// -// Generated by this command: -// -// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn -// - -// Package tailnettest is a generated GoMock package. -package tailnettest - -import ( - context "context" - reflect "reflect" - - tailnet "github.com/coder/coder/v2/tailnet" - uuid "github.com/google/uuid" - gomock "go.uber.org/mock/gomock" -) - -// MockMultiAgentConn is a mock of MultiAgentConn interface. -type MockMultiAgentConn struct { - ctrl *gomock.Controller - recorder *MockMultiAgentConnMockRecorder -} - -// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn. -type MockMultiAgentConnMockRecorder struct { - mock *MockMultiAgentConn -} - -// NewMockMultiAgentConn creates a new mock instance. 
-func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn { - mock := &MockMultiAgentConn{ctrl: ctrl} - mock.recorder = &MockMultiAgentConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder { - return m.recorder -} - -// AgentIsLegacy mocks base method. -func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AgentIsLegacy", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// AgentIsLegacy indicates an expected call of AgentIsLegacy. -func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0) -} - -// Close mocks base method. -func (m *MockMultiAgentConn) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close)) -} - -// IsClosed mocks base method. -func (m *MockMultiAgentConn) IsClosed() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsClosed") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsClosed indicates an expected call of IsClosed. -func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed)) -} - -// NextUpdate mocks base method. -func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextUpdate", arg0) - ret0, _ := ret[0].([]*tailnet.Node) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// NextUpdate indicates an expected call of NextUpdate. -func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0) -} - -// SubscribeAgent mocks base method. -func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubscribeAgent", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SubscribeAgent indicates an expected call of SubscribeAgent. -func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0) -} - -// UnsubscribeAgent mocks base method. -func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// UnsubscribeAgent indicates an expected call of UnsubscribeAgent. -func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0) -} - -// UpdateSelf mocks base method. 
-func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSelf", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateSelf indicates an expected call of UpdateSelf. -func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0) -} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index e9eb45ad96df9..e7ed6361a1090 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -21,7 +21,7 @@ import ( "github.com/coder/coder/v2/tailnet" ) -//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn +//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator // RunDERPAndSTUN creates a DERP mapping for tests. func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) { diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go index 3b3feaa13286b..a801cdfae0964 100644 --- a/tailnet/trackedconn.go +++ b/tailnet/trackedconn.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet/proto" ) const ( @@ -29,7 +30,7 @@ type TrackedConn struct { cancel func() kind QueueKind conn net.Conn - updates chan []*Node + updates chan *proto.CoordinateResponse logger slog.Logger lastData []byte @@ -55,7 +56,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), // coordinator mutex while queuing. Node updates don't // come quickly, so 512 should be plenty for all but // the most pathological cases. - updates := make(chan []*Node, ResponseBufferSize) + updates := make(chan *proto.CoordinateResponse, ResponseBufferSize) now := time.Now().Unix() return &TrackedConn{ ctx: ctx, @@ -72,10 +73,10 @@ func NewTrackedConn(ctx context.Context, cancel func(), } } -func (t *TrackedConn) Enqueue(n []*Node) (err error) { +func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) { atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) select { - case t.updates <- n: + case t.updates <- resp: return nil default: return ErrWouldBlock @@ -124,7 +125,16 @@ func (t *TrackedConn) SendUpdates() { case <-t.ctx.Done(): t.logger.Debug(t.ctx, "done sending updates") return - case nodes := <-t.updates: + case resp := <-t.updates: + nodes, err := OnlyNodeUpdates(resp) + if err != nil { + t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err)) + return + } + if len(nodes) == 0 { + t.logger.Debug(t.ctx, "skipping response with no nodes") + continue + } data, err := json.Marshal(nodes) if err != nil { t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
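
For reviewers new to the Coordination API added in tailnet/coordinator.go above: the sketch below shows how a client-side Conn can be wired to an agent through an in-memory Coordinator, replacing the old net.Pipe + ServeClient pattern. It is a minimal sketch built only from the signatures in this diff; the package name, helper name, and option set are hypothetical.

package example // hypothetical package, for illustration only

import (
	"context"
	"net/netip"

	"github.com/google/uuid"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// dialAgentInMemory builds a tailnet.Conn and coordinates it with the given
// agent via an in-memory Coordinator. Callers should Close() the returned
// Coordination and Conn when done.
func dialAgentInMemory(
	ctx context.Context, logger slog.Logger, derpMap *tailcfg.DERPMap,
	coord tailnet.Coordinator, agentID uuid.UUID,
) (*tailnet.Conn, tailnet.Coordination, error) {
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:   derpMap,
		Logger:    logger,
	})
	if err != nil {
		return nil, nil, err
	}
	// A non-nil clientID marks this as the client side of the tunnel;
	// an agent would pass uuid.Nil as the clientID instead.
	coordination := tailnet.NewInMemoryCoordination(
		ctx, logger, uuid.New(), agentID, coord, conn)
	return conn, coordination, nil
}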
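The remote flavor follows the same shape over dRPC. Below is a hedged sketch, again using only the signatures added in this diff (NewRemoteCoordination, Coordination.Error): open the Coordinate stream, hand it to NewRemoteCoordination, and block until the context ends or the stream errors. The function name and error wrapping are assumptions.

package example // hypothetical package, for illustration only

import (
	"context"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

func runRemoteCoordination(
	ctx context.Context, logger slog.Logger,
	client proto.DRPCTailnetClient, conn *tailnet.Conn, agentID uuid.UUID,
) error {
	stream, err := client.Coordinate(ctx)
	if err != nil {
		return xerrors.Errorf("open coordinate stream: %w", err)
	}
	// A non-nil tunnelTarget means we are acting as a client; agents pass
	// uuid.Nil instead.
	coordination := tailnet.NewRemoteCoordination(logger, stream, conn, agentID)
	// Close sends a graceful Disconnect over the stream.
	defer func() { _ = coordination.Close() }()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-coordination.Error():
		return err
	}
}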
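On the test side, the regenerated MockCoordinator is driven with plain buffered channels rather than per-call expectations, as the updated wsproxysdk test above does. A compact helper sketch, assuming go.uber.org/mock; the helper name is hypothetical. gomock converts the bidirectional channels to the directional types in Coordinate's return signature because they are assignable.

package example // hypothetical package, for illustration only

import (
	"testing"

	"github.com/google/uuid"
	"go.uber.org/mock/gomock"

	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/coder/v2/tailnet/tailnettest"
)

// expectCoordinate registers a single Coordinate expectation and returns the
// channels the test uses to inspect requests and inject responses.
func expectCoordinate(t *testing.T, peerID uuid.UUID) (
	*tailnettest.MockCoordinator,
	chan *proto.CoordinateRequest,
	chan *proto.CoordinateResponse,
) {
	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
	// Buffered channels avoid needing reader/writer goroutines in the test.
	reqs := make(chan *proto.CoordinateRequest, 100)
	resps := make(chan *proto.CoordinateResponse, 100)
	mCoord.EXPECT().
		Coordinate(gomock.Any(), peerID, gomock.Any(), gomock.Any()).
		Times(1).
		Return(reqs, resps)
	return mCoord, reqs, resps
}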
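Finally, the v1 compatibility path: TrackedConn now queues whole CoordinateResponses and only converts them to v1 node lists at write time via OnlyNodeUpdates, the conversion that v1RespLoop used to perform at enqueue time. A minimal sketch of that conversion step, with a hypothetical function name.

package example // hypothetical package, for illustration only

import (
	"encoding/json"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// v1Payload returns the JSON a v1 client would be sent, or nil if the
// response carries no node updates (for example, only disconnects).
func v1Payload(resp *proto.CoordinateResponse) ([]byte, error) {
	nodes, err := tailnet.OnlyNodeUpdates(resp)
	if err != nil {
		return nil, err
	}
	if len(nodes) == 0 {
		return nil, nil
	}
	return json.Marshal(nodes)
}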