Skip to content

Commit 5ed4b5d

Browse files
committed
WIP
1 parent c89f5f4 commit 5ed4b5d

File tree

2 files changed

+125
-96
lines changed

2 files changed

+125
-96
lines changed

agent/agent.go

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -70,44 +70,13 @@ const (
7070
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
7171
)
7272

73-
// agentImmortalDialer is a custom dialer for immortal streams that can
74-
// connect to the agent's own services via tailnet addresses.
73+
// agentImmortalDialer wraps the standard dialer for immortal streams.
74+
// Agent services are available on both tailnet and localhost interfaces.
7575
type agentImmortalDialer struct {
76-
agent *agent
7776
standardDialer *net.Dialer
7877
}
7978

8079
func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
81-
host, portStr, err := net.SplitHostPort(address)
82-
if err != nil {
83-
return nil, xerrors.Errorf("split host port %q: %w", address, err)
84-
}
85-
86-
port, err := strconv.Atoi(portStr)
87-
if err != nil {
88-
return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
89-
}
90-
91-
// Check if this is a connection to one of the agent's own services
92-
isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1"
93-
isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) ||
94-
port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort)
95-
96-
if isLocalhost && isAgentPort {
97-
// Get the agent ID from the current manifest
98-
manifest := d.agent.manifest.Load()
99-
if manifest == nil || manifest.AgentID == uuid.Nil {
100-
// Fallback to standard dialing if no manifest available yet
101-
return d.standardDialer.DialContext(ctx, network, address)
102-
}
103-
104-
// Connect to the agent's own tailnet address instead of localhost
105-
agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID)
106-
agentAddress := net.JoinHostPort(agentAddr.String(), portStr)
107-
return d.standardDialer.DialContext(ctx, network, agentAddress)
108-
}
109-
110-
// For other addresses, use standard dialing
11180
return d.standardDialer.DialContext(ctx, network, address)
11281
}
11382

@@ -392,10 +361,8 @@ func (a *agent) init() {
392361

393362
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
394363

395-
// Initialize immortal streams manager with a custom dialer
396-
// that can connect to the agent's own services
364+
// Initialize immortal streams manager
397365
immortalDialer := &agentImmortalDialer{
398-
agent: a,
399366
standardDialer: &net.Dialer{},
400367
}
401368
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer)
@@ -1531,6 +1498,7 @@ func (a *agent) createTailnet(
15311498
}
15321499

15331500
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
1501+
// Listen on tailnet interface for external connections
15341502
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
15351503
if err != nil {
15361504
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
@@ -1546,6 +1514,25 @@ func (a *agent) createTailnet(
15461514
}); err != nil {
15471515
return nil, err
15481516
}
1517+
1518+
// Also listen on localhost for immortal streams (only for workspacesdk.AgentSSHPort, not AgentStandardSSHPort)
1519+
if port == workspacesdk.AgentSSHPort {
1520+
localhostListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port))
1521+
if err != nil {
1522+
return nil, xerrors.Errorf("listen on localhost ssh port (%v): %w", port, err)
1523+
}
1524+
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
1525+
defer func() {
1526+
if err != nil {
1527+
_ = localhostListener.Close()
1528+
}
1529+
}()
1530+
if err = a.trackGoroutine(func() {
1531+
_ = a.sshServer.Serve(localhostListener)
1532+
}); err != nil {
1533+
return nil, err
1534+
}
1535+
}
15491536
}
15501537

15511538
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))
@@ -1616,6 +1603,7 @@ func (a *agent) createTailnet(
16161603
return nil, err
16171604
}
16181605

1606+
// Listen on tailnet interface for external connections
16191607
apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort))
16201608
if err != nil {
16211609
return nil, xerrors.Errorf("api listener: %w", err)
@@ -1652,6 +1640,43 @@ func (a *agent) createTailnet(
16521640
return nil, err
16531641
}
16541642

1643+
// Also listen on localhost for immortal streams WebSocket connections
1644+
localhostAPIListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort))
1645+
if err != nil {
1646+
return nil, xerrors.Errorf("localhost api listener: %w", err)
1647+
}
1648+
defer func() {
1649+
if err != nil {
1650+
_ = localhostAPIListener.Close()
1651+
}
1652+
}()
1653+
if err = a.trackGoroutine(func() {
1654+
defer localhostAPIListener.Close()
1655+
apiHandler := a.apiHandler()
1656+
server := &http.Server{
1657+
BaseContext: func(net.Listener) context.Context { return ctx },
1658+
Handler: apiHandler,
1659+
ReadTimeout: 20 * time.Second,
1660+
ReadHeaderTimeout: 20 * time.Second,
1661+
WriteTimeout: 20 * time.Second,
1662+
ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server_localhost"), slog.LevelInfo),
1663+
}
1664+
go func() {
1665+
select {
1666+
case <-ctx.Done():
1667+
case <-a.hardCtx.Done():
1668+
}
1669+
_ = server.Close()
1670+
}()
1671+
1672+
apiServErr := server.Serve(localhostAPIListener)
1673+
if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
1674+
a.logger.Critical(ctx, "serve localhost HTTP API server", slog.Error(apiServErr))
1675+
}
1676+
}); err != nil {
1677+
return nil, err
1678+
}
1679+
16551680
return network, nil
16561681
}
16571682

cli/ssh.go

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,10 @@ func (r *RootCmd) ssh() *serpent.Command {
440440
// Connect to the immortal stream via WebSocket
441441
rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
442442
if err != nil {
443-
// Clean up the stream if connection fails
444-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
443+
// Only clean up the stream if it's a permanent failure
444+
if !isNetworkError(err) {
445+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
446+
}
445447
return xerrors.Errorf("connect to immortal stream: %w", err)
446448
}
447449
}
@@ -481,25 +483,25 @@ func (r *RootCmd) ssh() *serpent.Command {
481483
}
482484
}
483485

484-
// Set up cleanup for immortal stream
486+
// Set up signal-based cleanup for immortal stream
487+
// Only delete on explicit user termination (SIGINT, SIGTERM), not network errors
485488
if immortalStreamClient != nil && streamID != nil {
486-
defer func() {
487-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
488-
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
489-
}
489+
// Create a signal-only context for cleanup
490+
signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...)
491+
defer signalStop()
492+
493+
go func() {
494+
<-signalCtx.Done()
495+
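// Clean up the stream once signalCtx is done (intended: user sent a termination signal).
// NOTE(review): if inv.SignalNotifyContext follows os/signal.NotifyContext, the deferred
// signalStop() also closes signalCtx.Done(), so this delete may run on any return, not
// only after a termination signal — confirm.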
496+
_ = immortalStreamClient.deleteStream(context.Background(), *streamID)
490497
}()
491498
}
492499

493500
wg.Add(1)
494501
go func() {
495502
defer wg.Done()
496503
watchAndClose(ctx, func() error {
497-
// Clean up immortal stream on termination
498-
if immortalStreamClient != nil && streamID != nil {
499-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
500-
logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err))
501-
}
502-
}
504+
// Don't delete immortal stream here - let signal handler do it
503505
stack.close(xerrors.New("watchAndClose"))
504506
return nil
505507
}, logger, client, workspace, errCh)
@@ -557,8 +559,10 @@ func (r *RootCmd) ssh() *serpent.Command {
557559
// Connect to the immortal stream and create SSH client
558560
rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
559561
if err != nil {
560-
// Clean up the stream if connection fails
561-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
562+
// Only clean up the stream if it's a permanent failure
563+
if !isNetworkError(err) {
564+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
565+
}
562566
return xerrors.Errorf("connect to immortal stream: %w", err)
563567
}
564568

@@ -569,7 +573,10 @@ func (r *RootCmd) ssh() *serpent.Command {
569573
})
570574
if err != nil {
571575
rawConn.Close()
572-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
576+
// Only clean up the stream if it's a permanent failure
577+
if !isNetworkError(err) {
578+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
579+
}
573580
return xerrors.Errorf("ssh handshake over immortal stream: %w", err)
574581
}
575582

@@ -603,12 +610,17 @@ func (r *RootCmd) ssh() *serpent.Command {
603610
}
604611
}
605612

606-
// Set up cleanup for immortal stream in regular SSH mode
613+
// Set up signal-based cleanup for immortal stream
614+
// Only delete on explicit user termination (SIGINT, SIGTERM), not network errors
607615
if immortalStreamClient != nil && streamID != nil {
608-
defer func() {
609-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
610-
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
611-
}
616+
// Create a signal-only context for cleanup
617+
signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...)
618+
defer signalStop()
619+
620+
go func() {
621+
<-signalCtx.Done()
622+
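// Clean up the stream once signalCtx is done (intended: user sent a termination signal).
// NOTE(review): if inv.SignalNotifyContext follows os/signal.NotifyContext, the deferred
// signalStop() also closes signalCtx.Done(), so this delete may run on any return, not
// only after a termination signal — confirm.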
623+
_ = immortalStreamClient.deleteStream(context.Background(), *streamID)
612624
}()
613625
}
614626

@@ -618,12 +630,7 @@ func (r *RootCmd) ssh() *serpent.Command {
618630
watchAndClose(
619631
ctx,
620632
func() error {
621-
// Clean up immortal stream on termination
622-
if immortalStreamClient != nil && streamID != nil {
623-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
624-
logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err))
625-
}
626-
}
633+
// Don't delete immortal stream here - let signal handler do it
627634
stack.close(xerrors.New("watchAndClose"))
628635
return nil
629636
},
@@ -923,66 +930,63 @@ func (r *RootCmd) ssh() *serpent.Command {
923930
return cmd
924931
}
925932

926-
// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn
933+
// connectToImmortalStreamWebSocket dials the agent's immortal stream WebSocket endpoint for streamID and returns it wrapped as a net.Conn.
934+
// The immortal stream infrastructure handles reconnection automatically.
927935
func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) {
928936
// Build the target address for the agent's HTTP API server
929-
// We'll let the WebSocket dialer handle the actual connection through the agent
930937
apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort)
931938
wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID)
932939

933940
// Create WebSocket connection using the agent's tailnet connection
934-
// The key is to use a custom dialer that routes through the agent connection
935941
dialOptions := &websocket.DialOptions{
936942
HTTPClient: &http.Client{
937943
Transport: &http.Transport{
938944
DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
939-
// Route all connections through the agent connection
940-
// The agent connection will handle routing to the correct internal address
941-
942-
conn, err := agentConn.DialContext(dialCtx, network, addr)
943-
if err != nil {
944-
return nil, err
945-
}
946-
947-
return conn, nil
945+
return agentConn.DialContext(dialCtx, network, addr)
948946
},
949947
},
950948
},
951-
// Disable compression for raw TCP data
952949
CompressionMode: websocket.CompressionDisabled,
953950
}
954951

955952
// Connect to the WebSocket endpoint
956-
conn, res, err := websocket.Dial(ctx, wsURL, dialOptions)
953+
conn, _, err := websocket.Dial(ctx, wsURL, dialOptions)
957954
if err != nil {
958-
if res != nil {
959-
logger.Error(ctx, "WebSocket dial failed",
960-
slog.F("stream_id", streamID),
961-
slog.F("websocket_url", wsURL),
962-
slog.F("status", res.StatusCode),
963-
slog.F("status_text", res.Status),
964-
slog.Error(err))
965-
} else {
966-
logger.Error(ctx, "WebSocket dial failed (no response)",
967-
slog.F("stream_id", streamID),
968-
slog.F("websocket_url", wsURL),
969-
slog.Error(err))
970-
}
971955
return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err)
972956
}
973957

974-
logger.Info(ctx, "successfully connected to immortal stream WebSocket",
975-
slog.F("stream_id", streamID))
976-
977958
// Convert WebSocket to net.Conn for SSH usage
978-
// Use MessageBinary for raw TCP data transport
959+
// The immortal stream's BackedPipe handles reconnection automatically
979960
netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
980961

981-
logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage")
982-
983962
return netConn, nil
984963
}
985964

965+
// isNetworkError reports whether err looks like a transient network failure.
// It returns false for a nil error; otherwise it returns true when err.Error()
// contains any of the substrings listed in networkErrors. The match is a plain,
// case-sensitive strings.Contains on the message text, not a check of the
// error's type. Callers in this commit use the result to decide whether to
// keep (true) or delete (false) the immortal stream after a failed connect or
// SSH handshake.
966+
func isNetworkError(err error) bool {
967+
if err == nil {
968+
return false
969+
}
970+
971+
errStr := err.Error()
972+
networkErrors := []string{
973+
"connection refused",
974+
"network is unreachable",
975+
"connection reset",
976+
"broken pipe",
977+
"timeout",
978+
"no route to host",
979+
}
980+
981+
for _, netErr := range networkErrors {
982+
if strings.Contains(errStr, netErr) {
983+
return true
984+
}
985+
}
986+
987+
return false
988+
}
989+
986990
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
987991
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
988992
// vscode-coder--myusername--myworkspace).

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy