Skip to content

Commit 6a5d5c6

Browse files
committed
feat(agentssh): Gracefully close SSH sessions on Close
By tracking and closing sessions manually before closing the underlying connections, we ensure that the termination is propagated to SSH/SFTP clients and they're not left waiting for a connection timeout. Refs: #6177
1 parent 0224426 commit 6a5d5c6

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

agent/agentssh/agentssh.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Server struct {
5050
mu sync.RWMutex // Protects following.
5151
listeners map[net.Listener]struct{}
5252
conns map[net.Conn]struct{}
53+
sessions map[ssh.Session]struct{}
5354
closing chan struct{}
5455
// Wait for goroutines to exit, waited without
5556
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8687
s := &Server{
8788
listeners: make(map[net.Listener]struct{}),
8889
conns: make(map[net.Conn]struct{}),
90+
sessions: make(map[ssh.Session]struct{}),
8991
logger: logger,
9092
}
9193

@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129131
}
130132
},
131133
SubsystemHandlers: map[string]ssh.SubsystemHandler{
132-
"sftp": s.sftpHandler,
134+
"sftp": s.sessionHandler,
133135
},
134136
MaxTimeout: maxTimeout,
135137
}
@@ -152,7 +154,25 @@ func (s *Server) ConnStats() ConnStats {
152154
}
153155

154156
func (s *Server) sessionHandler(session ssh.Session) {
157+
if !s.trackSession(session, true) {
158+
session.Exit(MagicSessionErrorCode)
159+
return
160+
}
161+
defer s.trackSession(session, false)
162+
155163
ctx := session.Context()
164+
165+
switch ss := session.Subsystem(); ss {
166+
case "":
167+
case "sftp":
168+
s.sftpHandler(session)
169+
return
170+
default:
171+
s.logger.Debug(ctx, "unsupported subsystem", slog.F("subsystem", ss))
172+
_ = session.Exit(1)
173+
return
174+
}
175+
156176
err := s.sessionStart(session)
157177
var exitError *exec.ExitError
158178
if xerrors.As(err, &exitError) {
@@ -560,6 +580,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560580
return true
561581
}
562582

583+
// trackSession registers the session with the server. If the server is
584+
// closing, the session is not registered and should be closed.
585+
//
586+
//nolint:revive
587+
func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
588+
s.mu.Lock()
589+
defer s.mu.Unlock()
590+
if add {
591+
if s.closing != nil {
592+
// Server closed.
593+
return false
594+
}
595+
s.sessions[ss] = struct{}{}
596+
return true
597+
}
598+
delete(s.sessions, ss)
599+
return true
600+
}
601+
563602
// Close the server and all active connections. Server can be re-used
564603
// after Close is done.
565604
func (s *Server) Close() error {
@@ -573,6 +612,12 @@ func (s *Server) Close() error {
573612
}
574613
s.closing = make(chan struct{})
575614

615+
// Close all active sessions to gracefully
616+
// terminate client connections.
617+
for ss := range s.sessions {
618+
_ = ss.Close()
619+
}
620+
576621
// Close all active listeners and connections.
577622
for l := range s.listeners {
578623
_ = l.Close()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy