diff --git a/agent/agent.go b/agent/agent.go index c2e2670a41257..f755587e793ec 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -64,6 +64,7 @@ type Options struct { SSHMaxTimeout time.Duration TailnetListenPort uint16 Subsystem codersdk.AgentSubsystem + Addresses []netip.Prefix PrometheusRegistry *prometheus.Registry } @@ -132,6 +133,7 @@ func New(options Options) Agent { connStatsChan: make(chan *agentsdk.Stats, 1), sshMaxTimeout: options.SSHMaxTimeout, subsystem: options.Subsystem, + addresses: options.Addresses, prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), @@ -177,6 +179,7 @@ type agent struct { lifecycleStates []agentsdk.PostLifecycleRequest network *tailnet.Conn + addresses []netip.Prefix connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] @@ -545,6 +548,10 @@ func (a *agent) run(ctx context.Context) error { } a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest)) + if manifest.AgentID == uuid.Nil { + return xerrors.New("nil agentID returned by manifest") + } + // Expand the directory and send it back to coderd so external // applications that rely on the directory can use it. // @@ -630,7 +637,7 @@ func (a *agent) run(ctx context.Context) error { network := a.network a.closeMutex.Unlock() if network == nil { - network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections) + network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } @@ -648,6 +655,11 @@ func (a *agent) run(ctx context.Context) error { a.startReportingConnectionStats(ctx) } else { + // Update the wireguard IPs if the agent ID changed. + err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) + if err != nil { + a.logger.Error(ctx, "update tailnet addresses", slog.Error(err)) + } // Update the DERP map and allow/disallow direct connections. network.SetDERPMap(manifest.DERPMap) network.SetBlockEndpoints(manifest.DisableDirectConnections) @@ -661,6 +673,20 @@ func (a *agent) run(ctx context.Context) error { return nil } +func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { + if len(a.addresses) == 0 { + return []netip.Prefix{ + // This is the IP that should be used primarily. + netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128), + // We also listen on the legacy codersdk.WorkspaceAgentIP. This + // allows for a transition away from wsconncache. 
+ netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + } + } + + return a.addresses +} + func (a *agent) trackConnGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() @@ -675,9 +701,9 @@ func (a *agent) trackConnGoroutine(fn func()) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { +func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { network, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + Addresses: a.wireguardAddresses(agentID), DERPMap: derpMap, Logger: a.logger.Named("tailnet"), ListenPort: a.tailnetListenPort, diff --git a/agent/agent_test.go b/agent/agent_test.go index 8ac7eca050af9..e9b1f485f718a 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -35,7 +35,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" "golang.org/x/crypto/ssh" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" @@ -45,6 +44,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/agent/agenttest" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -67,7 +67,7 @@ func TestAgent_Stats_SSH(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash") require.NoError(t, err) @@ -130,7 +130,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -157,7 +157,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -425,20 +425,19 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled // Allow the blank identifiers. - conn, client, _, _, _ := setupAgent(t, &client{}, 0) + conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) for _, test := range tests { test := test - + // Set new banner func and wait for the agent to call it to update the + // banner. 
ready := make(chan struct{}, 2) - client.mu.Lock() - client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) { + client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) { select { case ready <- struct{}{}: default: } return test.banner, nil - } - client.mu.Unlock() + }) <-ready <-ready // Wait for two updates to ensure the value has propagated. @@ -542,7 +541,7 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -592,7 +591,7 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -922,7 +921,7 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -954,7 +953,7 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -1062,16 +1061,15 @@ func TestAgent_StartupScript(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats), + tailnet.NewCoordinator(logger), + ) closer := agent.New(agent.Options{ Client: client, Filesystem: afero.NewMemMapFs(), @@ -1082,36 +1080,35 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 1) - require.Equal(t, output, client.getStartupLogs()[0].Output) + require.Len(t, client.GetStartupLogs(), 1) + require.Equal(t, output, client.GetStartupLogs()[0].Output) }) // This ensures that even when coderd sends back that the startup // script has written too many lines it will still succeed! 
t.Run("OverflowsAndSkips", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - patchWorkspaceLogs: func() error { - resp := httptest.NewRecorder() - httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ - Message: "Too many lines!", - }) - res := resp.Result() - defer res.Body.Close() - return codersdk.ReadBodyAsError(res) - }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) + client.PatchWorkspaceLogs = func() error { + resp := httptest.NewRecorder() + httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "Too many lines!", + }) + res := resp.Result() + defer res.Body.Close() + return codersdk.ReadBodyAsError(res) } closer := agent.New(agent.Options{ Client: client, @@ -1123,10 +1120,10 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 0) + require.Len(t, client.GetStartupLogs(), 0) }) } @@ -1138,28 +1135,26 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Once", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 0, - Script: echoHello, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 0, + Script: echoHello, }, }, }, 0) var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) collectedAt := gotMd["greeting"].CollectedAt require.Never(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() if len(gotMd) != 1 { panic("unexpected number of metadata") } @@ -1170,22 +1165,20 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Many", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 1, - Timeout: 100, - Script: echoHello, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 1, + Timeout: 100, + Script: echoHello, }, }, }, 0) var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) @@ -1195,7 +1188,7 @@ func TestAgent_Metadata(t *testing.T) { } if !assert.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return gotMd["greeting"].CollectedAt.After(collectedAt1) }, testutil.WaitShort, testutil.IntervalMedium) { t.Fatalf("expected metadata 
to be collected again") @@ -1221,29 +1214,27 @@ func TestAgentMetadata_Timing(t *testing.T) { script = "echo hello | tee -a " + greetingPath ) //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: reportInterval, - Script: script, - }, - { - Key: "bad", - Interval: reportInterval, - Script: "exit 1", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: reportInterval, + Script: script, + }, + { + Key: "bad", + Interval: reportInterval, + Script: "exit 1", }, }, }, 0) require.Eventually(t, func() bool { - return len(client.getMetadata()) == 2 + return len(client.GetMetadata()) == 2 }, testutil.WaitShort, testutil.IntervalMedium) for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) { - md := client.getMetadata() + md := client.GetMetadata() require.Len(t, md, 2, "got: %+v", md) require.Equal(t, "hello\n", md["greeting"].Value) @@ -1285,11 +1276,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "sleep 3", - StartupScriptTimeout: time.Nanosecond, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "sleep 3", + StartupScriptTimeout: time.Nanosecond, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1299,7 +1288,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1309,11 +1298,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartError", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "false", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "false", + StartupScriptTimeout: 30 * time.Second, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1323,7 +1310,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1333,11 +1320,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("Ready", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1347,7 +1332,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1357,15 +1342,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShuttingDown", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, 
&client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "sleep 3", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "sleep 3", + StartupScriptTimeout: 30 * time.Second, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we can inspect the state. @@ -1387,7 +1370,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1397,15 +1380,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "sleep 3", - ShutdownScriptTimeout: time.Nanosecond, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "sleep 3", + ShutdownScriptTimeout: time.Nanosecond, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we can inspect the state. @@ -1428,7 +1409,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1438,15 +1419,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownError", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "false", - ShutdownScriptTimeout: 30 * time.Second, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "false", + ShutdownScriptTimeout: 30 * time.Second, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we can inspect the state.
@@ -1469,7 +1448,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1480,17 +1459,18 @@ func TestAgent_Lifecycle(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) expected := "this-is-shutdown" - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ - DERPMap: tailnettest.RunDERPAndSTUN(t), + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ + DERPMap: derpMap, StartupScript: "echo 1", ShutdownScript: "echo " + expected, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) fs := afero.NewMemMapFs() agent := agent.New(agent.Options{ @@ -1536,71 +1516,63 @@ func TestAgent_Startup(t *testing.T) { t.Run("EmptyDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) - require.Equal(t, "", client.getStartup().ExpandedDirectory) + require.Equal(t, "", client.GetStartup().ExpandedDirectory) }) t.Run("HomeDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "~", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "~", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) t.Run("NotAbsoluteDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "coder/coder", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "coder/coder", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory) + require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory) }) t.Run("HomeEnvironmentVariable", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "$HOME", - }, + _, 
client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "$HOME", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) } @@ -1617,7 +1589,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) id := uuid.New() netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -1719,7 +1691,7 @@ func TestAgent_Dial(t *testing.T) { }() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) require.True(t, conn.AwaitReachable(context.Background())) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) @@ -1739,12 +1711,10 @@ func TestAgent_Speedtest(t *testing.T) { t.Skip("This test is relatively flakey because of Tailscale's speedtest code...") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - DERPMap: derpMap, - }, + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ + DERPMap: derpMap, }, 0) defer conn.Close() res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond) @@ -1761,17 +1731,16 @@ func TestAgent_Reconnect(t *testing.T) { defer coordinator.Close() agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats) - derpMap := tailnettest.RunDERPAndSTUN(t) - client := &client{ - t: t, - agentID: agentID, - manifest: agentsdk.Manifest{ + statsCh := make(chan *agentsdk.Stats, 50) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + client := agenttest.NewClient(t, + agentID, + agentsdk.Manifest{ DERPMap: derpMap, }, - statsChan: statsCh, - coordinator: coordinator, - } + statsCh, + coordinator, + ) initialized := atomic.Int32{} closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1786,7 +1755,7 @@ func TestAgent_Reconnect(t *testing.T) { require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) - client.lastWorkspaceAgent() + client.LastWorkspaceAgent() require.Eventually(t, func() bool { return initialized.Load() == 2 }, testutil.WaitShort, testutil.IntervalFast) @@ -1798,16 +1767,15 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { coordinator := tailnet.NewCoordinator(logger) defer coordinator.Close() - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ GitAuthConfigs: 1, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: coordinator, - } + make(chan *agentsdk.Stats, 50), + coordinator, + ) filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1830,7 +1798,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { func setupSSHCommand(t *testing.T, beforeArgs 
[]string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { //nolint:dogsled - agentConn, _, _, _, _ := setupAgent(t, &client{}, 0) + agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) waitGroup := sync.WaitGroup{} @@ -1883,12 +1851,11 @@ func setupSSHSession( ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, fs, _ := setupAgent(t, &client{ - manifest: options, - getServiceBanner: func() (codersdk.ServiceBannerConfig, error) { + conn, _, _, fs, _ := setupAgent(t, options, 0, func(c *agenttest.Client, _ *agent.Options) { + c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) { return serviceBanner, nil - }, - }, 0) + }) + }) if prepareFS != nil { prepareFS(fs) } @@ -1905,31 +1872,28 @@ func setupSSHSession( return session } -type closeFunc func() error - -func (c closeFunc) Close() error { - return c() -} - -func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) ( +func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( *codersdk.WorkspaceAgentConn, - *client, + *agenttest.Client, <-chan *agentsdk.Stats, afero.Fs, io.Closer, ) { - c.t = t logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - if c.manifest.DERPMap == nil { - c.manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) + if metadata.DERPMap == nil { + metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) } - c.coordinator = tailnet.NewCoordinator(logger) + if metadata.AgentID == uuid.Nil { + metadata.AgentID = uuid.New() + } + coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { - _ = c.coordinator.Close() + _ = coordinator.Close() }) - c.agentID = uuid.New() - c.statsChan = make(chan *agentsdk.Stats, 50) + statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() + c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator) + options := agent.Options{ Client: c, Filesystem: fs, @@ -1938,7 +1902,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( } for _, opt := range opts { - options = opt(options) + opt(c, &options) } closer := agent.New(options) @@ -1947,7 +1911,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( }) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: c.manifest.DERPMap, + DERPMap: metadata.DERPMap, Logger: logger.Named("client"), }) require.NoError(t, err) @@ -1961,15 +1925,15 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( }) go func() { defer close(serveClientDone) - c.coordinator.ServeClient(serverConn, uuid.New(), c.agentID) + coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: metadata.AgentID, + }) t.Cleanup(func() { _ = agentConn.Close() }) @@ -1980,7 +1944,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( if !agentConn.AwaitReachable(ctx) { t.Fatal("agent not reachable") } - return agentConn, c, c.statsChan, fs, closer + 
return agentConn, c, statsCh, fs, closer } var dialTestPayload = []byte("dean-was-here123") @@ -2043,146 +2007,6 @@ func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected } } -type client struct { - t *testing.T - agentID uuid.UUID - manifest agentsdk.Manifest - metadata map[string]agentsdk.PostMetadataRequest - statsChan chan *agentsdk.Stats - coordinator tailnet.Coordinator - lastWorkspaceAgent func() - patchWorkspaceLogs func() error - getServiceBanner func() (codersdk.ServiceBannerConfig, error) - - mu sync.Mutex // Protects following. - lifecycleStates []codersdk.WorkspaceAgentLifecycle - startup agentsdk.PostStartupRequest - logs []agentsdk.StartupLog -} - -func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { - return c.manifest, nil -} - -func (c *client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - closed := make(chan struct{}) - c.lastWorkspaceAgent = func() { - _ = serverConn.Close() - _ = clientConn.Close() - <-closed - } - c.t.Cleanup(c.lastWorkspaceAgent) - go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") - close(closed) - }() - return clientConn, nil -} - -func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { - doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - - go func() { - defer close(doneCh) - - setInterval(500 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return - case stat := <-statsChan: - select { - case c.statsChan <- stat: - case <-ctx.Done(): - return - default: - // We don't want to send old stats. - continue - } - } - } - }() - return closeFunc(func() error { - cancel() - <-doneCh - close(c.statsChan) - return nil - }), nil -} - -func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle { - c.mu.Lock() - defer c.mu.Unlock() - return c.lifecycleStates -} - -func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.lifecycleStates = append(c.lifecycleStates, req.State) - return nil -} - -func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { - return nil -} - -func (c *client) getStartup() agentsdk.PostStartupRequest { - c.mu.Lock() - defer c.mu.Unlock() - return c.startup -} - -func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest { - c.mu.Lock() - defer c.mu.Unlock() - return maps.Clone(c.metadata) -} - -func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.metadata == nil { - c.metadata = make(map[string]agentsdk.PostMetadataRequest) - } - c.metadata[key] = req - return nil -} - -func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.startup = startup - return nil -} - -func (c *client) getStartupLogs() []agentsdk.StartupLog { - c.mu.Lock() - defer c.mu.Unlock() - return c.logs -} - -func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.patchWorkspaceLogs != nil { - return c.patchWorkspaceLogs() - } - c.logs = append(c.logs, logs.Logs...) 
- return nil -} - -func (c *client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { - c.mu.Lock() - defer c.mu.Unlock() - if c.getServiceBanner != nil { - return c.getServiceBanner() - } - return codersdk.ServiceBannerConfig{}, nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // @@ -2214,9 +2038,8 @@ func TestAgent_Metrics_SSH(t *testing.T) { registry := prometheus.NewRegistry() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0, func(o agent.Options) agent.Options { + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.PrometheusRegistry = registry - return o }) sshClient, err := conn.SSHClient(ctx) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go new file mode 100644 index 0000000000000..c69ff59eb730b --- /dev/null +++ b/agent/agenttest/client.go @@ -0,0 +1,189 @@ +package agenttest + +import ( + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "golang.org/x/exp/maps" + + "cdr.dev/slog" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" +) + +func NewClient(t testing.TB, + agentID uuid.UUID, + manifest agentsdk.Manifest, + statsChan chan *agentsdk.Stats, + coordinator tailnet.Coordinator, +) *Client { + if manifest.AgentID == uuid.Nil { + manifest.AgentID = agentID + } + return &Client{ + t: t, + agentID: agentID, + manifest: manifest, + statsChan: statsChan, + coordinator: coordinator, + } +} + +type Client struct { + t testing.TB + agentID uuid.UUID + manifest agentsdk.Manifest + metadata map[string]agentsdk.PostMetadataRequest + statsChan chan *agentsdk.Stats + coordinator tailnet.Coordinator + LastWorkspaceAgent func() + PatchWorkspaceLogs func() error + GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) + + mu sync.Mutex // Protects following. + lifecycleStates []codersdk.WorkspaceAgentLifecycle + startup agentsdk.PostStartupRequest + logs []agentsdk.StartupLog +} + +func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { + return c.manifest, nil +} + +func (c *Client) Listen(_ context.Context) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + closed := make(chan struct{}) + c.LastWorkspaceAgent = func() { + _ = serverConn.Close() + _ = clientConn.Close() + <-closed + } + c.t.Cleanup(c.LastWorkspaceAgent) + go func() { + _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + close(closed) + }() + return clientConn, nil +} + +func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { + doneCh := make(chan struct{}) + ctx, cancel := context.WithCancel(ctx) + + go func() { + defer close(doneCh) + + setInterval(500 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case stat := <-statsChan: + select { + case c.statsChan <- stat: + case <-ctx.Done(): + return + default: + // We don't want to send old stats. 
+ continue + } + } + } + }() + return closeFunc(func() error { + cancel() + <-doneCh + close(c.statsChan) + return nil + }), nil +} + +func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { + c.mu.Lock() + defer c.mu.Unlock() + return c.lifecycleStates +} + +func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.lifecycleStates = append(c.lifecycleStates, req.State) + return nil +} + +func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { + return nil +} + +func (c *Client) GetStartup() agentsdk.PostStartupRequest { + c.mu.Lock() + defer c.mu.Unlock() + return c.startup +} + +func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest { + c.mu.Lock() + defer c.mu.Unlock() + return maps.Clone(c.metadata) +} + +func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.metadata == nil { + c.metadata = make(map[string]agentsdk.PostMetadataRequest) + } + c.metadata[key] = req + return nil +} + +func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.startup = startup + return nil +} + +func (c *Client) GetStartupLogs() []agentsdk.StartupLog { + c.mu.Lock() + defer c.mu.Unlock() + return c.logs +} + +func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.PatchWorkspaceLogs != nil { + return c.PatchWorkspaceLogs() + } + c.logs = append(c.logs, logs.Logs...) + return nil +} + +func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) { + c.mu.Lock() + defer c.mu.Unlock() + + c.GetServiceBannerFunc = f +} + +func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.GetServiceBannerFunc != nil { + return c.GetServiceBannerFunc() + } + return codersdk.ServiceBannerConfig{}, nil +} + +type closeFunc func() error + +func (c closeFunc) Close() error { + return c() +} diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 328107b39d54b..31970e84477cc 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -5961,6 +5961,9 @@ const docTemplate = `{ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { @@ -7617,6 +7620,7 @@ const docTemplate = `{ "workspace_actions", "tailnet_ha_coordinator", "convert-to-oidc", + "single_tailnet", "workspace_build_logs_ui" ], "x-enum-varnames": [ @@ -7624,6 +7628,7 @@ const docTemplate = `{ "ExperimentWorkspaceActions", "ExperimentTailnetHACoordinator", "ExperimentConvertToOIDC", + "ExperimentSingleTailnet", "ExperimentWorkspaceBuildLogsUI" ] }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 7ea1a1de0633c..841f9c50bbe5f 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -5251,6 +5251,9 @@ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { @@ -6818,6 +6821,7 @@ "workspace_actions", "tailnet_ha_coordinator", "convert-to-oidc", + "single_tailnet", "workspace_build_logs_ui" ], "x-enum-varnames": [ @@ -6825,6 +6829,7 @@ "ExperimentWorkspaceActions", "ExperimentTailnetHACoordinator", "ExperimentConvertToOIDC", + "ExperimentSingleTailnet", 
"ExperimentWorkspaceBuildLogsUI" ] }, diff --git a/coderd/coderd.go b/coderd/coderd.go index dc14727879f08..7104049187235 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -364,8 +364,23 @@ func New(options *Options) *API { } api.Auditor.Store(&options.Auditor) - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.TailnetCoordinator.Store(&options.TailnetCoordinator) + if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) { + api.agentProvider, err = NewServerTailnet(api.ctx, + options.Logger, + options.DERPServer, + options.DERPMap, + &api.TailnetCoordinator, + wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + ) + if err != nil { + panic("failed to setup server tailnet: " + err.Error()) + } + } else { + api.agentProvider = &wsconncache.AgentProvider{ + Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + } + } api.workspaceAppServer = &workspaceapps.Server{ Logger: options.Logger.Named("workspaceapps"), @@ -377,7 +392,7 @@ func New(options *Options) *API { RealIPConfig: options.RealIPConfig, SignedTokenProvider: api.WorkspaceAppsProvider, - WorkspaceConnCache: api.workspaceAgentCache, + AgentProvider: api.agentProvider, AppSecurityKey: options.AppSecurityKey, DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), @@ -921,10 +936,10 @@ type API struct { derpCloseFunc func() metricsCache *metricscache.Cache - workspaceAgentCache *wsconncache.Cache updateChecker *updatecheck.Checker WorkspaceAppsProvider workspaceapps.SignedTokenProvider workspaceAppServer *workspaceapps.Server + agentProvider workspaceapps.AgentProvider // Experiments contains the list of experiments currently enabled. // This is used to gate features that are not yet ready for production. @@ -951,7 +966,8 @@ func (api *API) Close() error { if coordinator != nil { _ = (*coordinator).Close() } - return api.workspaceAgentCache.Close() + _ = api.agentProvider.Close() + return nil } func compressHandler(h http.Handler) http.Handler { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f4bf035311e6a..d073b48824bd4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -109,6 +109,7 @@ type Options struct { GitAuthConfigs []*gitauth.Config TrialGenerator func(context.Context, string) error TemplateScheduleStore schedule.TemplateScheduleStore + Coordinator tailnet.Coordinator HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report HealthcheckTimeout time.Duration diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 5b53fcaa047e4..2ece768671280 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -302,7 +302,7 @@ func TestAgents(t *testing.T) { coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug)) coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{} coordinatorPtr.Store(&coordinator) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests registry := prometheus.NewRegistry() diff --git a/coderd/tailnet.go b/coderd/tailnet.go new file mode 100644 index 0000000000000..a1559e4efcd52 --- /dev/null +++ b/coderd/tailnet.go @@ -0,0 +1,339 @@ +package coderd + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "sync" + "sync/atomic" + "time" + 
+ "github.com/google/uuid" + "golang.org/x/xerrors" + "tailscale.com/derp" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" + "github.com/coder/coder/tailnet" +) + +var tailnetTransport *http.Transport + +func init() { + var valid bool + tailnetTransport, valid = http.DefaultTransport.(*http.Transport) + if !valid { + panic("dev error: default transport is the wrong type") + } +} + +// NewServerTailnet creates a new tailnet intended for use by coderd. It +// automatically falls back to wsconncache if a legacy agent is encountered. +func NewServerTailnet( + ctx context.Context, + logger slog.Logger, + derpServer *derp.Server, + derpMap *tailcfg.DERPMap, + coord *atomic.Pointer[tailnet.Coordinator], + cache *wsconncache.Cache, +) (*ServerTailnet, error) { + logger = logger.Named("servertailnet") + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: derpMap, + Logger: logger, + }) + if err != nil { + return nil, xerrors.Errorf("create tailnet conn: %w", err) + } + + serverCtx, cancel := context.WithCancel(ctx) + tn := &ServerTailnet{ + ctx: serverCtx, + cancel: cancel, + logger: logger, + conn: conn, + coord: coord, + cache: cache, + agentNodes: map[uuid.UUID]time.Time{}, + agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{}, + transport: tailnetTransport.Clone(), + } + tn.transport.DialContext = tn.dialContext + tn.transport.MaxIdleConnsPerHost = 10 + tn.transport.MaxIdleConns = 0 + agentConn := (*coord.Load()).ServeMultiAgent(uuid.New()) + tn.agentConn.Store(&agentConn) + + err = tn.getAgentConn().UpdateSelf(conn.Node()) + if err != nil { + tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + } + conn.SetNodeCallback(func(node *tailnet.Node) { + err := tn.getAgentConn().UpdateSelf(node) + if err != nil { + tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) + } + }) + + // This is set to allow local DERP traffic to be proxied through memory + // instead of needing to hit the external access URL. Don't use the ctx + // given in this callback, it's only valid while connecting. + conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { + if !region.EmbeddedRelay { + return nil + } + left, right := net.Pipe() + go func() { + defer left.Close() + defer right.Close() + brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) + derpServer.Accept(ctx, right, brw, "internal") + }() + return left + }) + + go tn.watchAgentUpdates() + go tn.expireOldAgents() + return tn, nil +} + +func (s *ServerTailnet) expireOldAgents() { + const ( + tick = 5 * time.Minute + cutoff = 30 * time.Minute + ) + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + } + + s.nodesMu.Lock() + agentConn := s.getAgentConn() + for agentID, lastConnection := range s.agentNodes { + // If no one has connected since the cutoff and there are no active + // connections, remove the agent. 
+ if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { + _ = agentConn + // err := agentConn.UnsubscribeAgent(agentID) + // if err != nil { + // s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) + // } + // delete(s.agentNodes, agentID) + + // TODO(coadler): actually remove from the netmap, then reenable + // the above + } + } + s.nodesMu.Unlock() + } +} + +func (s *ServerTailnet) watchAgentUpdates() { + for { + conn := s.getAgentConn() + nodes, ok := conn.NextUpdate(s.ctx) + if !ok { + if conn.IsClosed() && s.ctx.Err() == nil { + s.reinitCoordinator() + continue + } + return + } + + err := s.conn.UpdateNodes(nodes, false) + if err != nil { + s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) + return + } + } +} + +func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn { + return *s.agentConn.Load() +} + +func (s *ServerTailnet) reinitCoordinator() { + s.nodesMu.Lock() + agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New()) + s.agentConn.Store(&agentConn) + + // Resubscribe to all of the agents we're tracking. + for agentID := range s.agentNodes { + err := agentConn.SubscribeAgent(agentID) + if err != nil { + s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + s.nodesMu.Unlock() +} + +type ServerTailnet struct { + ctx context.Context + cancel func() + + logger slog.Logger + conn *tailnet.Conn + coord *atomic.Pointer[tailnet.Coordinator] + agentConn atomic.Pointer[tailnet.MultiAgentConn] + cache *wsconncache.Cache + nodesMu sync.Mutex + // agentNodes is a map of agent tailnetNodes the server wants to keep a + // connection to. It tracks the last time a connection to each agent was + // requested. + agentNodes map[uuid.UUID]time.Time + // agentTickets holds a map of all open connections to an agent. + agentTickets map[uuid.UUID]map[uuid.UUID]struct{} + + transport *http.Transport +} + +func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + proxy.Director = s.director(agentID, proxy.Director) + proxy.Transport = s.transport + + return proxy, func() {}, nil +} + +type agentIDKey struct{} + +// director makes sure agentIDKey is set on the context in the reverse proxy. +// This allows the transport to correctly identify which agent to dial. +func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) { + return func(req *http.Request) { + ctx := context.WithValue(req.Context(), agentIDKey{}, agentID) + *req = *req.WithContext(ctx) + prev(req) + } +} + +func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID) + if !ok { + return nil, xerrors.Errorf("no agent id attached") + } + + return s.DialAgentNetConn(ctx, agentID, network, addr) +} + +func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { + s.nodesMu.Lock() + defer s.nodesMu.Unlock() + + _, ok := s.agentNodes[agentID] + // If we don't have the node, subscribe.
+ if !ok { + s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID)) + err := s.getAgentConn().SubscribeAgent(agentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) + } + s.agentTickets[agentID] = map[uuid.UUID]struct{}{} + } + + s.agentNodes[agentID] = time.Now() + return nil +} + +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + var ( + conn *codersdk.WorkspaceAgentConn + ret = func() {} + ) + + if s.getAgentConn().AgentIsLegacy(agentID) { + s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID)) + cconn, release, err := s.cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err) + } + + conn = cconn.WorkspaceAgentConn + ret = release + } else { + err := s.ensureAgent(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("ensure agent: %w", err) + } + + s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID)) + conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }) + } + + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. + + reachable := conn.AwaitReachable(ctx) + if !reachable { + ret() + conn.Close() + return nil, nil, xerrors.New("agent is unreachable") + } + + return conn, ret, nil +} + +func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) { + conn, release, err := s.AgentConn(ctx, agentID) + if err != nil { + return nil, xerrors.Errorf("acquire agent conn: %w", err) + } + + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. 
+ + nc, err := conn.DialContext(ctx, network, addr) + if err != nil { + release() + conn.Close() + return nil, xerrors.Errorf("dial context: %w", err) + } + + return &netConnCloser{Conn: nc, close: func() { + release() + conn.Close() + }}, err +} + +type netConnCloser struct { + net.Conn + close func() +} + +func (c *netConnCloser) Close() error { + c.close() + return c.Conn.Close() +} + +func (s *ServerTailnet) Close() error { + s.cancel() + _ = s.cache.Close() + _ = s.conn.Close() + s.transport.CloseIdleConnections() + return nil +} diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go new file mode 100644 index 0000000000000..16d597607312c --- /dev/null +++ b/coderd/tailnet_test.go @@ -0,0 +1,207 @@ +package coderd_test + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agenttest" + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" +) + +func TestServerTailnet_AgentConn_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Connect through the ServerTailnet + agentID, _, serverTailnet := setupAgent(t, nil) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + +func TestServerTailnet_AgentConn_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + +func TestServerTailnet_ReverseProxy_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Connect through the ServerTailnet. + agentID, _, serverTailnet := setupAgent(t, nil) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. 
+ agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) { + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + derpMap, derpServer := tailnettest.RunDERPAndSTUN(t) + manifest := agentsdk.Manifest{ + AgentID: uuid.New(), + DERPMap: derpMap, + } + + var coordPtr atomic.Pointer[tailnet.Coordinator] + coord := tailnet.NewCoordinator(logger) + coordPtr.Store(&coord) + t.Cleanup(func() { + _ = coord.Close() + }) + + c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) + + options := agent.Options{ + Client: c, + Filesystem: afero.NewMemMapFs(), + Logger: logger.Named("agent"), + Addresses: agentAddresses, + } + + ag := agent.New(options) + t.Cleanup(func() { + _ = ag.Close() + }) + + // Wait for the agent to connect. + require.Eventually(t, func() bool { + return coord.Node(manifest.AgentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: manifest.DERPMap, + Logger: logger.Named("client"), + }) + require.NoError(t, err) + clientConn, serverConn := net.Pipe() + serveClientDone := make(chan struct{}) + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + _ = conn.Close() + <-serveClientDone + }) + go func() { + defer close(serveClientDone) + coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) + }() + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node, false) + }) + conn.SetNodeCallback(sendNode) + return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: manifest.AgentID, + AgentIP: codersdk.WorkspaceAgentIP, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }), nil + }, 0) + + serverTailnet, err := coderd.NewServerTailnet( + context.Background(), + logger, + derpServer, + manifest.DERPMap, + &coordPtr, + cache, + ) + require.NoError(t, err) + + t.Cleanup(func() { + _ = serverTailnet.Close() + }) + + return manifest.AgentID, ag, serverTailnet +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index bfe61b4a180df..c1f2e90c02de9 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request) } httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{ + AgentID: apiAgent.ID, Apps: convertApps(dbApps), DERPMap: api.DERPMap, GitAuthConfigs: len(api.GitAuthConfigs), @@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID) + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) 
if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error dialing workspace agent.", @@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } -func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { +// Deprecated: use api.agentProvider.AgentConn instead. +// See: https://github.com/coder/coder/issues/8218 +func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { clientConn, serverConn := net.Pipe() conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac return nil }) conn.SetNodeCallback(sendNodes) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + AgentIP: codersdk.WorkspaceAgentIP, + CloseFunc: func() error { cancel() _ = clientConn.Close() _ = serverConn.Close() + return nil }, - } + }) go func() { err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index 9432e09c9703d..0f0167f37bf79 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques return resp, err } -func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { +func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { + t.Helper() var resp *http.Response var err error require.Eventually(t, func() bool { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 1d3e8592d7a1c..9b2d9c4bfa297 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -23,7 +23,6 @@ import ( "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/coderd/util/slice" - "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" ) @@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{ "Sec-Websocket-Version": "Sec-WebSocket-Version", } +type AgentProvider interface { + // ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests + // to the specified agent. + // + // TODO: after wsconncache is deleted this doesn't need to return an error. + ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) + + // AgentConn returns a new connection to the specified agent. + // + // TODO: after wsconncache is deleted this doesn't need to return a release + // func.
+ AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) + + Close() error +} + // Server serves workspace apps endpoints, including: // - Path-based apps // - Subdomain app middleware @@ -83,7 +98,6 @@ type Server struct { RealIPConfig *httpmw.RealIPConfig SignedTokenProvider SignedTokenProvider - WorkspaceConnCache *wsconncache.Cache AppSecurityKey SecurityKey // DisablePathApps disables path-based apps. This is a security feature as path @@ -95,6 +109,8 @@ type Server struct { DisablePathApps bool SecureAuthCookie bool + AgentProvider AgentProvider + websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup } @@ -106,8 +122,8 @@ func (s *Server) Close() error { s.websocketWaitGroup.Wait() s.websocketWaitMutex.Unlock() - // The caller must close the SignedTokenProvider (if necessary) and the - // wsconncache. + // The caller must close the SignedTokenProvider and the AgentProvider (if + // necessary). return nil } @@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT r.URL.Path = path appURL.RawQuery = "" - proxy := httputil.NewSingleHostReverseProxy(appURL) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ - Status: http.StatusBadGateway, - Title: "Bad Gateway", - Description: "Failed to proxy request to application: " + err.Error(), - RetryEnabled: true, - DashboardURL: s.DashboardURL.String(), - }) - } - - conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID) if err != nil { site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ Status: http.StatusBadGateway, @@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT return } defer release() - proxy.Transport = conn.HTTPTransport() proxy.ModifyResponse = func(r *http.Response) error { r.Header.Del(httpmw.AccessControlAllowOriginHeader) @@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { go httpapi.Heartbeat(ctx, conn) - agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID) if err != nil { log.Debug(ctx, "dial workspace agent", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return } defer release() + defer agentConn.Close() log.Debug(ctx, "dialed workspace agent") ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command")) if err != nil { diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 19c7f65f9fb74..13d1588384954 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -1,9 +1,12 @@ // Package wsconncache caches workspace agent connections by UUID. +// Deprecated: Use ServerTailnet instead. package wsconncache import ( "context" "net/http" + "net/http/httputil" + "net/url" "sync" "time" @@ -13,13 +16,57 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" ) -// New creates a new workspace connection cache that closes -// connections after the inactive timeout provided. 
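For orientation, a minimal sketch of how the two AgentProvider methods are meant to be consumed together; the handler and its wiring are illustrative only, not part of this patch, and mirror the proxyWorkspaceApp and workspaceAgentPTY call sites above.

package example // illustrative sketch, not part of the patch

import (
	"net/http"
	"net/url"

	"github.com/google/uuid"

	"github.com/coder/coder/coderd/workspaceapps"
)

// proxyToAgent is a hypothetical caller of workspaceapps.AgentProvider. The
// acquire/use/release shape is the contract: release hands the pooled
// connection back to the wsconncache-backed implementation.
func proxyToAgent(p workspaceapps.AgentProvider, w http.ResponseWriter, r *http.Request, target, dashboard *url.URL, agentID uuid.UUID) {
	proxy, release, err := p.ReverseProxy(target, dashboard, agentID)
	if err != nil {
		http.Error(w, "dial workspace agent: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer release()

	// The returned proxy already has its transport bound to the agent's
	// tailnet connection and an error handler that renders the dashboard's
	// error page.
	proxy.ServeHTTP(w, r)
}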
+type AgentProvider struct { + Cache *Cache +} + +func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + conn, rel, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + return conn.WorkspaceAgentConn, rel, nil +} + +func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + + conn, release, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + proxy.Transport = conn.HTTPTransport() + + return proxy, release, nil +} + +func (a *AgentProvider) Close() error { + return a.Cache.Close() +} + +// New creates a new workspace connection cache that closes connections after +// the inactive timeout provided. +// +// Agent connections are cached due to Wireguard negotiation taking a few +// hundred milliseconds, depending on latency. // -// Agent connections are cached due to WebRTC negotiation -// taking a few hundred milliseconds. +// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased +// out because it creates a unique Tailnet for each agent. +// See: https://github.com/coder/coder/issues/8218 func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { if inactiveTimeout == 0 { inactiveTimeout = 5 * time.Minute diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 34b92267080e5..276e528313751 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -157,22 +157,23 @@ func TestCache(t *testing.T) { func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { t.Helper() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) + manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { _ = coordinator.Close() }) - agentID := uuid.New() + manifest.AgentID = uuid.New() closer := agent.New(agent.Options{ Client: &client{ t: t, - agentID: agentID, + agentID: manifest.AgentID, manifest: manifest, coordinator: coordinator, }, Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, + Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, }) t.Cleanup(func() { _ = closer.Close() @@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati _ = serverConn.Close() _ = conn.Close() }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) + go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: manifest.AgentID, + AgentIP: 
codersdk.WorkspaceAgentIP, + }) t.Cleanup(func() { _ = agentConn.Close() }) diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 1e281ef494099..bf150cd84940f 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR } type Manifest struct { + AgentID uuid.UUID `json:"agent_id"` // GitAuthConfigs stores the number of Git configurations // the Coder deployment has. If this number is >0, we // set up special configuration in the workspace. diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 3921963e86f4b..79266441b6dc6 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -1764,6 +1764,12 @@ const ( // oidc. ExperimentConvertToOIDC Experiment = "convert-to-oidc" + // ExperimentSingleTailnet replaces workspace connections inside coderd to + // all use a single tailnet, instead of the previous behavior of creating a + // single tailnet for each agent. + // WARNING: This cannot be enabled when using HA. + ExperimentSingleTailnet Experiment = "single_tailnet" + ExperimentWorkspaceBuildLogsUI Experiment = "workspace_build_logs_ui" // Add new experiments here! // ExperimentExample Experiment = "example" diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 64bd4fe2f8bfa..6b9b6f0d33f44 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "tailscale.com/ipn/ipnstate" @@ -27,8 +28,14 @@ import ( // WorkspaceAgentIP is a static IPv6 address with the Tailscale prefix that is used to route // connections from clients to this node. A dynamic address is not required because a Tailnet // client only dials a single agent at a time. +// +// Deprecated: use tailnet.IP() instead. This is kept for backwards +// compatibility with wsconncache. +// See: https://github.com/coder/coder/issues/8218 var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") +var ErrSkipClose = xerrors.New("skip tailnet close") + const ( WorkspaceAgentSSHPort = tailnet.WorkspaceAgentSSHPort WorkspaceAgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort @@ -120,11 +127,38 @@ func init() { } } +// NewWorkspaceAgentConn creates a new WorkspaceAgentConn. `conn` may be unique +// to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the +// conn is shared and closing it is undesirable, you may return ErrSkipClose from +// opts.CloseFunc. This will ensure the underlying conn is not closed. +func NewWorkspaceAgentConn(conn *tailnet.Conn, opts WorkspaceAgentConnOptions) *WorkspaceAgentConn { + return &WorkspaceAgentConn{ + Conn: conn, + opts: opts, + } +} + // WorkspaceAgentConn represents a connection to a workspace agent. // @typescript-ignore WorkspaceAgentConn type WorkspaceAgentConn struct { *tailnet.Conn - CloseFunc func() + opts WorkspaceAgentConnOptions +} + +// @typescript-ignore WorkspaceAgentConnOptions +type WorkspaceAgentConnOptions struct { + AgentID uuid.UUID + AgentIP netip.Addr + CloseFunc func() error +} + +func (c *WorkspaceAgentConn) agentAddress() netip.Addr { + var emptyIP netip.Addr + if cmp := c.opts.AgentIP.Compare(emptyIP); cmp != 0 { + return c.opts.AgentIP + } + + return tailnet.IPFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable.
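The CloseFunc/ErrSkipClose contract above is easiest to see in a small sketch; everything here besides the codersdk and tailnet identifiers is illustrative.

package example // illustrative sketch, not part of the patch

import (
	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/tailnet"
)

// newSharedAgentConn wraps one tailnet.Conn that is shared across many agents,
// as coderd's single-tailnet mode does. Returning ErrSkipClose from CloseFunc
// makes WorkspaceAgentConn.Close skip closing the underlying conn, so dropping
// one agent handle cannot sever every other agent's traffic.
func newSharedAgentConn(shared *tailnet.Conn, agentID uuid.UUID) *codersdk.WorkspaceAgentConn {
	return codersdk.NewWorkspaceAgentConn(shared, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
		// AgentIP is left as the zero netip.Addr, so agentAddress() falls
		// back to tailnet.IPFromUUID(agentID), the per-agent address.
		CloseFunc: func() error { return codersdk.ErrSkipClose },
	})
}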
@@ -132,7 +166,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.AwaitReachable(ctx, WorkspaceAgentIP) + return c.Conn.AwaitReachable(ctx, c.agentAddress()) } // Ping pings the agent and returns the round-trip time. @@ -141,13 +175,20 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.Ping(ctx, WorkspaceAgentIP) + return c.Conn.Ping(ctx, c.agentAddress()) } // Close ends the connection to the workspace agent. func (c *WorkspaceAgentConn) Close() error { - if c.CloseFunc != nil { - c.CloseFunc() + var cerr error + if c.opts.CloseFunc != nil { + cerr = c.opts.CloseFunc() + if xerrors.Is(cerr, ErrSkipClose) { + return nil + } + } + if cerr != nil { + return multierror.Append(cerr, c.Conn.Close()) } return c.Conn.Close() } @@ -176,10 +217,12 @@ type ReconnectingPTYRequest struct { func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort)) + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err } @@ -209,10 +252,12 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort)) + + return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSSHPort)) } // SSHClient calls SSH to create a client that uses a weak cipher @@ -220,10 +265,12 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + netConn, err := c.SSH(ctx) if err != nil { return nil, xerrors.Errorf("ssh: %w", err) } + sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ // SSH host validation isn't helpful, because obtaining a peer // connection already signifies user-intent to dial a workspace. 
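As a usage sketch (assuming only the exported methods above; the helper itself is hypothetical), note that callers never pick an IP themselves: every dial resolves the address through agentAddress, so legacy and per-agent addresses behave identically.

package example // illustrative sketch, not part of the patch

import (
	"context"

	"github.com/coder/coder/codersdk"
)

// runHostname opens an SSH session over the tailnet and runs a command. The
// conn may target a legacy agent (AgentIP set) or a new one (address derived
// from AgentID); SSHClient dials WorkspaceAgentSSHPort on agentAddress()
// either way.
func runHostname(ctx context.Context, conn *codersdk.WorkspaceAgentConn) (string, error) {
	client, err := conn.SSHClient(ctx)
	if err != nil {
		return "", err
	}
	defer client.Close()

	sess, err := client.NewSession()
	if err != nil {
		return "", err
	}
	defer sess.Close()

	out, err := sess.CombinedOutput("hostname")
	return string(out), err
}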
@@ -233,6 +280,7 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) if err != nil { return nil, xerrors.Errorf("ssh conn: %w", err) } + return ssh.NewClient(sshConn, channels, requests), nil } @@ -240,17 +288,21 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort)) + + speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) } + results, err := speedtest.RunClientWithConn(direction, duration, speedConn) if err != nil { return nil, xerrors.Errorf("run speedtest: %w", err) } + return results, err } @@ -259,19 +311,23 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest. func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - if network == "unix" { - return nil, xerrors.New("network must be tcp or udp") - } - _, rawPort, _ := net.SplitHostPort(addr) - port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port)) + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - if network == "udp" { + + _, rawPort, _ := net.SplitHostPort(addr) + port, _ := strconv.ParseUint(rawPort, 10, 16) + ipp := netip.AddrPortFrom(c.agentAddress(), uint16(port)) + + switch network { + case "tcp": + return c.Conn.DialContextTCP(ctx, ipp) + case "udp": return c.Conn.DialContextUDP(ctx, ipp) + default: + return nil, xerrors.Errorf("unknown network %q", network) } - return c.Conn.DialContextTCP(ctx, ipp) } type WorkspaceAgentListeningPortsResponse struct { @@ -309,7 +365,8 @@ func (c *WorkspaceAgentConn) ListeningPorts(ctx context.Context) (WorkspaceAgent func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) url := fmt.Sprintf("http://%s%s", host, path) req, err := http.NewRequestWithContext(ctx, method, url, body) @@ -332,13 +389,14 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { if network != "tcp" { return nil, xerrors.Errorf("network must be tcp") } + host, port, err := net.SplitHostPort(addr) if err != nil { return nil, xerrors.Errorf("split host port %q: %w", addr, err) } - // Verify that host is TailnetIP and port is - // TailnetStatisticsPort. - if host != WorkspaceAgentIP.String() || port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { + + // Verify that the port is WorkspaceAgentHTTPAPIServerPort.
+ if port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { return nil, xerrors.Errorf("request %q does not appear to be for http api", addr) } @@ -346,7 +404,12 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) + ipAddr, err := netip.ParseAddr(host) + if err != nil { + return nil, xerrors.Errorf("parse host addr: %w", err) + } + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(ipAddr, WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("dial http api: %w", err) } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 208c4511d261a..b76ebba9344f5 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -307,8 +307,8 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } - sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { - return conn.UpdateNodes(node, false) + sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { + return conn.UpdateNodes(nodes, false) }) conn.SetNodeCallback(sendNode) options.Logger.Debug(ctx, "serving coordinator") @@ -330,13 +330,15 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, err } - agentConn = &WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ + AgentID: agentID, + CloseFunc: func() error { cancel() <-closed + return conn.Close() }, - } + }) + if !agentConn.AwaitReachable(ctx) { _ = agentConn.Close() return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) diff --git a/docs/api/agents.md b/docs/api/agents.md index 69ff2fbe72318..7dcce5d52e847 100644 --- a/docs/api/agents.md +++ b/docs/api/agents.md @@ -292,6 +292,7 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/manifest \ ```json { + "agent_id": "string", "apps": [ { "command": "string", diff --git a/docs/api/schemas.md b/docs/api/schemas.md index 2a0861d413573..a042b6d1f6e04 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -161,6 +161,7 @@ ```json { + "agent_id": "string", "apps": [ { "command": "string", @@ -260,6 +261,7 @@ | Name | Type | Required | Restrictions | Description | | ---------------------------- | ------------------------------------------------------------------------------------------------- | -------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `agent_id` | string | false | | | | `apps` | array of [codersdk.WorkspaceApp](#codersdkworkspaceapp) | false | | | | `derpmap` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | | | `directory` | string | false | | | @@ -2543,6 +2545,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `workspace_actions` | | `tailnet_ha_coordinator` | | `convert-to-oidc` | +| `single_tailnet` | | `workspace_build_logs_ui` | ## codersdk.Feature diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index dc6ce99052b60..6f564eaa3a680 100644 --- 
a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -6,9 +6,8 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/require" - "github.com/google/uuid" + "github.com/stretchr/testify/require" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/coderd/coderdtest" diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index b0d9cfa64032f..889df136710c5 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database/pubsub" + "github.com/coder/coder/codersdk" agpl "github.com/coder/coder/tailnet" ) @@ -37,9 +38,12 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err closeFunc: cancelFunc, close: make(chan struct{}), nodes: map[uuid.UUID]*agpl.Node{}, - agentSockets: map[uuid.UUID]*agpl.TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{}, + agentSockets: map[uuid.UUID]agpl.Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, agentNameCache: nameCache, + clients: map[uuid.UUID]agpl.Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, + legacyAgents: map[uuid.UUID]struct{}{}, } if err := coord.runPubsub(ctx); err != nil { @@ -49,6 +53,56 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } +func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + m := (&agpl.MultiAgent{ + ID: id, + Logger: c.log, + AgentIsLegacyFunc: c.agentIsLegacy, + OnSubscribe: c.clientSubscribeToAgent, + OnNodeUpdate: c.clientNodeUpdate, + OnRemove: c.clientDisconnected, + }).Init() + c.addClient(id, m) + return m +} + +func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) { + c.mutex.Lock() + c.clients[id] = q + c.clientsToAgents[id] = map[uuid.UUID]agpl.Queue{} + c.mutex.Unlock() +} + +func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) + + node := c.nodes[enq.UniqueID()] + if node != nil { + err := c.sendNodeToAgentLocked(agentID, node) + if err != nil { + return nil, xerrors.Errorf("handle client update: %w", err) + } + } + + agentNode, ok := c.nodes[agentID] + // If we have the node locally, give it back to the multiagent. + if ok { + return agentNode, nil + } + + // If we don't have the node locally, notify other coordinators. + err := c.publishClientHello(agentID) + if err != nil { + return nil, xerrors.Errorf("publish client hello: %w", err) + } + + // nolint:nilnil + return nil, nil +} + type haCoordinator struct { id uuid.UUID log slog.Logger @@ -60,14 +114,26 @@ type haCoordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*agpl.Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*agpl.TrackedConn + agentSockets map[uuid.UUID]agpl.Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Queue + + // clients holds a map of all clients connected to the coordinator. This is + necessary because a client may not be subscribed to any agents.
+ clients map[uuid.UUID]agpl.Queue + // clientsToAgents is an index of clients to all of their subscribed agents. + clientsToAgents map[uuid.UUID]map[uuid.UUID]agpl.Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + + // legacyAgents holds a mapping of all agents detected as legacy, meaning + // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible + // with the new ServerTailnet, so they must be connected through + // wsconncache. + legacyAgents map[uuid.UUID]struct{} } // Node returns an in-memory node by ID. @@ -88,47 +154,62 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.clientLogger(id, agent) - - c.mutex.Lock() - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - connectionSockets = map[uuid.UUID]*agpl.TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets - } + logger := c.clientLogger(id, agentID) tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets[id] = tc + defer tc.Close() - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*agpl.Node{node}) - c.mutex.Unlock() + c.addClient(id, tc) + defer c.clientDisconnected(id) + + agentNode, err := c.clientSubscribeToAgent(tc, agentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) + } + + if agentNode != nil { + err := tc.Enqueue([]*agpl.Node{agentNode}) if err != nil { - return xerrors.Errorf("enqueue node: %w", err) + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) } - } else { - c.mutex.Unlock() - err := c.publishClientHello(agent) + } + + go tc.SendUpdates() + + decoder := json.NewDecoder(conn) + // Indefinitely handle messages from the client websocket. + for { + err := c.handleNextClientMessage(id, decoder) if err != nil { - return xerrors.Errorf("publish client hello: %w", err) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) } } - go tc.SendUpdates() } - defer func() { - c.mutex.Lock() - defer c.mutex.Unlock() +func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Queue) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + connectionSockets = map[uuid.UUID]agpl.Queue{} + c.agentToConnectionSockets[agentID] = connectionSockets + } + connectionSockets[enq.UniqueID()] = enq + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] +} + +func (c *haCoordinator) clientDisconnected(id uuid.UUID) { + c.mutex.Lock() + defer c.mutex.Unlock() + + for agentID := range c.clientsToAgents[id] { // Clean all traces of this connection from the map.
delete(c.nodes, id) - connectionSockets, ok := c.agentToConnectionSockets[agent] + connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { return } @@ -136,51 +217,65 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID if len(connectionSockets) != 0 { return } - delete(c.agentToConnectionSockets, agent) - }() - - decoder := json.NewDecoder(conn) - // Indefinitely handle messages from the client websocket. - for { - err := c.handleNextClientMessage(id, agent, decoder) - if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - return nil - } - return xerrors.Errorf("handle next client message: %w", err) - } + delete(c.agentToConnectionSockets, agentID) } + + delete(c.clients, id) + delete(c.clientsToAgents, id) } -func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { +func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { var node agpl.Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } + return c.clientNodeUpdate(id, &node) +} + +func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error { c.mutex.Lock() + defer c.mutex.Unlock() // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. - c.nodes[id] = &node - // Write the new node from this client to the actively connected agent. - agentSocket, ok := c.agentSockets[agent] + c.nodes[id] = node + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "publish node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } else { + // Write the new node from this client to the actively connected agent. + err := agentSocket.Enqueue([]*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + } + + return nil +} + +func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error { + agentSocket, ok := c.agentSockets[agentID] if !ok { - c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. - err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) if err != nil { return xerrors.Errorf("publish node to agent") } return nil } - err = agentSocket.Enqueue([]*agpl.Node{&node}) - c.mutex.Unlock() + err := agentSocket.Enqueue([]*agpl.Node{node}) if err != nil { - return xerrors.Errorf("enqueu nodes: %w", err) + return xerrors.Errorf("enqueue node: %w", err) } return nil } @@ -202,7 +297,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } // This uniquely identifies a connection that belongs to this goroutine. 
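The per-agent routing decision in clientNodeUpdate above reduces to a small rule, sketched here as standalone toy code (all names illustrative, not patch code): a nil Queue in clientsToAgents marks an agent owned by another replica, so the update travels over pubsub; a non-nil Queue is a locally connected agent and is enqueued directly.

package example // illustrative toy, not part of the patch

// node is a stand-in for agpl.Node.
type node struct{ preferredDERP int }

// fanOut mimics haCoordinator.clientNodeUpdate's routing. enqueue stands in
// for Queue.Enqueue and publish for publishNodesToAgent.
func fanOut(
	subscriptions map[string]func(node) error, // agentID -> local queue, nil if the agent is remote
	publish func(agentID string, n node) error,
	n node,
) {
	for agentID, enqueue := range subscriptions {
		if enqueue == nil {
			// Agent is connected to another coordinator replica: hand the
			// node off via pubsub so that replica can enqueue it.
			_ = publish(agentID, n)
			continue
		}
		// Agent is connected locally: write straight into its queue.
		_ = enqueue(n)
	}
}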
@@ -219,6 +314,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } c.mutex.Unlock() go tc.SendUpdates() @@ -234,10 +332,13 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // Only delete the connection if it's ours. It could have been // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } }() decoder := json.NewDecoder(conn) @@ -285,6 +386,13 @@ func (c *haCoordinator) handleClientHello(id uuid.UUID) error { return c.publishAgentToNodes(id, node) } +func (c *haCoordinator) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) { var node agpl.Node err := decoder.Decode(&node) @@ -293,6 +401,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( } c.mutex.Lock() + // Keep a cache of all legacy agents. + if len(node.Addresses) > 0 && node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP { + c.legacyAgents[id] = struct{}{} + } + oldNode := c.nodes[id] if oldNode != nil { if oldNode.AsOf.After(node.AsOf) { @@ -311,7 +424,9 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( for _, connectionSocket := range connectionSockets { _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } + c.mutex.Unlock() + return &node, nil } @@ -334,20 +449,18 @@ func (c *haCoordinator) Close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } - for _, connMap := range c.agentToConnectionSockets { - wg.Add(len(connMap)) - for _, socket := range connMap { - socket := socket - go func() { - _ = socket.Close() - wg.Done() - }() - } + wg.Add(len(c.clients)) + for _, client := range c.clients { + client := client + go func() { + _ = client.CoordinatorClose() + wg.Done() + }() } wg.Wait() @@ -422,13 +535,12 @@ func (c *haCoordinator) runPubsub(ctx context.Context) error { } go func() { for { - var message []byte select { case <-ctx.Done(): return - case message = <-messageQueue: + case message := <-messageQueue: + c.handlePubsubMessage(ctx, message) } - c.handlePubsubMessage(ctx, message) } }() diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 6c3cc73ac9e9d..37c516f5aa65e 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -125,6 +125,11 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store return c, nil } +func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + _, _ = c, id + panic("not implemented") // TODO: Implement +} + func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. 
w.WriteHeader(http.StatusOK) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index fce5e0cc7a3b1..ae5da832054e2 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -183,9 +183,12 @@ func New(ctx context.Context, opts *Options) (*Server, error) { SecurityKey: secKey, Logger: s.Logger.Named("proxy_token_provider"), }, - WorkspaceConnCache: wsconncache.New(s.DialWorkspaceAgent, 0), - AppSecurityKey: secKey, + AppSecurityKey: secKey, + // TODO: Convert wsproxy to use coderd.ServerTailnet. + AgentProvider: &wsconncache.AgentProvider{ + Cache: wsconncache.New(s.DialWorkspaceAgent, 0), + }, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, } @@ -273,6 +276,7 @@ func (s *Server) Close() error { tmp, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = s.SDKClient.WorkspaceProxyGoingAway(tmp) + _ = s.AppServer.AgentProvider.Close() return s.AppServer.Close() } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index cefbb660485e5..6593e7c79b52c 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/netip" "net/url" "strconv" "time" @@ -377,7 +376,10 @@ func agentHTTPClient(conn *codersdk.WorkspaceAgentConn) *http.Client { if err != nil { return nil, xerrors.Errorf("parse port %q: %w", port, err) } - return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.WorkspaceAgentIP, uint16(portUint))) + + // Addr doesn't matter here, besides the port. DialContext will + // automatically choose the right IP to dial. + return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint)) }, }, } diff --git a/scaletest/reconnectingpty/run_test.go b/scaletest/reconnectingpty/run_test.go index f6f70bbf574bf..382a3718436f9 100644 --- a/scaletest/reconnectingpty/run_test.go +++ b/scaletest/reconnectingpty/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" @@ -243,7 +244,7 @@ func Test_Runner(t *testing.T) { func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) { t.Helper() - client = coderdtest.New(t, &coderdtest.Options{ + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, }) user := coderdtest.CreateFirstUser(t, client) @@ -282,12 +283,16 @@ func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) agentClient.SetSessionToken(authToken) agentCloser := agent.New(agent.Options{ Client: agentClient, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent"), + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent").Leveled(slog.LevelDebug), }) t.Cleanup(func() { _ = agentCloser.Close() }) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + require.Eventually(t, func() bool { + t.Log("agent id", resources[0].Agents[0].ID) + return (*api.TailnetCoordinator.Load()).Node(resources[0].Agents[0].ID) != nil + }, testutil.WaitLong, testutil.IntervalMedium, "agent never connected") return client, resources[0].Agents[0].ID } diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index e53d408bcd428..c070a906be228 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -68,6 +68,7 @@ func TestRun(t 
*testing.T) { agentCloser := agent.New(agent.Options{ Client: agentClient, }) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) t.Cleanup(func() { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index ee1bcda3736e9..5ba2158c93b36 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1431,12 +1431,14 @@ export const Entitlements: Entitlement[] = [ export type Experiment = | "convert-to-oidc" | "moons" + | "single_tailnet" | "tailnet_ha_coordinator" | "workspace_actions" | "workspace_build_logs_ui" export const Experiments: Experiment[] = [ "convert-to-oidc", "moons", + "single_tailnet", "tailnet_ha_coordinator", "workspace_actions", "workspace_build_logs_ui", diff --git a/tailnet/conn.go b/tailnet/conn.go index 363ccb80ff48c..0534cf961771a 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -139,6 +139,7 @@ func NewConn(options *Options) (conn *Conn, err error) { } }() + IP() dialer := &tsdial.Dialer{ Logf: Logger(options.Logger.Named("tsdial")), } @@ -182,10 +183,17 @@ func NewConn(options *Options) (conn *Conn, err error) { netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() netStack, err := netstack.Create( - Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager) + Logger(options.Logger.Named("netstack")), + tunDevice, + wireguardEngine, + magicConn, + dialer, + dnsManager, + ) if err != nil { return nil, xerrors.Errorf("create netstack: %w", err) } + dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return netStack.DialContextTCP(ctx, dst) } @@ -203,7 +211,14 @@ func NewConn(options *Options) (conn *Conn, err error) { localIPs, _ := localIPSet.IPSet() logIPSet := netipx.IPSetBuilder{} logIPs, _ := logIPSet.IPSet() - wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter")))) + wireguardEngine.SetFilter(filter.New( + netMap.PacketFilter, + localIPs, + logIPs, + nil, + Logger(options.Logger.Named("packet-filter")), + )) + dialContext, dialCancel := context.WithCancel(context.Background()) server := &Conn{ blockEndpoints: options.BlockEndpoints, @@ -230,6 +245,7 @@ func NewConn(options *Options) (conn *Conn, err error) { _ = server.Close() } }() + wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) if err != nil { @@ -251,6 +267,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) server.lastMutex.Lock() @@ -262,6 +279,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) { server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason)) server.lastMutex.Lock() @@ -273,6 +291,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + netStack.ForwardTCPIn = server.forwardTCP netStack.ForwardTCPSockOpts = server.forwardTCPSockOpts @@ -284,22 +303,30 @@ func NewConn(options *Options) (conn *Conn, err error) { return server, nil } -// IP generates a new IP with a static service prefix. 
-func IP() netip.Addr { - // This is Tailscale's ephemeral service prefix. - // This can be changed easily later-on, because - // all of our nodes are ephemeral. +func maskUUID(uid uuid.UUID) uuid.UUID { + // This is Tailscale's ephemeral service prefix. This can be changed easily + // later on, because all of our nodes are ephemeral. // fd7a:115c:a1e0 - uid := uuid.New() uid[0] = 0xfd uid[1] = 0x7a uid[2] = 0x11 uid[3] = 0x5c uid[4] = 0xa1 uid[5] = 0xe0 + return uid +} + +// IP generates a random IP with a static service prefix. +func IP() netip.Addr { + uid := maskUUID(uuid.New()) return netip.AddrFrom16(uid) } +// IPFromUUID generates a new IP from a UUID. +func IPFromUUID(uid uuid.UUID) netip.Addr { + return netip.AddrFrom16(maskUUID(uid)) +} + // Conn is an actively listening Wireguard connection. type Conn struct { dialContext context.Context @@ -334,6 +361,29 @@ type Conn struct { trafficStats *connstats.Statistics } +func (c *Conn) SetAddresses(ips []netip.Prefix) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.netMap.Addresses = ips + + netMapCopy := *c.netMap + c.logger.Debug(context.Background(), "updating network map") + c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + +func (c *Conn) Addresses() []netip.Prefix { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.netMap.Addresses +} + func (c *Conn) SetNodeCallback(callback func(node *Node)) { c.lastMutex.Lock() c.nodeCallback = callback @@ -366,32 +416,6 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail c.magicConn.SetDERPRegionDialer(dialer) } -func (c *Conn) RemoveAllPeers() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.netMap.Peers = []*tailcfg.Node{} - c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") - if err != nil { - return xerrors.Errorf("update wireguard config: %w", err) - } - err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - if c.isClosed() { - return nil - } - if errors.Is(err, wgengine.ErrNoChanges) { - return nil - } - return xerrors.Errorf("reconfig: %w", err) - } - return nil -} - // UpdateNodes connects with a set of peers. This can be constantly updated, // and peers will continually be reconnected as necessary. If replacePeers is // true, all peers will be removed before adding the new ones. @@ -423,6 +447,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } delete(c.peerMap, peer.ID) } + for _, node := range nodes { // If no preferred DERP is provided, we can't reach the node.
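Since routing now keys on per-agent addresses, a short sketch (assuming only the exported tailnet functions above) demonstrates the property the rest of the patch relies on: IPFromUUID is deterministic and always lands inside Tailscale's fd7a:115c:a1e0 service prefix, so every replica derives the same address for an agent without coordinating.

package main // illustrative sketch, not part of the patch

import (
	"fmt"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

func main() {
	agentID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")

	// Deterministic: the same agent ID always maps to the same address.
	a := tailnet.IPFromUUID(agentID)
	b := tailnet.IPFromUUID(agentID)
	fmt.Println(a == b) // true

	// maskUUID pins the first six bytes, so every derived address sits in
	// the fd7a:115c:a1e0 prefix alongside the legacy WorkspaceAgentIP.
	fmt.Println(a) // fd7a:115c:a1e0:cccc:dddd:eeee:eeee:eeee
}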
if node.PreferredDERP == 0 { @@ -452,17 +477,29 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } c.peerMap[node.ID] = peerNode } + c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) for _, peer := range c.peerMap { c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) } + netMapCopy := *c.netMap c.logger.Debug(context.Background(), "updating network map") c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + +func (c *Conn) reconfig() error { cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") if err != nil { return xerrors.Errorf("update wireguard config: %w", err) } + err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) if err != nil { if c.isClosed() { @@ -473,6 +510,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } return xerrors.Errorf("reconfig: %w", err) } + return nil } diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 2e19379e6df03..0dd0812b94777 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -23,7 +23,7 @@ func TestMain(m *testing.M) { func TestTailnet(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) t.Run("InstantClose", func(t *testing.T) { t.Parallel() conn, err := tailnet.NewConn(&tailnet.Options{ @@ -172,7 +172,7 @@ func TestConn_PreferredDERP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Logger: logger.Named("w1"), diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 7ff279a508b46..93cf8c67af56b 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1,7 +1,6 @@ package tailnet import ( - "bytes" "context" "encoding/json" "errors" @@ -11,17 +10,16 @@ import ( "net/http" "net/netip" "sync" - "sync/atomic" "time" - "cdr.dev/slog" - "github.com/google/uuid" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/tailcfg" "tailscale.com/types/key" + + "cdr.dev/slog" ) // Coordinator exchanges nodes with agents to establish connections. @@ -44,6 +42,8 @@ type Coordinator interface { ServeAgent(conn net.Conn, id uuid.UUID, name string) error // Close closes the coordinator. Close() error + + ServeMultiAgent(id uuid.UUID) MultiAgentConn } // Node represents a node in the network. @@ -54,10 +54,11 @@ type Node struct { AsOf time.Time `json:"as_of"` // Key is the Wireguard public key of the node. Key key.NodePublic `json:"key"` - // DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections. + // DiscoKey is used for discovery messages over DERP to establish + // peer-to-peer connections. DiscoKey key.DiscoPublic `json:"disco"` - // PreferredDERP is the DERP server that peered connections - // should meet at to establish. + // PreferredDERP is the DERP server that peered connections should meet at + // to establish. PreferredDERP int `json:"preferred_derp"` // DERPLatency is the latency in seconds to each DERP server. 
DERPLatency map[string]float64 `json:"derp_latency"` @@ -68,8 +69,8 @@ type Node struct { DERPForcedWebsocket map[int]string `json:"derp_forced_websockets"` // Addresses are the IP address ranges this connection exposes. Addresses []netip.Prefix `json:"addresses"` - // AllowedIPs specify what addresses can dial the connection. - // We allow all by default. + // AllowedIPs specify what addresses can dial the connection. We allow all + // by default. AllowedIPs []netip.Prefix `json:"allowed_ips"` // Endpoints are ip:port combinations that can be used to establish // peer-to-peer connections. @@ -130,12 +131,33 @@ func NewCoordinator(logger slog.Logger) Coordinator { // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ -// This coordinator is incompatible with multiple Coder -// replicas as all node data is in-memory. +// This coordinator is incompatible with multiple Coder replicas as all node +// data is in-memory. type coordinator struct { core *core } +func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { + m := (&MultiAgent{ + ID: id, + Logger: c.core.logger, + AgentIsLegacyFunc: c.core.agentIsLegacy, + OnSubscribe: c.core.clientSubscribeToAgent, + OnUnsubscribe: c.core.clientUnsubscribeFromAgent, + OnNodeUpdate: c.core.clientNodeUpdate, + OnRemove: c.core.clientDisconnected, + }).Init() + c.core.addClient(id, m) + return m +} + +func (c *core) addClient(id uuid.UUID, ma Queue) { + c.mutex.Lock() + c.clients[id] = ma + c.clientsToAgents[id] = map[uuid.UUID]Queue{} + c.mutex.Unlock() +} + // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { @@ -146,14 +168,38 @@ type core struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*TrackedConn + agentSockets map[uuid.UUID]Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue + + // clients holds a map of all clients connected to the coordinator. This is + // necessary because a client may not be subscribed to any agents. + clients map[uuid.UUID]Queue + // clientsToAgents is an index of clients to all of their subscribed agents. + clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + + // legacyAgents holds a mapping of all agents detected as legacy, meaning + // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible + // with the new ServerTailnet, so they must be connected through + // wsconncache. + legacyAgents map[uuid.UUID]struct{} +} + +type Queue interface { + UniqueID() uuid.UUID + Enqueue(n []*Node) error + Name() string + Stats() (start, lastWrite int64) + Overwrites() int64 + // CoordinatorClose is used by the coordinator when closing a Queue. It + // should skip removing itself from the coordinator.
+ CoordinatorClose() error + Close() error } func newCore(logger slog.Logger) *core { @@ -165,128 +211,18 @@ func newCore(logger slog.Logger) *core { return &core{ logger: logger, closed: false, - nodes: make(map[uuid.UUID]*Node), - agentSockets: map[uuid.UUID]*TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, + nodes: map[uuid.UUID]*Node{}, + agentSockets: map[uuid.UUID]Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{}, agentNameCache: nameCache, + legacyAgents: map[uuid.UUID]struct{}{}, + clients: map[uuid.UUID]Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{}, } } var ErrWouldBlock = xerrors.New("would block") -type TrackedConn struct { - ctx context.Context - cancel func() - conn net.Conn - updates chan []*Node - logger slog.Logger - lastData []byte - - // ID is an ephemeral UUID used to uniquely identify the owner of the - // connection. - ID uuid.UUID - - Name string - Start int64 - LastWrite int64 - Overwrites int64 -} - -func (t *TrackedConn) Enqueue(n []*Node) (err error) { - atomic.StoreInt64(&t.LastWrite, time.Now().Unix()) - select { - case t.updates <- n: - return nil - default: - return ErrWouldBlock - } -} - -// Close the connection and cancel the context for reading node updates from the queue -func (t *TrackedConn) Close() error { - t.cancel() - return t.conn.Close() -} - -// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. -// It is exported so that tests can use it. -const WriteTimeout = time.Second * 5 - -// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is -// canceled. -func (t *TrackedConn) SendUpdates() { - for { - select { - case <-t.ctx.Done(): - t.logger.Debug(t.ctx, "done sending updates") - return - case nodes := <-t.updates: - data, err := json.Marshal(nodes) - if err != nil { - t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) - return - } - if bytes.Equal(t.lastData, data) { - t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data))) - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) - _ = t.Close() - return - } - _, err = t.conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "could not write nodes to connection", - slog.Error(err), slog.F("nodes", string(data))) - _ = t.Close() - return - } - t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data))) - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. 
- err = t.conn.SetWriteDeadline(time.Time{}) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) - _ = t.Close() - return - } - t.lastData = data - } - } -} - -func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { - // buffer updates so they don't block, since we hold the - // coordinator mutex while queuing. Node updates don't - // come quickly, so 512 should be plenty for all but - // the most pathological cases. - updates := make(chan []*Node, 512) - now := time.Now().Unix() - return &TrackedConn{ - ctx: ctx, - conn: conn, - cancel: cancel, - updates: updates, - logger: logger, - ID: id, - Start: now, - LastWrite: now, - Overwrites: overwrites, - } -} - // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -321,16 +257,29 @@ func (c *core) agentCount() int { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.core.clientLogger(id, agent) + logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent) + + tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) + defer tc.Close() + + c.core.addClient(id, tc) + defer c.core.clientDisconnected(id) + + agentNode, err := c.core.clientSubscribeToAgent(tc, agentID) if err != nil { - return err + return xerrors.Errorf("subscribe agent: %w", err) + } + + if agentNode != nil { + err := tc.Enqueue([]*Node{agentNode}) + if err != nil { + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) + } } - defer c.core.clientDisconnected(id, agent) // On this goroutine, we read updates from the client and publish them. We start a second goroutine // to write updates back to the client. @@ -338,7 +287,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) decoder := json.NewDecoder(conn) for { - err := c.handleNextClientMessage(id, agent, decoder) + err := c.handleNextClientMessage(id, decoder) if err != nil { logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err)) if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { @@ -353,99 +302,133 @@ func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) } -// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is -// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some -// updates. 
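With the rewrite above, ServeClient is now a composition of the multi-agent primitives: track the conn, register the client, subscribe to one agent, and enqueue that agent's current node if it exists. A rough caller-side sketch over net.Pipe, loosely modeled on how the coordinator tests drive it; dialCoordinator is illustrative, and it assumes the agent is already connected so an initial node comes back:

```go
package tailnetsketch

import (
	"encoding/json"
	"net"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

// dialCoordinator pushes one client node through ServeClient and reads
// back the first batch of agent nodes from the same conn.
func dialCoordinator(coord tailnet.Coordinator, agentID uuid.UUID, node *tailnet.Node) ([]*tailnet.Node, error) {
	clientSide, serverSide := net.Pipe()

	serveErr := make(chan error, 1)
	go func() {
		serveErr <- coord.ServeClient(serverSide, uuid.New(), agentID)
	}()

	// Publish our node; the coordinator fans it out to the subscribed agent.
	if err := json.NewEncoder(clientSide).Encode(node); err != nil {
		return nil, err
	}

	// Read the agent's node, the one enqueued during subscription.
	var nodes []*tailnet.Node
	if err := json.NewDecoder(clientSide).Decode(&nodes); err != nil {
		return nil, err
	}

	_ = clientSide.Close() // unblocks ServeClient's read loop
	<-serveErr
	return nodes, nil
}
```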
-func (c *core) initAndTrackClient( - ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID, -) ( - *TrackedConn, error, -) { - logger := c.clientLogger(id, agent) - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { - return nil, xerrors.New("coordinator is closed") - } - tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) - - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*Node{node}) - // this should never error since we're still the only goroutine that - // knows about the TrackedConn. If we hit an error something really - // wrong is happening - if err != nil { - logger.Critical(ctx, "unable to queue initial node", slog.Error(err)) - return nil, err - } - } - - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets, ok := c.agentToConnectionSockets[agent] +func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { - connectionSockets = map[uuid.UUID]*TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets + connectionSockets = map[uuid.UUID]Queue{} + c.agentToConnectionSockets[agentID] = connectionSockets } - connectionSockets[id] = tc - logger.Debug(ctx, "added tracked connection") - return tc, nil + connectionSockets[enq.UniqueID()] = enq + + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] } -func (c *core) clientDisconnected(id, agent uuid.UUID) { - logger := c.clientLogger(id, agent) +func (c *core) clientDisconnected(id uuid.UUID) { + logger := c.clientLogger(id, uuid.Nil) c.mutex.Lock() defer c.mutex.Unlock() // Clean all traces of this connection from the map. 
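clientsToAgents, added above, is a reverse index: it lets clientDisconnected (continuing below) walk only the departing client's own subscriptions instead of scanning every agent's connection map. The bookkeeping reduced to plain sets, in the sketch package (biIndex is hypothetical):

```go
package tailnetsketch

import "github.com/google/uuid"

// biIndex mirrors the agentToConnectionSockets / clientsToAgents pair:
// each subscription edge is stored in both directions, so removal from
// either side touches only that entity's own edges.
type biIndex struct {
	agentToClients map[uuid.UUID]map[uuid.UUID]struct{}
	clientToAgents map[uuid.UUID]map[uuid.UUID]struct{}
}

// addClient runs before any subscribe, as core.addClient does, so the
// reverse set always exists.
func (b *biIndex) addClient(clientID uuid.UUID) {
	b.clientToAgents[clientID] = map[uuid.UUID]struct{}{}
}

func (b *biIndex) subscribe(clientID, agentID uuid.UUID) {
	if b.agentToClients[agentID] == nil {
		b.agentToClients[agentID] = map[uuid.UUID]struct{}{}
	}
	b.agentToClients[agentID][clientID] = struct{}{}
	b.clientToAgents[clientID][agentID] = struct{}{}
}

// disconnectClient prunes the client from each subscribed agent and
// drops empty per-agent sets, just as clientDisconnected does below.
func (b *biIndex) disconnectClient(clientID uuid.UUID) {
	for agentID := range b.clientToAgents[clientID] {
		delete(b.agentToClients[agentID], clientID)
		if len(b.agentToClients[agentID]) == 0 {
			delete(b.agentToClients, agentID)
		}
	}
	delete(b.clientToAgents, clientID)
}
```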
delete(c.nodes, id) logger.Debug(context.Background(), "deleted client node") - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - return - } - delete(connectionSockets, id) - logger.Debug(context.Background(), "deleted client connectionSocket from map") - if len(connectionSockets) != 0 { - return + + for agentID := range c.clientsToAgents[id] { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + continue + } + delete(connectionSockets, id) + logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID)) + + if len(connectionSockets) == 0 { + delete(c.agentToConnectionSockets, agentID) + logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) + } } - delete(c.agentToConnectionSockets, agent) - logger.Debug(context.Background(), "deleted last client connectionSocket from map") + + delete(c.clients, id) + delete(c.clientsToAgents, id) + logger.Debug(context.Background(), "deleted client agents") } -func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { - logger := c.core.clientLogger(id, agent) +func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { + logger := c.core.clientLogger(id, uuid.Nil) + var node Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } + logger.Debug(context.Background(), "got client node update", slog.F("node", node)) - return c.core.clientNodeUpdate(id, agent, &node) + return c.core.clientNodeUpdate(id, &node) } -func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { - logger := c.clientLogger(id, agent) +func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error { c.mutex.Lock() defer c.mutex.Unlock() + // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. c.nodes[id] = node - agentSocket, ok := c.agentSockets[agent] - if !ok { - logger.Debug(context.Background(), "no agent socket, unable to send node") - return nil + return c.clientNodeUpdateLocked(id, node) +} + +func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { + logger := c.clientLogger(id, uuid.Nil) + + agents := []uuid.UUID{} + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID)) + continue + } + + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) + continue + } + agents = append(agents, agentID) } - err := agentSocket.Enqueue([]*Node{node}) - if err != nil { - return xerrors.Errorf("Enqueue node: %w", err) + logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents)) + return nil +} + +func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + logger := c.clientLogger(enq.UniqueID(), agentID) + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) + + node, ok := c.nodes[enq.UniqueID()] + if ok { + // If we have the client node, send it to the agent. If not, it will be + // sent async. 
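Note that clientNodeUpdateLocked above is deliberately best-effort: a nil entry (agent subscribed but not yet connected) or a full queue must not block delivery to the remaining agents. Distilled into the sketch package, reusing its Node type (fanOut and enqueuer are illustrative):

```go
package tailnetsketch

import "github.com/google/uuid"

// enqueuer is the slice of the Queue interface that fan-out needs.
type enqueuer interface {
	Enqueue(n []*Node) error
}

// fanOut pushes one node update to every subscribed agent queue and
// reports which agents accepted, instead of failing on the first error.
func fanOut(node *Node, agents map[uuid.UUID]enqueuer) (delivered []uuid.UUID) {
	for agentID, q := range agents {
		if q == nil {
			// Subscription exists but the agent has not connected yet;
			// it will receive client nodes once its socket registers.
			continue
		}
		if err := q.Enqueue([]*Node{node}); err != nil {
			continue // drop for this agent, keep delivering to the rest
		}
		delivered = append(delivered, agentID)
	}
	return delivered
}
```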
+ agentSocket, ok := c.agentSockets[agentID] + if !ok { + logger.Debug(context.Background(), "subscribe to agent; socket is nil") + } else { + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + return nil, xerrors.Errorf("enqueue client to agent: %w", err) + } + } + } else { + logger.Debug(context.Background(), "multiagent node doesn't exist") } - logger.Debug(context.Background(), "enqueued node to agent") + + agentNode, ok := c.nodes[agentID] + if !ok { + // This is OK; once the agent connects, the node will be sent over. + logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) + } + + // Send the subscribed agent back to the multi agent. + return agentNode, nil +} + +func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + delete(c.clientsToAgents[enq.UniqueID()], agentID) + delete(c.agentToConnectionSockets[agentID], enq.UniqueID()) + return nil } @@ -493,11 +476,14 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) { // Only delete the connection if it's ours. It could have been // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) logger.Debug(context.Background(), "deleted agent socket and node") } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } } // initAndTrackAgent creates a TrackedConn for the agent, and sends any initial node updates if we have any. It is @@ -519,7 +505,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) @@ -549,6 +535,10 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } + logger.Debug(ctx, "added agent socket") return tc, nil } @@ -564,11 +554,31 @@ func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder return c.core.agentNodeUpdate(id, &node) } +// This is copied from codersdk because importing it here would cause an import +// cycle. This is just temporary until wsconncache is phased out. +var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") + +// This is temporary until we no longer need to detect agent backwards +// compatibility. +// See: https://github.com/coder/coder/issues/8218 +func (c *core) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { logger := c.agentLogger(id) c.mutex.Lock() defer c.mutex.Unlock() c.nodes[id] = node + + // Keep a cache of all legacy agents.
+ if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP { + c.legacyAgents[id] = struct{}{} + } + connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { logger.Debug(context.Background(), "no client sockets; unable to send node") @@ -588,6 +598,7 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { slog.F("client_id", clientID), slog.Error(err)) } } + return nil } @@ -611,20 +622,18 @@ func (c *core) close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } - for _, connMap := range c.agentToConnectionSockets { - wg.Add(len(connMap)) - for _, socket := range connMap { - socket := socket - go func() { - _ = socket.Close() - wg.Done() - }() - } + wg.Add(len(c.clients)) + for _, client := range c.clients { + client := client + go func() { + _ = client.CoordinatorClose() + wg.Done() + }() } c.mutex.Unlock() @@ -649,8 +658,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { } func CoordinatorHTTPDebug( - agentSocketsMap map[uuid.UUID]*TrackedConn, - agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]*TrackedConn, + agentSocketsMap map[uuid.UUID]Queue, + agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue, agentNameCache *lru.Cache[uuid.UUID, string], ) func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, _ *http.Request) { @@ -658,7 +667,7 @@ func CoordinatorHTTPDebug( type idConn struct { id uuid.UUID - conn *TrackedConn + conn Queue } { @@ -671,16 +680,17 @@ func CoordinatorHTTPDebug( } slices.SortFunc(agentSockets, func(a, b idConn) bool { - return a.conn.Name < b.conn.Name + return a.conn.Name() < b.conn.Name() }) for _, agent := range agentSockets { + start, lastWrite := agent.conn.Stats() _, _ = fmt.Fprintf(w, "
<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created %v ago, write %v ago, overwrites %d</li>\n", - agent.conn.Name, + agent.conn.Name(), agent.id.String(), - now.Sub(time.Unix(agent.conn.Start, 0)).Round(time.Second), - now.Sub(time.Unix(agent.conn.LastWrite, 0)).Round(time.Second), - agent.conn.Overwrites, + now.Sub(time.Unix(start, 0)).Round(time.Second), + now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + agent.conn.Overwrites(), ) }
[remainder of the CoordinatorHTTPDebug diff (the clients listing, with per-connection "created %v ago, write %v ago" lines) lost in extraction]
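The agentIsLegacy/legacyAgentIP machinery above hinges on a single observable: agents predating this change advertise the shared codersdk.WorkspaceAgentIP as their first address, while updated agents lead with a per-agent IP. A self-contained sketch of that detection (isLegacy is hypothetical; only the IP constant comes from the diff):

```go
package main

import (
	"fmt"
	"net/netip"
)

// legacyAgentIP duplicates codersdk.WorkspaceAgentIP, just as the
// coordinator does to avoid the import cycle noted above.
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")

// isLegacy reports whether the first advertised address is the shared
// legacy IP, meaning the agent must be reached through wsconncache.
func isLegacy(addresses []netip.Prefix) bool {
	return len(addresses) > 0 && addresses[0].Addr() == legacyAgentIP
}

func main() {
	legacy := []netip.Prefix{netip.PrefixFrom(legacyAgentIP, 128)}
	modern := []netip.Prefix{netip.PrefixFrom(netip.MustParseAddr("fd7a:115c:a1e0::2f"), 128)}
	fmt.Println(isLegacy(legacy), isLegacy(modern)) // true false
}
```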
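Finally, the shutdown change: close now calls CoordinatorClose on every queue, honoring the Queue contract that CoordinatorClose skips removing the queue from a coordinator that already holds its own mutex. The fan-out reduced to a sketch (closeAll and coordinatorCloser are illustrative):

```go
package tailnetsketch

import (
	"sync"

	"github.com/google/uuid"
)

// coordinatorCloser is the piece of Queue that shutdown needs.
type coordinatorCloser interface {
	CoordinatorClose() error
}

// closeAll closes every queue concurrently and waits for all of them,
// mirroring core.close above. Calling Close here instead would invite
// each queue to deregister itself from the coordinator mid-iteration.
func closeAll(queues map[uuid.UUID]coordinatorCloser) {
	var wg sync.WaitGroup
	wg.Add(len(queues))
	for _, q := range queues {
		q := q // capture the loop variable for the goroutine (pre-Go 1.22)
		go func() {
			defer wg.Done()
			_ = q.CoordinatorClose()
		}()
	}
	wg.Wait()
}
```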