diff --git a/agent/agent.go b/agent/agent.go index e22d5c3576123..f538ef93b4af8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -4,8 +4,6 @@ import ( "bufio" "bytes" "context" - "crypto/rand" - "crypto/rsa" "encoding/binary" "encoding/json" "errors" @@ -16,11 +14,9 @@ import ( "net/http" "net/netip" "os" - "os/exec" "os/user" "path/filepath" "reflect" - "runtime" "sort" "strconv" "strings" @@ -28,12 +24,9 @@ import ( "time" "github.com/armon/circbuf" - "github.com/gliderlabs/ssh" "github.com/google/uuid" - "github.com/pkg/sftp" "github.com/spf13/afero" "go.uber.org/atomic" - gossh "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" @@ -41,7 +34,7 @@ import ( "tailscale.com/types/netlogtype" "cdr.dev/slog" - "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitauth" @@ -56,19 +49,6 @@ const ( ProtocolReconnectingPTY = "reconnecting-pty" ProtocolSSH = "ssh" ProtocolDial = "dial" - - // MagicSessionErrorCode indicates that something went wrong with the session, rather than the - // command just returning a nonzero exit code, and is chosen as an arbitrary, high number - // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. - MagicSessionErrorCode = 229 - - // MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. - // This is stripped from any commands being executed, and is counted towards connection stats. - MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" - // MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. - MagicSSHSessionTypeVSCode = "vscode" - // MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. - MagicSSHSessionTypeJetBrains = "jetbrains" ) type Options struct { @@ -165,7 +145,7 @@ type agent struct { // manifest is atomic because values can change after reconnection. manifest atomic.Pointer[agentsdk.Manifest] sessionToken atomic.Pointer[string] - sshServer *ssh.Server + sshServer *agentssh.Server sshMaxTimeout time.Duration lifecycleUpdate chan struct{} @@ -177,10 +157,20 @@ type agent struct { connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] - connCountVSCode atomic.Int64 - connCountJetBrains atomic.Int64 connCountReconnectingPTY atomic.Int64 - connCountSSHSession atomic.Int64 +} + +func (a *agent) init(ctx context.Context) { + sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout) + if err != nil { + panic(err) + } + sshSrv.Env = a.envVars + sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } + sshSrv.Manifest = &a.manifest + a.sshServer = sshSrv + + go a.runLoop(ctx) } // runLoop attempts to start the agent in a retry loop. @@ -223,7 +213,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it is certain the clocks are in sync. CollectedAt: time.Now(), } - cmd, err := a.createCommand(ctx, md.Script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil) if err != nil { result.Error = err.Error() return result @@ -633,28 +623,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ } }() if err = a.trackConnGoroutine(func() { - var wg sync.WaitGroup - for { - conn, err := sshListener.Accept() - if err != nil { - break - } - wg.Add(1) - closed := make(chan struct{}) - go func() { - select { - case <-closed: - case <-a.closed: - _ = conn.Close() - } - wg.Done() - }() - go func() { - defer close(closed) - a.sshServer.HandleConn(conn) - }() - } - wg.Wait() + _ = a.sshServer.Serve(sshListener) }); err != nil { return nil, err } @@ -857,7 +826,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error { }() } - cmd, err := a.createCommand(ctx, script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, script, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -990,394 +959,6 @@ func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan str return logsFinished, nil } -func (a *agent) init(ctx context.Context) { - // Clients' should ignore the host key when connecting. - // The agent needs to authenticate with coderd to SSH, - // so SSH authentication doesn't improve security. - randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - randomSigner, err := gossh.NewSignerFromKey(randomHostKey) - if err != nil { - panic(err) - } - - sshLogger := a.logger.Named("ssh-server") - forwardHandler := &ssh.ForwardedTCPHandler{} - unixForwardHandler := &forwardedUnixHandler{log: a.logger} - - a.sshServer = &ssh.Server{ - ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, - "direct-streamlocal@openssh.com": directStreamLocalHandler, - "session": ssh.DefaultSessionHandler, - }, - ConnectionFailedCallback: func(conn net.Conn, err error) { - sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) - }, - Handler: func(session ssh.Session) { - err := a.handleSSHSession(session) - var exitError *exec.ExitError - if xerrors.As(err, &exitError) { - a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) - _ = session.Exit(exitError.ExitCode()) - return - } - if err != nil { - a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) - // This exit code is designed to be unlikely to be confused for a legit exit code - // from the process. - _ = session.Exit(MagicSessionErrorCode) - return - } - _ = session.Exit(0) - }, - HostSigners: []ssh.Signer{randomSigner}, - LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { - // Allow local port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("destination-host", destinationHost), - slog.F("destination-port", destinationPort)) - return true - }, - PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return true - }, - ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - // Allow reverse port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("bind-host", bindHost), - slog.F("bind-port", bindPort)) - return true - }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, - "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - }, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - return &gossh.ServerConfig{ - NoClientAuth: true, - } - }, - SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": func(session ssh.Session) { - ctx := session.Context() - - // Typically sftp sessions don't request a TTY, but if they do, - // we must ensure the gliderlabs/ssh CRLF emulation is disabled. - // Otherwise sftp will be broken. This can happen if a user sets - // `RequestTTY force` in their SSH config. - session.DisablePTYEmulation() - - var opts []sftp.ServerOption - // Change current working directory to the users home - // directory so that SFTP connections land there. - homedir, err := userHomeDir() - if err != nil { - sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) - } else { - opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) - } - - server, err := sftp.NewServer(session, opts...) - if err != nil { - sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err)) - return - } - defer server.Close() - - err = server.Serve() - if errors.Is(err, io.EOF) { - // Unless we call `session.Exit(0)` here, the client won't - // receive `exit-status` because `(*sftp.Server).Close()` - // calls `Close()` on the underlying connection (session), - // which actually calls `channel.Close()` because it isn't - // wrapped. This causes sftp clients to receive a non-zero - // exit code. Typically sftp clients don't echo this exit - // code but `scp` on macOS does (when using the default - // SFTP backend). - _ = session.Exit(0) - return - } - sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err)) - _ = session.Exit(1) - }, - }, - MaxTimeout: a.sshMaxTimeout, - } - - go a.runLoop(ctx) -} - -// createCommand processes raw command input with OpenSSH-like behavior. -// If the script provided is empty, it will default to the users shell. -// This injects environment variables specified by the user at launch too. -func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { - currentUser, err := user.Current() - if err != nil { - return nil, xerrors.Errorf("get current user: %w", err) - } - username := currentUser.Username - - shell, err := usershell.Get(username) - if err != nil { - return nil, xerrors.Errorf("get user shell: %w", err) - } - - manifest := a.manifest.Load() - if manifest == nil { - return nil, xerrors.Errorf("no metadata was provided") - } - - // OpenSSH executes all commands with the users current shell. - // We replicate that behavior for IDE support. - caller := "-c" - if runtime.GOOS == "windows" { - caller = "/c" - } - args := []string{caller, script} - - // gliderlabs/ssh returns a command slice of zero - // when a shell is requested. - if len(script) == 0 { - args = []string{} - if runtime.GOOS != "windows" { - // On Linux and macOS, we should start a login - // shell to consume juicy environment variables! - args = append(args, "-l") - } - } - - cmd := exec.CommandContext(ctx, shell, args...) - cmd.Dir = manifest.Directory - - // If the metadata directory doesn't exist, we run the command - // in the users home directory. - _, err = os.Stat(cmd.Dir) - if cmd.Dir == "" || err != nil { - // Default to user home if a directory is not set. - homedir, err := userHomeDir() - if err != nil { - return nil, xerrors.Errorf("get home dir: %w", err) - } - cmd.Dir = homedir - } - cmd.Env = append(os.Environ(), env...) - executablePath, err := os.Executable() - if err != nil { - return nil, xerrors.Errorf("getting os executable: %w", err) - } - // Set environment variables reliable detection of being inside a - // Coder workspace. - cmd.Env = append(cmd.Env, "CODER=true") - cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) - // Git on Windows resolves with UNIX-style paths. - // If using backslashes, it's unable to find the executable. - unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") - cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) - - // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load())) - - // Set SSH connection environment variables (these are also set by OpenSSH - // and thus expected to be present by SSH clients). Since the agent does - // networking in-memory, trying to provide accurate values here would be - // nonsensical. For now, we hard code these values so that they're present. - srcAddr, srcPort := "0.0.0.0", "0" - dstAddr, dstPort := "0.0.0.0", "0" - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) - - // This adds the ports dialog to code-server that enables - // proxying a port dynamically. - cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) - - // Hide Coder message on code-server's "Getting Started" page - cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") - - // Load environment variables passed via the agent. - // These should override all variables we manually specify. - for envKey, value := range manifest.EnvironmentVariables { - // Expanding environment variables allows for customization - // of the $PATH, among other variables. Customers can prepend - // or append to the $PATH, so allowing expand is required! - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) - } - - // Agent-level environment variables should take over all! - // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range a.envVars { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) - } - - return cmd, nil -} - -func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { - ctx := session.Context() - env := session.Environ() - var magicType string - for index, kv := range env { - if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) { - continue - } - magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=") - env = append(env[:index], env[index+1:]...) - } - switch magicType { - case MagicSSHSessionTypeVSCode: - a.connCountVSCode.Add(1) - defer a.connCountVSCode.Add(-1) - case MagicSSHSessionTypeJetBrains: - a.connCountJetBrains.Add(1) - defer a.connCountJetBrains.Add(-1) - case "": - a.connCountSSHSession.Add(1) - defer a.connCountSSHSession.Add(-1) - default: - a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) - } - - cmd, err := a.createCommand(ctx, session.RawCommand(), env) - if err != nil { - return err - } - - if ssh.AgentRequested(session) { - l, err := ssh.NewAgentListener() - if err != nil { - return xerrors.Errorf("new agent listener: %w", err) - } - defer l.Close() - go ssh.ForwardAgentConnections(l, session) - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) - } - - sshPty, windowSize, isPty := session.Pty() - if isPty { - // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). - // See https://github.com/coder/coder/issues/3371. - session.DisablePTYEmulation() - - if !isQuietLogin(session.RawCommand()) { - manifest := a.manifest.Load() - if manifest != nil { - err = showMOTD(session, manifest.MOTDFile) - if err != nil { - a.logger.Error(ctx, "show MOTD", slog.Error(err)) - } - } else { - a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") - } - } - - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - - // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd, pty.WithPTYOption( - pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), - )) - if err != nil { - return xerrors.Errorf("start command: %w", err) - } - var wg sync.WaitGroup - defer func() { - defer wg.Wait() - closeErr := ptty.Close() - if closeErr != nil { - a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) - if retErr == nil { - retErr = closeErr - } - } - }() - go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - } - } - }() - // We don't add input copy to wait group because - // it won't return until the session is closed. - go func() { - _, _ = io.Copy(ptty.Input(), session) - }() - - // In low parallelism scenarios, the command may exit and we may close - // the pty before the output copy has started. This can result in the - // output being lost. To avoid this, we wait for the output copy to - // start before waiting for the command to exit. This ensures that the - // output copy goroutine will be scheduled before calling close on the - // pty. This shouldn't be needed because of `pty.Dup()` below, but it - // may not be supported on all platforms. - outputCopyStarted := make(chan struct{}) - ptyOutput := func() io.ReadCloser { - defer close(outputCopyStarted) - // Try to dup so we can separate stdin and stdout closure. - // Once the original pty is closed, the dup will return - // input/output error once the buffered data has been read. - stdout, err := ptty.Dup() - if err == nil { - return stdout - } - // If we can't dup, we shouldn't close - // the fd since it's tied to stdin. - return readNopCloser{ptty.Output()} - } - wg.Add(1) - go func() { - // Ensure data is flushed to session on command exit, if we - // close the session too soon, we might lose data. - defer wg.Done() - - stdout := ptyOutput() - defer stdout.Close() - - _, _ = io.Copy(session, stdout) - }() - <-outputCopyStarted - - err = process.Wait() - var exitErr *exec.ExitError - // ExitErrors just mean the command we run returned a non-zero exit code, which is normal - // and not something to be concerned about. But, if it's something else, we should log it. - if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(ctx, "wait error", slog.Error(err)) - } - return err - } - - cmd.Stdout = session - cmd.Stderr = session.Stderr() - // This blocks forever until stdin is received if we don't - // use StdinPipe. It's unknown what causes this. - stdinPipe, err := cmd.StdinPipe() - if err != nil { - return xerrors.Errorf("create stdin pipe: %w", err) - } - go func() { - _, _ = io.Copy(stdinPipe, session) - _ = stdinPipe.Close() - }() - err = cmd.Start() - if err != nil { - return xerrors.Errorf("start: %w", err) - } - return cmd.Wait() -} - -type readNopCloser struct{ io.Reader } - -// Close implements io.Closer. -func (readNopCloser) Close() error { return nil } - func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) { defer conn.Close() @@ -1416,7 +997,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m logger.Debug(ctx, "creating new session") // Empty command will default to the users shell! - cmd, err := a.createCommand(ctx, msg.Command, nil) + cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -1590,9 +1171,11 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) { } // The count of active sessions. - stats.SessionCountSSH = a.connCountSSHSession.Load() - stats.SessionCountVSCode = a.connCountVSCode.Load() - stats.SessionCountJetBrains = a.connCountJetBrains.Load() + sshStats := a.sshServer.ConnStats() + stats.SessionCountSSH = sshStats.Sessions + stats.SessionCountVSCode = sshStats.VSCode + stats.SessionCountJetBrains = sshStats.JetBrains + stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load() // Compute the median connection latency! @@ -1692,8 +1275,16 @@ func (a *agent) Close() error { } ctx := context.Background() + a.logger.Info(ctx, "shutting down agent") a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) + // Attempt to gracefully shut down all active SSH connections and + // stop accepting new ones. + err := a.sshServer.Shutdown(ctx) + if err != nil { + a.logger.Error(ctx, "ssh server shutdown", slog.Error(err)) + } + lifecycleState := codersdk.WorkspaceAgentLifecycleOff if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" { scriptDone := make(chan error, 1) @@ -1785,101 +1376,6 @@ func (r *reconnectingPTY) Close() { r.timeout.Stop() } -// Bicopy copies all of the data between the two connections and will close them -// after one or both of them are done writing. If the context is canceled, both -// of the connections will be closed. -func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - defer func() { - _ = c1.Close() - _ = c2.Close() - }() - - var wg sync.WaitGroup - copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer func() { - wg.Done() - // If one side of the copy fails, ensure the other one exits as - // well. - cancel() - }() - _, _ = io.Copy(dst, src) - } - - wg.Add(2) - go copyFunc(c1, c2) - go copyFunc(c2, c1) - - // Convert waitgroup to a channel so we can also wait on the context. - done := make(chan struct{}) - go func() { - defer close(done) - wg.Wait() - }() - - select { - case <-ctx.Done(): - case <-done: - } -} - -// isQuietLogin checks if the SSH server should perform a quiet login or not. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 -func isQuietLogin(rawCommand string) bool { - // We are always quiet unless this is a login shell. - if len(rawCommand) != 0 { - return true - } - - // Best effort, if we can't get the home directory, - // we can't lookup .hushlogin. - homedir, err := userHomeDir() - if err != nil { - return false - } - - _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) - return err == nil -} - -// showMOTD will output the message of the day from -// the given filename to dest, if the file exists. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 -func showMOTD(dest io.Writer, filename string) error { - if filename == "" { - return nil - } - - f, err := os.Open(filename) - if err != nil { - if xerrors.Is(err, os.ErrNotExist) { - // This is not an error, there simply isn't a MOTD to show. - return nil - } - return xerrors.Errorf("open MOTD: %w", err) - } - defer f.Close() - - s := bufio.NewScanner(f) - for s.Scan() { - // Carriage return ensures each line starts - // at the beginning of the terminal. - _, err = fmt.Fprint(dest, s.Text()+"\r\n") - if err != nil { - return xerrors.Errorf("write MOTD: %w", err) - } - } - if err := s.Err(); err != nil { - return xerrors.Errorf("read MOTD: %w", err) - } - - return nil -} - // userHomeDir returns the home directory of the current user, giving // priority to the $HOME environment variable. func userHomeDir() (string, error) { diff --git a/agent/agent_test.go b/agent/agent_test.go index ec76aa1b0b6b9..8d7d641e1f73d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -41,6 +41,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -131,13 +132,13 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() - command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'" + command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'" expected := "" if runtime.GOOS == "windows" { - expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%" + expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%" command = "cmd.exe /c echo " + expected } output, err := session.Output(command) @@ -158,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() stdin, err := session.StdinPipe() require.NoError(t, err) @@ -1595,7 +1596,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe } waitGroup.Add(1) go func() { - agent.Bicopy(context.Background(), conn, ssh) + agentssh.Bicopy(context.Background(), conn, ssh) waitGroup.Done() }() } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go new file mode 100644 index 0000000000000..c882380bacf48 --- /dev/null +++ b/agent/agentssh/agentssh.go @@ -0,0 +1,677 @@ +package agentssh + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + "go.uber.org/atomic" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty" +) + +const ( + // MagicSessionErrorCode indicates that something went wrong with the session, rather than the + // command just returning a nonzero exit code, and is chosen as an arbitrary, high number + // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. + MagicSessionErrorCode = 229 + + // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. + // This is stripped from any commands being executed, and is counted towards connection stats. + MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" + // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. + MagicSessionTypeVSCode = "vscode" + // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. + MagicSessionTypeJetBrains = "jetbrains" +) + +type Server struct { + mu sync.RWMutex // Protects following. + listeners map[net.Listener]struct{} + conns map[net.Conn]struct{} + closing chan struct{} + // Wait for goroutines to exit, waited without + // a lock on mu but protected by closing. + wg sync.WaitGroup + + logger slog.Logger + srv *ssh.Server + + Env map[string]string + AgentToken func() string + Manifest *atomic.Pointer[agentsdk.Manifest] + + connCountVSCode atomic.Int64 + connCountJetBrains atomic.Int64 + connCountSSHSession atomic.Int64 +} + +func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) { + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) + if err != nil { + return nil, err + } + + forwardHandler := &ssh.ForwardedTCPHandler{} + unixForwardHandler := &forwardedUnixHandler{log: logger} + + s := &Server{ + listeners: make(map[net.Listener]struct{}), + conns: make(map[net.Conn]struct{}), + logger: logger, + } + + s.srv = &ssh.Server{ + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": directStreamLocalHandler, + "session": ssh.DefaultSessionHandler, + }, + ConnectionFailedCallback: func(_ net.Conn, err error) { + s.logger.Info(ctx, "ssh connection ended", slog.Error(err)) + }, + Handler: s.sessionHandler, + HostSigners: []ssh.Signer{randomSigner}, + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + s.logger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return true + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + // Allow reverse port forwarding all! + s.logger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) + return true + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + NoClientAuth: true, + } + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpHandler, + }, + MaxTimeout: maxTimeout, + } + + return s, nil +} + +type ConnStats struct { + Sessions int64 + VSCode int64 + JetBrains int64 +} + +func (s *Server) ConnStats() ConnStats { + return ConnStats{ + Sessions: s.connCountSSHSession.Load(), + VSCode: s.connCountVSCode.Load(), + JetBrains: s.connCountJetBrains.Load(), + } +} + +func (s *Server) sessionHandler(session ssh.Session) { + ctx := session.Context() + err := s.sessionStart(session) + var exitError *exec.ExitError + if xerrors.As(err, &exitError) { + s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + _ = session.Exit(exitError.ExitCode()) + return + } + if err != nil { + s.logger.Warn(ctx, "ssh session failed", slog.Error(err)) + // This exit code is designed to be unlikely to be confused for a legit exit code + // from the process. + _ = session.Exit(MagicSessionErrorCode) + return + } + _ = session.Exit(0) +} + +func (s *Server) sessionStart(session ssh.Session) (retErr error) { + ctx := session.Context() + env := session.Environ() + var magicType string + for index, kv := range env { + if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) { + continue + } + magicType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=") + env = append(env[:index], env[index+1:]...) + } + switch magicType { + case MagicSessionTypeVSCode: + s.connCountVSCode.Add(1) + defer s.connCountVSCode.Add(-1) + case MagicSessionTypeJetBrains: + s.connCountJetBrains.Add(1) + defer s.connCountJetBrains.Add(-1) + case "": + s.connCountSSHSession.Add(1) + defer s.connCountSSHSession.Add(-1) + default: + s.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) + } + + cmd, err := s.CreateCommand(ctx, session.RawCommand(), env) + if err != nil { + return err + } + + if ssh.AgentRequested(session) { + l, err := ssh.NewAgentListener() + if err != nil { + return xerrors.Errorf("new agent listener: %w", err) + } + defer l.Close() + go ssh.ForwardAgentConnections(l, session) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) + } + + sshPty, windowSize, isPty := session.Pty() + if isPty { + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + + if !isQuietLogin(session.RawCommand()) { + manifest := s.Manifest.Load() + if manifest != nil { + err = showMOTD(session, manifest.MOTDFile) + if err != nil { + s.logger.Error(ctx, "show MOTD", slog.Error(err)) + } + } else { + s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + } + } + + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + + // The pty package sets `SSH_TTY` on supported platforms. + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), + )) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + var wg sync.WaitGroup + defer func() { + defer wg.Wait() + closeErr := ptty.Close() + if closeErr != nil { + s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() + go func() { + for win := range windowSize { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + } + } + }() + // We don't add input copy to wait group because + // it won't return until the session is closed. + go func() { + _, _ = io.Copy(ptty.Input(), session) + }() + + // In low parallelism scenarios, the command may exit and we may close + // the pty before the output copy has started. This can result in the + // output being lost. To avoid this, we wait for the output copy to + // start before waiting for the command to exit. This ensures that the + // output copy goroutine will be scheduled before calling close on the + // pty. This shouldn't be needed because of `pty.Dup()` below, but it + // may not be supported on all platforms. + outputCopyStarted := make(chan struct{}) + ptyOutput := func() io.ReadCloser { + defer close(outputCopyStarted) + // Try to dup so we can separate stdin and stdout closure. + // Once the original pty is closed, the dup will return + // input/output error once the buffered data has been read. + stdout, err := ptty.Dup() + if err == nil { + return stdout + } + // If we can't dup, we shouldn't close + // the fd since it's tied to stdin. + return readNopCloser{ptty.Output()} + } + wg.Add(1) + go func() { + // Ensure data is flushed to session on command exit, if we + // close the session too soon, we might lose data. + defer wg.Done() + + stdout := ptyOutput() + defer stdout.Close() + + _, _ = io.Copy(session, stdout) + }() + <-outputCopyStarted + + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + s.logger.Warn(ctx, "wait error", slog.Error(err)) + } + return err + } + + cmd.Stdout = session + cmd.Stderr = session.Stderr() + // This blocks forever until stdin is received if we don't + // use StdinPipe. It's unknown what causes this. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return xerrors.Errorf("create stdin pipe: %w", err) + } + go func() { + _, _ = io.Copy(stdinPipe, session) + _ = stdinPipe.Close() + }() + err = cmd.Start() + if err != nil { + return xerrors.Errorf("start: %w", err) + } + return cmd.Wait() +} + +type readNopCloser struct{ io.Reader } + +// Close implements io.Closer. +func (readNopCloser) Close() error { return nil } + +func (s *Server) sftpHandler(session ssh.Session) { + ctx := session.Context() + + // Typically sftp sessions don't request a TTY, but if they do, + // we must ensure the gliderlabs/ssh CRLF emulation is disabled. + // Otherwise sftp will be broken. This can happen if a user sets + // `RequestTTY force` in their SSH config. + session.DisablePTYEmulation() + + var opts []sftp.ServerOption + // Change current working directory to the users home + // directory so that SFTP connections land there. + homedir, err := userHomeDir() + if err != nil { + s.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + } else { + opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) + } + + server, err := sftp.NewServer(session, opts...) + if err != nil { + s.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) + return + } + defer server.Close() + + err = server.Serve() + if errors.Is(err, io.EOF) { + // Unless we call `session.Exit(0)` here, the client won't + // receive `exit-status` because `(*sftp.Server).Close()` + // calls `Close()` on the underlying connection (session), + // which actually calls `channel.Close()` because it isn't + // wrapped. This causes sftp clients to receive a non-zero + // exit code. Typically sftp clients don't echo this exit + // code but `scp` on macOS does (when using the default + // SFTP backend). + _ = session.Exit(0) + return + } + s.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) + _ = session.Exit(1) +} + +// CreateCommand processes raw command input with OpenSSH-like behavior. +// If the script provided is empty, it will default to the users shell. +// This injects environment variables specified by the user at launch too. +func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { + currentUser, err := user.Current() + if err != nil { + return nil, xerrors.Errorf("get current user: %w", err) + } + username := currentUser.Username + + shell, err := usershell.Get(username) + if err != nil { + return nil, xerrors.Errorf("get user shell: %w", err) + } + + manifest := s.Manifest.Load() + if manifest == nil { + return nil, xerrors.Errorf("no metadata was provided") + } + + // OpenSSH executes all commands with the users current shell. + // We replicate that behavior for IDE support. + caller := "-c" + if runtime.GOOS == "windows" { + caller = "/c" + } + args := []string{caller, script} + + // gliderlabs/ssh returns a command slice of zero + // when a shell is requested. + if len(script) == 0 { + args = []string{} + if runtime.GOOS != "windows" { + // On Linux and macOS, we should start a login + // shell to consume juicy environment variables! + args = append(args, "-l") + } + } + + cmd := exec.CommandContext(ctx, shell, args...) + cmd.Dir = manifest.Directory + + // If the metadata directory doesn't exist, we run the command + // in the users home directory. + _, err = os.Stat(cmd.Dir) + if cmd.Dir == "" || err != nil { + // Default to user home if a directory is not set. + homedir, err := userHomeDir() + if err != nil { + return nil, xerrors.Errorf("get home dir: %w", err) + } + cmd.Dir = homedir + } + cmd.Env = append(os.Environ(), env...) + executablePath, err := os.Executable() + if err != nil { + return nil, xerrors.Errorf("getting os executable: %w", err) + } + // Set environment variables reliable detection of being inside a + // Coder workspace. + cmd.Env = append(cmd.Env, "CODER=true") + cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) + // Git on Windows resolves with UNIX-style paths. + // If using backslashes, it's unable to find the executable. + unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") + cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) + + // Specific Coder subcommands require the agent token exposed! + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken())) + + // Set SSH connection environment variables (these are also set by OpenSSH + // and thus expected to be present by SSH clients). Since the agent does + // networking in-memory, trying to provide accurate values here would be + // nonsensical. For now, we hard code these values so that they're present. + srcAddr, srcPort := "0.0.0.0", "0" + dstAddr, dstPort := "0.0.0.0", "0" + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) + + // This adds the ports dialog to code-server that enables + // proxying a port dynamically. + cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) + + // Hide Coder message on code-server's "Getting Started" page + cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") + + // Load environment variables passed via the agent. + // These should override all variables we manually specify. + for envKey, value := range manifest.EnvironmentVariables { + // Expanding environment variables allows for customization + // of the $PATH, among other variables. Customers can prepend + // or append to the $PATH, so allowing expand is required! + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) + } + + // Agent-level environment variables should take over all! + // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". + for envKey, value := range s.Env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) + } + + return cmd, nil +} + +func (s *Server) Serve(l net.Listener) error { + defer l.Close() + + s.trackListener(l, true) + defer s.trackListener(l, false) + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.handleConn(l, conn) + } +} + +func (s *Server) handleConn(l net.Listener, c net.Conn) { + defer c.Close() + + if !s.trackConn(l, c, true) { + // Server is closed or we no longer want + // connections from this listener. + s.logger.Debug(context.Background(), "received connection after server closed") + return + } + defer s.trackConn(l, c, false) + + s.srv.HandleConn(c) +} + +// trackListener registers the listener with the server. If the server is +// closing, the function will block until the server is closed. +// +//nolint:revive +func (s *Server) trackListener(l net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + for s.closing != nil { + closing := s.closing + // Wait until close is complete before + // serving a new listener. + s.mu.Unlock() + <-closing + s.mu.Lock() + } + s.wg.Add(1) + s.listeners[l] = struct{}{} + return + } + s.wg.Done() + delete(s.listeners, l) +} + +// trackConn registers the connection with the server. If the server is +// closed or the listener is closed, the connection is not registered +// and should be closed. +// +//nolint:revive +func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + found := false + for ll := range s.listeners { + if l == ll { + found = true + break + } + } + if s.closing != nil || !found { + // Server or listener closed. + return false + } + s.wg.Add(1) + s.conns[c] = struct{}{} + return true + } + s.wg.Done() + delete(s.conns, c) + return true +} + +// Close the server and all active connections. Server can be re-used +// after Close is done. +func (s *Server) Close() error { + s.mu.Lock() + + // Guard against multiple calls to Close and + // accepting new connections during close. + if s.closing != nil { + s.mu.Unlock() + return xerrors.New("server is closing") + } + s.closing = make(chan struct{}) + + // Close all active listeners and connections. + for l := range s.listeners { + _ = l.Close() + } + for c := range s.conns { + _ = c.Close() + } + + // Close the underlying SSH server. + err := s.srv.Close() + + s.mu.Unlock() + s.wg.Wait() // Wait for all goroutines to exit. + + s.mu.Lock() + close(s.closing) + s.closing = nil + s.mu.Unlock() + + return err +} + +// Shutdown gracefully closes all active SSH connections and stops +// accepting new connections. +// +// Shutdown is not implemented. +func (*Server) Shutdown(_ context.Context) error { + // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. + return nil +} + +// isQuietLogin checks if the SSH server should perform a quiet login or not. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 +func isQuietLogin(rawCommand string) bool { + // We are always quiet unless this is a login shell. + if len(rawCommand) != 0 { + return true + } + + // Best effort, if we can't get the home directory, + // we can't lookup .hushlogin. + homedir, err := userHomeDir() + if err != nil { + return false + } + + _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) + return err == nil +} + +// showMOTD will output the message of the day from +// the given filename to dest, if the file exists. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 +func showMOTD(dest io.Writer, filename string) error { + if filename == "" { + return nil + } + + f, err := os.Open(filename) + if err != nil { + if xerrors.Is(err, os.ErrNotExist) { + // This is not an error, there simply isn't a MOTD to show. + return nil + } + return xerrors.Errorf("open MOTD: %w", err) + } + defer f.Close() + + s := bufio.NewScanner(f) + for s.Scan() { + // Carriage return ensures each line starts + // at the beginning of the terminal. + _, err = fmt.Fprint(dest, s.Text()+"\r\n") + if err != nil { + return xerrors.Errorf("write MOTD: %w", err) + } + } + if err := s.Err(); err != nil { + return xerrors.Errorf("read MOTD: %w", err) + } + + return nil +} + +// userHomeDir returns the home directory of the current user, giving +// priority to the $HOME environment variable. +func userHomeDir() (string, error) { + // First we check the environment. + homedir, err := os.UserHomeDir() + if err == nil { + return homedir, nil + } + + // As a fallback, we try the user information. + u, err := user.Current() + if err != nil { + return "", xerrors.Errorf("current user: %w", err) + } + return u.HomeDir, nil +} diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go new file mode 100644 index 0000000000000..684c0e36bbb18 --- /dev/null +++ b/agent/agentssh/agentssh_test.go @@ -0,0 +1,139 @@ +// Package agentssh_test provides tests for basic functinoality of the agentssh +// package, more test coverage can be found in the `agent` and `cli` package(s). +package agentssh_test + +import ( + "bytes" + "context" + "net" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/goleak" + "golang.org/x/crypto/ssh" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestNewServer_ServeClient(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := sshClient(t, ln.Addr().String()) + var b bytes.Buffer + sess, err := c.NewSession() + sess.Stdout = &b + require.NoError(t, err) + err = sess.Start("echo hello") + require.NoError(t, err) + + err = sess.Wait() + require.NoError(t, err) + + require.Equal(t, "hello", strings.TrimSpace(b.String())) + + err = s.Close() + require.NoError(t, err) + <-done +} + +func TestNewServer_CloseActiveConnections(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + pty := ptytest.New(t) + + doClose := make(chan struct{}) + go func() { + defer wg.Done() + c := sshClient(t, ln.Addr().String()) + sess, err := c.NewSession() + sess.Stdin = pty.Input() + sess.Stdout = pty.Output() + sess.Stderr = pty.Output() + + assert.NoError(t, err) + err = sess.Start("") + assert.NoError(t, err) + + close(doClose) + err = sess.Wait() + assert.Error(t, err) + }() + + <-doClose + err = s.Close() + require.NoError(t, err) + + wg.Wait() +} + +func sshClient(t *testing.T, addr string) *ssh.Client { + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshConn.Close() + }) + c := ssh.NewClient(sshConn, channels, requests) + t.Cleanup(func() { + _ = c.Close() + }) + return c +} diff --git a/agent/agentssh/bicopy.go b/agent/agentssh/bicopy.go new file mode 100644 index 0000000000000..64cd2a716058c --- /dev/null +++ b/agent/agentssh/bicopy.go @@ -0,0 +1,47 @@ +package agentssh + +import ( + "context" + "io" + "sync" +) + +// Bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/agent/ssh.go b/agent/agentssh/forward.go similarity index 99% rename from agent/ssh.go rename to agent/agentssh/forward.go index 8aa41a1d268ed..1e3635fd8ff91 100644 --- a/agent/ssh.go +++ b/agent/agentssh/forward.go @@ -1,4 +1,4 @@ -package agent +package agentssh import ( "context" diff --git a/cli/portforward.go b/cli/portforward.go index c746216889a55..dad82381bfb5b 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -14,7 +14,7 @@ import ( "github.com/pion/udp" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" @@ -226,7 +226,7 @@ func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *co } defer remoteConn.Close() - agent.Bicopy(ctx, netConn, remoteConn) + agentssh.Bicopy(ctx, netConn, remoteConn) }(netConn) } }(spec) diff --git a/cli/ssh.go b/cli/ssh.go index e9168f6999f6b..e1ebbcd04cfd2 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -23,7 +23,7 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -574,7 +574,7 @@ func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Cl } } - agent.Bicopy(ctx, localConn, remoteConn) + agentssh.Bicopy(ctx, localConn, remoteConn) }() } }() diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 7189d19ff9a3b..82d112d7273ac 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -18,7 +18,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/tracing" @@ -575,7 +575,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { return } defer ptNetConn.Close() - agent.Bicopy(ctx, wsNetConn, ptNetConn) + agentssh.Bicopy(ctx, wsNetConn, ptNetConn) } // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: