From 42ce2463c159898351e22b5f568d3f436cd43d35 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:20:59 +0000 Subject: [PATCH 1/4] chore: added support for immortal streams to cli and agent --- agent/agent.go | 50 ++++- cli/exp.go | 1 + cli/immortalstreams.go | 188 ++++++++++++++++++ cli/portforward.go | 17 ++ cli/ssh.go | 304 ++++++++++++++++++++++++++++- coderd/coderd.go | 5 + coderd/workspaceagents.go | 206 +++++++++++++++++++ codersdk/workspaceagents.go | 41 ++++ codersdk/workspacesdk/agentconn.go | 69 +++++++ 9 files changed, 871 insertions(+), 10 deletions(-) create mode 100644 cli/immortalstreams.go diff --git a/agent/agent.go b/agent/agent.go index 31b48edd4dc83..4cefcfa9f8616 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -70,6 +70,47 @@ const ( EnvProcOOMScore = "CODER_PROC_OOM_SCORE" ) +// agentImmortalDialer is a custom dialer for immortal streams that can +// connect to the agent's own services via tailnet addresses. +type agentImmortalDialer struct { + agent *agent + standardDialer *net.Dialer +} + +func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return nil, xerrors.Errorf("split host port %q: %w", address, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, xerrors.Errorf("parse port %q: %w", portStr, err) + } + + // Check if this is a connection to one of the agent's own services + isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1" + isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) || + port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort) + + if isLocalhost && isAgentPort { + // Get the agent ID from the current manifest + manifest := d.agent.manifest.Load() + if manifest == nil || manifest.AgentID == uuid.Nil { + // Fallback to standard dialing if no manifest available yet + return d.standardDialer.DialContext(ctx, network, address) + } + + // Connect to the agent's own tailnet address instead of localhost + agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID) + agentAddress := net.JoinHostPort(agentAddr.String(), portStr) + return d.standardDialer.DialContext(ctx, network, agentAddress) + } + + // For other addresses, use standard dialing + return d.standardDialer.DialContext(ctx, network, address) +} + type Options struct { Filesystem afero.Fs LogDir string @@ -351,8 +392,13 @@ func (a *agent) init() { a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...) 
- // Initialize immortal streams manager - a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), &net.Dialer{}) + // Initialize immortal streams manager with a custom dialer + // that can connect to the agent's own services + immortalDialer := &agentImmortalDialer{ + agent: a, + standardDialer: &net.Dialer{}, + } + a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer) a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), diff --git a/cli/exp.go b/cli/exp.go index dafd85402663e..a4d4640bf6057 100644 --- a/cli/exp.go +++ b/cli/exp.go @@ -16,6 +16,7 @@ func (r *RootCmd) expCmd() *serpent.Command { r.mcpCommand(), r.promptExample(), r.rptyCommand(), + r.immortalStreamCmd(), }, } return cmd diff --git a/cli/immortalstreams.go b/cli/immortalstreams.go new file mode 100644 index 0000000000000..7dc3e0300d7ab --- /dev/null +++ b/cli/immortalstreams.go @@ -0,0 +1,188 @@ +package cli + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/cli/cliui" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/serpent" +) + +// immortalStreamClient provides methods to interact with immortal streams API +// This uses the main codersdk.Client to make server-proxied requests to agents +type immortalStreamClient struct { + client *codersdk.Client + agentID uuid.UUID + logger slog.Logger +} + +// newImmortalStreamClient creates a new client for immortal streams +func newImmortalStreamClient(client *codersdk.Client, agentID uuid.UUID, logger slog.Logger) *immortalStreamClient { + return &immortalStreamClient{ + client: client, + agentID: agentID, + logger: logger, + } +} + +// createStream creates a new immortal stream +func (c *immortalStreamClient) createStream(ctx context.Context, port int) (*codersdk.ImmortalStream, error) { + stream, err := c.client.WorkspaceAgentCreateImmortalStream(ctx, c.agentID, codersdk.CreateImmortalStreamRequest{ + TCPPort: port, + }) + if err != nil { + return nil, err + } + return &stream, nil +} + +// listStreams lists all immortal streams +func (c *immortalStreamClient) listStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { + return c.client.WorkspaceAgentImmortalStreams(ctx, c.agentID) +} + +// deleteStream deletes an immortal stream +func (c *immortalStreamClient) deleteStream(ctx context.Context, streamID uuid.UUID) error { + return c.client.WorkspaceAgentDeleteImmortalStream(ctx, c.agentID, streamID) +} + +// CLI Commands + +func (r *RootCmd) immortalStreamCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "immortal-stream", + Short: "Manage immortal streams in workspaces", + Long: "Immortal streams provide persistent TCP connections to workspace services that automatically reconnect when interrupted.", + Middleware: serpent.Chain( + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*serpent.Command{ + r.immortalStreamListCmd(), + r.immortalStreamDeleteCmd(), + }, + } + return cmd +} + +func (r *RootCmd) immortalStreamListCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "list ", + Short: "List active immortal streams in a workspace", + Middleware: serpent.Chain( + serpent.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + workspaceName := 
inv.Args[0] + + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + if err != nil { + return err + } + + if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { + return xerrors.New("workspace must be running to list immortal streams") + } + + // Create immortal stream client + // Note: We don't need to dial the agent for management operations + // as these go through the server's proxy endpoints + streamClient := newImmortalStreamClient(client, workspaceAgent.ID, inv.Logger) + streams, err := streamClient.listStreams(ctx) + if err != nil { + return xerrors.Errorf("list immortal streams: %w", err) + } + + if len(streams) == 0 { + cliui.Info(inv.Stderr, "No active immortal streams found.") + return nil + } + + // Display the streams in a table + displayImmortalStreams(inv, streams) + return nil + }, + } + return cmd +} + +func (r *RootCmd) immortalStreamDeleteCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "delete ", + Short: "Delete an active immortal stream", + Middleware: serpent.Chain( + serpent.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + workspaceName := inv.Args[0] + streamName := inv.Args[1] + + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + if err != nil { + return err + } + + if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { + return xerrors.New("workspace must be running to delete immortal streams") + } + + // Create immortal stream client + streamClient := newImmortalStreamClient(client, workspaceAgent.ID, inv.Logger) + streams, err := streamClient.listStreams(ctx) + if err != nil { + return xerrors.Errorf("list immortal streams: %w", err) + } + + var targetStream *codersdk.ImmortalStream + for _, stream := range streams { + if stream.Name == streamName { + targetStream = &stream + break + } + } + + if targetStream == nil { + return xerrors.Errorf("immortal stream %q not found", streamName) + } + + // Delete the stream + err = streamClient.deleteStream(ctx, targetStream.ID) + if err != nil { + return xerrors.Errorf("delete immortal stream: %w", err) + } + + cliui.Info(inv.Stderr, fmt.Sprintf("Deleted immortal stream %q (ID: %s)", streamName, targetStream.ID)) + return nil + }, + } + return cmd +} + +func displayImmortalStreams(inv *serpent.Invocation, streams []codersdk.ImmortalStream) { + _, _ = fmt.Fprintf(inv.Stderr, "Active Immortal Streams:\n\n") + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6s %-20s %-20s\n", "NAME", "PORT", "CREATED", "LAST CONNECTED") + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6s %-20s %-20s\n", "----", "----", "-------", "--------------") + + for _, stream := range streams { + createdTime := stream.CreatedAt.Format("2006-01-02 15:04:05") + lastConnTime := stream.LastConnectionAt.Format("2006-01-02 15:04:05") + + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6d %-20s %-20s\n", + stream.Name, stream.TCPPort, createdTime, lastConnTime) + } + _, _ = fmt.Fprintf(inv.Stderr, "\n") +} diff --git a/cli/portforward.go b/cli/portforward.go index 1b055d9e4362e..d96a0d697d289 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -39,6 +39,10 @@ func (r *RootCmd) portForward() *serpent.Command { udpForwards []string // : disableAutostart bool appearanceConfig codersdk.AppearanceConfig + + // Immortal streams flags + immortal bool + immortalFallback bool = true // Default to true for port-forward ) client := 
new(codersdk.Client) cmd := &serpent.Command{ @@ -212,6 +216,19 @@ func (r *RootCmd) portForward() *serpent.Command { Description: "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols.", Value: serpent.StringArrayOf(&udpForwards), }, + { + Flag: "immortal", + Description: "Use immortal streams for port forwarding connections, providing automatic reconnection when interrupted.", + Value: serpent.BoolOf(&immortal), + Hidden: true, + }, + { + Flag: "immortal-fallback", + Description: "If immortal streams are unavailable due to connection limits, fall back to regular TCP connection.", + Default: "true", + Value: serpent.BoolOf(&immortalFallback), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } diff --git a/cli/ssh.go b/cli/ssh.go index a2f0db7327bef..8ce2a0420f172 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -48,6 +48,7 @@ import ( "github.com/coder/quartz" "github.com/coder/retry" "github.com/coder/serpent" + "github.com/coder/websocket" ) const ( @@ -85,6 +86,10 @@ func (r *RootCmd) ssh() *serpent.Command { containerName string containerUser string + + // Immortal streams flags + immortal bool + immortalFallback bool // Default to false for SSH ) client := new(codersdk.Client) wsClient := workspacesdk.New(client) @@ -387,11 +392,83 @@ func (r *RootCmd) ssh() *serpent.Command { } if stdio { - rawSSH, err := conn.SSH(ctx) - if err != nil { - return xerrors.Errorf("connect SSH: %w", err) + var rawSSH net.Conn + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + if immortal { + // Use immortal stream for SSH connection + immortalStreamClient = newImmortalStreamClient(client, workspaceAgent.ID, logger) + + // Create immortal stream to agent SSH port (1) + stream, err := immortalStreamClient.createStream(ctx, 1) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for SSH", + slog.Error(err), + slog.F("agent_id", workspaceAgent.ID), + slog.F("target_port", 1), + slog.F("workspace", workspace.Name), + slog.F("agent_status", workspaceAgent.Status), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", + slog.F("max_streams", "32")) + } else { + logger.Warn(ctx, "Agent SSH service not available on port 1, falling back to regular SSH connection", + slog.F("reason", "connection_refused"), + slog.F("suggestion", "agent SSH server may not be running")) + } + logger.Info(ctx, "attempting fallback to regular SSH connection") + rawSSH, err = conn.SSH(ctx) + if err != nil { + logger.Error(ctx, "fallback SSH connection also failed", slog.Error(err)) + return xerrors.Errorf("connect SSH (fallback): %w", err) + } + logger.Info(ctx, "successfully connected via regular SSH fallback") + } else { + return xerrors.Errorf("create immortal stream for SSH: %w", err) + } + } else { + streamID = &stream.ID + logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID)) + + // Connect to the immortal stream via WebSocket + rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails 
+ _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("connect to immortal stream: %w", err) + } + } + } else { + // Use regular SSH connection + rawSSH, err = conn.SSH(ctx) + if err != nil { + return xerrors.Errorf("connect SSH: %w", err) + } } - copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter) + + var copier io.Closer + + if tcpConn, ok := rawSSH.(*gonet.TCPConn); ok { + // Use specialized raw SSH copier for regular TCP connections + rawCopier := newRawSSHCopier(logger, tcpConn, stdioReader, stdioWriter) + copier = rawCopier + // Start copying in the background for rawSSHCopier + go rawCopier.copy(&wg) + } else { + // Use generic copier for immortal stream connections + genericCopier := newGenericSSHCopier(logger, rawSSH, stdioReader, stdioWriter) + copier = genericCopier + // Start copying in the background for genericSSHCopier + go genericCopier.copy(&wg) + } + if err = stack.push("rawSSHCopier", copier); err != nil { return err } @@ -404,22 +481,108 @@ func (r *RootCmd) ssh() *serpent.Command { } } + // Set up cleanup for immortal stream + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + wg.Add(1) go func() { defer wg.Done() watchAndClose(ctx, func() error { + // Clean up immortal stream on termination + if immortalStreamClient != nil && streamID != nil { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) + } + } stack.close(xerrors.New("watchAndClose")) return nil }, logger, client, workspace, errCh) }() - copier.copy(&wg) + // The copying is already started in the background above + wg.Wait() return nil } - sshClient, err := conn.SSHClient(ctx) - if err != nil { - return xerrors.Errorf("ssh client: %w", err) + var sshClient *gossh.Client + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + if immortal { + // Use immortal stream for SSH connection + immortalStreamClient = newImmortalStreamClient(client, workspaceAgent.ID, logger) + + // Create immortal stream to agent SSH port (1) + stream, err := immortalStreamClient.createStream(ctx, 1) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for SSH (regular mode)", + slog.Error(err), + slog.F("agent_id", workspaceAgent.ID), + slog.F("target_port", 1), + slog.F("workspace", workspace.Name), + slog.F("agent_status", workspaceAgent.Status), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", + slog.F("max_streams", "32")) + } else { + logger.Warn(ctx, "Agent SSH service not available on port 1, falling back to regular SSH connection", + slog.F("reason", "connection_refused"), + slog.F("suggestion", "agent SSH server may not be running")) + } + logger.Info(ctx, "attempting fallback to regular SSH client") + sshClient, err = conn.SSHClient(ctx) + if err != nil { + logger.Error(ctx, "fallback SSH client creation also failed", slog.Error(err)) + 
return xerrors.Errorf("ssh client (fallback): %w", err) + } + logger.Info(ctx, "successfully created SSH client via regular fallback") + } else { + return xerrors.Errorf("create immortal stream for SSH: %w", err) + } + } else { + streamID = &stream.ID + logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID)) + + // Connect to the immortal stream and create SSH client + rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("connect to immortal stream: %w", err) + } + + // Create SSH client over the immortal stream connection + sshConn, chans, reqs, err := gossh.NewClientConn(rawConn, "localhost:22", &gossh.ClientConfig{ + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Timeout: 30 * time.Second, + }) + if err != nil { + rawConn.Close() + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("ssh handshake over immortal stream: %w", err) + } + + sshClient = gossh.NewClient(sshConn, chans, reqs) + } + } else { + // Use regular SSH connection + sshClient, err = conn.SSHClient(ctx) + if err != nil { + return xerrors.Errorf("ssh client: %w", err) + } } + if err = stack.push("ssh client", sshClient); err != nil { return err } @@ -440,12 +603,27 @@ func (r *RootCmd) ssh() *serpent.Command { } } + // Set up cleanup for immortal stream in regular SSH mode + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + wg.Add(1) go func() { defer wg.Done() watchAndClose( ctx, func() error { + // Clean up immortal stream on termination + if immortalStreamClient != nil && streamID != nil { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) + } + } stack.close(xerrors.New("watchAndClose")) return nil }, @@ -728,11 +906,83 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.BoolOf(&forceNewTunnel), Hidden: true, }, + { + Flag: "immortal", + Description: "Use immortal streams for SSH connection, providing automatic reconnection when interrupted.", + Value: serpent.BoolOf(&immortal), + Hidden: true, + }, + { + Flag: "immortal-fallback", + Description: "If immortal streams are unavailable due to connection limits, fall back to regular TCP connection.", + Value: serpent.BoolOf(&immortalFallback), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } +// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn +func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) { + // Build the target address for the agent's HTTP API server + // We'll let the WebSocket dialer handle the actual connection through the agent + apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort) + wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID) + + // Create WebSocket connection using the agent's tailnet connection + // The key is to use a custom dialer that routes through the agent 
connection + dialOptions := &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { + // Route all connections through the agent connection + // The agent connection will handle routing to the correct internal address + + conn, err := agentConn.DialContext(dialCtx, network, addr) + if err != nil { + return nil, err + } + + return conn, nil + }, + }, + }, + // Disable compression for raw TCP data + CompressionMode: websocket.CompressionDisabled, + } + + // Connect to the WebSocket endpoint + conn, res, err := websocket.Dial(ctx, wsURL, dialOptions) + if err != nil { + if res != nil { + logger.Error(ctx, "WebSocket dial failed", + slog.F("stream_id", streamID), + slog.F("websocket_url", wsURL), + slog.F("status", res.StatusCode), + slog.F("status_text", res.Status), + slog.Error(err)) + } else { + logger.Error(ctx, "WebSocket dial failed (no response)", + slog.F("stream_id", streamID), + slog.F("websocket_url", wsURL), + slog.Error(err)) + } + return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err) + } + + logger.Info(ctx, "successfully connected to immortal stream WebSocket", + slog.F("stream_id", streamID)) + + // Convert WebSocket to net.Conn for SSH usage + // Use MessageBinary for raw TCP data transport + netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) + + logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage") + + return netConn, nil +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1276,6 +1526,44 @@ func newRawSSHCopier(logger slog.Logger, conn *gonet.TCPConn, r io.Reader, w io. 
return &rawSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})} } +// genericSSHCopier is similar to rawSSHCopier but works with any net.Conn (e.g., immortal streams) +type genericSSHCopier struct { + conn net.Conn + logger slog.Logger + r io.Reader + w io.Writer + done chan struct{} +} + +func newGenericSSHCopier(logger slog.Logger, conn net.Conn, r io.Reader, w io.Writer) *genericSSHCopier { + return &genericSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})} +} + +func (c *genericSSHCopier) copy(wg *sync.WaitGroup) { + defer close(c.done) + + // Copy stdin to connection + go func() { + defer c.conn.Close() + _, err := io.Copy(c.conn, c.r) + if err != nil { + c.logger.Debug(context.Background(), "error copying stdin to connection", slog.Error(err)) + } + }() + + // Copy connection to stdout + _, err := io.Copy(c.w, c.conn) + if err != nil { + c.logger.Debug(context.Background(), "error copying connection to stdout", slog.Error(err)) + } +} + +func (c *genericSSHCopier) Close() error { + c.conn.Close() + <-c.done + return nil +} + func (c *rawSSHCopier) copy(wg *sync.WaitGroup) { defer close(c.done) logCtx := context.Background() diff --git a/coderd/coderd.go b/coderd/coderd.go index 8ab204f8a31ef..986ea8daa079f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1391,6 +1391,11 @@ func New(options *Options) *API { r.Get("/containers/watch", api.watchWorkspaceAgentContainers) r.Post("/containers/devcontainers/{devcontainer}/recreate", api.workspaceAgentRecreateDevcontainer) r.Get("/coordinate", api.workspaceAgentClientCoordinate) + r.Route("/immortal-streams", func(r chi.Router) { + r.Get("/", api.workspaceAgentImmortalStreams) + r.Post("/", api.workspaceAgentCreateImmortalStream) + r.Delete("/{immortalstream}", api.workspaceAgentDeleteImmortalStream) + }) // PTY is part of workspaceAppServer. 
}) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index f2ee1ac18e823..9da3c2fb85127 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -805,6 +805,212 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } +// @Summary Get workspace agent immortal streams +// @ID get-workspace-agent-immortal-streams +// @Security CoderSessionToken +// @Produce json +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Success 200 {array} codersdk.ImmortalStream +// @Router /workspaceagents/{workspaceagent}/immortal-streams [get] +func (api *API) workspaceAgentImmortalStreams(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + streams, err := agentConn.ImmortalStreams(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching immortal streams.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, streams) +} + +// @Summary Create workspace agent immortal stream +// @ID create-workspace-agent-immortal-stream +// @Security CoderSessionToken +// @Accept json +// @Produce json +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Param request body codersdk.CreateImmortalStreamRequest true "Create immortal stream request" +// @Success 201 {object} codersdk.ImmortalStream +// @Router /workspaceagents/{workspaceagent}/immortal-streams [post] +func (api *API) workspaceAgentCreateImmortalStream(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + var req codersdk.CreateImmortalStreamRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != 
codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + stream, err := agentConn.CreateImmortalStream(ctx, req) + if err != nil { + // Check for specific error types from the agent + if strings.Contains(err.Error(), "too many immortal streams") { + httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{ + Message: "Too many immortal streams.", + }) + return + } + if strings.Contains(err.Error(), "connection was refused") { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "The connection was refused.", + }) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error creating immortal stream.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusCreated, stream) +} + +// @Summary Delete workspace agent immortal stream +// @ID delete-workspace-agent-immortal-stream +// @Security CoderSessionToken +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Param immortalstream path string true "Immortal stream ID" format(uuid) +// @Success 200 {object} codersdk.Response +// @Router /workspaceagents/{workspaceagent}/immortal-streams/{immortalstream} [delete] +func (api *API) workspaceAgentDeleteImmortalStream(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + streamIDStr := chi.URLParam(r, "immortalstream") + streamID, err := uuid.Parse(streamIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid immortal stream ID format.", + Detail: err.Error(), + }) + return + } + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + err = agentConn.DeleteImmortalStream(ctx, streamID) + if err != nil { + if strings.Contains(err.Error(), "stream not found") { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "Immortal stream not found.", + }) + return + } + httpapi.Write(ctx, rw, 
http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error deleting immortal stream.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{ + Message: "Immortal stream deleted successfully.", + }) +} + // @Summary Watch workspace agent for container updates. // @ID watch-workspace-agent-for-container-updates // @Security CoderSessionToken diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 4f3faedb534fc..a0fdad857b412 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -393,6 +393,47 @@ func (c *Client) WorkspaceAgentListeningPorts(ctx context.Context, agentID uuid. return listeningPorts, json.NewDecoder(res.Body).Decode(&listeningPorts) } +// WorkspaceAgentImmortalStreams returns a list of immortal streams for the given agent. +func (c *Client) WorkspaceAgentImmortalStreams(ctx context.Context, agentID uuid.UUID) ([]ImmortalStream, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams", agentID), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + var streams []ImmortalStream + return streams, json.NewDecoder(res.Body).Decode(&streams) +} + +// WorkspaceAgentCreateImmortalStream creates a new immortal stream for the given agent. +func (c *Client) WorkspaceAgentCreateImmortalStream(ctx context.Context, agentID uuid.UUID, req CreateImmortalStreamRequest) (ImmortalStream, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams", agentID), req) + if err != nil { + return ImmortalStream{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return ImmortalStream{}, ReadBodyAsError(res) + } + var stream ImmortalStream + return stream, json.NewDecoder(res.Body).Decode(&stream) +} + +// WorkspaceAgentDeleteImmortalStream deletes an immortal stream for the given agent. +func (c *Client) WorkspaceAgentDeleteImmortalStream(ctx context.Context, agentID uuid.UUID, streamID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams/%s", agentID, streamID), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ReadBodyAsError(res) + } + return nil +} + // WorkspaceAgentDevcontainerStatus is the status of a devcontainer. type WorkspaceAgentDevcontainerStatus string diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index bb929c9ba2a04..36dd471712a3c 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -1,6 +1,7 @@ package workspacesdk import ( + "bytes" "context" "encoding/binary" "encoding/json" @@ -312,6 +313,74 @@ func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } +// ImmortalStreams lists the immortal streams that are currently active in the workspace. 
+func (c *AgentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/immortal-stream", nil) + if err != nil { + return nil, xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, codersdk.ReadBodyAsError(res) + } + + var streams []codersdk.ImmortalStream + return streams, json.NewDecoder(res.Body).Decode(&streams) +} + +// CreateImmortalStream creates a new immortal stream to the specified port. +func (c *AgentConn) CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + // Note: We can't easily add logging here since AgentConn doesn't have a logger + // But we can add some debug info to the error messages + + reqBody, err := json.Marshal(req) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("marshal request: %w", err) + } + + res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/immortal-stream", bytes.NewReader(reqBody)) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("do request to agent /api/v0/immortal-stream: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusCreated { + bodyErr := codersdk.ReadBodyAsError(res) + return codersdk.ImmortalStream{}, xerrors.Errorf("agent responded with status %d: %w", res.StatusCode, bodyErr) + } + + var stream codersdk.ImmortalStream + err = json.NewDecoder(res.Body).Decode(&stream) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("decode response: %w", err) + } + return stream, nil +} + +// DeleteImmortalStream deletes an immortal stream by ID. +func (c *AgentConn) DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + path := fmt.Sprintf("/api/v0/immortal-stream/%s", streamID) + res, err := c.apiRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + return codersdk.ReadBodyAsError(res) + } + + return nil +} + // Netcheck returns a network check report from the workspace agent. 
func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) From 4b56fdee1e75cbb0a5026873578e30298592664d Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 19 Aug 2025 18:16:42 +0000 Subject: [PATCH 2/4] WIP --- cli/ssh.go | 8 ++++---- coderd/workspaceagents.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 8ce2a0420f172..e299363481711 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -411,11 +411,11 @@ func (r *RootCmd) ssh() *serpent.Command { slog.F("agent_status", workspaceAgent.Status), slog.F("immortal_fallback_enabled", immortalFallback)) - shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || strings.Contains(err.Error(), "The connection was refused")) if shouldFallback { - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many immortal streams") { logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", slog.F("max_streams", "32")) } else { @@ -528,11 +528,11 @@ func (r *RootCmd) ssh() *serpent.Command { slog.F("agent_status", workspaceAgent.Status), slog.F("immortal_fallback_enabled", immortalFallback)) - shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || strings.Contains(err.Error(), "The connection was refused")) if shouldFallback { - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many immortal streams") { logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", slog.F("max_streams", "32")) } else { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 9da3c2fb85127..09d3cb421438f 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -915,7 +915,7 @@ func (api *API) workspaceAgentCreateImmortalStream(rw http.ResponseWriter, r *ht stream, err := agentConn.CreateImmortalStream(ctx, req) if err != nil { // Check for specific error types from the agent - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many Immortal Streams") { httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{ Message: "Too many immortal streams.", }) From c89f5f40211e0163e9fcc39571cd6255e1a1c972 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 19 Aug 2025 18:33:32 +0000 Subject: [PATCH 3/4] WIP --- cli/portforward.go | 100 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 8 deletions(-) diff --git a/cli/portforward.go b/cli/portforward.go index d96a0d697d289..d29a17e8e5b6c 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -13,6 +13,7 @@ import ( "sync" "syscall" + "github.com/google/uuid" "golang.org/x/xerrors" "cdr.dev/slog" @@ -152,7 +153,7 @@ func (r *RootCmd) portForward() *serpent.Command { // first, opportunistically try to listen on IPv6 spec6 := spec spec6.listenHost = ipv6Loopback - l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger) + l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger, immortal, immortalFallback, client, 
workspaceAgent.ID) if err6 != nil { logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6)) } else { @@ -160,7 +161,7 @@ func (r *RootCmd) portForward() *serpent.Command { } spec.listenHost = ipv4Loopback } - l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) + l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger, immortal, immortalFallback, client, workspaceAgent.ID) if err != nil { logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) return err @@ -242,6 +243,10 @@ func listenAndPortForward( wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, + immortal bool, + immortalFallback bool, + client *codersdk.Client, + agentID uuid.UUID, ) (net.Listener, error) { logger = logger.With( slog.F("network", spec.network), @@ -281,17 +286,96 @@ func listenAndPortForward( go func(netConn net.Conn) { defer netConn.Close() - remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress) - if err != nil { - _, _ = fmt.Fprintf(inv.Stderr, - "Failed to dial '%s://%s' in workspace: %s\n", - spec.network, dialAddress, err) - return + + var remoteConn net.Conn + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + // Only use immortal streams for TCP connections + if immortal && spec.network == "tcp" { + // Create immortal stream client + immortalStreamClient = newImmortalStreamClient(client, agentID, logger) + + // Create immortal stream to the target port + stream, err := immortalStreamClient.createStream(ctx, int(spec.dialPort)) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for port forward", + slog.Error(err), + slog.F("agent_id", agentID), + slog.F("target_port", spec.dialPort), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "Too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular port forward", + slog.F("max_streams", "32"), + slog.F("target_port", spec.dialPort)) + } else { + logger.Warn(ctx, "service not available, falling back to regular port forward", + slog.F("reason", "connection_refused"), + slog.F("target_port", spec.dialPort)) + } + logger.Debug(ctx, "attempting fallback to regular port forward") + remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress) + if err != nil { + logger.Error(ctx, "fallback port forward also failed", slog.Error(err)) + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to dial '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } + logger.Debug(ctx, "successfully connected via regular port forward fallback") + } else { + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to create immortal stream for '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } + } else { + streamID = &stream.ID + logger.Debug(ctx, "created immortal stream for port forward", + slog.F("stream_name", stream.Name), + slog.F("stream_id", stream.ID), + slog.F("target_port", spec.dialPort)) + + // Connect to the immortal stream via WebSocket + remoteConn, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to connect to immortal stream for '%s://%s' 
in workspace: %s\n", + spec.network, dialAddress, err) + return + } + } + } else { + // Use regular connection for UDP or when immortal is disabled + remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress) + if err != nil { + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to dial '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } } + defer remoteConn.Close() logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) + // Set up cleanup for immortal stream + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + agentssh.Bicopy(ctx, netConn, remoteConn) logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) From 5ed4b5df0399dfd0525c7b8e630d54f9803885df Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:50:52 +0000 Subject: [PATCH 4/4] WIP --- agent/agent.go | 97 ++++++++++++++++++++++++-------------- cli/ssh.go | 124 +++++++++++++++++++++++++------------------------ 2 files changed, 125 insertions(+), 96 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 4cefcfa9f8616..f8ad5eb73f1a9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -70,44 +70,13 @@ const ( EnvProcOOMScore = "CODER_PROC_OOM_SCORE" ) -// agentImmortalDialer is a custom dialer for immortal streams that can -// connect to the agent's own services via tailnet addresses. +// agentImmortalDialer wraps the standard dialer for immortal streams. +// Agent services are available on both tailnet and localhost interfaces. type agentImmortalDialer struct { - agent *agent standardDialer *net.Dialer } func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - host, portStr, err := net.SplitHostPort(address) - if err != nil { - return nil, xerrors.Errorf("split host port %q: %w", address, err) - } - - port, err := strconv.Atoi(portStr) - if err != nil { - return nil, xerrors.Errorf("parse port %q: %w", portStr, err) - } - - // Check if this is a connection to one of the agent's own services - isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1" - isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) || - port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort) - - if isLocalhost && isAgentPort { - // Get the agent ID from the current manifest - manifest := d.agent.manifest.Load() - if manifest == nil || manifest.AgentID == uuid.Nil { - // Fallback to standard dialing if no manifest available yet - return d.standardDialer.DialContext(ctx, network, address) - } - - // Connect to the agent's own tailnet address instead of localhost - agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID) - agentAddress := net.JoinHostPort(agentAddr.String(), portStr) - return d.standardDialer.DialContext(ctx, network, agentAddress) - } - - // For other addresses, use standard dialing return d.standardDialer.DialContext(ctx, network, address) } @@ -392,10 +361,8 @@ func (a *agent) init() { a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...) 
- // Initialize immortal streams manager with a custom dialer - // that can connect to the agent's own services + // Initialize immortal streams manager immortalDialer := &agentImmortalDialer{ - agent: a, standardDialer: &net.Dialer{}, } a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer) @@ -1531,6 +1498,7 @@ func (a *agent) createTailnet( } for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} { + // Listen on tailnet interface for external connections sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port)) if err != nil { return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err) @@ -1546,6 +1514,25 @@ func (a *agent) createTailnet( }); err != nil { return nil, err } + + // Also listen on localhost for immortal streams (only for SSH port 1) + if port == workspacesdk.AgentSSHPort { + localhostListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) + if err != nil { + return nil, xerrors.Errorf("listen on localhost ssh port (%v): %w", port, err) + } + // nolint:revive // We do want to run the deferred functions when createTailnet returns. + defer func() { + if err != nil { + _ = localhostListener.Close() + } + }() + if err = a.trackGoroutine(func() { + _ = a.sshServer.Serve(localhostListener) + }); err != nil { + return nil, err + } + } } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort)) @@ -1616,6 +1603,7 @@ func (a *agent) createTailnet( return nil, err } + // Listen on tailnet interface for external connections apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("api listener: %w", err) @@ -1652,6 +1640,43 @@ func (a *agent) createTailnet( return nil, err } + // Also listen on localhost for immortal streams WebSocket connections + localhostAPIListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort)) + if err != nil { + return nil, xerrors.Errorf("localhost api listener: %w", err) + } + defer func() { + if err != nil { + _ = localhostAPIListener.Close() + } + }() + if err = a.trackGoroutine(func() { + defer localhostAPIListener.Close() + apiHandler := a.apiHandler() + server := &http.Server{ + BaseContext: func(net.Listener) context.Context { return ctx }, + Handler: apiHandler, + ReadTimeout: 20 * time.Second, + ReadHeaderTimeout: 20 * time.Second, + WriteTimeout: 20 * time.Second, + ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server_localhost"), slog.LevelInfo), + } + go func() { + select { + case <-ctx.Done(): + case <-a.hardCtx.Done(): + } + _ = server.Close() + }() + + apiServErr := server.Serve(localhostAPIListener) + if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") { + a.logger.Critical(ctx, "serve localhost HTTP API server", slog.Error(apiServErr)) + } + }); err != nil { + return nil, err + } + return network, nil } diff --git a/cli/ssh.go b/cli/ssh.go index e299363481711..478473d294ee3 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -440,8 +440,10 @@ func (r *RootCmd) ssh() *serpent.Command { // Connect to the immortal stream via WebSocket rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) if err != nil { - // Clean up the stream if connection fails - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a 
permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("connect to immortal stream: %w", err) } } @@ -481,12 +483,17 @@ func (r *RootCmd) ssh() *serpent.Command { } } - // Set up cleanup for immortal stream + // Set up signal-based cleanup for immortal stream + // Only delete on explicit user termination (SIGINT, SIGTERM), not network errors if immortalStreamClient != nil && streamID != nil { - defer func() { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) - } + // Create a signal-only context for cleanup + signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...) + defer signalStop() + + go func() { + <-signalCtx.Done() + // User sent termination signal - clean up the stream + _ = immortalStreamClient.deleteStream(context.Background(), *streamID) }() } @@ -494,12 +501,7 @@ func (r *RootCmd) ssh() *serpent.Command { go func() { defer wg.Done() watchAndClose(ctx, func() error { - // Clean up immortal stream on termination - if immortalStreamClient != nil && streamID != nil { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) - } - } + // Don't delete immortal stream here - let signal handler do it stack.close(xerrors.New("watchAndClose")) return nil }, logger, client, workspace, errCh) @@ -557,8 +559,10 @@ func (r *RootCmd) ssh() *serpent.Command { // Connect to the immortal stream and create SSH client rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) if err != nil { - // Clean up the stream if connection fails - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("connect to immortal stream: %w", err) } @@ -569,7 +573,10 @@ func (r *RootCmd) ssh() *serpent.Command { }) if err != nil { rawConn.Close() - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("ssh handshake over immortal stream: %w", err) } @@ -603,12 +610,17 @@ func (r *RootCmd) ssh() *serpent.Command { } } - // Set up cleanup for immortal stream in regular SSH mode + // Set up signal-based cleanup for immortal stream + // Only delete on explicit user termination (SIGINT, SIGTERM), not network errors if immortalStreamClient != nil && streamID != nil { - defer func() { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) - } + // Create a signal-only context for cleanup + signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...) 
+ defer signalStop() + + go func() { + <-signalCtx.Done() + // User sent termination signal - clean up the stream + _ = immortalStreamClient.deleteStream(context.Background(), *streamID) }() } @@ -618,12 +630,7 @@ func (r *RootCmd) ssh() *serpent.Command { watchAndClose( ctx, func() error { - // Clean up immortal stream on termination - if immortalStreamClient != nil && streamID != nil { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) - } - } + // Don't delete immortal stream here - let signal handler do it stack.close(xerrors.New("watchAndClose")) return nil }, @@ -923,66 +930,63 @@ func (r *RootCmd) ssh() *serpent.Command { return cmd } -// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn +// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket +// The immortal stream infrastructure handles reconnection automatically func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) { // Build the target address for the agent's HTTP API server - // We'll let the WebSocket dialer handle the actual connection through the agent apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort) wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID) // Create WebSocket connection using the agent's tailnet connection - // The key is to use a custom dialer that routes through the agent connection dialOptions := &websocket.DialOptions{ HTTPClient: &http.Client{ Transport: &http.Transport{ DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { - // Route all connections through the agent connection - // The agent connection will handle routing to the correct internal address - - conn, err := agentConn.DialContext(dialCtx, network, addr) - if err != nil { - return nil, err - } - - return conn, nil + return agentConn.DialContext(dialCtx, network, addr) }, }, }, - // Disable compression for raw TCP data CompressionMode: websocket.CompressionDisabled, } // Connect to the WebSocket endpoint - conn, res, err := websocket.Dial(ctx, wsURL, dialOptions) + conn, _, err := websocket.Dial(ctx, wsURL, dialOptions) if err != nil { - if res != nil { - logger.Error(ctx, "WebSocket dial failed", - slog.F("stream_id", streamID), - slog.F("websocket_url", wsURL), - slog.F("status", res.StatusCode), - slog.F("status_text", res.Status), - slog.Error(err)) - } else { - logger.Error(ctx, "WebSocket dial failed (no response)", - slog.F("stream_id", streamID), - slog.F("websocket_url", wsURL), - slog.Error(err)) - } return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err) } - logger.Info(ctx, "successfully connected to immortal stream WebSocket", - slog.F("stream_id", streamID)) - // Convert WebSocket to net.Conn for SSH usage - // Use MessageBinary for raw TCP data transport + // The immortal stream's BackedPipe handles reconnection automatically netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) - logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage") - return netConn, nil } +// isNetworkError checks if an error is a temporary network error +func isNetworkError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + networkErrors := []string{ + "connection refused", + 
"network is unreachable", + "connection reset", + "broken pipe", + "timeout", + "no route to host", + } + + for _, netErr := range networkErrors { + if strings.Contains(errStr, netErr) { + return true + } + } + + return false +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy
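
For reference, a minimal sketch (not part of the patch series above) of driving the immortal-stream endpoints added in PATCH 1/4 directly through the new codersdk client methods. It assumes a *codersdk.Client and agent UUID obtained through the usual helpers (e.g. the CLI's InitClient middleware and getWorkspaceAndAgent), and that CreateImmortalStreamRequest carries the TCPPort field used in cli/immortalstreams.go; port 1 mirrors the AgentSSHPort target hard-coded in cli/ssh.go.

package example

import (
	"context"
	"fmt"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/codersdk"
)

// demoImmortalStreams exercises the three server-proxied endpoints introduced
// by this series: create, list, and delete an immortal stream on one agent.
func demoImmortalStreams(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) error {
	// Create a stream targeting the agent's SSH service (port 1, the
	// AgentSSHPort convention used by the --immortal SSH path above).
	stream, err := client.WorkspaceAgentCreateImmortalStream(ctx, agentID, codersdk.CreateImmortalStreamRequest{
		TCPPort: 1,
	})
	if err != nil {
		return xerrors.Errorf("create immortal stream: %w", err)
	}

	// List active streams; this backs `coder exp immortal-stream list <workspace>`.
	streams, err := client.WorkspaceAgentImmortalStreams(ctx, agentID)
	if err != nil {
		return xerrors.Errorf("list immortal streams: %w", err)
	}
	for _, s := range streams {
		fmt.Printf("%s\tport %d\tcreated %s\n", s.Name, s.TCPPort, s.CreatedAt.Format("2006-01-02 15:04:05"))
	}

	// Delete by ID, as `coder exp immortal-stream delete` does after resolving
	// the stream name.
	return client.WorkspaceAgentDeleteImmortalStream(ctx, agentID, stream.ID)
}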