Skip to content

Commit e73338d

Browse files
committed
WIP
1 parent dccaf9b commit e73338d

File tree

1 file changed

+92
-8
lines changed

1 file changed

+92
-8
lines changed

cli/portforward.go

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"sync"
1414
"syscall"
1515

16+
"github.com/google/uuid"
1617
"golang.org/x/xerrors"
1718

1819
"cdr.dev/slog"
@@ -152,15 +153,15 @@ func (r *RootCmd) portForward() *serpent.Command {
152153
// first, opportunistically try to listen on IPv6
153154
spec6 := spec
154155
spec6.listenHost = ipv6Loopback
155-
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
156+
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger, immortal, immortalFallback, client, workspaceAgent.ID)
156157
if err6 != nil {
157158
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
158159
} else {
159160
listeners = append(listeners, l6)
160161
}
161162
spec.listenHost = ipv4Loopback
162163
}
163-
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
164+
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger, immortal, immortalFallback, client, workspaceAgent.ID)
164165
if err != nil {
165166
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
166167
return err
@@ -242,6 +243,10 @@ func listenAndPortForward(
242243
wg *sync.WaitGroup,
243244
spec portForwardSpec,
244245
logger slog.Logger,
246+
immortal bool,
247+
immortalFallback bool,
248+
client *codersdk.Client,
249+
agentID uuid.UUID,
245250
) (net.Listener, error) {
246251
logger = logger.With(
247252
slog.F("network", spec.network),
@@ -281,17 +286,96 @@ func listenAndPortForward(
281286

282287
go func(netConn net.Conn) {
283288
defer netConn.Close()
284-
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
285-
if err != nil {
286-
_, _ = fmt.Fprintf(inv.Stderr,
287-
"Failed to dial '%s://%s' in workspace: %s\n",
288-
spec.network, dialAddress, err)
289-
return
289+
290+
var remoteConn net.Conn
291+
var immortalStreamClient *immortalStreamClient
292+
var streamID *uuid.UUID
293+
294+
// Only use immortal streams for TCP connections
295+
if immortal && spec.network == "tcp" {
296+
// Create immortal stream client
297+
immortalStreamClient = newImmortalStreamClient(client, agentID, logger)
298+
299+
// Create immortal stream to the target port
300+
stream, err := immortalStreamClient.createStream(ctx, int(spec.dialPort))
301+
if err != nil {
302+
logger.Error(ctx, "failed to create immortal stream for port forward",
303+
slog.Error(err),
304+
slog.F("agent_id", agentID),
305+
slog.F("target_port", spec.dialPort),
306+
slog.F("immortal_fallback_enabled", immortalFallback))
307+
308+
shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") ||
309+
strings.Contains(err.Error(), "The connection was refused"))
310+
311+
if shouldFallback {
312+
if strings.Contains(err.Error(), "Too many immortal streams") {
313+
logger.Warn(ctx, "too many immortal streams, falling back to regular port forward",
314+
slog.F("max_streams", "32"),
315+
slog.F("target_port", spec.dialPort))
316+
} else {
317+
logger.Warn(ctx, "service not available, falling back to regular port forward",
318+
slog.F("reason", "connection_refused"),
319+
slog.F("target_port", spec.dialPort))
320+
}
321+
logger.Debug(ctx, "attempting fallback to regular port forward")
322+
remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress)
323+
if err != nil {
324+
logger.Error(ctx, "fallback port forward also failed", slog.Error(err))
325+
_, _ = fmt.Fprintf(inv.Stderr,
326+
"Failed to dial '%s://%s' in workspace: %s\n",
327+
spec.network, dialAddress, err)
328+
return
329+
}
330+
logger.Debug(ctx, "successfully connected via regular port forward fallback")
331+
} else {
332+
_, _ = fmt.Fprintf(inv.Stderr,
333+
"Failed to create immortal stream for '%s://%s' in workspace: %s\n",
334+
spec.network, dialAddress, err)
335+
return
336+
}
337+
} else {
338+
streamID = &stream.ID
339+
logger.Debug(ctx, "created immortal stream for port forward",
340+
slog.F("stream_name", stream.Name),
341+
slog.F("stream_id", stream.ID),
342+
slog.F("target_port", spec.dialPort))
343+
344+
// Connect to the immortal stream via WebSocket
345+
remoteConn, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
346+
if err != nil {
347+
// Clean up the stream if connection fails
348+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
349+
_, _ = fmt.Fprintf(inv.Stderr,
350+
"Failed to connect to immortal stream for '%s://%s' in workspace: %s\n",
351+
spec.network, dialAddress, err)
352+
return
353+
}
354+
}
355+
} else {
356+
// Use regular connection for UDP or when immortal is disabled
357+
remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress)
358+
if err != nil {
359+
_, _ = fmt.Fprintf(inv.Stderr,
360+
"Failed to dial '%s://%s' in workspace: %s\n",
361+
spec.network, dialAddress, err)
362+
return
363+
}
290364
}
365+
291366
defer remoteConn.Close()
292367
logger.Debug(ctx,
293368
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
294369

370+
// Set up cleanup for immortal stream
371+
if immortalStreamClient != nil && streamID != nil {
372+
defer func() {
373+
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
374+
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
375+
}
376+
}()
377+
}
378+
295379
agentssh.Bicopy(ctx, netConn, remoteConn)
296380
logger.Debug(ctx,
297381
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy