diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f49a64924bd36..ec682a735c248 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -130,9 +130,10 @@ type Server struct { // a lock on mu but protected by closing. wg sync.WaitGroup - Execer agentexec.Execer - logger slog.Logger - srv *ssh.Server + Execer agentexec.Execer + logger slog.Logger + srv *ssh.Server + x11Forwarder *x11Forwarder config *Config @@ -188,6 +189,14 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom config: config, metrics: metrics, + x11Forwarder: &x11Forwarder{ + logger: logger, + x11HandlerErrors: metrics.x11HandlerErrors, + fs: fs, + displayOffset: *config.X11DisplayOffset, + sessions: make(map[*x11Session]struct{}), + connections: make(map[net.Conn]struct{}), + }, } srv := &ssh.Server{ @@ -455,7 +464,7 @@ func (s *Server) sessionHandler(session ssh.Session) { x11, hasX11 := session.X11() if hasX11 { - display, handled := s.x11Handler(ctx, x11) + display, handled := s.x11Forwarder.x11Handler(ctx, session) if !handled { logger.Error(ctx, "x11 handler failed") closeCause("x11 handler failed") @@ -1114,6 +1123,9 @@ func (s *Server) Close() error { s.mu.Unlock() + s.logger.Debug(ctx, "closing X11 forwarding") + _ = s.x11Forwarder.Close() + s.logger.Debug(ctx, "waiting for all goroutines to exit") s.wg.Wait() // Wait for all goroutines to exit. diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 439f2c3021791..8c23d32bfa5d1 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -7,15 +7,16 @@ import ( "errors" "fmt" "io" - "math" "net" "os" "path/filepath" "strconv" + "sync" "time" "github.com/gliderlabs/ssh" "github.com/gofrs/flock" + "github.com/prometheus/client_golang/prometheus" "github.com/spf13/afero" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -29,8 +30,33 @@ const ( X11StartPort = 6000 // X11DefaultDisplayOffset is the default offset for X11 forwarding. X11DefaultDisplayOffset = 10 + X11MaxDisplays = 200 + // X11MaxPort is the highest port we will ever use for X11 forwarding. This limits the total number of TCP sockets + // we will create. It seems more useful to have a maximum port number than a direct limit on sockets with no max + // port because we'd like to be able to tell users the exact range of ports the Agent might use. + X11MaxPort = X11StartPort + X11MaxDisplays ) +type x11Forwarder struct { + logger slog.Logger + x11HandlerErrors *prometheus.CounterVec + fs afero.Fs + displayOffset int + + mu sync.Mutex + sessions map[*x11Session]struct{} + connections map[net.Conn]struct{} + closing bool + wg sync.WaitGroup +} + +type x11Session struct { + session ssh.Session + display int + listener net.Listener + usedAt time.Time +} + // x11Callback is called when the client requests X11 forwarding. func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool { // Always allow. @@ -39,114 +65,234 @@ func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool { // x11Handler is called when a session has requested X11 forwarding. // It listens for X11 connections and forwards them to the client. -func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) { +func (x *x11Forwarder) x11Handler(ctx ssh.Context, sshSession ssh.Session) (displayNumber int, handled bool) { + x11, hasX11 := sshSession.X11() + if !hasX11 { + return -1, false + } serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) if !valid { - s.logger.Warn(ctx, "failed to get server connection") + x.logger.Warn(ctx, "failed to get server connection") return -1, false } hostname, err := os.Hostname() if err != nil { - s.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1) + x.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) + x.x11HandlerErrors.WithLabelValues("hostname").Add(1) return -1, false } - ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset) + x11session, err := x.createX11Session(ctx, sshSession) if err != nil { - s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1) + x.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err)) + x.x11HandlerErrors.WithLabelValues("listen").Add(1) return -1, false } - s.trackListener(ln, true) defer func() { if !handled { - s.trackListener(ln, false) - _ = ln.Close() + x.closeAndRemoveSession(x11session) } }() - err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie) + err = addXauthEntry(ctx, x.fs, hostname, strconv.Itoa(x11session.display), x11.AuthProtocol, x11.AuthCookie) if err != nil { - s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) - s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1) + x.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) + x.x11HandlerErrors.WithLabelValues("xauthority").Add(1) return -1, false } + // clean up the X11 session if the SSH session completes. go func() { - // Don't leave the listener open after the session is gone. <-ctx.Done() - _ = ln.Close() + x.closeAndRemoveSession(x11session) }() - go func() { - defer ln.Close() - defer s.trackListener(ln, false) - - for { - conn, err := ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err)) + go x.listenForConnections(ctx, x11session, serverConn, x11) + + return x11session.display, true +} + +func (x *x11Forwarder) trackGoroutine() (closing bool, done func()) { + x.mu.Lock() + defer x.mu.Unlock() + if !x.closing { + x.wg.Add(1) + return false, func() { x.wg.Done() } + } + return true, func() {} +} + +func (x *x11Forwarder) listenForConnections( + ctx context.Context, session *x11Session, serverConn *gossh.ServerConn, x11 ssh.X11, +) { + defer x.closeAndRemoveSession(session) + if closing, done := x.trackGoroutine(); closing { + return + } else { // nolint: revive + defer done() + } + + for { + conn, err := session.listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { return } - if x11.SingleConnection { - s.logger.Debug(ctx, "single connection requested, closing X11 listener") - _ = ln.Close() - } + x.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err)) + return + } + if x11.SingleConnection { + x.logger.Debug(ctx, "single connection requested, closing X11 listener") + x.closeAndRemoveSession(session) + } - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) - _ = conn.Close() - continue - } - tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) - if !ok { - s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) - _ = conn.Close() - continue - } + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) + _ = conn.Close() + continue + } + tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) + if !ok { + x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) + _ = conn.Close() + continue + } - channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { - OriginatorAddress string - OriginatorPort uint32 - }{ - OriginatorAddress: tcpAddr.IP.String(), - // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) - OriginatorPort: uint32(tcpAddr.Port), - })) - if err != nil { - s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) - _ = conn.Close() - continue - } - go gossh.DiscardRequests(reqs) + channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { + OriginatorAddress string + OriginatorPort uint32 + }{ + OriginatorAddress: tcpAddr.IP.String(), + // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) + OriginatorPort: uint32(tcpAddr.Port), + })) + if err != nil { + x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) + _ = conn.Close() + continue + } + go gossh.DiscardRequests(reqs) - if !s.trackConn(ln, conn, true) { - s.logger.Warn(ctx, "failed to track X11 connection") - _ = conn.Close() - continue - } - go func() { - defer s.trackConn(ln, conn, false) - Bicopy(ctx, conn, channel) - }() + if !x.trackConn(conn, true) { + x.logger.Warn(ctx, "failed to track X11 connection") + _ = conn.Close() + continue } - }() + go func() { + defer x.trackConn(conn, false) + Bicopy(ctx, conn, channel) + }() + } +} - return display, true +// closeAndRemoveSession closes and removes the session. +func (x *x11Forwarder) closeAndRemoveSession(x11session *x11Session) { + _ = x11session.listener.Close() + x.mu.Lock() + delete(x.sessions, x11session) + x.mu.Unlock() +} + +// createX11Session creates an X11 forwarding session. +func (x *x11Forwarder) createX11Session(ctx context.Context, sshSession ssh.Session) (*x11Session, error) { + var ( + ln net.Listener + display int + err error + ) + // retry listener creation after evictions. Limit to 10 retries to prevent pathological cases looping forever. + const maxRetries = 10 + for try := range maxRetries { + ln, display, err = x.createX11Listener(ctx) + if err == nil { + break + } + if try == maxRetries-1 { + return nil, xerrors.New("max retries exceeded while creating X11 session") + } + x.logger.Warn(ctx, "failed to create X11 listener; will evict an X11 forwarding session", + slog.F("num_current_sessions", x.numSessions()), + slog.Error(err)) + x.evictLeastRecentlyUsedSession() + } + x.mu.Lock() + defer x.mu.Unlock() + if x.closing { + closeErr := ln.Close() + if closeErr != nil { + x.logger.Error(ctx, "error closing X11 listener", slog.Error(closeErr)) + } + return nil, xerrors.New("server is closing") + } + x11Sess := &x11Session{ + session: sshSession, + display: display, + listener: ln, + usedAt: time.Now(), + } + x.sessions[x11Sess] = struct{}{} + return x11Sess, nil +} + +func (x *x11Forwarder) numSessions() int { + x.mu.Lock() + defer x.mu.Unlock() + return len(x.sessions) +} + +func (x *x11Forwarder) popLeastRecentlyUsedSession() *x11Session { + x.mu.Lock() + defer x.mu.Unlock() + var lru *x11Session + for s := range x.sessions { + if lru == nil { + lru = s + continue + } + if s.usedAt.Before(lru.usedAt) { + lru = s + continue + } + } + if lru == nil { + x.logger.Debug(context.Background(), "tried to pop from empty set of X11 sessions") + return nil + } + delete(x.sessions, lru) + return lru +} + +func (x *x11Forwarder) evictLeastRecentlyUsedSession() { + lru := x.popLeastRecentlyUsedSession() + if lru == nil { + return + } + err := lru.listener.Close() + if err != nil { + x.logger.Error(context.Background(), "failed to close evicted X11 session listener", slog.Error(err)) + } + // when we evict, we also want to force the SSH session to be closed as well. This is because we intend to reuse + // the X11 TCP listener port for a new X11 forwarding session. If we left the SSH session up, then graphical apps + // started in that session could potentially connect to an unintended X11 Server (i.e. the display on a different + // computer than the one that started the SSH session). Most likely, this session is a zombie anyway if we've + // reached the maximum number of X11 forwarding sessions. + err = lru.session.Close() + if err != nil { + x.logger.Error(context.Background(), "failed to close evicted X11 SSH session", slog.Error(err)) + } } // createX11Listener creates a listener for X11 forwarding, it will use // the next available port starting from X11StartPort and displayOffset. -func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) { +func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) { var lc net.ListenConfig // Look for an open port to listen on. - for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ { + for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ { + if ctx.Err() != nil { + return nil, -1, ctx.Err() + } ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) if err == nil { display = port - X11StartPort @@ -156,6 +302,49 @@ func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err) } +// trackConn registers the connection with the x11Forwarder. If the server is +// closed, the connection is not registered and should be closed. +// +//nolint:revive +func (x *x11Forwarder) trackConn(c net.Conn, add bool) (ok bool) { + x.mu.Lock() + defer x.mu.Unlock() + if add { + if x.closing { + // Server or listener closed. + return false + } + x.wg.Add(1) + x.connections[c] = struct{}{} + return true + } + x.wg.Done() + delete(x.connections, c) + return true +} + +func (x *x11Forwarder) Close() error { + x.mu.Lock() + x.closing = true + + for s := range x.sessions { + sErr := s.listener.Close() + if sErr != nil { + x.logger.Debug(context.Background(), "failed to close X11 listener", slog.Error(sErr)) + } + } + for c := range x.connections { + cErr := c.Close() + if cErr != nil { + x.logger.Debug(context.Background(), "failed to close X11 connection", slog.Error(cErr)) + } + } + + x.mu.Unlock() + x.wg.Wait() + return nil +} + // addXauthEntry adds an Xauthority entry to the Xauthority file. // The Xauthority file is located at ~/.Xauthority. func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: