@@ -13,6 +13,7 @@ import (
13
13
"sync"
14
14
"syscall"
15
15
16
+ "github.com/google/uuid"
16
17
"golang.org/x/xerrors"
17
18
18
19
"cdr.dev/slog"
@@ -152,15 +153,15 @@ func (r *RootCmd) portForward() *serpent.Command {
152
153
// first, opportunistically try to listen on IPv6
153
154
spec6 := spec
154
155
spec6 .listenHost = ipv6Loopback
155
- l6 , err6 := listenAndPortForward (ctx , inv , conn , wg , spec6 , logger )
156
+ l6 , err6 := listenAndPortForward (ctx , inv , conn , wg , spec6 , logger , immortal , immortalFallback , client , workspaceAgent . ID )
156
157
if err6 != nil {
157
158
logger .Info (ctx , "failed to opportunistically listen on IPv6" , slog .F ("spec" , spec ), slog .Error (err6 ))
158
159
} else {
159
160
listeners = append (listeners , l6 )
160
161
}
161
162
spec .listenHost = ipv4Loopback
162
163
}
163
- l , err := listenAndPortForward (ctx , inv , conn , wg , spec , logger )
164
+ l , err := listenAndPortForward (ctx , inv , conn , wg , spec , logger , immortal , immortalFallback , client , workspaceAgent . ID )
164
165
if err != nil {
165
166
logger .Error (ctx , "failed to listen" , slog .F ("spec" , spec ), slog .Error (err ))
166
167
return err
@@ -242,6 +243,10 @@ func listenAndPortForward(
242
243
wg * sync.WaitGroup ,
243
244
spec portForwardSpec ,
244
245
logger slog.Logger ,
246
+ immortal bool ,
247
+ immortalFallback bool ,
248
+ client * codersdk.Client ,
249
+ agentID uuid.UUID ,
245
250
) (net.Listener , error ) {
246
251
logger = logger .With (
247
252
slog .F ("network" , spec .network ),
@@ -281,17 +286,96 @@ func listenAndPortForward(
281
286
282
287
go func (netConn net.Conn ) {
283
288
defer netConn .Close ()
284
- remoteConn , err := conn .DialContext (ctx , spec .network , dialAddress )
285
- if err != nil {
286
- _ , _ = fmt .Fprintf (inv .Stderr ,
287
- "Failed to dial '%s://%s' in workspace: %s\n " ,
288
- spec .network , dialAddress , err )
289
- return
289
+
290
+ var remoteConn net.Conn
291
+ var immortalStreamClient * immortalStreamClient
292
+ var streamID * uuid.UUID
293
+
294
+ // Only use immortal streams for TCP connections
295
+ if immortal && spec .network == "tcp" {
296
+ // Create immortal stream client
297
+ immortalStreamClient = newImmortalStreamClient (client , agentID , logger )
298
+
299
+ // Create immortal stream to the target port
300
+ stream , err := immortalStreamClient .createStream (ctx , int (spec .dialPort ))
301
+ if err != nil {
302
+ logger .Error (ctx , "failed to create immortal stream for port forward" ,
303
+ slog .Error (err ),
304
+ slog .F ("agent_id" , agentID ),
305
+ slog .F ("target_port" , spec .dialPort ),
306
+ slog .F ("immortal_fallback_enabled" , immortalFallback ))
307
+
308
+ shouldFallback := immortalFallback && (strings .Contains (err .Error (), "Too many immortal streams" ) ||
309
+ strings .Contains (err .Error (), "The connection was refused" ))
310
+
311
+ if shouldFallback {
312
+ if strings .Contains (err .Error (), "Too many immortal streams" ) {
313
+ logger .Warn (ctx , "too many immortal streams, falling back to regular port forward" ,
314
+ slog .F ("max_streams" , "32" ),
315
+ slog .F ("target_port" , spec .dialPort ))
316
+ } else {
317
+ logger .Warn (ctx , "service not available, falling back to regular port forward" ,
318
+ slog .F ("reason" , "connection_refused" ),
319
+ slog .F ("target_port" , spec .dialPort ))
320
+ }
321
+ logger .Debug (ctx , "attempting fallback to regular port forward" )
322
+ remoteConn , err = conn .DialContext (ctx , spec .network , dialAddress )
323
+ if err != nil {
324
+ logger .Error (ctx , "fallback port forward also failed" , slog .Error (err ))
325
+ _ , _ = fmt .Fprintf (inv .Stderr ,
326
+ "Failed to dial '%s://%s' in workspace: %s\n " ,
327
+ spec .network , dialAddress , err )
328
+ return
329
+ }
330
+ logger .Debug (ctx , "successfully connected via regular port forward fallback" )
331
+ } else {
332
+ _ , _ = fmt .Fprintf (inv .Stderr ,
333
+ "Failed to create immortal stream for '%s://%s' in workspace: %s\n " ,
334
+ spec .network , dialAddress , err )
335
+ return
336
+ }
337
+ } else {
338
+ streamID = & stream .ID
339
+ logger .Debug (ctx , "created immortal stream for port forward" ,
340
+ slog .F ("stream_name" , stream .Name ),
341
+ slog .F ("stream_id" , stream .ID ),
342
+ slog .F ("target_port" , spec .dialPort ))
343
+
344
+ // Connect to the immortal stream via WebSocket
345
+ remoteConn , err = connectToImmortalStreamWebSocket (ctx , conn , stream .ID , logger )
346
+ if err != nil {
347
+ // Clean up the stream if connection fails
348
+ _ = immortalStreamClient .deleteStream (ctx , stream .ID )
349
+ _ , _ = fmt .Fprintf (inv .Stderr ,
350
+ "Failed to connect to immortal stream for '%s://%s' in workspace: %s\n " ,
351
+ spec .network , dialAddress , err )
352
+ return
353
+ }
354
+ }
355
+ } else {
356
+ // Use regular connection for UDP or when immortal is disabled
357
+ remoteConn , err = conn .DialContext (ctx , spec .network , dialAddress )
358
+ if err != nil {
359
+ _ , _ = fmt .Fprintf (inv .Stderr ,
360
+ "Failed to dial '%s://%s' in workspace: %s\n " ,
361
+ spec .network , dialAddress , err )
362
+ return
363
+ }
290
364
}
365
+
291
366
defer remoteConn .Close ()
292
367
logger .Debug (ctx ,
293
368
"dialed remote" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
294
369
370
+ // Set up cleanup for immortal stream
371
+ if immortalStreamClient != nil && streamID != nil {
372
+ defer func () {
373
+ if err := immortalStreamClient .deleteStream (context .Background (), * streamID ); err != nil {
374
+ logger .Error (context .Background (), "failed to cleanup immortal stream" , slog .Error (err ))
375
+ }
376
+ }()
377
+ }
378
+
295
379
agentssh .Bicopy (ctx , netConn , remoteConn )
296
380
logger .Debug (ctx ,
297
381
"connection closing" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
0 commit comments