Skip to content

Commit c916a9e

Browse files
authored
fix(agent): guard against multiple rpty race for same id (#7998)
* fix(agent): guard against multiple rpty race for same id * fix(agent): ensure pty is closed on error
1 parent 9440b3d commit c916a9e

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

agent/agent.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10251025
}()
10261026

10271027
var rpty *reconnectingPTY
1028-
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
1028+
sendConnected := make(chan *reconnectingPTY, 1)
1029+
// On store, reserve this ID to prevent multiple concurrent new connections.
1030+
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10291031
if ok {
1032+
close(sendConnected) // Unused.
10301033
logger.Debug(ctx, "connecting to existing session")
1031-
rpty, ok = rawRPTY.(*reconnectingPTY)
1034+
c, ok := waitReady.(chan *reconnectingPTY)
10321035
if !ok {
1033-
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
1036+
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10341037
}
1038+
rpty, ok = <-c
1039+
if !ok || rpty == nil {
1040+
return xerrors.Errorf("reconnecting pty closed before connection")
1041+
}
1042+
c <- rpty // Put it back for the next reconnect.
10351043
} else {
10361044
logger.Debug(ctx, "creating new session")
10371045

1046+
connected := false
1047+
defer func() {
1048+
if !connected && retErr != nil {
1049+
a.reconnectingPTYs.Delete(msg.ID)
1050+
close(sendConnected)
1051+
}
1052+
}()
1053+
10381054
// Empty command will default to the users shell!
10391055
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
10401056
if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10551071
return xerrors.Errorf("start command: %w", err)
10561072
}
10571073

1058-
ctx, cancelFunc := context.WithCancel(ctx)
1074+
ctx, cancel := context.WithCancel(ctx)
10591075
rpty = &reconnectingPTY{
10601076
activeConns: map[string]net.Conn{
10611077
// We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10641080
},
10651081
ptty: ptty,
10661082
// Timeouts created with an after func can be reset!
1067-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
1083+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
10681084
circularBuffer: circularBuffer,
10691085
}
1070-
a.reconnectingPTYs.Store(msg.ID, rpty)
10711086
// We don't need to separately monitor for the process exiting.
10721087
// When it exits, our ptty.OutputReader() will return EOF after
10731088
// reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11151130
rpty.Close()
11161131
a.reconnectingPTYs.Delete(msg.ID)
11171132
}); err != nil {
1133+
_ = process.Kill()
1134+
_ = ptty.Close()
11181135
return xerrors.Errorf("start routine: %w", err)
11191136
}
1137+
connected = true
1138+
sendConnected <- rpty
11201139
}
11211140
// Resize the PTY to initial height + width.
11221141
err := rpty.ptty.Resize(msg.Height, msg.Width)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy