From e5be50695a5e57d33b286f7e78edaa5bfc0d8101 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 7 Aug 2025 11:16:53 +0000 Subject: [PATCH] "chore: add backed reader, writer and pipe" --- coderd/agentapi/backedpipe/backed_pipe.go | 332 ++++++++ .../agentapi/backedpipe/backed_pipe_test.go | 727 ++++++++++++++++++ coderd/agentapi/backedpipe/backed_reader.go | 150 ++++ .../agentapi/backedpipe/backed_reader_test.go | 471 ++++++++++++ coderd/agentapi/backedpipe/backed_writer.go | 244 ++++++ .../agentapi/backedpipe/backed_writer_test.go | 411 ++++++++++ coderd/agentapi/backedpipe/ring_buffer.go | 140 ++++ .../backedpipe/ring_buffer_internal_test.go | 162 ++++ .../agentapi/backedpipe/ring_buffer_test.go | 326 ++++++++ testutil/duration.go | 9 +- 10 files changed, 2968 insertions(+), 4 deletions(-) create mode 100644 coderd/agentapi/backedpipe/backed_pipe.go create mode 100644 coderd/agentapi/backedpipe/backed_pipe_test.go create mode 100644 coderd/agentapi/backedpipe/backed_reader.go create mode 100644 coderd/agentapi/backedpipe/backed_reader_test.go create mode 100644 coderd/agentapi/backedpipe/backed_writer.go create mode 100644 coderd/agentapi/backedpipe/backed_writer_test.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer_internal_test.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer_test.go diff --git a/coderd/agentapi/backedpipe/backed_pipe.go b/coderd/agentapi/backedpipe/backed_pipe.go new file mode 100644 index 0000000000000..784b8d55353b1 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_pipe.go @@ -0,0 +1,332 @@ +package backedpipe + +import ( + "context" + "io" + "sync" + + "golang.org/x/xerrors" +) + +const ( + // DefaultBufferSize is the default buffer size for the BackedWriter (64MB) + DefaultBufferSize = 64 * 1024 * 1024 +) + +// ReconnectFunc is called when the BackedPipe needs to establish a new connection. +// It should: +// 1. Establish a new connection to the remote side +// 2. Exchange sequence numbers with the remote side +// 3. Return the new connection and the remote's current sequence number +// +// The writerSeqNum parameter is the local writer's current sequence number, +// which should be sent to the remote side so it knows where to resume reading from. +// +// The returned readerSeqNum should be the remote side's current sequence number, +// which indicates where the local reader should resume from. +type ReconnectFunc func(ctx context.Context, writerSeqNum uint64) (conn io.ReadWriteCloser, readerSeqNum uint64, err error) + +// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections. +// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection +// and data replay capabilities. +type BackedPipe struct { + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + reader *BackedReader + writer *BackedWriter + reconnectFn ReconnectFunc + conn io.ReadWriteCloser + connected bool + closed bool + + // Reconnection state + reconnecting bool + + // Error channel for receiving connection errors from reader/writer + errorChan chan error + + // Connection state notification + connectionChanged chan struct{} +} + +// NewBackedPipe creates a new BackedPipe with default options and the specified reconnect function. +// The pipe starts disconnected and must be connected using Connect(). 
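+//
+// A minimal usage sketch (hedged: dialFn is a hypothetical reconnect
+// function; a real one must perform the sequence-number exchange described
+// on ReconnectFunc):
+//
+//	dialFn := func(ctx context.Context, writerSeq uint64) (io.ReadWriteCloser, uint64, error) {
+//		conn, remoteSeq, err := dialAndHandshake(ctx, writerSeq) // hypothetical helper
+//		return conn, remoteSeq, err
+//	}
+//	bp := NewBackedPipe(ctx, dialFn)
+//	if err := bp.Connect(ctx); err != nil {
+//		// handle initial connection failure
+//	}
+//	_, _ = bp.Write([]byte("hello")) // buffered and replayed across reconnects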
+func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe { + pipeCtx, cancel := context.WithCancel(ctx) + + bp := &BackedPipe{ + ctx: pipeCtx, + cancel: cancel, + reader: NewBackedReader(), + writer: NewBackedWriterWithCapacity(DefaultBufferSize), // 64MB default buffer + reconnectFn: reconnectFn, + errorChan: make(chan error, 2), // Buffer for reader and writer errors + connectionChanged: make(chan struct{}, 1), + } + + // Set up error callbacks + bp.reader.SetErrorCallback(func(err error) { + select { + case bp.errorChan <- err: + case <-bp.ctx.Done(): + } + }) + + bp.writer.SetErrorCallback(func(err error) { + select { + case bp.errorChan <- err: + case <-bp.ctx.Done(): + } + }) + + // Start error handler goroutine + go bp.handleErrors() + + return bp +} + +// Connect establishes the initial connection using the reconnect function. +func (bp *BackedPipe) Connect(ctx context.Context) error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return xerrors.New("pipe is closed") + } + + if bp.connected { + return xerrors.New("pipe is already connected") + } + + return bp.reconnectLocked(ctx) +} + +// Read implements io.Reader by delegating to the BackedReader. +func (bp *BackedPipe) Read(p []byte) (int, error) { + bp.mu.RLock() + reader := bp.reader + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return 0, io.ErrClosedPipe + } + + return reader.Read(p) +} + +// Write implements io.Writer by delegating to the BackedWriter. +func (bp *BackedPipe) Write(p []byte) (int, error) { + bp.mu.RLock() + writer := bp.writer + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return 0, io.ErrClosedPipe + } + + return writer.Write(p) +} + +// Close closes the pipe and all underlying connections. +func (bp *BackedPipe) Close() error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return nil + } + + bp.closed = true + bp.cancel() // Cancel main context + + // Close underlying components + var readerErr, writerErr, connErr error + + if bp.reader != nil { + readerErr = bp.reader.Close() + } + + if bp.writer != nil { + writerErr = bp.writer.Close() + } + + if bp.conn != nil { + connErr = bp.conn.Close() + bp.conn = nil + } + + bp.connected = false + bp.signalConnectionChange() + + // Return first error encountered + if readerErr != nil { + return readerErr + } + if writerErr != nil { + return writerErr + } + return connErr +} + +// Connected returns whether the pipe is currently connected. +func (bp *BackedPipe) Connected() bool { + bp.mu.RLock() + defer bp.mu.RUnlock() + return bp.connected +} + +// signalConnectionChange signals that the connection state has changed. +func (bp *BackedPipe) signalConnectionChange() { + select { + case bp.connectionChanged <- struct{}{}: + default: + // Channel is full, which is fine - we just want to signal that something changed + } +} + +// reconnectLocked handles the reconnection logic. Must be called with write lock held. 
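+// The mutex is deliberately released around the reconnectFn call (which may
+// block on network I/O) and re-acquired afterwards; the reconnecting flag
+// guards against concurrent attempts while the lock is dropped.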
+func (bp *BackedPipe) reconnectLocked(ctx context.Context) error { + if bp.reconnecting { + return xerrors.New("reconnection already in progress") + } + + bp.reconnecting = true + defer func() { + bp.reconnecting = false + }() + + // Close existing connection if any + if bp.conn != nil { + _ = bp.conn.Close() + bp.conn = nil + } + + bp.connected = false + bp.signalConnectionChange() + + // Get current writer sequence number to send to remote + writerSeqNum := bp.writer.SequenceNum() + + // Unlock during reconnect attempt to avoid blocking reads/writes + bp.mu.Unlock() + conn, readerSeqNum, err := bp.reconnectFn(ctx, writerSeqNum) + bp.mu.Lock() + + if err != nil { + return xerrors.Errorf("reconnect failed: %w", err) + } + + // Validate sequence numbers + if readerSeqNum > writerSeqNum { + _ = conn.Close() + return xerrors.Errorf("remote sequence number %d exceeds local sequence %d, cannot replay", + readerSeqNum, writerSeqNum) + } + + // Validate writer can replay from the requested sequence + if !bp.writer.CanReplayFrom(readerSeqNum) { + _ = conn.Close() + // Calculate data loss + currentSeq := bp.writer.SequenceNum() + dataLoss := currentSeq - DefaultBufferSize - readerSeqNum + return xerrors.Errorf("cannot replay from sequence %d (current: %d, data loss: ~%d bytes)", + readerSeqNum, currentSeq, dataLoss) + } + + // Reconnect reader and writer + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go bp.reader.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- conn + + err = bp.writer.Reconnect(readerSeqNum, conn) + if err != nil { + _ = conn.Close() + return xerrors.Errorf("reconnect writer: %w", err) + } + + // Success - update state + bp.conn = conn + bp.connected = true + bp.signalConnectionChange() + + return nil +} + +// handleErrors listens for connection errors from reader/writer and triggers reconnection. +func (bp *BackedPipe) handleErrors() { + for { + select { + case <-bp.ctx.Done(): + return + case err := <-bp.errorChan: + // Connection error occurred + bp.mu.Lock() + + // Skip if already closed or not connected + if bp.closed || !bp.connected { + bp.mu.Unlock() + continue + } + + // Mark as disconnected + bp.connected = false + bp.signalConnectionChange() + + // Try to reconnect + reconnectErr := bp.reconnectLocked(bp.ctx) + bp.mu.Unlock() + + if reconnectErr != nil { + // Reconnection failed - log or handle as needed + // For now, we'll just continue and wait for manual reconnection + _ = err // Use the original error + } + } + } +} + +// WaitForConnection blocks until the pipe is connected or the context is canceled. +func (bp *BackedPipe) WaitForConnection(ctx context.Context) error { + for { + bp.mu.RLock() + connected := bp.connected + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return io.ErrClosedPipe + } + + if connected { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-bp.connectionChanged: + // Connection state changed, check again + } + } +} + +// ForceReconnect forces a reconnection attempt immediately. +// This can be used to force a reconnection if a new connection is established. 
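+//
+// A hedged sketch of one intended use, e.g. when the caller learns out of
+// band that the current connection is stale:
+//
+//	if err := bp.ForceReconnect(ctx); err != nil {
+//		// reconnect handshake failed; the pipe stays disconnected
+//	}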
+func (bp *BackedPipe) ForceReconnect(ctx context.Context) error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return io.ErrClosedPipe + } + + return bp.reconnectLocked(ctx) +} diff --git a/coderd/agentapi/backedpipe/backed_pipe_test.go b/coderd/agentapi/backedpipe/backed_pipe_test.go new file mode 100644 index 0000000000000..c841112ed07e1 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_pipe_test.go @@ -0,0 +1,727 @@ +package backedpipe_test + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockConnection implements io.ReadWriteCloser for testing +type mockConnection struct { + mu sync.Mutex + readBuffer bytes.Buffer + writeBuffer bytes.Buffer + closed bool + readError error + writeError error + closeError error + readFunc func([]byte) (int, error) + writeFunc func([]byte) (int, error) + seqNum uint64 +} + +func newMockConnection() *mockConnection { + return &mockConnection{} +} + +func (mc *mockConnection) Read(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.readFunc != nil { + return mc.readFunc(p) + } + + if mc.readError != nil { + return 0, mc.readError + } + + return mc.readBuffer.Read(p) +} + +func (mc *mockConnection) Write(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.writeFunc != nil { + return mc.writeFunc(p) + } + + if mc.writeError != nil { + return 0, mc.writeError + } + + return mc.writeBuffer.Write(p) +} + +func (mc *mockConnection) Close() error { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.closed = true + return mc.closeError +} + +func (mc *mockConnection) WriteString(s string) { + mc.mu.Lock() + defer mc.mu.Unlock() + _, _ = mc.readBuffer.WriteString(s) +} + +func (mc *mockConnection) ReadString() string { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.writeBuffer.String() +} + +func (mc *mockConnection) SetReadError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readError = err +} + +func (mc *mockConnection) SetWriteError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.writeError = err +} + +func (mc *mockConnection) Reset() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readBuffer.Reset() + mc.writeBuffer.Reset() + mc.readError = nil + mc.writeError = nil + mc.closed = false +} + +// mockReconnectFunc creates a unified reconnect function with all behaviors enabled +func mockReconnectFunc(connections ...*mockConnection) (backedpipe.ReconnectFunc, *int, chan struct{}) { + connectionIndex := 0 + callCount := 0 + signalChan := make(chan struct{}, 1) + + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if connectionIndex >= len(connections) { + return nil, 0, xerrors.New("no more connections available") + } + + conn := connections[connectionIndex] + connectionIndex++ + + // Signal when reconnection happens + if connectionIndex > 1 { + select { + case signalChan <- struct{}{}: + default: + } + } + + // Determine readerSeqNum based on call count + var readerSeqNum uint64 + switch { + case callCount == 1: + readerSeqNum = 0 + case conn.seqNum != 0: + readerSeqNum = conn.seqNum + default: + readerSeqNum = writerSeqNum + } + + return conn, readerSeqNum, nil + } + + return reconnectFn, &callCount, signalChan +} + +func TestBackedPipe_NewBackedPipe(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _, _ := 
mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + require.NotNil(t, bp) + require.False(t, bp.Connected()) +} + +func TestBackedPipe_Connect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) +} + +func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Second connect should fail + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "already connected") +} + +func TestBackedPipe_ConnectAfterClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedPipe_BasicReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write data + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Simulate data coming back + conn.WriteString("world") + + // Read data + buf := make([]byte, 10) + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) +} + +func TestBackedPipe_WriteBeforeConnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Write before connecting should succeed (buffered) + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Connect should replay the buffered data + err = bp.Connect(ctx) + require.NoError(t, err) + + // Check that data was replayed to connection + require.Equal(t, "hello", conn.ReadString()) +} + +func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Start a read that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}) + var readErr error + + go func() { + defer close(readDone) + close(readStarted) // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = bp.Read(buf) + }() + + // Wait for the goroutine to start + <-readStarted + + // Give a brief moment for the read to actually block + time.Sleep(time.Millisecond) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Good, still blocked + } + + // Close should unblock the read + bp.Close() + + select { + case <-readDone: + require.Equal(t, io.ErrClosedPipe, readErr) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after close") 
+ } +} + +func TestBackedPipe_Reconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17 + reconnectFn, _, signalChan := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write some data before failure + bp.Write([]byte("before disconnect***")) + + // Simulate connection failure + conn1.SetReadError(xerrors.New("connection lost")) + conn1.SetWriteError(xerrors.New("connection lost")) + + // Trigger a write to cause the pipe to notice the failure + _, _ = bp.Write([]byte("trigger failure ")) + + <-signalChan + + err = bp.WaitForConnection(ctx) + require.NoError(t, err) + + replayedData := conn2.ReadString() + require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17") + + // Verify that new writes work with the reconnected pipe + _, err = bp.Write([]byte("new data after reconnect")) + require.NoError(t, err) + + // Read all data from the connection (replayed + new data) + allData := conn2.ReadString() + require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data") +} + +func TestBackedPipe_Close(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + err = bp.Close() + require.NoError(t, err) + require.True(t, conn.closed) + + // Operations after close should fail + _, err = bp.Read(make([]byte, 10)) + require.Equal(t, io.ErrClosedPipe, err) + + _, err = bp.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_CloseIdempotent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bp.Close() + require.NoError(t, err) +} + +func TestBackedPipe_WaitForConnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Should timeout when not connected + // Use a shorter timeout for this test to speed up test runs + timeoutCtx, cancel := context.WithTimeout(ctx, testutil.WaitSuperShort) + defer cancel() + + err := bp.WaitForConnection(timeoutCtx) + require.Equal(t, context.DeadlineExceeded, err) + + // Connect in background after a brief delay + connectionStarted := make(chan struct{}) + go func() { + close(connectionStarted) + // Small delay to ensure WaitForConnection is called first + time.Sleep(time.Millisecond) + bp.Connect(context.Background()) + }() + + // Wait for connection goroutine to start + <-connectionStarted + + // Should succeed once connected + err = bp.WaitForConnection(context.Background()) + require.NoError(t, err) +} + +func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + 
var wg sync.WaitGroup + numWriters := 3 + writesPerWriter := 10 + + // Fill read buffer with test data first + testData := make([]byte, 1000) + for i := range testData { + testData[i] = 'A' + } + conn.WriteString(string(testData)) + + // Channel to collect all written data + writtenData := make(chan byte, numWriters*writesPerWriter) + + // Start a few readers + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < 10; j++ { + bp.Read(buf) + time.Sleep(time.Millisecond) // Small delay to avoid busy waiting + } + }() + } + + // Start writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + data := []byte{byte(id + '0')} + bp.Write(data) + writtenData <- byte(id + '0') + time.Sleep(time.Millisecond) // Small delay + } + }(i) + } + + // Wait with timeout + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-done: + // Success + case <-time.After(5 * time.Second): + t.Fatal("Test timed out") + } + + // Close the channel and collect all written data + close(writtenData) + var allWritten []byte + for b := range writtenData { + allWritten = append(allWritten, b) + } + + // Verify that all written data was received by the connection + // Note: Since this test uses the old mock that returns readerSeqNum = 0, + // all data will be replayed, so we expect to receive all written data + receivedData := conn.ReadString() + require.GreaterOrEqual(t, len(receivedData), len(allWritten), "Connection should have received at least all written data") + + // Check that all written bytes appear in the received data + for _, writtenByte := range allWritten { + require.Contains(t, receivedData, string(writtenByte), "Written byte %c should be present in received data", writtenByte) + } +} + +func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + failingReconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + return nil, 0, xerrors.New("reconnect failed") + } + + bp := backedpipe.NewBackedPipe(ctx, failingReconnectFn) + + err := bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "reconnect failed") + require.False(t, bp.Connected()) +} + +func TestBackedPipe_ForceReconnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Write some data to the first connection + _, err = bp.Write([]byte("test data")) + require.NoError(t, err) + require.Equal(t, "test data", conn1.ReadString()) + + // Force a reconnection + err = bp.ForceReconnect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 2, *callCount) + + // Since the mock now returns the proper sequence number, no data should be replayed + // The new connection should be empty + require.Equal(t, "", conn2.ReadString()) + + // Verify that data can still be written and read after forced reconnection + _, err = bp.Write([]byte("new data")) + require.NoError(t, err) + require.Equal(t, "new data", conn2.ReadString()) + + // Verify that reads work with the new connection + conn2.WriteString("response 
data") + buf := make([]byte, 20) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 13, n) + require.Equal(t, "response data", string(buf[:n])) +} + +func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Close the pipe first + err := bp.Close() + require.NoError(t, err) + + // Try to force reconnect when closed + err = bp.ForceReconnect(ctx) + require.Error(t, err) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Don't connect initially, just force reconnect + err := bp.ForceReconnect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Verify we can write and read + _, err = bp.Write([]byte("test")) + require.NoError(t, err) + require.Equal(t, "test", conn.ReadString()) + + conn.WriteString("response") + buf := make([]byte, 10) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 8, n) + require.Equal(t, "response", string(buf[:n])) +} + +func TestBackedPipe_EOFTriggersReconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create connections where we can control when EOF occurs + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.WriteString("newdata") // Pre-populate conn2 with data + + // Make conn1 return EOF after reading "world" + hasReadData := false + conn1.readFunc = func(p []byte) (int, error) { + // Don't lock here - the Read method already holds the lock + + // First time: return "world" + if !hasReadData && conn1.readBuffer.Len() > 0 { + n, _ := conn1.readBuffer.Read(p) + hasReadData = true + return n, nil + } + // After that: return EOF + return 0, io.EOF + } + conn1.WriteString("world") + + callCount := 0 + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if callCount == 1 { + return conn1, 0, nil + } + if callCount == 2 { + // Second call is the reconnection after EOF + return conn2, writerSeqNum, nil // conn2 already has the reader sequence at writerSeqNum + } + + return nil, 0, xerrors.New("no more connections") + } + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + require.Equal(t, 1, callCount) + + // Write some data + _, err = bp.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + + // First read should succeed + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) + + // Next read will encounter EOF and should trigger reconnection + // After reconnection, it should read from conn2 + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, "newdata", string(buf[:n])) + + // Verify reconnection happened + require.Equal(t, 2, callCount) + + // Verify the pipe is still connected and functional + require.True(t, bp.Connected()) + + // Further writes should go to the new connection + _, err = bp.Write([]byte("aftereof")) + require.NoError(t, err) + require.Equal(t, "aftereof", conn2.ReadString()) +} + +func 
BenchmarkBackedPipe_Write(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect(ctx) + + data := make([]byte, 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bp.Write(data) + } +} + +func BenchmarkBackedPipe_Read(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect(ctx) + + buf := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Fill connection with fresh data for each iteration + conn.WriteString(string(buf)) + bp.Read(buf) + } +} diff --git a/coderd/agentapi/backedpipe/backed_reader.go b/coderd/agentapi/backedpipe/backed_reader.go new file mode 100644 index 0000000000000..7b57986638065 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_reader.go @@ -0,0 +1,150 @@ +package backedpipe + +import ( + "io" + "sync" +) + +// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections. +// It tracks sequence numbers for all bytes read and can handle reconnection, +// blocking reads when disconnected instead of erroring. +type BackedReader struct { + mu sync.Mutex + cond *sync.Cond + reader io.Reader + sequenceNum uint64 + closed bool + + // Error callback to notify parent when connection fails + onError func(error) +} + +// NewBackedReader creates a new BackedReader. The reader is initially disconnected +// and must be connected using Reconnect before reads will succeed. +func NewBackedReader() *BackedReader { + br := &BackedReader{} + br.cond = sync.NewCond(&br.mu) + return br +} + +// Read implements io.Reader. It blocks when disconnected until either: +// 1. A reconnection is established +// 2. The reader is closed +// +// When connected, it reads from the underlying reader and updates sequence numbers. +// Connection failures are automatically detected and reported to the higher layer via callback. +func (br *BackedReader) Read(p []byte) (int, error) { + br.mu.Lock() + defer br.mu.Unlock() + + for { + for br.reader == nil && !br.closed { + br.cond.Wait() + } + + // Check if closed + if br.closed { + return 0, io.ErrClosedPipe + } + + br.mu.Unlock() + n, err := br.reader.Read(p) + br.mu.Lock() + + if err == nil { + br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract + return n, nil + } + + br.reader = nil + + if br.onError != nil { + br.onError(err) + } + + // If we got some data before the error, return it + if n > 0 { + br.sequenceNum += uint64(n) + return n, nil + } + + // Return to Step 2 (continue the loop) + } +} + +// Reconnect coordinates reconnection using channels for better synchronization. +// The seqNum channel is used to send the current sequence number to the caller. +// The newR channel is used to receive the new reader from the caller. +// This allows for better coordination during the reconnection process. 
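+//
+// Caller-side choreography (mirroring how BackedPipe drives it):
+//
+//	seqNum := make(chan uint64, 1)
+//	newR := make(chan io.Reader, 1)
+//	go br.Reconnect(seqNum, newR)
+//	seq := <-seqNum // forward seq to the remote side
+//	newR <- conn    // then hand over the new connection (or nil on failure)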
+func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) { + // Grab the lock + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + // Send 0 sequence number and close the channel to indicate closed state + seqNum <- 0 + close(seqNum) + return + } + + // Get the sequence number to send to the other side via seqNum channel + seqNum <- br.sequenceNum + close(seqNum) + + // Wait for the reconnect to complete, via newR channel, and give us a new io.Reader + newReader := <-newR + + // If reconnection fails while we are starting it, the caller sends nil on newR + if newReader == nil { + // Reconnection failed, keep current state + return + } + + // Reconnection successful + br.reader = newReader + + // Notify any waiting reads via the cond + br.cond.Broadcast() +} + +// Closes the reader and wakes up any blocked reads. +// After closing, all Read calls will return io.ErrClosedPipe. +func (br *BackedReader) Close() error { + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + return nil + } + + br.closed = true + br.reader = nil + + // Wake up any blocked reads + br.cond.Broadcast() + + return nil +} + +// SetErrorCallback sets the callback function that will be called when +// a connection error occurs (excluding EOF). +func (br *BackedReader) SetErrorCallback(fn func(error)) { + br.mu.Lock() + defer br.mu.Unlock() + br.onError = fn +} + +// SequenceNum returns the current sequence number (total bytes read). +func (br *BackedReader) SequenceNum() uint64 { + br.mu.Lock() + defer br.mu.Unlock() + return br.sequenceNum +} + +// Connected returns whether the reader is currently connected. +func (br *BackedReader) Connected() bool { + br.mu.Lock() + defer br.mu.Unlock() + return br.reader != nil +} diff --git a/coderd/agentapi/backedpipe/backed_reader_test.go b/coderd/agentapi/backedpipe/backed_reader_test.go new file mode 100644 index 0000000000000..810abb7c64bd6 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_reader_test.go @@ -0,0 +1,471 @@ +package backedpipe_test + +import ( + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" +) + +// mockReader implements io.Reader with controllable behavior for testing +type mockReader struct { + mu sync.Mutex + data []byte + pos int + err error + readFunc func([]byte) (int, error) +} + +func newMockReader(data string) *mockReader { + return &mockReader{data: []byte(data)} +} + +func (mr *mockReader) Read(p []byte) (int, error) { + mr.mu.Lock() + defer mr.mu.Unlock() + + if mr.readFunc != nil { + return mr.readFunc(p) + } + + if mr.err != nil { + return 0, mr.err + } + + if mr.pos >= len(mr.data) { + return 0, io.EOF + } + + n := copy(p, mr.data[mr.pos:]) + mr.pos += n + return n, nil +} + +func (mr *mockReader) setError(err error) { + mr.mu.Lock() + defer mr.mu.Unlock() + mr.err = err +} + +func TestBackedReader_NewBackedReader(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + assert.NotNil(t, br) + assert.Equal(t, uint64(0), br.SequenceNum()) + assert.False(t, br.Connected()) +} + +func TestBackedReader_BasicReadOperation(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("hello world") + + // Connect the reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number from reader + seq := <-seqNum + assert.Equal(t, uint64(0), 
seq) + + // Send new reader + newR <- reader + + // Read data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(buf)) + assert.Equal(t, uint64(5), br.SequenceNum()) + + // Read more data + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, " worl", string(buf)) + assert.Equal(t, uint64(10), br.SequenceNum()) +} + +func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + // Start a read operation that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}) + var readErr error + + go func() { + defer close(readDone) + close(readStarted) // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = br.Read(buf) + }() + + // Wait for the goroutine to start + <-readStarted + + // Give a brief moment for the read to actually block on the condition variable + // This is much shorter and more deterministic than the previous approach + time.Sleep(time.Millisecond) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Good, still blocked + } + + // Connect and the read should unblock + reader := newMockReader("test") + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader + + // Wait for read to complete + select { + case <-readDone: + assert.NoError(t, readErr) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after reconnection") + } +} + +func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader1 := newMockReader("first") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader1 + + // Read some data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, "first", string(buf[:n])) + assert.Equal(t, uint64(5), br.SequenceNum()) + + // Set up error callback to verify error notification + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block due to connection failure + readDone := make(chan error, 1) + go func() { + _, err := br.Read(buf) + readDone <- err + }() + + // Wait for the error to be reported via callback + select { + case receivedErr := <-errorReceived: + assert.Error(t, receivedErr) + assert.Contains(t, receivedErr.Error(), "connection lost") + case <-time.After(time.Second): + t.Fatal("Error callback was not invoked within timeout") + } + + // Verify disconnection + assert.False(t, br.Connected()) + + // Reconnect with new reader + reader2 := newMockReader("second") + seqNum2 := make(chan uint64, 1) + newR2 := make(chan io.Reader, 1) + + go br.Reconnect(seqNum2, newR2) + + // Get sequence number and send new reader + seq := <-seqNum2 + assert.Equal(t, uint64(5), seq) // Should return current sequence number + newR2 <- reader2 + + // Wait for read to unblock and succeed with new data + select { + case readErr := <-readDone: + assert.NoError(t, readErr) // Should succeed with new reader + case <-time.After(time.Second): + 
t.Fatal("Read did not unblock after reconnection") + } +} + +func TestBackedReader_Close(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Connect + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader + + // First, read all available data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) // "test" is 4 bytes + + // Close the reader before EOF triggers reconnection + err = br.Close() + require.NoError(t, err) + + // After close, reads should return ErrClosedPipe + n, err = br.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.ErrClosedPipe, err) + + // Subsequent reads should return ErrClosedPipe + _, err = br.Read(buf) + assert.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedReader_CloseIdempotent(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + err := br.Close() + assert.NoError(t, err) + + // Second close should be no-op + err = br.Close() + assert.NoError(t, err) +} + +func TestBackedReader_ReconnectAfterClose(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + err := br.Close() + require.NoError(t, err) + + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Should get 0 sequence number for closed reader + seq := <-seqNum + assert.Equal(t, uint64(0), seq) +} + +// Helper function to reconnect a reader using channels +func reconnectReader(br *backedpipe.BackedReader, reader io.Reader) { + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader +} + +func TestBackedReader_SequenceNumberTracking(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("0123456789") + + reconnectReader(br, reader) + + // Read in chunks and verify sequence number + buf := make([]byte, 3) + + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(3), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(6), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(9), br.SequenceNum()) +} + +func TestBackedReader_ConcurrentReads(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader(strings.Repeat("a", 1000)) + + reconnectReader(br, reader) + + var wg sync.WaitGroup + numReaders := 5 + readsPerReader := 10 + + for i := 0; i < numReaders; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < readsPerReader; j++ { + br.Read(buf) + } + }() + } + + wg.Wait() + + // Should have read some data (exact amount depends on scheduling) + assert.True(t, br.SequenceNum() > 0) + assert.True(t, br.SequenceNum() <= 1000) +} + +func TestBackedReader_EOFHandling(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Set up error callback to track when EOF triggers disconnection + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + reconnectReader(br, reader) + + // Read all data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) + 
assert.Equal(t, "test", string(buf[:n])) + + // Next read should encounter EOF, which triggers disconnection + // The read should block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for EOF to be reported via error callback + select { + case receivedErr := <-errorReceived: + assert.Equal(t, io.EOF, receivedErr) + case <-time.After(time.Second): + t.Fatal("EOF was not reported via error callback within timeout") + } + + // Reader should be disconnected after EOF + assert.False(t, br.Connected()) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection after EOF") + default: + // Good, still blocked + } + + // Reconnect with new data + reader2 := newMockReader("more") + reconnectReader(br, reader2) + + // Wait for the blocked read to complete with new data + select { + case <-readDone: + require.NoError(t, readErr) + assert.Equal(t, 4, readN) + assert.Equal(t, "more", string(buf[:readN])) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after reconnection") + } +} + +func BenchmarkBackedReader_Read(b *testing.B) { + br := backedpipe.NewBackedReader() + buf := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Create fresh reader with data for each iteration + data := strings.Repeat("x", 1024) // 1KB of data per iteration + reader := newMockReader(data) + reconnectReader(br, reader) + + br.Read(buf) + } +} + +func TestBackedReader_PartialReads(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + // Create a reader that returns partial reads + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Always return just 1 byte at a time + if len(p) == 0 { + return 0, nil + } + p[0] = 'A' + return 1, nil + }, + } + + reconnectReader(br, reader) + + // Read multiple times + buf := make([]byte, 10) + for i := 0; i < 5; i++ { + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 1, n) + assert.Equal(t, byte('A'), buf[0]) + } + + assert.Equal(t, uint64(5), br.SequenceNum()) +} diff --git a/coderd/agentapi/backedpipe/backed_writer.go b/coderd/agentapi/backedpipe/backed_writer.go new file mode 100644 index 0000000000000..bc72d8bfc7385 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_writer.go @@ -0,0 +1,244 @@ +package backedpipe + +import ( + "context" + "io" + "sync" + + "golang.org/x/xerrors" +) + +// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections. +// It maintains a ring buffer of recent writes for replay during reconnection and +// always writes to the buffer even when disconnected. +type BackedWriter struct { + mu sync.Mutex + cond *sync.Cond + writer io.Writer + buffer *RingBuffer + sequenceNum uint64 // total bytes written + closed bool + + // Error callback to notify parent when connection fails + onError func(error) +} + +// NewBackedWriter creates a new BackedWriter with a 64MB ring buffer. +// The writer is initially disconnected and will buffer writes until connected. +func NewBackedWriter() *BackedWriter { + return NewBackedWriterWithCapacity(64 * 1024 * 1024) +} + +// NewBackedWriterWithCapacity creates a new BackedWriter with the specified buffer capacity. +// The writer is initially disconnected and will buffer writes until connected. 
+func NewBackedWriterWithCapacity(capacity int) *BackedWriter { + bw := &BackedWriter{ + buffer: NewRingBufferWithCapacity(capacity), + } + bw.cond = sync.NewCond(&bw.mu) + return bw +} + +// Write implements io.Writer. It always writes to the ring buffer, even when disconnected. +// When connected, it also writes to the underlying writer. If the underlying write fails, +// the writer is marked as disconnected but the buffer write still succeeds. +func (bw *BackedWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return 0, io.ErrClosedPipe + } + + // Always write to buffer first + written, _ := bw.buffer.Write(p) + //nolint:gosec // Safe conversion: written is always non-negative from buffer.Write + bw.sequenceNum += uint64(written) + + // If connected, also write to underlying writer + if bw.writer != nil { + // Unlock during actual write to avoid blocking other operations + bw.mu.Unlock() + n, err := bw.writer.Write(p) + bw.mu.Lock() + + if n != len(p) { + err = xerrors.Errorf("partial write: wrote %d of %d bytes", n, len(p)) + } + + if err != nil { + // Connection failed, mark as disconnected + bw.writer = nil + + // Notify parent of error if callback is set + if bw.onError != nil { + bw.onError(err) + } + } + } + + return written, nil +} + +// Reconnect replaces the current writer with a new one and replays data from the specified +// sequence number. If the requested sequence number is no longer in the buffer, +// returns an error indicating data loss. +func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return xerrors.New("cannot reconnect closed writer") + } + + if newWriter == nil { + return xerrors.New("new writer cannot be nil") + } + + // Check if we can replay from the requested sequence number + if replayFromSeq > bw.sequenceNum { + return xerrors.Errorf("cannot replay from future sequence %d: current sequence is %d", replayFromSeq, bw.sequenceNum) + } + + // Calculate how many bytes we need to replay + replayBytes := bw.sequenceNum - replayFromSeq + + var replayData []byte + if replayBytes > 0 { + // Get the last replayBytes from buffer + // If the buffer doesn't have enough data (some was evicted), + // ReadLast will return an error + var err error + // Safe conversion: replayBytes is always non-negative due to the check above + // No overflow possible since replayBytes is calculated as sequenceNum - replayFromSeq + // and uint64->int conversion is safe for reasonable buffer sizes + //nolint:gosec // Safe conversion: replayBytes is calculated from uint64 subtraction + replayData, err = bw.buffer.ReadLast(int(replayBytes)) + if err != nil { + return xerrors.Errorf("failed to read replay data: %w", err) + } + } + + // Set new writer + bw.writer = newWriter + + // Replay data if needed + if len(replayData) > 0 { + bw.mu.Unlock() + n, err := newWriter.Write(replayData) + bw.mu.Lock() + + if err != nil { + bw.writer = nil + return xerrors.Errorf("replay failed: %w", err) + } + + if n != len(replayData) { + bw.writer = nil + return xerrors.Errorf("partial replay: wrote %d of %d bytes", n, len(replayData)) + } + } + + // Wake up any operations waiting for connection + bw.cond.Broadcast() + + return nil +} + +// Close closes the writer and prevents further writes. +// After closing, all Write calls will return io.ErrClosedPipe. 
+// This code keeps the Close() signature consistent with io.Closer, +// but it never actually returns an error. +func (bw *BackedWriter) Close() error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return nil + } + + bw.closed = true + bw.writer = nil + + // Wake up any blocked operations + bw.cond.Broadcast() + + return nil +} + +// SetErrorCallback sets the callback function that will be called when +// a connection error occurs. +func (bw *BackedWriter) SetErrorCallback(fn func(error)) { + bw.mu.Lock() + defer bw.mu.Unlock() + bw.onError = fn +} + +// SequenceNum returns the current sequence number (total bytes written). +func (bw *BackedWriter) SequenceNum() uint64 { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.sequenceNum +} + +// Connected returns whether the writer is currently connected. +func (bw *BackedWriter) Connected() bool { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.writer != nil +} + +// CanReplayFrom returns true if the writer can replay data from the given sequence number. +func (bw *BackedWriter) CanReplayFrom(seqNum uint64) bool { + bw.mu.Lock() + defer bw.mu.Unlock() + return seqNum <= bw.sequenceNum && bw.sequenceNum-seqNum <= DefaultBufferSize +} + +// WaitForConnection blocks until the writer is connected or the context is canceled. +func (bw *BackedWriter) WaitForConnection(ctx context.Context) error { + bw.mu.Lock() + defer bw.mu.Unlock() + + return bw.waitForConnectionLocked(ctx) +} + +// waitForConnectionLocked waits for connection with lock held. +func (bw *BackedWriter) waitForConnectionLocked(ctx context.Context) error { + for bw.writer == nil && !bw.closed { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Use a timeout to avoid infinite waiting + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + bw.cond.Broadcast() + case <-done: + } + }() + + bw.cond.Wait() + close(done) + + // Check context again after waking up + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + } + + if bw.closed { + return io.ErrClosedPipe + } + + return nil +} diff --git a/coderd/agentapi/backedpipe/backed_writer_test.go b/coderd/agentapi/backedpipe/backed_writer_test.go new file mode 100644 index 0000000000000..f92a79c6f366b --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_writer_test.go @@ -0,0 +1,411 @@ +package backedpipe_test + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockWriter implements io.Writer with controllable behavior for testing +type mockWriter struct { + mu sync.Mutex + buffer bytes.Buffer + err error + writeFunc func([]byte) (int, error) + writeCalls int +} + +func newMockWriter() *mockWriter { + return &mockWriter{} +} + +// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior +func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter { + return backedpipe.NewBackedWriterWithCapacity(bufferSize) +} + +func (mw *mockWriter) Write(p []byte) (int, error) { + mw.mu.Lock() + defer mw.mu.Unlock() + + mw.writeCalls++ + + if mw.writeFunc != nil { + return mw.writeFunc(p) + } + + if mw.err != nil { + return 0, mw.err + } + + return mw.buffer.Write(p) +} + +func (mw *mockWriter) Len() int { + mw.mu.Lock() + defer mw.mu.Unlock() + return mw.buffer.Len() +} + +func (mw *mockWriter) Reset() { + mw.mu.Lock() + defer mw.mu.Unlock() + 
mw.buffer.Reset() + mw.writeCalls = 0 + mw.err = nil + mw.writeFunc = nil +} + +func (mw *mockWriter) setError(err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + mw.err = err +} + +func TestBackedWriter_NewBackedWriter(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + require.NotNil(t, bw) + require.Equal(t, uint64(0), bw.SequenceNum()) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_WriteToBufferWhenDisconnected(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write should succeed even when disconnected + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + + // Data should be in buffer +} + +func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + require.True(t, bw.Connected()) + + // Write should go to both buffer and underlying writer + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Data should be buffered + + // Check underlying writer + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Cause write to fail + writer.setError(xerrors.New("write failed")) + + // Write should still succeed to buffer but disconnect + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) // Buffer write succeeds + require.Equal(t, 5, n) + require.False(t, bw.Connected()) // Should be disconnected + + // Data should still be in buffer +} + +func TestBackedWriter_ReplayOnReconnect(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write some data while disconnected + bw.Write([]byte("hello")) + bw.Write([]byte(" world")) + + require.Equal(t, uint64(11), bw.SequenceNum()) + + // Reconnect and request replay from beginning + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Should have replayed all data + require.Equal(t, []byte("hello world"), writer.buffer.Bytes()) + + // Write new data should go to both + bw.Write([]byte("!")) + require.Equal(t, []byte("hello world!"), writer.buffer.Bytes()) +} + +func TestBackedWriter_PartialReplay(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write some data + bw.Write([]byte("hello")) + bw.Write([]byte(" world")) + bw.Write([]byte("!")) + + // Reconnect and request replay from middle + writer := newMockWriter() + err := bw.Reconnect(5, writer) // From " world!" 
+ require.NoError(t, err) + + // Should have replayed only the requested portion + require.Equal(t, []byte(" world!"), writer.buffer.Bytes()) +} + +func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + bw.Write([]byte("hello")) + + writer := newMockWriter() + err := bw.Reconnect(10, writer) // Future sequence + require.Error(t, err) + require.Contains(t, err.Error(), "future sequence") +} + +func TestBackedWriter_ReplayDataLoss(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing + + // Fill buffer beyond capacity to cause eviction + bw.Write([]byte("0123456789")) // Fills buffer exactly + bw.Write([]byte("abcdef")) // Should evict "012345" + + writer := newMockWriter() + err := bw.Reconnect(0, writer) // Try to replay from evicted data + // With the new error handling, this should fail because we can't read all the data + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read replay data") +} + +func TestBackedWriter_BufferEviction(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(5) // Very small buffer for testing + + // Write data that will cause eviction + n, err := bw.Write([]byte("abcde")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Write more to cause eviction + n, err = bw.Write([]byte("fg")) + require.NoError(t, err) + require.Equal(t, 2, n) + + // Buffer should contain "cdefg" (latest data) +} + +func TestBackedWriter_Close(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + bw.Reconnect(0, writer) + + err := bw.Close() + require.NoError(t, err) + + // Writes after close should fail + _, err = bw.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) + + // Reconnect after close should fail + err = bw.Reconnect(0, newMockWriter()) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedWriter_CloseIdempotent(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + err := bw.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bw.Close() + require.NoError(t, err) +} + +func TestBackedWriter_CanReplayFrom(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing eviction + + // Empty buffer + require.True(t, bw.CanReplayFrom(0)) + require.False(t, bw.CanReplayFrom(1)) + + // Write some data + bw.Write([]byte("hello")) + require.True(t, bw.CanReplayFrom(0)) + require.True(t, bw.CanReplayFrom(3)) + require.True(t, bw.CanReplayFrom(5)) + require.False(t, bw.CanReplayFrom(6)) + + // Fill buffer and cause eviction + bw.Write([]byte("world!")) + require.True(t, bw.CanReplayFrom(0)) // Can replay from any sequence up to current + require.True(t, bw.CanReplayFrom(bw.SequenceNum())) +} + +func TestBackedWriter_WaitForConnection(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Should timeout when not connected + // Use a shorter timeout for this test to speed up test runs + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperShort) + defer cancel() + + err := bw.WaitForConnection(ctx) + require.Equal(t, context.DeadlineExceeded, err) + + // Should succeed immediately when connected + writer := newMockWriter() + bw.Reconnect(0, writer) + + ctx = context.Background() + err = bw.WaitForConnection(ctx) + require.NoError(t, err) +} + +func TestBackedWriter_ConcurrentWrites(t *testing.T) { + t.Parallel() + + bw := 
backedpipe.NewBackedWriter()
+	writer := newMockWriter()
+	bw.Reconnect(0, writer)
+
+	var wg sync.WaitGroup
+	numWriters := 10
+	writesPerWriter := 50
+
+	for i := 0; i < numWriters; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for j := 0; j < writesPerWriter; j++ {
+				data := []byte{byte(id + '0')}
+				bw.Write(data)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Should have written expected amount to buffer
+	expectedBytes := uint64(numWriters * writesPerWriter) //nolint:gosec // Safe conversion: test constants with small values
+	require.Equal(t, expectedBytes, bw.SequenceNum())
+	// Note: underlying writer may not receive all bytes due to potential disconnections
+	// during concurrent operations, but the buffer should track all writes
+	require.True(t, writer.Len() <= int(expectedBytes)) //nolint:gosec // Safe conversion: expectedBytes is calculated from small test values
+}
+
+func TestBackedWriter_ReconnectDuringReplay(t *testing.T) {
+	t.Parallel()
+
+	bw := backedpipe.NewBackedWriter()
+	bw.Write([]byte("hello world"))
+
+	// Create a writer that fails during replay
+	writer := &mockWriter{
+		writeFunc: func(p []byte) (int, error) {
+			return 0, xerrors.New("replay failed")
+		},
+	}
+
+	err := bw.Reconnect(0, writer)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "replay failed")
+	require.False(t, bw.Connected())
+}
+
+func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) {
+	t.Parallel()
+
+	bw := backedpipe.NewBackedWriter()
+
+	// Create a writer that does partial writes
+	writer := &mockWriter{
+		writeFunc: func(p []byte) (int, error) {
+			if len(p) > 3 {
+				return 3, nil // Only write first 3 bytes
+			}
+			return len(p), nil
+		},
+	}
+
+	bw.Reconnect(0, writer)
+
+	// Write should succeed to buffer but disconnect due to partial write
+	n, err := bw.Write([]byte("hello"))
+	require.NoError(t, err)
+	require.Equal(t, 5, n)
+	require.False(t, bw.Connected())
+
+	// The replay buffer still holds all five bytes despite the partial write
+}
+
+func BenchmarkBackedWriter_Write(b *testing.B) {
+	bw := backedpipe.NewBackedWriter() // 64MB default buffer
+	writer := newMockWriter()
+	bw.Reconnect(0, writer)
+
+	data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		bw.Write(data)
+	}
+}
+
+func BenchmarkBackedWriter_Reconnect(b *testing.B) {
+	bw := backedpipe.NewBackedWriter()
+
+	// Fill buffer with data
+	data := bytes.Repeat([]byte("x"), 1024)
+	for i := 0; i < 32; i++ {
+		bw.Write(data)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		writer := newMockWriter()
+		bw.Reconnect(0, writer)
+	}
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer.go b/coderd/agentapi/backedpipe/ring_buffer.go
new file mode 100644
index 0000000000000..f092385741e0c
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer.go
@@ -0,0 +1,140 @@
+package backedpipe
+
+import (
+	"sync"
+
+	"golang.org/x/xerrors"
+)
+
+// RingBuffer implements an efficient circular buffer with a fixed-size allocation.
+// It supports concurrent access and handles wrap-around seamlessly.
+// The buffer is designed for high-performance scenarios where avoiding
+// dynamic memory allocation during operation is critical.
+type RingBuffer struct {
+	mu     sync.RWMutex
+	buffer []byte
+	start  int // index of first valid byte
+	end    int // index after last valid byte
+	size   int // current number of bytes in buffer
+	cap    int // maximum capacity
+}
+
+// NewRingBuffer creates a new ring buffer with 64MB capacity.
+func NewRingBuffer() *RingBuffer {
+	const capacity = 64 * 1024 * 1024 // 64MB
+	return NewRingBufferWithCapacity(capacity)
+}
+
+// NewRingBufferWithCapacity creates a new ring buffer with the specified capacity.
+// If capacity is <= 0, it defaults to 64MB.
+func NewRingBufferWithCapacity(capacity int) *RingBuffer {
+	if capacity <= 0 {
+		capacity = 64 * 1024 * 1024 // Default to 64MB
+	}
+	return &RingBuffer{
+		buffer: make([]byte, capacity),
+		cap:    capacity,
+	}
+}
+
+// Write writes data to the ring buffer. If the buffer would overflow,
+// it evicts the oldest data to make room for new data.
+// Returns the number of bytes written and the number of bytes evicted.
+func (rb *RingBuffer) Write(data []byte) (written int, evicted int) {
+	if len(data) == 0 {
+		return 0, 0
+	}
+
+	rb.mu.Lock()
+	defer rb.mu.Unlock()
+
+	written = len(data)
+
+	// If data is larger than capacity, only keep the last capacity bytes
+	if len(data) > rb.cap {
+		evicted = len(data) - rb.cap
+		data = data[evicted:]
+		written = rb.cap
+		// Clear buffer and write new data
+		rb.start = 0
+		rb.end = 0
+		rb.size = 0
+	}
+
+	// Calculate how much we need to evict to fit new data
+	spaceNeeded := len(data)
+	availableSpace := rb.cap - rb.size
+
+	if spaceNeeded > availableSpace {
+		bytesToEvict := spaceNeeded - availableSpace
+		evicted += bytesToEvict
+		rb.evict(bytesToEvict)
+	}
+
+	// Write the data in at most two copies, wrapping around the end of
+	// the backing slice if necessary.
+	n := copy(rb.buffer[rb.end:], data)
+	copy(rb.buffer, data[n:])
+	rb.end = (rb.end + len(data)) % rb.cap
+	rb.size += len(data)
+
+	return written, evicted
+}
+
+// evict removes the specified number of bytes from the beginning of the buffer.
+// Must be called with lock held.
+func (rb *RingBuffer) evict(count int) {
+	if count >= rb.size {
+		// Evict everything
+		rb.start = 0
+		rb.end = 0
+		rb.size = 0
+		return
+	}
+
+	rb.start = (rb.start + count) % rb.cap
+	rb.size -= count
+}
+
+// ReadLast returns the last n bytes from the buffer.
+// If n is 0 or negative, it returns nil.
+// If the buffer is empty or holds fewer than n bytes, an error is returned.
+func (rb *RingBuffer) ReadLast(n int) ([]byte, error) {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+
+	if n <= 0 {
+		return nil, nil
+	}
+
+	if rb.size == 0 {
+		return nil, xerrors.New("buffer is empty")
+	}
+
+	// If more is requested than is buffered, return an error
+	if n > rb.size {
+		return nil, xerrors.Errorf("requested %d bytes but only %d available", n, rb.size)
+	}
+
+	result := make([]byte, n)
+
+	// Calculate where to start reading from (n bytes before the end).
+	// The constructor guarantees cap > 0, so the modulo is safe.
+	actualStart := (rb.start + rb.size - n) % rb.cap
+
+	// Copy the last n bytes
+	if actualStart+n <= rb.cap {
+		// No wrap needed
+		copy(result, rb.buffer[actualStart:actualStart+n])
+	} else {
+		// Need to wrap around
+		firstChunk := rb.cap - actualStart
+		copy(result[0:firstChunk], rb.buffer[actualStart:rb.cap])
+		copy(result[firstChunk:], rb.buffer[0:n-firstChunk])
+	}
+
+	return result, nil
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer_internal_test.go b/coderd/agentapi/backedpipe/ring_buffer_internal_test.go
new file mode 100644
index 0000000000000..5a23880774057
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer_internal_test.go
@@ -0,0 +1,162 @@
+package backedpipe
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRingBuffer_ClearInternal(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	rb.Write([]byte("hello"))
+	require.Equal(t, 5, rb.size)
+
+	rb.Clear()
+	require.Equal(t, 0, rb.size)
+	require.Equal(t, "", rb.String())
+}
+
+func TestRingBuffer_Available(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	require.Equal(t, 10, rb.Available())
+
+	rb.Write([]byte("hello"))
+	require.Equal(t, 5, rb.Available())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, 0, rb.Available())
+}
+
+func TestRingBuffer_StringInternal(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	require.Equal(t, "", rb.String())
+
+	rb.Write([]byte("hello"))
+	require.Equal(t, "hello", rb.String())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, "helloworld", rb.String())
+}
+
+func TestRingBuffer_StringWithWrapAround(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(5)
+	rb.Write([]byte("hello"))
+	require.Equal(t, "hello", rb.String())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, "world", rb.String())
+}
+
+func TestRingBuffer_ConcurrentAccessWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(1000)
+	var wg sync.WaitGroup
+
+	// Start multiple goroutines writing
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			data := fmt.Sprintf("data-%d", id)
+			for j := 0; j < 100; j++ {
+				rb.Write([]byte(data))
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Verify buffer is still in valid state
+	require.NotEmpty(t, rb.String())
+}
+
+func TestRingBuffer_EdgeCaseEvictionWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(3)
+	rb.Write([]byte("hello"))
+	rb.Write([]byte("world"))
+
+	// Capacity is 3, so only the last 3 bytes of "world" survive: "rld"
+	require.Equal(t, "rld", rb.String())
+
+	// Write more data to cause more eviction
+	rb.Write([]byte("test"))
+	require.Equal(t, "est", rb.String())
+}
+
+// TestRingBuffer_ComplexWrapAroundScenarioWithString tests complex wrap-around with String
+func TestRingBuffer_ComplexWrapAroundScenarioWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(5)
+
+	// Fill buffer
+	rb.Write([]byte("abcde"))
+	require.Equal(t, "abcde", rb.String())
+
+	// Write more to cause wrap-around
+	rb.Write([]byte("fgh"))
+	require.Equal(t, "defgh", rb.String())
+
+	// Write even more
+	rb.Write([]byte("ijklmn"))
+	require.Equal(t, "jklmn", rb.String())
+}
+
+// Available returns the remaining free space (helper for internal tests only).
+func (rb *RingBuffer) Available() int {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+	return rb.cap - rb.size
+}
+
+// Clear empties the buffer (helper for internal tests only).
+func (rb *RingBuffer) Clear() {
+	rb.mu.Lock()
+	defer rb.mu.Unlock()
+
+	rb.start = 0
+	rb.end = 0
+	rb.size = 0
+}
+
+// String returns the buffered bytes as a string (helper for internal tests only).
+func (rb *RingBuffer) String() string {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+
+	if rb.size == 0 {
+		return ""
+	}
+
+	result := make([]byte, rb.size)
+
+	if rb.start+rb.size <= rb.cap {
+		// No wrap needed
+		copy(result, rb.buffer[rb.start:rb.start+rb.size])
+	} else {
+		// Need to wrap around
+		firstChunk := rb.cap - rb.start
+		copy(result[0:firstChunk], rb.buffer[rb.start:rb.cap])
+		copy(result[firstChunk:], rb.buffer[0:rb.size-firstChunk])
+	}
+
+	return string(result)
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer_test.go b/coderd/agentapi/backedpipe/ring_buffer_test.go
new file mode 100644
index 0000000000000..3febbcb433e5a
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer_test.go
@@ -0,0 +1,326 @@
+package backedpipe_test
+
+import (
+	"bytes"
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/coder/coder/v2/coderd/agentapi/backedpipe"
+)
+
+func TestRingBuffer_NewRingBuffer(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(100)
+	// Test that we can write and read from the buffer
+	written, evicted := rb.Write([]byte("test"))
+	require.Equal(t, 4, written)
+	require.Equal(t, 0, evicted)
+
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("test"), data)
+}
+
+func TestRingBuffer_WriteAndRead(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(10)
+
+	// Write some data
+	rb.Write([]byte("hello"))
+
+	// Read last 4 bytes
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, "ello", string(data))
+
+	// Write more data
+	rb.Write([]byte("world"))
+
+	// Read last 5 bytes
+	data, err = rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, "world", string(data))
+
+	// Read last 3 bytes
+	data, err = rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, "rld", string(data))
+
+	// Reading more than the 10 bytes currently buffered should fail
+	_, err = rb.ReadLast(15)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "requested 15 bytes but only")
+}
+
+func TestRingBuffer_OverflowEviction(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(5)
+
+	// Fill buffer
+	written, evicted := rb.Write([]byte("abcde"))
+	require.Equal(t, 5, written)
+	require.Equal(t, 0, evicted)
+
+	// Overflow should evict oldest data
+	written, evicted = rb.Write([]byte("fg"))
+	require.Equal(t, 2, written)
+	require.Equal(t, 2, evicted)
+
+	// Should now contain "cdefg"
+	data, err := rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, []byte("cdefg"), data)
+}
+
+func TestRingBuffer_LargeWrite(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(5)
+
+	// Write data larger than capacity
+	written, evicted := rb.Write([]byte("abcdefghij"))
+	
require.Equal(t, 5, written) + require.Equal(t, 5, evicted) + + // Should contain last 5 bytes + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("fghij"), data) +} + +func TestRingBuffer_WrapAround(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(5) + + // Fill buffer + rb.Write([]byte("abcde")) + + // Write more to cause wrap-around + rb.Write([]byte("fgh")) + + // Should contain "defgh" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("defgh"), data) + + // Test reading last 3 bytes after wrap + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("fgh"), data) +} + +func TestRingBuffer_ReadLastEdgeCases(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(3) + + // Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain) + rb.Write([]byte("hello")) + + // Test reading negative count + data, err := rb.ReadLast(-1) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading zero bytes + data, err = rb.ReadLast(0) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading more than available (buffer has 3 bytes, try to read 10) + _, err = rb.ReadLast(10) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 10 bytes but only 3 available") + + // Test reading exact amount available + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("llo"), data) +} + +func TestRingBuffer_EmptyWrite(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(10) + + // Write empty data + written, evicted := rb.Write([]byte{}) + require.Equal(t, 0, written) + require.Equal(t, 0, evicted) + + // Buffer should still be empty + _, err := rb.ReadLast(5) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer is empty") +} + +func TestRingBuffer_ConcurrentAccess(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(1000) + var wg sync.WaitGroup + + // Start multiple goroutines writing + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + data := []byte(fmt.Sprintf("data-%d", id)) + for j := 0; j < 100; j++ { + rb.Write(data) + } + }(i) + } + + // Start multiple goroutines reading + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + _, err := rb.ReadLast(100) + if err != nil { + // Error is expected if buffer doesn't have enough data + continue + } + } + }() + } + + wg.Wait() + + // Verify buffer is still in valid state + data, err := rb.ReadLast(1000) + require.NoError(t, err) + require.NotNil(t, data) +} + +func TestRingBuffer_MultipleWrites(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(10) + + // Write data in chunks + rb.Write([]byte("ab")) + rb.Write([]byte("cd")) + rb.Write([]byte("ef")) + + data, err := rb.ReadLast(6) + require.NoError(t, err) + require.Equal(t, []byte("abcdef"), data) + + // Test partial reads + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("cdef"), data) + + data, err = rb.ReadLast(2) + require.NoError(t, err) + require.Equal(t, []byte("ef"), data) +} + +func TestRingBuffer_EdgeCaseEviction(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(3) + + // Write data that will cause eviction + written, evicted := rb.Write([]byte("abc")) + require.Equal(t, 3, written) + require.Equal(t, 0, evicted) + + // Write more to cause eviction + written, evicted = 
rb.Write([]byte("d")) + require.Equal(t, 1, written) + require.Equal(t, 1, evicted) + + // Should now contain "bcd" + data, err := rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("bcd"), data) +} + +func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(8) + + // Fill buffer + rb.Write([]byte("12345678")) + + // Evict some and add more to create complex wrap scenario + rb.Write([]byte("abcd")) + data, err := rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("5678abcd"), data) + + // Add more + rb.Write([]byte("xyz")) + data, err = rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("8abcdxyz"), data) + + // Test reading various amounts from the end + data, err = rb.ReadLast(7) + require.NoError(t, err) + require.Equal(t, []byte("abcdxyz"), data) + + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("dxyz"), data) +} + +// Benchmark tests for performance validation +func BenchmarkRingBuffer_Write(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rb.Write(data) + } +} + +func BenchmarkRingBuffer_ReadLast(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + // Fill buffer with test data + for i := 0; i < 64; i++ { + rb.Write(bytes.Repeat([]byte("x"), 1024)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rb.ReadLast((i % 100) + 1) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRingBuffer_ConcurrentAccess(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 100) + + // Pre-fill buffer with enough data + for i := 0; i < 100; i++ { + rb.Write(data) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + rb.Write(data) + _, err := rb.ReadLast(100) // Read only what we know is available + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/testutil/duration.go b/testutil/duration.go index a8c35030cdea2..821684f6b0f98 100644 --- a/testutil/duration.go +++ b/testutil/duration.go @@ -7,10 +7,11 @@ import ( // Constants for timing out operations, usable for creating contexts // that timeout or in require.Eventually. const ( - WaitShort = 10 * time.Second - WaitMedium = 15 * time.Second - WaitLong = 25 * time.Second - WaitSuperLong = 60 * time.Second + WaitSuperShort = 100 * time.Millisecond + WaitShort = 10 * time.Second + WaitMedium = 15 * time.Second + WaitLong = 25 * time.Second + WaitSuperLong = 60 * time.Second ) // Constants for delaying repeated operations, e.g. in
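Reviewer note (not part of the patch): the ring buffer's eviction arithmetic is easiest to sanity-check with a tiny standalone program. The sketch below exercises the same semantics the tests above assert: capacity-bounded writes, oldest-first eviction, and an error (rather than a short read) when more bytes are requested than are buffered. It assumes the import path used by the external tests.

package main

import (
	"fmt"

	"github.com/coder/coder/v2/coderd/agentapi/backedpipe"
)

func main() {
	// Capacity of 5 bytes; the buffer keeps only the most recent data.
	rb := backedpipe.NewRingBufferWithCapacity(5)

	written, evicted := rb.Write([]byte("abcde"))
	fmt.Println(written, evicted) // 5 0: buffer is now exactly full

	written, evicted = rb.Write([]byte("fg"))
	fmt.Println(written, evicted) // 2 2: "ab" was evicted to make room

	data, err := rb.ReadLast(5)
	fmt.Println(string(data), err) // "cdefg" <nil>

	// Requesting more than is buffered is an error, not a short read.
	_, err = rb.ReadLast(6)
	fmt.Println(err) // requested 6 bytes but only 5 available
}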
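Reviewer note (not part of the patch): the reconnect-and-replay flow that backed_writer_test.go exercises can likewise be sketched end to end. This is a minimal illustration under two assumptions: that Reconnect accepts any io.Writer (as the mockWriter in the tests suggests), and that bytes.Buffer stands in for a real network connection.

package main

import (
	"bytes"
	"fmt"

	"github.com/coder/coder/v2/coderd/agentapi/backedpipe"
)

func main() {
	bw := backedpipe.NewBackedWriter()

	// Writes succeed while disconnected: they land in the replay buffer
	// and advance the sequence number.
	bw.Write([]byte("hello"))
	bw.Write([]byte(" world"))
	fmt.Println(bw.SequenceNum()) // 11

	// Suppose the remote side reports that it has already seen 5 bytes.
	// Reconnect then replays everything from that sequence onward.
	var conn bytes.Buffer // assumption: stands in for the real connection
	if bw.CanReplayFrom(5) {
		if err := bw.Reconnect(5, &conn); err != nil {
			fmt.Println("reconnect failed:", err)
			return
		}
	}
	fmt.Println(conn.String()) // " world"
}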