Skip to content

Commit 9bcaa2f

Browse files
committed
WIP
1 parent 57b68a4 commit 9bcaa2f

File tree

4 files changed

+156
-124
lines changed

4 files changed

+156
-124
lines changed

agent/immortalstreams/manager.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func (m *Manager) CreateStream(ctx context.Context, port int) (*codersdk.Immorta
8080
name,
8181
port,
8282
m.logger.With(slog.F("stream_id", id), slog.F("stream_name", name)),
83+
m.dialer,
8384
)
8485

8586
// Start the stream

agent/immortalstreams/manager_test.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,24 @@ func TestManager_CreateStream(t *testing.T) {
4040
if err != nil {
4141
return
4242
}
43-
// Just echo for testing
44-
go func() {
45-
defer conn.Close()
46-
_, _ = io.Copy(conn, conn)
47-
}()
43+
// Just echo for testing with proper cleanup
44+
go func(c net.Conn) {
45+
defer c.Close()
46+
// Use a buffer to avoid blocking indefinitely
47+
buf := make([]byte, 1024)
48+
for {
49+
n, err := c.Read(buf)
50+
if err != nil {
51+
return
52+
}
53+
if n > 0 {
54+
_, err = c.Write(buf[:n])
55+
if err != nil {
56+
return
57+
}
58+
}
59+
}
60+
}(conn)
4861
}
4962
}()
5063

agent/immortalstreams/stream.go

Lines changed: 125 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package immortalstreams
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"sync"
89
"time"
@@ -22,6 +23,7 @@ type Stream struct {
2223
port int
2324
createdAt time.Time
2425
logger slog.Logger
26+
dialer Dialer
2527

2628
mu sync.RWMutex
2729
localConn io.ReadWriteCloser
@@ -56,8 +58,7 @@ type Stream struct {
5658

5759
// reconnectRequest represents a pending reconnection request
5860
type reconnectRequest struct {
59-
writerSeqNum uint64
60-
response chan reconnectResponse
61+
response chan reconnectResponse
6162
}
6263

6364
// reconnectResponse represents a reconnection response
@@ -67,83 +68,75 @@ type reconnectResponse struct {
6768
err error
6869
}
6970

71+
// streamReconnector implements the backedpipe.Reconnector interface for Stream
72+
type streamReconnector struct {
73+
stream *Stream
74+
// Track the current client connection so we can close it during reconnect
75+
mu sync.Mutex
76+
currentConn io.ReadWriteCloser
77+
}
78+
79+
// Reconnect implements the backedpipe.Reconnector interface
80+
func (sr *streamReconnector) Reconnect(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
81+
sr.stream.logger.Info(context.Background(), "reconnector attempting to dial local connection",
82+
slog.F("port", sr.stream.port),
83+
slog.F("writer_seq", writerSeqNum))
84+
85+
// Dial the local TCP port directly
86+
conn, err := sr.stream.dialer.DialContext(ctx, "tcp", fmt.Sprintf("localhost:%d", sr.stream.port))
87+
if err != nil {
88+
sr.stream.logger.Warn(context.Background(), "failed to dial local connection",
89+
slog.Error(err),
90+
slog.F("port", sr.stream.port))
91+
return nil, 0, err
92+
}
93+
94+
sr.stream.logger.Info(context.Background(), "successfully dialed local connection",
95+
slog.F("port", sr.stream.port))
96+
97+
// Store the new connection for tracking
98+
sr.mu.Lock()
99+
sr.currentConn = conn
100+
sr.mu.Unlock()
101+
102+
// Update stream state
103+
sr.stream.mu.Lock()
104+
sr.stream.connected = true
105+
sr.stream.lastConnectionAt = time.Now()
106+
if sr.stream.reconnectCond != nil {
107+
sr.stream.reconnectCond.Broadcast()
108+
}
109+
sr.stream.mu.Unlock()
110+
111+
return conn, 0, nil // Start from sequence 0 for new connections
112+
}
113+
70114
// NewStream creates a new immortal stream
71-
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream {
115+
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger, dialer Dialer) *Stream {
72116
stream := &Stream{
73117
id: id,
74118
name: name,
75119
port: port,
76120
createdAt: time.Now(),
77121
logger: logger,
122+
dialer: dialer,
78123
disconnectChan: make(chan struct{}, 1),
79124
shutdownChan: make(chan struct{}),
80125
reconnectReq: make(chan struct{}, 1),
81126
}
82127
stream.reconnectCond = sync.NewCond(&stream.mu)
83128

84-
// Create a reconnect function that waits for a client connection
85-
reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
86-
// Wait for HandleReconnect to be called with a new connection
87-
responseChan := make(chan reconnectResponse, 1)
88-
89-
stream.mu.Lock()
90-
stream.pendingReconnect = &reconnectRequest{
91-
writerSeqNum: writerSeqNum,
92-
response: responseChan,
93-
}
94-
stream.handshakePending = true
95-
// Mark disconnected if we previously had a client connection
96-
if stream.connected {
97-
stream.connected = false
98-
stream.lastDisconnectionAt = time.Now()
99-
}
100-
stream.logger.Info(context.Background(), "pending reconnect set",
101-
slog.F("writer_seq", writerSeqNum))
102-
// Signal waiters a reconnect request is pending
103-
stream.reconnectCond.Broadcast()
104-
stream.mu.Unlock()
105-
106-
// Fast path: if the stream is already shutting down, abort immediately
107-
select {
108-
case <-stream.shutdownChan:
109-
stream.mu.Lock()
110-
// Clear the pending request since we're aborting
111-
if stream.pendingReconnect != nil {
112-
stream.pendingReconnect = nil
113-
}
114-
stream.mu.Unlock()
115-
return nil, 0, xerrors.New("stream is shutting down")
116-
default:
117-
}
118-
119-
// Wait for response from HandleReconnect or context cancellation
120-
stream.logger.Info(context.Background(), "reconnect function waiting for response")
121-
select {
122-
case resp := <-responseChan:
123-
stream.logger.Info(context.Background(), "reconnect function got response",
124-
slog.F("has_conn", resp.conn != nil),
125-
slog.F("read_seq", resp.readSeq),
126-
slog.Error(resp.err))
127-
return resp.conn, resp.readSeq, resp.err
128-
case <-ctx.Done():
129-
// Context was canceled, clear pending request and return error
130-
stream.mu.Lock()
131-
stream.pendingReconnect = nil
132-
stream.handshakePending = false
133-
stream.mu.Unlock()
134-
return nil, 0, ctx.Err()
135-
case <-stream.shutdownChan:
136-
// Stream is being shut down, clear pending request and return error
137-
stream.mu.Lock()
138-
stream.pendingReconnect = nil
139-
stream.handshakePending = false
140-
stream.mu.Unlock()
141-
return nil, 0, xerrors.New("stream is shutting down")
142-
}
129+
// Create BackedPipe with background context and reconnector
130+
reconnector := &streamReconnector{
131+
stream: stream,
143132
}
133+
stream.pipe = backedpipe.NewBackedPipe(context.Background(), reconnector)
144134

145-
// Create BackedPipe with background context
146-
stream.pipe = backedpipe.NewBackedPipe(context.Background(), reconnectFn)
135+
// Initiate the first connection
136+
if err := stream.pipe.Connect(); err != nil {
137+
stream.logger.Warn(context.Background(), "failed to connect pipe initially", slog.Error(err))
138+
// Continue anyway - the pipe will retry connections as needed
139+
}
147140

148141
// Start reconnect worker: dedupe pokes and call ForceReconnect when safe.
149142
go func() {
@@ -240,22 +233,9 @@ func (s *Stream) HandleReconnect(clientConn io.ReadWriteCloser, readSeqNum uint6
240233
s.mu.Unlock()
241234
respCh <- reconnectResponse{conn: clientConn, readSeq: readSeqNum, err: nil}
242235

243-
// Wait until the pipe reports a connected state so the handshake fully completes.
244-
// Use a bounded timeout to avoid hanging forever in pathological cases.
245-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
246-
err := s.pipe.WaitForConnection(ctx)
247-
cancel()
248-
if err != nil {
249-
s.mu.Lock()
250-
s.connected = false
251-
if s.reconnectCond != nil {
252-
s.reconnectCond.Broadcast()
253-
}
254-
s.mu.Unlock()
255-
s.logger.Warn(context.Background(), "failed to connect backed pipe", slog.Error(err))
256-
return xerrors.Errorf("failed to establish connection: %w", err)
257-
}
258-
236+
// The reconnector interface handles the connection establishment.
237+
// By the time we respond to the reconnect request, the connection should be established.
238+
// We just need to update our state to reflect the successful connection.
259239
s.mu.Lock()
260240
s.lastConnectionAt = time.Now()
261241
s.connected = true
@@ -333,23 +313,35 @@ func (s *Stream) Close() error {
333313
s.handshakePending = false
334314
}
335315

336-
// Close the backed pipe
337-
if s.pipe != nil {
338-
if err := s.pipe.Close(); err != nil {
339-
s.logger.Warn(context.Background(), "failed to close backed pipe", slog.Error(err))
340-
}
341-
}
342-
343-
// Close connections
316+
// Close connections first to unblock io.Copy operations
344317
if s.localConn != nil {
345318
if err := s.localConn.Close(); err != nil {
346319
s.logger.Warn(context.Background(), "failed to close local connection", slog.Error(err))
347320
}
348321
}
349322

350-
// Wait for goroutines to finish
323+
// Then close the backed pipe
324+
if s.pipe != nil {
325+
if err := s.pipe.Close(); err != nil {
326+
s.logger.Warn(context.Background(), "failed to close backed pipe", slog.Error(err))
327+
}
328+
}
329+
330+
// Wait for goroutines to finish with a timeout
351331
s.mu.Unlock()
352-
s.goroutines.Wait()
332+
done := make(chan struct{})
333+
go func() {
334+
defer close(done)
335+
s.goroutines.Wait()
336+
}()
337+
338+
select {
339+
case <-done:
340+
// Goroutines finished normally
341+
case <-time.After(5 * time.Second):
342+
// Timeout - log warning but continue
343+
s.logger.Warn(context.Background(), "timeout waiting for stream goroutines to finish during close")
344+
}
353345
s.mu.Lock()
354346

355347
return nil
@@ -403,8 +395,17 @@ func (s *Stream) startCopyingLocked() {
403395
defer s.goroutines.Done()
404396

405397
_, err := io.Copy(s.pipe, s.localConn)
406-
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, io.ErrClosedPipe) {
407-
s.logger.Debug(context.Background(), "error copying from local to pipe", slog.Error(err))
398+
if err != nil {
399+
// Handle different error types appropriately
400+
if xerrors.Is(err, io.EOF) {
401+
s.logger.Debug(context.Background(), "local connection closed (EOF)")
402+
} else if xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, backedpipe.ErrPipeClosed) {
403+
s.logger.Debug(context.Background(), "pipe closed during copy", slog.Error(err))
404+
} else if xerrors.Is(err, backedpipe.ErrWriterClosed) {
405+
s.logger.Debug(context.Background(), "writer closed during copy", slog.Error(err))
406+
} else {
407+
s.logger.Debug(context.Background(), "error copying from local to pipe", slog.Error(err))
408+
}
408409
}
409410

410411
// Local connection closed, signal disconnection
@@ -426,13 +427,35 @@ func (s *Stream) startCopyingLocked() {
426427
for {
427428
// Use a buffer for copying
428429
n, err := s.pipe.Read(buf)
429-
// Log significant events
430-
if errors.Is(err, io.EOF) {
431-
s.logger.Debug(context.Background(), "got EOF from pipe")
432-
s.SignalDisconnect()
433-
} else if err != nil && !errors.Is(err, io.ErrClosedPipe) {
434-
s.logger.Debug(context.Background(), "error reading from pipe", slog.Error(err))
435-
s.SignalDisconnect()
430+
431+
// Handle different error types appropriately
432+
if err != nil {
433+
// Check for fatal errors that should terminate the goroutine
434+
if xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, backedpipe.ErrPipeClosed) {
435+
// The pipe itself is closed, we're done
436+
s.logger.Debug(context.Background(), "pipe closed, exiting copy goroutine", slog.Error(err))
437+
s.SignalDisconnect()
438+
return
439+
}
440+
441+
// Log various error types with appropriate context
442+
switch {
443+
case errors.Is(err, io.EOF):
444+
s.logger.Debug(context.Background(), "got EOF from pipe")
445+
s.SignalDisconnect()
446+
case xerrors.Is(err, backedpipe.ErrReconnectFailed):
447+
s.logger.Debug(context.Background(), "reconnect failed, pipe will retry", slog.Error(err))
448+
s.SignalDisconnect()
449+
case xerrors.Is(err, backedpipe.ErrReconnectionInProgress):
450+
s.logger.Debug(context.Background(), "reconnection in progress", slog.Error(err))
451+
// Don't signal disconnect - reconnection is already happening
452+
case xerrors.Is(err, backedpipe.ErrInvalidSequenceNumber):
453+
s.logger.Warn(context.Background(), "sequence number mismatch during reconnect", slog.Error(err))
454+
s.SignalDisconnect()
455+
default:
456+
s.logger.Debug(context.Background(), "error reading from pipe", slog.Error(err))
457+
s.SignalDisconnect()
458+
}
436459
}
437460

438461
if n > 0 {
@@ -447,14 +470,9 @@ func (s *Stream) startCopyingLocked() {
447470
}
448471

449472
if err != nil {
450-
// Check if this is a fatal error
451-
if xerrors.Is(err, io.ErrClosedPipe) {
452-
// The pipe itself is closed, we're done
453-
s.logger.Debug(context.Background(), "pipe closed, exiting copy goroutine")
454-
s.SignalDisconnect()
455-
return
456-
}
457-
// Any other error (including EOF) is handled by BackedPipe; continue
473+
// For non-fatal errors, BackedPipe handles reconnection internally
474+
// We continue the loop to keep reading after reconnection
475+
continue
458476
}
459477
}
460478
}()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad — The Proxy. © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy