
Commit 9b88a68

Protect map and state with the same mutex
I moved the conn closes back to the lifecycle, too.
1 parent 56ca7ac commit 9b88a68
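
For context, this change collapses two synchronization concerns into one lock: the sync.Cond that already guards the lifecycle state now also guards the active-connection map (and the circular buffer). Below is a minimal, self-contained sketch of that pattern, not the commit's code; the names here (ptyState, State, StateReady, attach, closeAll) are simplified stand-ins for the real types in agent/reconnectingpty.

package main

import (
	"fmt"
	"net"
	"sync"
)

type State int

const (
	StateStarting State = iota
	StateReady
	StateClosing
)

type ptyState struct {
	cond  *sync.Cond
	state State

	// Guarded by cond.L as well; the point of the change is that the same
	// mutex protects both the state and the connection map.
	activeConns map[string]net.Conn
}

func newPtyState() *ptyState {
	return &ptyState{
		cond:        sync.NewCond(&sync.Mutex{}),
		activeConns: map[string]net.Conn{},
	}
}

// setState advances the state and wakes any waiters.
func (s *ptyState) setState(state State) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	s.state = state
	s.cond.Broadcast()
}

// attach waits for StateReady and, while still holding the lock, registers
// the connection so no concurrent close or map reset can race with it.
func (s *ptyState) attach(id string, conn net.Conn) {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	for s.state < StateReady {
		s.cond.Wait()
	}
	s.activeConns[id] = conn
}

// closeAll runs in the lifecycle: close and clear connections under the same
// lock so attach cannot add to a map that is being torn down.
func (s *ptyState) closeAll() {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	for _, conn := range s.activeConns {
		_ = conn.Close()
	}
	s.activeConns = map[string]net.Conn{}
	s.state = StateClosing
	s.cond.Broadcast()
}

func main() {
	s := newPtyState()
	s.setState(StateReady)

	client, server := net.Pipe()
	defer client.Close()

	s.attach("conn-1", server)
	s.closeAll()
	fmt.Println("connections after close:", len(s.activeConns))
}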

File tree

agent/reconnectingpty/buffered.go
agent/reconnectingpty/reconnectingpty.go
agent/reconnectingpty/screen.go

3 files changed: +58 −58 lines changed

agent/reconnectingpty/buffered.go

Lines changed: 47 additions & 52 deletions
@@ -5,7 +5,6 @@ import (
 	"errors"
 	"io"
 	"net"
-	"sync"
 	"time"

 	"github.com/armon/circbuf"
@@ -23,9 +22,6 @@ import (
 type bufferedReconnectingPTY struct {
 	command *pty.Cmd

-	// mutex protects writing to the circular buffer and connections.
-	mutex sync.RWMutex
-
 	activeConns    map[string]net.Conn
 	circularBuffer *circbuf.Buffer

@@ -100,7 +96,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 				break
 			}
 			part := buffer[:read]
-			rpty.mutex.Lock()
+			rpty.state.cond.L.Lock()
 			_, err = rpty.circularBuffer.Write(part)
 			if err != nil {
 				logger.Error(ctx, "write to circular buffer", slog.Error(err))
@@ -119,7 +115,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
 					rpty.metrics.WithLabelValues("write").Add(1)
 				}
 			}
-			rpty.mutex.Unlock()
+			rpty.state.cond.L.Unlock()
 		}
 	}()

@@ -136,14 +132,29 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
 	logger.Debug(ctx, "reconnecting pty ready")
 	rpty.state.setState(StateReady, nil)

-	state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing)
+	state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing, nil)
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
 		rpty.Close(reasonErr.Error())
 	}
 	rpty.timer.Stop()

+	rpty.state.cond.L.Lock()
+	// Log these closes only for debugging since the connections or processes
+	// might have already closed on their own.
+	for _, conn := range rpty.activeConns {
+		err := conn.Close()
+		if err != nil {
+			logger.Debug(ctx, "closed conn with error", slog.Error(err))
+		}
+	}
+	// Connections get removed once they close but it is possible there is still
+	// some data that will be written before that happens so clear the map now to
+	// avoid writing to closed connections.
+	rpty.activeConns = map[string]net.Conn{}
+	rpty.state.cond.L.Unlock()
+
 	// Log close/kill only for debugging since the process might have already
 	// closed on its own.
 	err := rpty.ptty.Close()
@@ -167,65 +178,49 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()

-	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
-	if state != StateReady {
-		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
-	}
+	// Once we are ready, attach the active connection while we hold the mutex.
+	_, err := rpty.state.waitForStateOrContext(ctx, StateReady, func(state State, err error) error {
+		if state != StateReady {
+			return xerrors.Errorf("reconnecting pty ready wait: %w", err)
+		}
+
+		go heartbeat(ctx, rpty.timer, rpty.timeout)
+
+		// Resize the PTY to initial height + width.
+		err = rpty.ptty.Resize(height, width)
+		if err != nil {
+			// We can continue after this, it's not fatal!
+			logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
+			rpty.metrics.WithLabelValues("resize").Add(1)
+		}

-	go heartbeat(ctx, rpty.timer, rpty.timeout)
+		// Write any previously stored data for the TTY and store the connection for
+		// future writes.
+		prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
+		_, err = conn.Write(prevBuf)
+		if err != nil {
+			rpty.metrics.WithLabelValues("write").Add(1)
+			return xerrors.Errorf("write buffer to conn: %w", err)
+		}
+		rpty.activeConns[connID] = conn

-	err = rpty.doAttach(ctx, connID, conn, height, width, logger)
+		return nil
+	})
 	if err != nil {
 		return err
 	}

-	go func() {
-		_, _ = rpty.state.waitForStateOrContext(ctx, StateClosing)
-		rpty.mutex.Lock()
-		defer rpty.mutex.Unlock()
+	defer func() {
+		rpty.state.cond.L.Lock()
+		defer rpty.state.cond.L.Unlock()
 		delete(rpty.activeConns, connID)
-		// Log closes only for debugging since the connection might have already
-		// closed on its own.
-		err := conn.Close()
-		if err != nil {
-			logger.Debug(ctx, "closed conn with error", slog.Error(err))
-		}
 	}()

 	// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
 	readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger)
 	return nil
 }

-// doAttach adds the connection to the map, replays the buffer, and starts the
-// heartbeat. It exists separately only so we can defer the mutex unlock which
-// is not possible in Attach since it blocks.
-func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
-	// Ensure we do not write to or close connections while we attach.
-	rpty.mutex.Lock()
-	defer rpty.mutex.Unlock()
-
-	// Resize the PTY to initial height + width.
-	err := rpty.ptty.Resize(height, width)
-	if err != nil {
-		// We can continue after this, it's not fatal!
-		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
-		rpty.metrics.WithLabelValues("resize").Add(1)
-	}
-
-	// Write any previously stored data for the TTY and store the connection for
-	// future writes.
-	prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
-	_, err = conn.Write(prevBuf)
-	if err != nil {
-		rpty.metrics.WithLabelValues("write").Add(1)
-		return xerrors.Errorf("write buffer to conn: %w", err)
-	}
-	rpty.activeConns[connID] = conn
-
-	return nil
-}
-
 func (rpty *bufferedReconnectingPTY) Wait() {
 	_, _ = rpty.state.waitForState(StateClosing)
 }

agent/reconnectingpty/reconnectingpty.go

Lines changed: 9 additions & 4 deletions
@@ -167,8 +167,9 @@ func (s *ptyState) waitForState(state State) (State, error) {
 }

 // waitForStateOrContext blocks until the state or a greater one is reached or
-// the provided context ends.
-func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
+// the provided context ends. If fn is non-nil it will be ran while the lock is
+// held and fn's error will replace waitForStateOrContext's error.
+func (s *ptyState) waitForStateOrContext(ctx context.Context, state State, fn func(state State, err error) error) (State, error) {
 	nevermind := make(chan struct{})
 	defer close(nevermind)
 	go func() {
@@ -185,10 +186,14 @@ func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (Stat
 	for ctx.Err() == nil && state > s.state {
 		s.cond.Wait()
 	}
+	err := s.error
 	if ctx.Err() != nil {
-		return s.state, ctx.Err()
+		err = ctx.Err()
 	}
-	return s.state, s.error
+	if fn != nil {
+		return s.state, fn(s.state, err)
+	}
+	return s.state, err
 }

 // readConnLoop reads messages from conn and writes to ptty as needed. Blocks
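
To make the new return value concrete, here is a small, self-contained sketch (illustrative names only, not the package's code) of the error precedence waitForStateOrContext now applies: the stored state error is the default, a context error overrides it, and a non-nil callback's return value replaces whichever was chosen.

package main

import (
	"context"
	"errors"
	"fmt"
)

// resolveErr mirrors the new tail of waitForStateOrContext in simplified
// form: pick the stored error, let a context error override it, and let a
// supplied callback replace the result entirely.
func resolveErr(ctx context.Context, stateErr error, fn func(err error) error) error {
	err := stateErr
	if ctx.Err() != nil {
		err = ctx.Err()
	}
	if fn != nil {
		return fn(err)
	}
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	stateErr := errors.New("process exited")

	// Without a callback the stored error is returned as-is.
	fmt.Println(resolveErr(ctx, stateErr, nil))

	// A cancelled context takes precedence over the stored error.
	cancel()
	fmt.Println(resolveErr(ctx, stateErr, nil))

	// A callback sees the chosen error and its return value wins, which is
	// how the buffered Attach above wraps "reconnecting pty ready wait".
	fmt.Println(resolveErr(ctx, stateErr, func(err error) error {
		return fmt.Errorf("reconnecting pty ready wait: %w", err)
	}))
}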

agent/reconnectingpty/screen.go

Lines changed: 2 additions & 2 deletions
@@ -130,7 +130,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
 	logger.Debug(ctx, "reconnecting pty ready")
 	rpty.state.setState(StateReady, nil)

-	state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing)
+	state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing, nil)
 	if state < StateClosing {
 		// If we have not closed yet then the context is what unblocked us (which
 		// means the agent is shutting down) so move into the closing phase.
@@ -155,7 +155,7 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()

-	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
+	state, err := rpty.state.waitForStateOrContext(ctx, StateReady, nil)
 	if state != StateReady {
 		return xerrors.Errorf("reconnecting pty ready wait: %w", err)
 	}
