From 37adb622eb7519b4cd275c0f2658bf2df96423d6 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:34:04 +0000 Subject: [PATCH 1/7] refactor(agent): Move SSH server into agentssh package Refs: #6177 --- agent/agent.go | 568 ++----------------------- agent/agent_test.go | 11 +- agent/agentssh/agentssh.go | 576 ++++++++++++++++++++++++++ agent/agentssh/agentssh_test.go | 136 ++++++ agent/agentssh/bicopy.go | 47 +++ agent/{ssh.go => agentssh/forward.go} | 2 +- cli/portforward.go | 4 +- cli/ssh.go | 4 +- coderd/workspaceagents.go | 4 +- 9 files changed, 804 insertions(+), 548 deletions(-) create mode 100644 agent/agentssh/agentssh.go create mode 100644 agent/agentssh/agentssh_test.go create mode 100644 agent/agentssh/bicopy.go rename agent/{ssh.go => agentssh/forward.go} (99%) diff --git a/agent/agent.go b/agent/agent.go index e22d5c3576123..3906a182139ad 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -4,8 +4,6 @@ import ( "bufio" "bytes" "context" - "crypto/rand" - "crypto/rsa" "encoding/binary" "encoding/json" "errors" @@ -16,11 +14,9 @@ import ( "net/http" "net/netip" "os" - "os/exec" "os/user" "path/filepath" "reflect" - "runtime" "sort" "strconv" "strings" @@ -28,12 +24,9 @@ import ( "time" "github.com/armon/circbuf" - "github.com/gliderlabs/ssh" "github.com/google/uuid" - "github.com/pkg/sftp" "github.com/spf13/afero" "go.uber.org/atomic" - gossh "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" @@ -41,7 +34,7 @@ import ( "tailscale.com/types/netlogtype" "cdr.dev/slog" - "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitauth" @@ -56,19 +49,6 @@ const ( ProtocolReconnectingPTY = "reconnecting-pty" ProtocolSSH = "ssh" ProtocolDial = "dial" - - // MagicSessionErrorCode indicates that something went wrong with the session, rather than the - // command just returning a nonzero exit code, and is chosen as an arbitrary, high number - // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. - MagicSessionErrorCode = 229 - - // MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. - // This is stripped from any commands being executed, and is counted towards connection stats. - MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" - // MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. - MagicSSHSessionTypeVSCode = "vscode" - // MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. - MagicSSHSessionTypeJetBrains = "jetbrains" ) type Options struct { @@ -165,7 +145,7 @@ type agent struct { // manifest is atomic because values can change after reconnection. manifest atomic.Pointer[agentsdk.Manifest] sessionToken atomic.Pointer[string] - sshServer *ssh.Server + sshServer *agentssh.Server sshMaxTimeout time.Duration lifecycleUpdate chan struct{} @@ -177,10 +157,19 @@ type agent struct { connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] - connCountVSCode atomic.Int64 - connCountJetBrains atomic.Int64 connCountReconnectingPTY atomic.Int64 - connCountSSHSession atomic.Int64 +} + +func (a *agent) init(ctx context.Context) { + sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout) + if err != nil { + panic(err) + } + sshSrv.Env = a.envVars + sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } + a.sshServer = sshSrv + + go a.runLoop(ctx) } // runLoop attempts to start the agent in a retry loop. @@ -223,7 +212,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it is certain the clocks are in sync. CollectedAt: time.Now(), } - cmd, err := a.createCommand(ctx, md.Script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil) if err != nil { result.Error = err.Error() return result @@ -489,6 +478,7 @@ func (a *agent) run(ctx context.Context) error { } oldManifest := a.manifest.Swap(&manifest) + a.sshServer.SetManifest(&manifest) // The startup script should only execute on the first run! if oldManifest == nil { @@ -633,28 +623,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ } }() if err = a.trackConnGoroutine(func() { - var wg sync.WaitGroup - for { - conn, err := sshListener.Accept() - if err != nil { - break - } - wg.Add(1) - closed := make(chan struct{}) - go func() { - select { - case <-closed: - case <-a.closed: - _ = conn.Close() - } - wg.Done() - }() - go func() { - defer close(closed) - a.sshServer.HandleConn(conn) - }() - } - wg.Wait() + _ = a.sshServer.Serve(sshListener) }); err != nil { return nil, err } @@ -857,7 +826,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error { }() } - cmd, err := a.createCommand(ctx, script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, script, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -990,394 +959,6 @@ func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan str return logsFinished, nil } -func (a *agent) init(ctx context.Context) { - // Clients' should ignore the host key when connecting. - // The agent needs to authenticate with coderd to SSH, - // so SSH authentication doesn't improve security. - randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - randomSigner, err := gossh.NewSignerFromKey(randomHostKey) - if err != nil { - panic(err) - } - - sshLogger := a.logger.Named("ssh-server") - forwardHandler := &ssh.ForwardedTCPHandler{} - unixForwardHandler := &forwardedUnixHandler{log: a.logger} - - a.sshServer = &ssh.Server{ - ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, - "direct-streamlocal@openssh.com": directStreamLocalHandler, - "session": ssh.DefaultSessionHandler, - }, - ConnectionFailedCallback: func(conn net.Conn, err error) { - sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) - }, - Handler: func(session ssh.Session) { - err := a.handleSSHSession(session) - var exitError *exec.ExitError - if xerrors.As(err, &exitError) { - a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) - _ = session.Exit(exitError.ExitCode()) - return - } - if err != nil { - a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) - // This exit code is designed to be unlikely to be confused for a legit exit code - // from the process. - _ = session.Exit(MagicSessionErrorCode) - return - } - _ = session.Exit(0) - }, - HostSigners: []ssh.Signer{randomSigner}, - LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { - // Allow local port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("destination-host", destinationHost), - slog.F("destination-port", destinationPort)) - return true - }, - PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return true - }, - ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - // Allow reverse port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("bind-host", bindHost), - slog.F("bind-port", bindPort)) - return true - }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, - "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - }, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - return &gossh.ServerConfig{ - NoClientAuth: true, - } - }, - SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": func(session ssh.Session) { - ctx := session.Context() - - // Typically sftp sessions don't request a TTY, but if they do, - // we must ensure the gliderlabs/ssh CRLF emulation is disabled. - // Otherwise sftp will be broken. This can happen if a user sets - // `RequestTTY force` in their SSH config. - session.DisablePTYEmulation() - - var opts []sftp.ServerOption - // Change current working directory to the users home - // directory so that SFTP connections land there. - homedir, err := userHomeDir() - if err != nil { - sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) - } else { - opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) - } - - server, err := sftp.NewServer(session, opts...) - if err != nil { - sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err)) - return - } - defer server.Close() - - err = server.Serve() - if errors.Is(err, io.EOF) { - // Unless we call `session.Exit(0)` here, the client won't - // receive `exit-status` because `(*sftp.Server).Close()` - // calls `Close()` on the underlying connection (session), - // which actually calls `channel.Close()` because it isn't - // wrapped. This causes sftp clients to receive a non-zero - // exit code. Typically sftp clients don't echo this exit - // code but `scp` on macOS does (when using the default - // SFTP backend). - _ = session.Exit(0) - return - } - sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err)) - _ = session.Exit(1) - }, - }, - MaxTimeout: a.sshMaxTimeout, - } - - go a.runLoop(ctx) -} - -// createCommand processes raw command input with OpenSSH-like behavior. -// If the script provided is empty, it will default to the users shell. -// This injects environment variables specified by the user at launch too. -func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { - currentUser, err := user.Current() - if err != nil { - return nil, xerrors.Errorf("get current user: %w", err) - } - username := currentUser.Username - - shell, err := usershell.Get(username) - if err != nil { - return nil, xerrors.Errorf("get user shell: %w", err) - } - - manifest := a.manifest.Load() - if manifest == nil { - return nil, xerrors.Errorf("no metadata was provided") - } - - // OpenSSH executes all commands with the users current shell. - // We replicate that behavior for IDE support. - caller := "-c" - if runtime.GOOS == "windows" { - caller = "/c" - } - args := []string{caller, script} - - // gliderlabs/ssh returns a command slice of zero - // when a shell is requested. - if len(script) == 0 { - args = []string{} - if runtime.GOOS != "windows" { - // On Linux and macOS, we should start a login - // shell to consume juicy environment variables! - args = append(args, "-l") - } - } - - cmd := exec.CommandContext(ctx, shell, args...) - cmd.Dir = manifest.Directory - - // If the metadata directory doesn't exist, we run the command - // in the users home directory. - _, err = os.Stat(cmd.Dir) - if cmd.Dir == "" || err != nil { - // Default to user home if a directory is not set. - homedir, err := userHomeDir() - if err != nil { - return nil, xerrors.Errorf("get home dir: %w", err) - } - cmd.Dir = homedir - } - cmd.Env = append(os.Environ(), env...) - executablePath, err := os.Executable() - if err != nil { - return nil, xerrors.Errorf("getting os executable: %w", err) - } - // Set environment variables reliable detection of being inside a - // Coder workspace. - cmd.Env = append(cmd.Env, "CODER=true") - cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) - // Git on Windows resolves with UNIX-style paths. - // If using backslashes, it's unable to find the executable. - unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") - cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) - - // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load())) - - // Set SSH connection environment variables (these are also set by OpenSSH - // and thus expected to be present by SSH clients). Since the agent does - // networking in-memory, trying to provide accurate values here would be - // nonsensical. For now, we hard code these values so that they're present. - srcAddr, srcPort := "0.0.0.0", "0" - dstAddr, dstPort := "0.0.0.0", "0" - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) - - // This adds the ports dialog to code-server that enables - // proxying a port dynamically. - cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) - - // Hide Coder message on code-server's "Getting Started" page - cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") - - // Load environment variables passed via the agent. - // These should override all variables we manually specify. - for envKey, value := range manifest.EnvironmentVariables { - // Expanding environment variables allows for customization - // of the $PATH, among other variables. Customers can prepend - // or append to the $PATH, so allowing expand is required! - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) - } - - // Agent-level environment variables should take over all! - // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range a.envVars { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) - } - - return cmd, nil -} - -func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { - ctx := session.Context() - env := session.Environ() - var magicType string - for index, kv := range env { - if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) { - continue - } - magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=") - env = append(env[:index], env[index+1:]...) - } - switch magicType { - case MagicSSHSessionTypeVSCode: - a.connCountVSCode.Add(1) - defer a.connCountVSCode.Add(-1) - case MagicSSHSessionTypeJetBrains: - a.connCountJetBrains.Add(1) - defer a.connCountJetBrains.Add(-1) - case "": - a.connCountSSHSession.Add(1) - defer a.connCountSSHSession.Add(-1) - default: - a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) - } - - cmd, err := a.createCommand(ctx, session.RawCommand(), env) - if err != nil { - return err - } - - if ssh.AgentRequested(session) { - l, err := ssh.NewAgentListener() - if err != nil { - return xerrors.Errorf("new agent listener: %w", err) - } - defer l.Close() - go ssh.ForwardAgentConnections(l, session) - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) - } - - sshPty, windowSize, isPty := session.Pty() - if isPty { - // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). - // See https://github.com/coder/coder/issues/3371. - session.DisablePTYEmulation() - - if !isQuietLogin(session.RawCommand()) { - manifest := a.manifest.Load() - if manifest != nil { - err = showMOTD(session, manifest.MOTDFile) - if err != nil { - a.logger.Error(ctx, "show MOTD", slog.Error(err)) - } - } else { - a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") - } - } - - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - - // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd, pty.WithPTYOption( - pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), - )) - if err != nil { - return xerrors.Errorf("start command: %w", err) - } - var wg sync.WaitGroup - defer func() { - defer wg.Wait() - closeErr := ptty.Close() - if closeErr != nil { - a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) - if retErr == nil { - retErr = closeErr - } - } - }() - go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - } - } - }() - // We don't add input copy to wait group because - // it won't return until the session is closed. - go func() { - _, _ = io.Copy(ptty.Input(), session) - }() - - // In low parallelism scenarios, the command may exit and we may close - // the pty before the output copy has started. This can result in the - // output being lost. To avoid this, we wait for the output copy to - // start before waiting for the command to exit. This ensures that the - // output copy goroutine will be scheduled before calling close on the - // pty. This shouldn't be needed because of `pty.Dup()` below, but it - // may not be supported on all platforms. - outputCopyStarted := make(chan struct{}) - ptyOutput := func() io.ReadCloser { - defer close(outputCopyStarted) - // Try to dup so we can separate stdin and stdout closure. - // Once the original pty is closed, the dup will return - // input/output error once the buffered data has been read. - stdout, err := ptty.Dup() - if err == nil { - return stdout - } - // If we can't dup, we shouldn't close - // the fd since it's tied to stdin. - return readNopCloser{ptty.Output()} - } - wg.Add(1) - go func() { - // Ensure data is flushed to session on command exit, if we - // close the session too soon, we might lose data. - defer wg.Done() - - stdout := ptyOutput() - defer stdout.Close() - - _, _ = io.Copy(session, stdout) - }() - <-outputCopyStarted - - err = process.Wait() - var exitErr *exec.ExitError - // ExitErrors just mean the command we run returned a non-zero exit code, which is normal - // and not something to be concerned about. But, if it's something else, we should log it. - if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(ctx, "wait error", slog.Error(err)) - } - return err - } - - cmd.Stdout = session - cmd.Stderr = session.Stderr() - // This blocks forever until stdin is received if we don't - // use StdinPipe. It's unknown what causes this. - stdinPipe, err := cmd.StdinPipe() - if err != nil { - return xerrors.Errorf("create stdin pipe: %w", err) - } - go func() { - _, _ = io.Copy(stdinPipe, session) - _ = stdinPipe.Close() - }() - err = cmd.Start() - if err != nil { - return xerrors.Errorf("start: %w", err) - } - return cmd.Wait() -} - -type readNopCloser struct{ io.Reader } - -// Close implements io.Closer. -func (readNopCloser) Close() error { return nil } - func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) { defer conn.Close() @@ -1416,7 +997,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m logger.Debug(ctx, "creating new session") // Empty command will default to the users shell! - cmd, err := a.createCommand(ctx, msg.Command, nil) + cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -1590,9 +1171,11 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) { } // The count of active sessions. - stats.SessionCountSSH = a.connCountSSHSession.Load() - stats.SessionCountVSCode = a.connCountVSCode.Load() - stats.SessionCountJetBrains = a.connCountJetBrains.Load() + sshStats := a.sshServer.ConnStats() + stats.SessionCountSSH = sshStats.Sessions + stats.SessionCountVSCode = sshStats.VSCode + stats.SessionCountJetBrains = sshStats.JetBrains + stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load() // Compute the median connection latency! @@ -1692,8 +1275,16 @@ func (a *agent) Close() error { } ctx := context.Background() + a.logger.Info(ctx, "shutting down agent") a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) + // Attempt to gracefully shut down all active SSH connections and + // stop accepting new ones. + err := a.sshServer.Shutdown(ctx) + if err != nil { + a.logger.Error(ctx, "ssh server shutdown", slog.Error(err)) + } + lifecycleState := codersdk.WorkspaceAgentLifecycleOff if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" { scriptDone := make(chan error, 1) @@ -1785,101 +1376,6 @@ func (r *reconnectingPTY) Close() { r.timeout.Stop() } -// Bicopy copies all of the data between the two connections and will close them -// after one or both of them are done writing. If the context is canceled, both -// of the connections will be closed. -func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - defer func() { - _ = c1.Close() - _ = c2.Close() - }() - - var wg sync.WaitGroup - copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer func() { - wg.Done() - // If one side of the copy fails, ensure the other one exits as - // well. - cancel() - }() - _, _ = io.Copy(dst, src) - } - - wg.Add(2) - go copyFunc(c1, c2) - go copyFunc(c2, c1) - - // Convert waitgroup to a channel so we can also wait on the context. - done := make(chan struct{}) - go func() { - defer close(done) - wg.Wait() - }() - - select { - case <-ctx.Done(): - case <-done: - } -} - -// isQuietLogin checks if the SSH server should perform a quiet login or not. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 -func isQuietLogin(rawCommand string) bool { - // We are always quiet unless this is a login shell. - if len(rawCommand) != 0 { - return true - } - - // Best effort, if we can't get the home directory, - // we can't lookup .hushlogin. - homedir, err := userHomeDir() - if err != nil { - return false - } - - _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) - return err == nil -} - -// showMOTD will output the message of the day from -// the given filename to dest, if the file exists. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 -func showMOTD(dest io.Writer, filename string) error { - if filename == "" { - return nil - } - - f, err := os.Open(filename) - if err != nil { - if xerrors.Is(err, os.ErrNotExist) { - // This is not an error, there simply isn't a MOTD to show. - return nil - } - return xerrors.Errorf("open MOTD: %w", err) - } - defer f.Close() - - s := bufio.NewScanner(f) - for s.Scan() { - // Carriage return ensures each line starts - // at the beginning of the terminal. - _, err = fmt.Fprint(dest, s.Text()+"\r\n") - if err != nil { - return xerrors.Errorf("write MOTD: %w", err) - } - } - if err := s.Err(); err != nil { - return xerrors.Errorf("read MOTD: %w", err) - } - - return nil -} - // userHomeDir returns the home directory of the current user, giving // priority to the $HOME environment variable. func userHomeDir() (string, error) { diff --git a/agent/agent_test.go b/agent/agent_test.go index ec76aa1b0b6b9..8d7d641e1f73d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -41,6 +41,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -131,13 +132,13 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() - command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'" + command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'" expected := "" if runtime.GOOS == "windows" { - expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%" + expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%" command = "cmd.exe /c echo " + expected } output, err := session.Output(command) @@ -158,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() stdin, err := session.StdinPipe() require.NoError(t, err) @@ -1595,7 +1596,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe } waitGroup.Add(1) go func() { - agent.Bicopy(context.Background(), conn, ssh) + agentssh.Bicopy(context.Background(), conn, ssh) waitGroup.Done() }() } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go new file mode 100644 index 0000000000000..94e716b260fbe --- /dev/null +++ b/agent/agentssh/agentssh.go @@ -0,0 +1,576 @@ +package agentssh + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty" +) + +const ( + // MagicSessionErrorCode indicates that something went wrong with the session, rather than the + // command just returning a nonzero exit code, and is chosen as an arbitrary, high number + // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. + MagicSessionErrorCode = 229 + + // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. + // This is stripped from any commands being executed, and is counted towards connection stats. + MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" + // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. + MagicSessionTypeVSCode = "vscode" + // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. + MagicSessionTypeJetBrains = "jetbrains" +) + +type Server struct { + ctx context.Context + cancel context.CancelFunc + serveWg sync.WaitGroup + logger slog.Logger + + srv *ssh.Server + + Env map[string]string + AgentToken func() string + + manifest atomic.Pointer[agentsdk.Manifest] + + connCountVSCode atomic.Int64 + connCountJetBrains atomic.Int64 + connCountSSHSession atomic.Int64 +} + +func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) { + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) + if err != nil { + return nil, err + } + + forwardHandler := &ssh.ForwardedTCPHandler{} + unixForwardHandler := &forwardedUnixHandler{log: logger} + + sCtx, sCancel := context.WithCancel(context.Background()) + s := &Server{ + ctx: sCtx, + cancel: sCancel, + logger: logger, + } + + s.srv = &ssh.Server{ + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": directStreamLocalHandler, + "session": ssh.DefaultSessionHandler, + }, + ConnectionFailedCallback: func(_ net.Conn, err error) { + logger.Info(ctx, "ssh connection ended", slog.Error(err)) + }, + Handler: s.sessionHandler, + HostSigners: []ssh.Signer{randomSigner}, + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + logger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return true + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + // Allow reverse port forwarding all! + logger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) + return true + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + NoClientAuth: true, + } + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpHandler, + }, + MaxTimeout: maxTimeout, + } + + return s, nil +} + +// SetManifest sets the manifest used for starting commands. +func (a *Server) SetManifest(m *agentsdk.Manifest) { + a.manifest.Store(m) +} + +type ConnStats struct { + Sessions int64 + VSCode int64 + JetBrains int64 +} + +func (a *Server) ConnStats() ConnStats { + return ConnStats{ + Sessions: a.connCountSSHSession.Load(), + VSCode: a.connCountVSCode.Load(), + JetBrains: a.connCountJetBrains.Load(), + } +} + +func (a *Server) sessionHandler(session ssh.Session) { + ctx := session.Context() + err := a.sessionStart(session) + var exitError *exec.ExitError + if xerrors.As(err, &exitError) { + a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + _ = session.Exit(exitError.ExitCode()) + return + } + if err != nil { + a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) + // This exit code is designed to be unlikely to be confused for a legit exit code + // from the process. + _ = session.Exit(MagicSessionErrorCode) + return + } + _ = session.Exit(0) +} + +func (a *Server) sessionStart(session ssh.Session) (retErr error) { + ctx := session.Context() + env := session.Environ() + var magicType string + for index, kv := range env { + if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) { + continue + } + magicType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=") + env = append(env[:index], env[index+1:]...) + } + switch magicType { + case MagicSessionTypeVSCode: + a.connCountVSCode.Add(1) + defer a.connCountVSCode.Add(-1) + case MagicSessionTypeJetBrains: + a.connCountJetBrains.Add(1) + defer a.connCountJetBrains.Add(-1) + case "": + a.connCountSSHSession.Add(1) + defer a.connCountSSHSession.Add(-1) + default: + a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) + } + + cmd, err := a.CreateCommand(ctx, session.RawCommand(), env) + if err != nil { + return err + } + + if ssh.AgentRequested(session) { + l, err := ssh.NewAgentListener() + if err != nil { + return xerrors.Errorf("new agent listener: %w", err) + } + defer l.Close() + go ssh.ForwardAgentConnections(l, session) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) + } + + sshPty, windowSize, isPty := session.Pty() + if isPty { + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + + if !isQuietLogin(session.RawCommand()) { + manifest := a.manifest.Load() + if manifest != nil { + err = showMOTD(session, manifest.MOTDFile) + if err != nil { + a.logger.Error(ctx, "show MOTD", slog.Error(err)) + } + } else { + a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + } + } + + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + + // The pty package sets `SSH_TTY` on supported platforms. + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), + )) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + var wg sync.WaitGroup + defer func() { + defer wg.Wait() + closeErr := ptty.Close() + if closeErr != nil { + a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() + go func() { + for win := range windowSize { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + } + } + }() + // We don't add input copy to wait group because + // it won't return until the session is closed. + go func() { + _, _ = io.Copy(ptty.Input(), session) + }() + + // In low parallelism scenarios, the command may exit and we may close + // the pty before the output copy has started. This can result in the + // output being lost. To avoid this, we wait for the output copy to + // start before waiting for the command to exit. This ensures that the + // output copy goroutine will be scheduled before calling close on the + // pty. This shouldn't be needed because of `pty.Dup()` below, but it + // may not be supported on all platforms. + outputCopyStarted := make(chan struct{}) + ptyOutput := func() io.ReadCloser { + defer close(outputCopyStarted) + // Try to dup so we can separate stdin and stdout closure. + // Once the original pty is closed, the dup will return + // input/output error once the buffered data has been read. + stdout, err := ptty.Dup() + if err == nil { + return stdout + } + // If we can't dup, we shouldn't close + // the fd since it's tied to stdin. + return readNopCloser{ptty.Output()} + } + wg.Add(1) + go func() { + // Ensure data is flushed to session on command exit, if we + // close the session too soon, we might lose data. + defer wg.Done() + + stdout := ptyOutput() + defer stdout.Close() + + _, _ = io.Copy(session, stdout) + }() + <-outputCopyStarted + + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + a.logger.Warn(ctx, "wait error", slog.Error(err)) + } + return err + } + + cmd.Stdout = session + cmd.Stderr = session.Stderr() + // This blocks forever until stdin is received if we don't + // use StdinPipe. It's unknown what causes this. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return xerrors.Errorf("create stdin pipe: %w", err) + } + go func() { + _, _ = io.Copy(stdinPipe, session) + _ = stdinPipe.Close() + }() + err = cmd.Start() + if err != nil { + return xerrors.Errorf("start: %w", err) + } + return cmd.Wait() +} + +type readNopCloser struct{ io.Reader } + +// Close implements io.Closer. +func (readNopCloser) Close() error { return nil } + +func (a *Server) sftpHandler(session ssh.Session) { + ctx := session.Context() + + // Typically sftp sessions don't request a TTY, but if they do, + // we must ensure the gliderlabs/ssh CRLF emulation is disabled. + // Otherwise sftp will be broken. This can happen if a user sets + // `RequestTTY force` in their SSH config. + session.DisablePTYEmulation() + + var opts []sftp.ServerOption + // Change current working directory to the users home + // directory so that SFTP connections land there. + homedir, err := userHomeDir() + if err != nil { + a.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + } else { + opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) + } + + server, err := sftp.NewServer(session, opts...) + if err != nil { + a.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) + return + } + defer server.Close() + + err = server.Serve() + if errors.Is(err, io.EOF) { + // Unless we call `session.Exit(0)` here, the client won't + // receive `exit-status` because `(*sftp.Server).Close()` + // calls `Close()` on the underlying connection (session), + // which actually calls `channel.Close()` because it isn't + // wrapped. This causes sftp clients to receive a non-zero + // exit code. Typically sftp clients don't echo this exit + // code but `scp` on macOS does (when using the default + // SFTP backend). + _ = session.Exit(0) + return + } + a.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) + _ = session.Exit(1) +} + +// CreateCommand processes raw command input with OpenSSH-like behavior. +// If the script provided is empty, it will default to the users shell. +// This injects environment variables specified by the user at launch too. +func (a *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { + currentUser, err := user.Current() + if err != nil { + return nil, xerrors.Errorf("get current user: %w", err) + } + username := currentUser.Username + + shell, err := usershell.Get(username) + if err != nil { + return nil, xerrors.Errorf("get user shell: %w", err) + } + + manifest := a.manifest.Load() + if manifest == nil { + return nil, xerrors.Errorf("no metadata was provided") + } + + // OpenSSH executes all commands with the users current shell. + // We replicate that behavior for IDE support. + caller := "-c" + if runtime.GOOS == "windows" { + caller = "/c" + } + args := []string{caller, script} + + // gliderlabs/ssh returns a command slice of zero + // when a shell is requested. + if len(script) == 0 { + args = []string{} + if runtime.GOOS != "windows" { + // On Linux and macOS, we should start a login + // shell to consume juicy environment variables! + args = append(args, "-l") + } + } + + cmd := exec.CommandContext(ctx, shell, args...) + cmd.Dir = manifest.Directory + + // If the metadata directory doesn't exist, we run the command + // in the users home directory. + _, err = os.Stat(cmd.Dir) + if cmd.Dir == "" || err != nil { + // Default to user home if a directory is not set. + homedir, err := userHomeDir() + if err != nil { + return nil, xerrors.Errorf("get home dir: %w", err) + } + cmd.Dir = homedir + } + cmd.Env = append(os.Environ(), env...) + executablePath, err := os.Executable() + if err != nil { + return nil, xerrors.Errorf("getting os executable: %w", err) + } + // Set environment variables reliable detection of being inside a + // Coder workspace. + cmd.Env = append(cmd.Env, "CODER=true") + cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) + // Git on Windows resolves with UNIX-style paths. + // If using backslashes, it's unable to find the executable. + unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") + cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) + + // Specific Coder subcommands require the agent token exposed! + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", a.AgentToken())) + + // Set SSH connection environment variables (these are also set by OpenSSH + // and thus expected to be present by SSH clients). Since the agent does + // networking in-memory, trying to provide accurate values here would be + // nonsensical. For now, we hard code these values so that they're present. + srcAddr, srcPort := "0.0.0.0", "0" + dstAddr, dstPort := "0.0.0.0", "0" + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) + + // This adds the ports dialog to code-server that enables + // proxying a port dynamically. + cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) + + // Hide Coder message on code-server's "Getting Started" page + cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") + + // Load environment variables passed via the agent. + // These should override all variables we manually specify. + for envKey, value := range manifest.EnvironmentVariables { + // Expanding environment variables allows for customization + // of the $PATH, among other variables. Customers can prepend + // or append to the $PATH, so allowing expand is required! + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) + } + + // Agent-level environment variables should take over all! + // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". + for envKey, value := range a.Env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) + } + + return cmd, nil +} + +func (a *Server) Serve(l net.Listener) error { + a.serveWg.Add(1) + defer a.serveWg.Done() + return a.srv.Serve(l) +} + +func (a *Server) Close() error { + err := a.srv.Close() + a.serveWg.Wait() + return err +} + +// Shutdown gracefully closes all active SSH connections and stops +// accepting new connections. +// +// Shutdown is not implemented. +func (a *Server) Shutdown(ctx context.Context) error { + // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. + return nil +} + +// isQuietLogin checks if the SSH server should perform a quiet login or not. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 +func isQuietLogin(rawCommand string) bool { + // We are always quiet unless this is a login shell. + if len(rawCommand) != 0 { + return true + } + + // Best effort, if we can't get the home directory, + // we can't lookup .hushlogin. + homedir, err := userHomeDir() + if err != nil { + return false + } + + _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) + return err == nil +} + +// showMOTD will output the message of the day from +// the given filename to dest, if the file exists. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 +func showMOTD(dest io.Writer, filename string) error { + if filename == "" { + return nil + } + + f, err := os.Open(filename) + if err != nil { + if xerrors.Is(err, os.ErrNotExist) { + // This is not an error, there simply isn't a MOTD to show. + return nil + } + return xerrors.Errorf("open MOTD: %w", err) + } + defer f.Close() + + s := bufio.NewScanner(f) + for s.Scan() { + // Carriage return ensures each line starts + // at the beginning of the terminal. + _, err = fmt.Fprint(dest, s.Text()+"\r\n") + if err != nil { + return xerrors.Errorf("write MOTD: %w", err) + } + } + if err := s.Err(); err != nil { + return xerrors.Errorf("read MOTD: %w", err) + } + + return nil +} + +// userHomeDir returns the home directory of the current user, giving +// priority to the $HOME environment variable. +func userHomeDir() (string, error) { + // First we check the environment. + homedir, err := os.UserHomeDir() + if err == nil { + return homedir, nil + } + + // As a fallback, we try the user information. + u, err := user.Current() + if err != nil { + return "", xerrors.Errorf("current user: %w", err) + } + return u.HomeDir, nil +} diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go new file mode 100644 index 0000000000000..ecdb0a19eb5d8 --- /dev/null +++ b/agent/agentssh/agentssh_test.go @@ -0,0 +1,136 @@ +package agentssh_test + +import ( + "bytes" + "context" + "net" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/crypto/ssh" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestNewServer_ServeClient(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.SetManifest(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := sshClient(t, ln.Addr().String()) + var b bytes.Buffer + sess, err := c.NewSession() + sess.Stdout = &b + require.NoError(t, err) + err = sess.Start("echo hello") + require.NoError(t, err) + + err = sess.Wait() + require.NoError(t, err) + + require.Equal(t, "hello", strings.TrimSpace(b.String())) + + err = s.Close() + require.NoError(t, err) + <-done +} + +func TestNewServer_CloseActiveConnections(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.SetManifest(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + pty := ptytest.New(t) + + doClose := make(chan struct{}) + go func() { + defer wg.Done() + c := sshClient(t, ln.Addr().String()) + sess, err := c.NewSession() + sess.Stdin = pty.Input() + sess.Stdout = pty.Output() + sess.Stderr = pty.Output() + + assert.NoError(t, err) + err = sess.Start("") + assert.NoError(t, err) + + close(doClose) + err = sess.Wait() + assert.Error(t, err) + }() + + <-doClose + err = s.Close() + require.NoError(t, err) + + wg.Wait() +} + +func sshClient(t *testing.T, addr string) *ssh.Client { + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshConn.Close() + }) + c := ssh.NewClient(sshConn, channels, requests) + t.Cleanup(func() { + _ = c.Close() + }) + return c +} diff --git a/agent/agentssh/bicopy.go b/agent/agentssh/bicopy.go new file mode 100644 index 0000000000000..64cd2a716058c --- /dev/null +++ b/agent/agentssh/bicopy.go @@ -0,0 +1,47 @@ +package agentssh + +import ( + "context" + "io" + "sync" +) + +// Bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/agent/ssh.go b/agent/agentssh/forward.go similarity index 99% rename from agent/ssh.go rename to agent/agentssh/forward.go index 8aa41a1d268ed..1e3635fd8ff91 100644 --- a/agent/ssh.go +++ b/agent/agentssh/forward.go @@ -1,4 +1,4 @@ -package agent +package agentssh import ( "context" diff --git a/cli/portforward.go b/cli/portforward.go index c746216889a55..dad82381bfb5b 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -14,7 +14,7 @@ import ( "github.com/pion/udp" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" @@ -226,7 +226,7 @@ func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *co } defer remoteConn.Close() - agent.Bicopy(ctx, netConn, remoteConn) + agentssh.Bicopy(ctx, netConn, remoteConn) }(netConn) } }(spec) diff --git a/cli/ssh.go b/cli/ssh.go index e9168f6999f6b..e1ebbcd04cfd2 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -23,7 +23,7 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -574,7 +574,7 @@ func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Cl } } - agent.Bicopy(ctx, localConn, remoteConn) + agentssh.Bicopy(ctx, localConn, remoteConn) }() } }() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6ce14dad7689e..293ab3f0a06d8 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -30,7 +30,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/gitauth" @@ -620,7 +620,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { return } defer ptNetConn.Close() - agent.Bicopy(ctx, wsNetConn, ptNetConn) + agentssh.Bicopy(ctx, wsNetConn, ptNetConn) } // @Summary Get listening ports for workspace agent From d5f7a4e4f8705bfc67d02f9b137b1f0622ec6e34 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:46:58 +0000 Subject: [PATCH 2/7] Rename receivers --- agent/agentssh/agentssh.go | 84 +++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 94e716b260fbe..b2c421c54f1aa 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -136,8 +136,8 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration } // SetManifest sets the manifest used for starting commands. -func (a *Server) SetManifest(m *agentsdk.Manifest) { - a.manifest.Store(m) +func (s *Server) SetManifest(m *agentsdk.Manifest) { + s.manifest.Store(m) } type ConnStats struct { @@ -146,25 +146,25 @@ type ConnStats struct { JetBrains int64 } -func (a *Server) ConnStats() ConnStats { +func (s *Server) ConnStats() ConnStats { return ConnStats{ - Sessions: a.connCountSSHSession.Load(), - VSCode: a.connCountVSCode.Load(), - JetBrains: a.connCountJetBrains.Load(), + Sessions: s.connCountSSHSession.Load(), + VSCode: s.connCountVSCode.Load(), + JetBrains: s.connCountJetBrains.Load(), } } -func (a *Server) sessionHandler(session ssh.Session) { +func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() - err := a.sessionStart(session) + err := s.sessionStart(session) var exitError *exec.ExitError if xerrors.As(err, &exitError) { - a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) _ = session.Exit(exitError.ExitCode()) return } if err != nil { - a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) + s.logger.Warn(ctx, "ssh session failed", slog.Error(err)) // This exit code is designed to be unlikely to be confused for a legit exit code // from the process. _ = session.Exit(MagicSessionErrorCode) @@ -173,7 +173,7 @@ func (a *Server) sessionHandler(session ssh.Session) { _ = session.Exit(0) } -func (a *Server) sessionStart(session ssh.Session) (retErr error) { +func (s *Server) sessionStart(session ssh.Session) (retErr error) { ctx := session.Context() env := session.Environ() var magicType string @@ -186,19 +186,19 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { } switch magicType { case MagicSessionTypeVSCode: - a.connCountVSCode.Add(1) - defer a.connCountVSCode.Add(-1) + s.connCountVSCode.Add(1) + defer s.connCountVSCode.Add(-1) case MagicSessionTypeJetBrains: - a.connCountJetBrains.Add(1) - defer a.connCountJetBrains.Add(-1) + s.connCountJetBrains.Add(1) + defer s.connCountJetBrains.Add(-1) case "": - a.connCountSSHSession.Add(1) - defer a.connCountSSHSession.Add(-1) + s.connCountSSHSession.Add(1) + defer s.connCountSSHSession.Add(-1) default: - a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) + s.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) } - cmd, err := a.CreateCommand(ctx, session.RawCommand(), env) + cmd, err := s.CreateCommand(ctx, session.RawCommand(), env) if err != nil { return err } @@ -220,14 +220,14 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { session.DisablePTYEmulation() if !isQuietLogin(session.RawCommand()) { - manifest := a.manifest.Load() + manifest := s.manifest.Load() if manifest != nil { err = showMOTD(session, manifest.MOTDFile) if err != nil { - a.logger.Error(ctx, "show MOTD", slog.Error(err)) + s.logger.Error(ctx, "show MOTD", slog.Error(err)) } } else { - a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") } } @@ -236,7 +236,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { // The pty package sets `SSH_TTY` on supported platforms. ptty, process, err := pty.Start(cmd, pty.WithPTYOption( pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), + pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), )) if err != nil { return xerrors.Errorf("start command: %w", err) @@ -246,7 +246,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { defer wg.Wait() closeErr := ptty.Close() if closeErr != nil { - a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) if retErr == nil { retErr = closeErr } @@ -257,7 +257,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) // If the pty is closed, then command has exited, no need to log. if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) } } }() @@ -306,7 +306,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { // ExitErrors just mean the command we run returned a non-zero exit code, which is normal // and not something to be concerned about. But, if it's something else, we should log it. if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(ctx, "wait error", slog.Error(err)) + s.logger.Warn(ctx, "wait error", slog.Error(err)) } return err } @@ -335,7 +335,7 @@ type readNopCloser struct{ io.Reader } // Close implements io.Closer. func (readNopCloser) Close() error { return nil } -func (a *Server) sftpHandler(session ssh.Session) { +func (s *Server) sftpHandler(session ssh.Session) { ctx := session.Context() // Typically sftp sessions don't request a TTY, but if they do, @@ -349,14 +349,14 @@ func (a *Server) sftpHandler(session ssh.Session) { // directory so that SFTP connections land there. homedir, err := userHomeDir() if err != nil { - a.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + s.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) } else { opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) } server, err := sftp.NewServer(session, opts...) if err != nil { - a.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) + s.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) return } defer server.Close() @@ -374,14 +374,14 @@ func (a *Server) sftpHandler(session ssh.Session) { _ = session.Exit(0) return } - a.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) + s.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) _ = session.Exit(1) } // CreateCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. -func (a *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { +func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { currentUser, err := user.Current() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) @@ -393,7 +393,7 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) return nil, xerrors.Errorf("get user shell: %w", err) } - manifest := a.manifest.Load() + manifest := s.manifest.Load() if manifest == nil { return nil, xerrors.Errorf("no metadata was provided") } @@ -446,7 +446,7 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", a.AgentToken())) + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken())) // Set SSH connection environment variables (these are also set by OpenSSH // and thus expected to be present by SSH clients). Since the agent does @@ -475,22 +475,22 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) // Agent-level environment variables should take over all! // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range a.Env { + for envKey, value := range s.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) } return cmd, nil } -func (a *Server) Serve(l net.Listener) error { - a.serveWg.Add(1) - defer a.serveWg.Done() - return a.srv.Serve(l) +func (s *Server) Serve(l net.Listener) error { + s.serveWg.Add(1) + defer s.serveWg.Done() + return s.srv.Serve(l) } -func (a *Server) Close() error { - err := a.srv.Close() - a.serveWg.Wait() +func (s *Server) Close() error { + err := s.srv.Close() + s.serveWg.Wait() return err } @@ -498,7 +498,7 @@ func (a *Server) Close() error { // accepting new connections. // // Shutdown is not implemented. -func (a *Server) Shutdown(ctx context.Context) error { +func (*Server) Shutdown(ctx context.Context) error { // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. return nil } From 3ea3e707ce40f3d9a4c4be0481e748484607f55d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:52:33 +0000 Subject: [PATCH 3/7] Remove unused context --- agent/agentssh/agentssh.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index b2c421c54f1aa..ff273790dffec 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,8 +47,6 @@ const ( ) type Server struct { - ctx context.Context - cancel context.CancelFunc serveWg sync.WaitGroup logger slog.Logger @@ -80,10 +78,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := &forwardedUnixHandler{log: logger} - sCtx, sCancel := context.WithCancel(context.Background()) s := &Server{ - ctx: sCtx, - cancel: sCancel, logger: logger, } From 126813f029765d9bf6606d7cbcaee40a40ad845f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:54:22 +0000 Subject: [PATCH 4/7] Use s logger --- agent/agentssh/agentssh.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index ff273790dffec..d929b1e0293e5 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -89,13 +89,13 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration "session": ssh.DefaultSessionHandler, }, ConnectionFailedCallback: func(_ net.Conn, err error) { - logger.Info(ctx, "ssh connection ended", slog.Error(err)) + s.logger.Info(ctx, "ssh connection ended", slog.Error(err)) }, Handler: s.sessionHandler, HostSigners: []ssh.Signer{randomSigner}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { // Allow local port forwarding all! - logger.Debug(ctx, "local port forward", + s.logger.Debug(ctx, "local port forward", slog.F("destination-host", destinationHost), slog.F("destination-port", destinationPort)) return true @@ -105,7 +105,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { // Allow reverse port forwarding all! - logger.Debug(ctx, "local port forward", + s.logger.Debug(ctx, "local port forward", slog.F("bind-host", bindHost), slog.F("bind-port", bindPort)) return true From 667d038fe262d3bd21e827569f3534f242cfe9b0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 15:04:02 +0000 Subject: [PATCH 5/7] Rename unused arg _ --- agent/agentssh/agentssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index d929b1e0293e5..a1cad342d4fbe 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -493,7 +493,7 @@ func (s *Server) Close() error { // accepting new connections. // // Shutdown is not implemented. -func (*Server) Shutdown(ctx context.Context) error { +func (*Server) Shutdown(_ context.Context) error { // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. return nil } From 94d759396ad87cb386f54917ea51cc56e4a6fe0f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 6 Apr 2023 09:50:51 +0000 Subject: [PATCH 6/7] Address PR feedback --- agent/agent.go | 2 +- agent/agentssh/agentssh.go | 14 ++++---------- agent/agentssh/agentssh_test.go | 7 +++++-- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 3906a182139ad..f538ef93b4af8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -167,6 +167,7 @@ func (a *agent) init(ctx context.Context) { } sshSrv.Env = a.envVars sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } + sshSrv.Manifest = &a.manifest a.sshServer = sshSrv go a.runLoop(ctx) @@ -478,7 +479,6 @@ func (a *agent) run(ctx context.Context) error { } oldManifest := a.manifest.Swap(&manifest) - a.sshServer.SetManifest(&manifest) // The startup script should only execute on the first run! if oldManifest == nil { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index a1cad342d4fbe..c511d0e5f9eab 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -16,11 +16,11 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "github.com/gliderlabs/ssh" "github.com/pkg/sftp" + "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -54,8 +54,7 @@ type Server struct { Env map[string]string AgentToken func() string - - manifest atomic.Pointer[agentsdk.Manifest] + Manifest *atomic.Pointer[agentsdk.Manifest] connCountVSCode atomic.Int64 connCountJetBrains atomic.Int64 @@ -130,11 +129,6 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration return s, nil } -// SetManifest sets the manifest used for starting commands. -func (s *Server) SetManifest(m *agentsdk.Manifest) { - s.manifest.Store(m) -} - type ConnStats struct { Sessions int64 VSCode int64 @@ -215,7 +209,7 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { session.DisablePTYEmulation() if !isQuietLogin(session.RawCommand()) { - manifest := s.manifest.Load() + manifest := s.Manifest.Load() if manifest != nil { err = showMOTD(session, manifest.MOTDFile) if err != nil { @@ -388,7 +382,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) return nil, xerrors.Errorf("get user shell: %w", err) } - manifest := s.manifest.Load() + manifest := s.Manifest.Load() if manifest == nil { return nil, xerrors.Errorf("no metadata was provided") } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index ecdb0a19eb5d8..684c0e36bbb18 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -1,3 +1,5 @@ +// Package agentssh_test provides tests for basic functinoality of the agentssh +// package, more test coverage can be found in the `agent` and `cli` package(s). package agentssh_test import ( @@ -10,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "go.uber.org/goleak" "golang.org/x/crypto/ssh" @@ -34,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) { // The assumption is that these are set before serving SSH connections. s.AgentToken = func() string { return "" } - s.SetManifest(&agentsdk.Manifest{}) + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -74,7 +77,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { // The assumption is that these are set before serving SSH connections. s.AgentToken = func() string { return "" } - s.SetManifest(&agentsdk.Manifest{}) + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) From ed63a2bcf048d0fc178908ff06743c4abd823fd0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 6 Apr 2023 11:38:04 +0000 Subject: [PATCH 7/7] Improve handling of serve/close --- agent/agentssh/agentssh.go | 128 ++++++++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 8 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index c511d0e5f9eab..c882380bacf48 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,10 +47,16 @@ const ( ) type Server struct { - serveWg sync.WaitGroup - logger slog.Logger + mu sync.RWMutex // Protects following. + listeners map[net.Listener]struct{} + conns map[net.Conn]struct{} + closing chan struct{} + // Wait for goroutines to exit, waited without + // a lock on mu but protected by closing. + wg sync.WaitGroup - srv *ssh.Server + logger slog.Logger + srv *ssh.Server Env map[string]string AgentToken func() string @@ -78,7 +84,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration unixForwardHandler := &forwardedUnixHandler{log: logger} s := &Server{ - logger: logger, + listeners: make(map[net.Listener]struct{}), + conns: make(map[net.Conn]struct{}), + logger: logger, } s.srv = &ssh.Server{ @@ -472,14 +480,118 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) } func (s *Server) Serve(l net.Listener) error { - s.serveWg.Add(1) - defer s.serveWg.Done() - return s.srv.Serve(l) + defer l.Close() + + s.trackListener(l, true) + defer s.trackListener(l, false) + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.handleConn(l, conn) + } } +func (s *Server) handleConn(l net.Listener, c net.Conn) { + defer c.Close() + + if !s.trackConn(l, c, true) { + // Server is closed or we no longer want + // connections from this listener. + s.logger.Debug(context.Background(), "received connection after server closed") + return + } + defer s.trackConn(l, c, false) + + s.srv.HandleConn(c) +} + +// trackListener registers the listener with the server. If the server is +// closing, the function will block until the server is closed. +// +//nolint:revive +func (s *Server) trackListener(l net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + for s.closing != nil { + closing := s.closing + // Wait until close is complete before + // serving a new listener. + s.mu.Unlock() + <-closing + s.mu.Lock() + } + s.wg.Add(1) + s.listeners[l] = struct{}{} + return + } + s.wg.Done() + delete(s.listeners, l) +} + +// trackConn registers the connection with the server. If the server is +// closed or the listener is closed, the connection is not registered +// and should be closed. +// +//nolint:revive +func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + found := false + for ll := range s.listeners { + if l == ll { + found = true + break + } + } + if s.closing != nil || !found { + // Server or listener closed. + return false + } + s.wg.Add(1) + s.conns[c] = struct{}{} + return true + } + s.wg.Done() + delete(s.conns, c) + return true +} + +// Close the server and all active connections. Server can be re-used +// after Close is done. func (s *Server) Close() error { + s.mu.Lock() + + // Guard against multiple calls to Close and + // accepting new connections during close. + if s.closing != nil { + s.mu.Unlock() + return xerrors.New("server is closing") + } + s.closing = make(chan struct{}) + + // Close all active listeners and connections. + for l := range s.listeners { + _ = l.Close() + } + for c := range s.conns { + _ = c.Close() + } + + // Close the underlying SSH server. err := s.srv.Close() - s.serveWg.Wait() + + s.mu.Unlock() + s.wg.Wait() // Wait for all goroutines to exit. + + s.mu.Lock() + close(s.closing) + s.closing = nil + s.mu.Unlock() + return err } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy