Skip to content

Commit aa660e0

Browse files
authored
feat(agentssh): Gracefully close SSH sessions on Close (#7027)
By tracking and closing sessions manually before closing the underlying connections, we ensure that the termination is propagated to SSH/SFTP clients and they're not left waiting for a connection timeout. Refs: #6177
1 parent f4f40d0 commit aa660e0

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

agent/agentssh/agentssh.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Server struct {
5050
mu sync.RWMutex // Protects following.
5151
listeners map[net.Listener]struct{}
5252
conns map[net.Conn]struct{}
53+
sessions map[ssh.Session]struct{}
5354
closing chan struct{}
5455
// Wait for goroutines to exit, waited without
5556
// a lock on mu but protected by closing.
@@ -86,6 +87,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
8687
s := &Server{
8788
listeners: make(map[net.Listener]struct{}),
8889
conns: make(map[net.Conn]struct{}),
90+
sessions: make(map[ssh.Session]struct{}),
8991
logger: logger,
9092
}
9193

@@ -129,7 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
129131
}
130132
},
131133
SubsystemHandlers: map[string]ssh.SubsystemHandler{
132-
"sftp": s.sftpHandler,
134+
"sftp": s.sessionHandler,
133135
},
134136
MaxTimeout: maxTimeout,
135137
}
@@ -152,7 +154,26 @@ func (s *Server) ConnStats() ConnStats {
152154
}
153155

154156
func (s *Server) sessionHandler(session ssh.Session) {
157+
if !s.trackSession(session, true) {
158+
// See (*Server).Close() for why we call Close instead of Exit.
159+
_ = session.Close()
160+
return
161+
}
162+
defer s.trackSession(session, false)
163+
155164
ctx := session.Context()
165+
166+
switch ss := session.Subsystem(); ss {
167+
case "":
168+
case "sftp":
169+
s.sftpHandler(session)
170+
return
171+
default:
172+
s.logger.Debug(ctx, "unsupported subsystem", slog.F("subsystem", ss))
173+
_ = session.Exit(1)
174+
return
175+
}
176+
156177
err := s.sessionStart(session)
157178
var exitError *exec.ExitError
158179
if xerrors.As(err, &exitError) {
@@ -560,6 +581,25 @@ func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
560581
return true
561582
}
562583

584+
// trackSession registers the session with the server. If the server is
585+
// closing, the session is not registered and should be closed.
586+
//
587+
//nolint:revive
588+
func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
589+
s.mu.Lock()
590+
defer s.mu.Unlock()
591+
if add {
592+
if s.closing != nil {
593+
// Server closed.
594+
return false
595+
}
596+
s.sessions[ss] = struct{}{}
597+
return true
598+
}
599+
delete(s.sessions, ss)
600+
return true
601+
}
602+
563603
// Close the server and all active connections. Server can be re-used
564604
// after Close is done.
565605
func (s *Server) Close() error {
@@ -573,6 +613,15 @@ func (s *Server) Close() error {
573613
}
574614
s.closing = make(chan struct{})
575615

616+
// Close all active sessions to gracefully
617+
// terminate client connections.
618+
for ss := range s.sessions {
619+
// We call Close on the underlying channel here because we don't
620+
// want to send an exit status to the client (via Exit()).
621+
// Typically OpenSSH clients will return 255 as the exit status.
622+
_ = ss.Close()
623+
}
624+
576625
// Close all active listeners and connections.
577626
for l := range s.listeners {
578627
_ = l.Close()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy