From a86247287cbc7e35a473f39b9d8698b7393a82cc Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 07:42:46 +0000 Subject: [PATCH 01/10] feat(cli): use coder connect in `coder ssh`, if available --- cli/cliutil/stdioconn.go | 36 +++++ cli/ssh.go | 232 ++++++++++++++++++++++----- cli/ssh_internal_test.go | 118 ++++++++++++++ cli/ssh_test.go | 86 +++++----- cli/testdata/coder_ssh_--help.golden | 4 + codersdk/workspacesdk/agentconn.go | 4 +- docs/reference/cli/ssh.md | 8 + 7 files changed, 411 insertions(+), 77 deletions(-) create mode 100644 cli/cliutil/stdioconn.go diff --git a/cli/cliutil/stdioconn.go b/cli/cliutil/stdioconn.go new file mode 100644 index 0000000000000..7f919dbf9d456 --- /dev/null +++ b/cli/cliutil/stdioconn.go @@ -0,0 +1,36 @@ +package cliutil + +import ( + "io" + "net" + "time" +) + +type StdioConn struct { + io.Reader + io.Writer +} + +func (*StdioConn) Close() (err error) { + return nil +} + +func (*StdioConn) LocalAddr() net.Addr { + return nil +} + +func (*StdioConn) RemoteAddr() net.Addr { + return nil +} + +func (*StdioConn) SetDeadline(_ time.Time) error { + return nil +} + +func (*StdioConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (*StdioConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/cli/ssh.go b/cli/ssh.go index e02443e7032c6..82cab0aee1219 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" @@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string + forceTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command { containerUser string ) client := new(codersdk.Client) + wsClient := workspacesdk.New(client) cmd := &serpent.Command{ Annotations: workspaceCommand, Use: "ssh ", @@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command { parsedEnv = append(parsedEnv, [2]string{k, v}) } - deploymentSSHConfig := codersdk.SSHConfigResponse{ + cliConfig := codersdk.SSHConfigResponse{ HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix, } workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname( ctx, inv, client, - inv.Args[0], deploymentSSHConfig, disableAutostart) + inv.Args[0], cliConfig, disableAutostart) if err != nil { return err } @@ -275,10 +278,34 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + // See if we can use the Coder Connect tunnel + if !forceTunnel { + connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) + if err != nil { + return xerrors.Errorf("get agent connection info: %w", err) + } + + coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", + workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) + exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) + if exists { + _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") + defer cancel() + addr := fmt.Sprintf("%s:22", coderConnectHost) + if stdio { + if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { + logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) + } + return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack) + } + return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack) + } + } + if r.disableDirect { _, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.") } - conn, err := workspacesdk.New(client). + conn, err := wsClient. DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ Logger: logger, BlockEndpoints: r.disableDirect, @@ -454,36 +481,11 @@ func (r *RootCmd) ssh() *serpent.Command { stdinFile, validIn := inv.Stdin.(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File) if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - inState, err := pty.MakeInputRaw(stdinFile.Fd()) - if err != nil { - return err - } - defer func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - }() - outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) + restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession) + defer restorePtyFn() if err != nil { - return err + return xerrors.Errorf("configure pty: %w", err) } - defer func() { - _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) - }() - - windowChange := listenWindowSize(ctx) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-windowChange: - } - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err != nil { - continue - } - _ = sshSession.WindowChange(height, width) - } - }() } for _, kv := range parsedEnv { @@ -662,11 +664,51 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.StringOf(&containerUser), Hidden: true, // Hidden until this features is at least in beta. }, + { + Flag: "force-tunnel", + Description: "Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Value: serpent.BoolOf(&forceTunnel), + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } +func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) { + inState, err := pty.MakeInputRaw(stdinFile.Fd()) + if err != nil { + return restoreFn, err + } + restoreFn = func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + } + outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) + if err != nil { + return restoreFn, err + } + restoreFn = func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) + } + + windowChange := listenWindowSize(ctx) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-windowChange: + } + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } + _ = sshSession.WindowChange(height, width) + } + }() + return restoreFn, nil +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1374,12 +1416,13 @@ func setStatsCallback( } type sshNetworkStats struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` + UsingCoderConnect bool `json:"using_coder_connect"` } func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { @@ -1450,6 +1493,121 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { + conn, err := net.Dial("tcp", addr) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("tcp conn", conn); err != nil { + return err + } + + agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{ + Reader: stdin, + Writer: stdout, + }) + + return nil +} + +func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error { + client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{ + // We've already checked the agent's address + // is within the Coder service prefix. + // #nosec + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("ssh client", client); err != nil { + return err + } + + session, err := client.NewSession() + if err != nil { + return xerrors.Errorf("create ssh session: %w", err) + } + if err := stack.push("ssh session", session); err != nil { + return err + } + + stdinFile, validIn := stdin.(*os.File) + stdoutFile, validOut := stdout.(*os.File) + if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { + restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session) + defer restorePtyFn() + if err != nil { + return xerrors.Errorf("configure pty: %w", err) + } + } + + session.Stdin = stdin + session.Stdout = stdout + session.Stderr = stderr + + err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{}) + if err != nil { + return xerrors.Errorf("request pty: %w", err) + } + + err = session.Shell() + if err != nil { + return xerrors.Errorf("start shell: %w", err) + } + + if validOut { + // Set initial window size. + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err == nil { + _ = session.WindowChange(height, width) + } + } + + err = session.Wait() + if err != nil { + if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { + // Clear the error since it's not useful beyond + // reporting status. + return ExitError(exitErr.ExitStatus(), nil) + } + // If the connection drops unexpectedly, we get an + // ExitMissingError but no other error details, so try to at + // least give the user a better message + if errors.Is(err, &gossh.ExitMissingError{}) { + return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) + } + return xerrors.Errorf("session ended: %w", err) + } + + return nil +} + +func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { + fs, ok := ctx.Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + // The VS Code extension obtains the PID of the SSH process to + // find the log file associated with a SSH session. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + stats := &sshNetworkStats{ + UsingCoderConnect: true, + } + rawStats, err := json.Marshal(stats) + if err != nil { + return xerrors.Errorf("marshal network stats: %w", err) + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) + if err != nil { + return xerrors.Errorf("write network stats: %w", err) + } + return nil +} + // Converts workspace name input to owner/workspace.agent format // Possible valid input formats: // workspace diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d5e4c049347b2..6de6bd9ea24bf 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -3,20 +3,26 @@ package cli import ( "context" "fmt" + "io" + "net" "net/url" "sync" "testing" "time" + gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/quartz" + "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -220,6 +226,118 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } +func TestCoderConnectPTY(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + ptty := ptytest.New(t) + ptyDone := make(chan struct{}) + go func() { + err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack) + assert.NoError(t, err) + close(ptyDone) + }() + ptty.ExpectMatch("Connected!") + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + ptty.WriteLine("exit") + <-ptyDone +} + +func TestCoderConnectStdio(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + stdioDone := make(chan struct{}) + go func() { + err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack) + assert.NoError(t, err) + close(stdioDone) + }() + + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-stdioDone +} + +type sshServer struct { + server *gliderssh.Server +} + +func newSSHServer(addr string) *sshServer { + return &sshServer{ + server: &gliderssh.Server{ + Addr: addr, + Handler: func(s gliderssh.Session) { + _, _ = io.WriteString(s, "Connected!") + }, + }, + } +} + +func (s *sshServer) Serve(ln net.Listener) error { + return s.server.Serve(ln) +} + +func (s *sshServer) Close() error { + return s.server.Close() +} + type fakeCloser struct { closes *[]*fakeCloser err error diff --git a/cli/ssh_test.go b/cli/ssh_test.go index c8ad072270169..e9dd7c4bc42b2 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -43,6 +43,7 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" + "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" @@ -473,7 +474,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -542,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -605,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -773,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -835,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -894,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1082,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1741,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -2110,6 +2111,46 @@ func TestSSH_Container(t *testing.T) { }) } +func TestSSH_CoderConnect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) + + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) + + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Making an SSH server available here is difficult, so we'll just check + // the command attempts to dial it. + require.ErrorContains(t, err, "dial coder connect host") + require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + + // The network info file should be created since we passed `--stdio` + assert.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion @@ -2153,35 +2194,6 @@ func tGo(t *testing.T, fn func()) (done <-chan struct{}) { return doneC } -type stdioConn struct { - io.Reader - io.Writer -} - -func (*stdioConn) Close() (err error) { - return nil -} - -func (*stdioConn) LocalAddr() net.Addr { - return nil -} - -func (*stdioConn) RemoteAddr() net.Addr { - return nil -} - -func (*stdioConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetWriteDeadline(_ time.Time) error { - return nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 1f7122dd655a2..9aefb24145596 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,6 +12,10 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). + --force-tunnel bool + Force the use of a new tunnel to the workspace, even if the Coder + Connect tunnel is available. + -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK. diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index fa569080f7dd2..97b4268c68780 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -185,14 +185,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port)) } -// SSHClient calls SSH to create a client that uses a weak cipher -// to improve throughput. +// SSHClient calls SSH to create a client func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -// that uses a weak cipher to improve throughput. func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index c5bae755c8419..b9d76afd1452f 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,6 +138,14 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. +### --force-tunnel + +| | | +|------|-------------------| +| Type | bool | + +Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available. + ### --disable-autostart | | | From d8e1c90fa6f33852d687d9741b4f851c676979de Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 11:06:29 +0000 Subject: [PATCH 02/10] fix windows tests --- cli/ssh_internal_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 6de6bd9ea24bf..82659e1495afd 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -324,7 +324,7 @@ func newSSHServer(addr string) *sshServer { server: &gliderssh.Server{ Addr: addr, Handler: func(s gliderssh.Session) { - _, _ = io.WriteString(s, "Connected!") + _, _ = io.WriteString(s.Stderr(), "Connected!") }, }, } From 578ebb0b744d5db9cde72101412a4686dfcd931a Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 28 Apr 2025 12:46:42 +0000 Subject: [PATCH 03/10] rename flag, extra test --- cli/ssh.go | 4 +- cli/ssh_test.go | 88 +++++++++++++++++++--------- cli/testdata/coder_ssh_--help.golden | 4 +- docs/reference/cli/ssh.md | 4 +- 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 82cab0aee1219..82c66eb939964 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -665,8 +665,8 @@ func (r *RootCmd) ssh() *serpent.Command { Hidden: true, // Hidden until this features is at least in beta. }, { - Flag: "force-tunnel", - Description: "Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Flag: "force-new-tunnel", + Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", Value: serpent.BoolOf(&forceTunnel), }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), diff --git a/cli/ssh_test.go b/cli/ssh_test.go index e9dd7c4bc42b2..1a2d2aa37425f 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2114,41 +2114,71 @@ func TestSSH_Container(t *testing.T) { func TestSSH_CoderConnect(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() - fs := afero.NewMemMapFs() - //nolint:revive,staticcheck - ctx = context.WithValue(ctx, "fs", fs) + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) - client, workspace, agentToken := setupWorkspaceForAgent(t) - inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") - clitest.SetupConfig(t, client, root) - _ = ptytest.New(t).Attach(inv) + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) - errCh := make(chan error, 1) - tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() - errCh <- err + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Making an SSH server available here is difficult, so we'll just check + // the command attempts to dial it. + require.ErrorContains(t, err, "dial coder connect host") + require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + + // The network info file should be created since we passed `--stdio` + assert.Eventually(t, func() bool { + entries, err := afero.ReadDir(fs, "/net") + if err != nil { + return false + } + return len(entries) > 0 + }, testutil.WaitLong, testutil.IntervalFast) }) - _ = agenttest.New(t, client.URL, agentToken) - coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + t.Run("Disabled", func(t *testing.T) { + t.Parallel() - err := testutil.TryReceive(ctx, t, errCh) - // Making an SSH server available here is difficult, so we'll just check - // the command attempts to dial it. - require.ErrorContains(t, err, "dial coder connect host") - require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") - - // The network info file should be created since we passed `--stdio` - assert.Eventually(t, func() bool { - entries, err := afero.ReadDir(fs, "/net") - if err != nil { - return false - } - return len(entries) > 0 - }, testutil.WaitLong, testutil.IntervalFast) + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--force-new-tunnel") + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + cmdDone := tGo(t, func() { + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + assert.NoError(t, err) + }) + // Shouldn't fail to dial the coder connect host `--force-new-tunnel` + // is passed. + pty.ExpectMatch("Waiting") + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + pty.WriteLine("exit") + <-cmdDone + }) } // tGoContext runs fn in a goroutine passing a context that will be diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 9aefb24145596..12e70e03682de 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,8 +12,8 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). - --force-tunnel bool - Force the use of a new tunnel to the workspace, even if the Coder + --force-new-tunnel bool + Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index b9d76afd1452f..e7d1b75a616c6 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,13 +138,13 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. -### --force-tunnel +### --force-new-tunnel | | | |------|-------------------| | Type | bool | -Force the use of a new tunnel to the workspace, even if the Coder Connect tunnel is available. +Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. ### --disable-autostart From be118e600ff4731af7996925557be39d40ecfde8 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:25:26 +0000 Subject: [PATCH 04/10] reduce scope --- cli/cliutil/stdioconn.go | 14 +-- cli/ssh.go | 171 ++++++++------------------- cli/ssh_internal_test.go | 34 +----- cli/ssh_test.go | 16 +-- cli/testdata/coder_ssh_--help.golden | 4 - docs/reference/cli/ssh.md | 8 -- 6 files changed, 66 insertions(+), 181 deletions(-) diff --git a/cli/cliutil/stdioconn.go b/cli/cliutil/stdioconn.go index 7f919dbf9d456..ed87fe552cbd5 100644 --- a/cli/cliutil/stdioconn.go +++ b/cli/cliutil/stdioconn.go @@ -6,31 +6,31 @@ import ( "time" ) -type StdioConn struct { +type ReaderWriterConn struct { io.Reader io.Writer } -func (*StdioConn) Close() (err error) { +func (*ReaderWriterConn) Close() (err error) { return nil } -func (*StdioConn) LocalAddr() net.Addr { +func (*ReaderWriterConn) LocalAddr() net.Addr { return nil } -func (*StdioConn) RemoteAddr() net.Addr { +func (*ReaderWriterConn) RemoteAddr() net.Addr { return nil } -func (*StdioConn) SetDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetDeadline(_ time.Time) error { return nil } -func (*StdioConn) SetReadDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error { return nil } -func (*StdioConn) SetWriteDeadline(_ time.Time) error { +func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/cli/ssh.go b/cli/ssh.go index 82c66eb939964..365f46ab07db9 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -67,7 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string - forceTunnel bool + forceNewTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -278,27 +278,38 @@ func (r *RootCmd) ssh() *serpent.Command { return err } - // See if we can use the Coder Connect tunnel - if !forceTunnel { + // If we're in stdio mode, check to see if we can use Coder Connect. + // We don't support Coder Connect over non-stdio coder ssh yet. + if stdio && !forceNewTunnel { connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) if err != nil { return xerrors.Errorf("get agent connection info: %w", err) } - coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) if exists { _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") defer cancel() - addr := fmt.Sprintf("%s:22", coderConnectHost) - if stdio { + + if networkInfoDir != "" { if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) } - return runCoderConnectStdio(ctx, addr, stdioReader, stdioWriter, stack) } - return runCoderConnectPTY(ctx, addr, inv.Stdin, inv.Stdout, inv.Stderr, stack) + + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) + defer stopPolling() + + usageAppName := getUsageAppName(usageApp) + if usageAppName != "" { + closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{ + AgentID: workspaceAgent.ID, + AppName: usageAppName, + }) + defer closeUsage() + } + return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack) } } @@ -481,11 +492,36 @@ func (r *RootCmd) ssh() *serpent.Command { stdinFile, validIn := inv.Stdin.(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File) if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, sshSession) - defer restorePtyFn() + inState, err := pty.MakeInputRaw(stdinFile.Fd()) + if err != nil { + return err + } + defer func() { + _ = pty.RestoreTerminal(stdinFile.Fd(), inState) + }() + outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) if err != nil { - return xerrors.Errorf("configure pty: %w", err) + return err } + defer func() { + _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) + }() + + windowChange := listenWindowSize(ctx) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-windowChange: + } + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } + _ = sshSession.WindowChange(height, width) + } + }() } for _, kv := range parsedEnv { @@ -667,48 +703,14 @@ func (r *RootCmd) ssh() *serpent.Command { { Flag: "force-new-tunnel", Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", - Value: serpent.BoolOf(&forceTunnel), + Value: serpent.BoolOf(&forceNewTunnel), + Hidden: true, }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } -func configurePTY(ctx context.Context, stdinFile *os.File, stdoutFile *os.File, sshSession *gossh.Session) (restoreFn func(), err error) { - inState, err := pty.MakeInputRaw(stdinFile.Fd()) - if err != nil { - return restoreFn, err - } - restoreFn = func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - } - outState, err := pty.MakeOutputRaw(stdoutFile.Fd()) - if err != nil { - return restoreFn, err - } - restoreFn = func() { - _ = pty.RestoreTerminal(stdinFile.Fd(), inState) - _ = pty.RestoreTerminal(stdoutFile.Fd(), outState) - } - - windowChange := listenWindowSize(ctx) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-windowChange: - } - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err != nil { - continue - } - _ = sshSession.WindowChange(height, width) - } - }() - return restoreFn, nil -} - // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1502,7 +1504,7 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return err } - agentssh.Bicopy(ctx, conn, &cliutil.StdioConn{ + agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{ Reader: stdin, Writer: stdout, }) @@ -1510,79 +1512,6 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return nil } -func runCoderConnectPTY(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stderr io.Writer, stack *closerStack) error { - client, err := gossh.Dial("tcp", addr, &gossh.ClientConfig{ - // We've already checked the agent's address - // is within the Coder service prefix. - // #nosec - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - return xerrors.Errorf("dial coder connect host: %w", err) - } - if err := stack.push("ssh client", client); err != nil { - return err - } - - session, err := client.NewSession() - if err != nil { - return xerrors.Errorf("create ssh session: %w", err) - } - if err := stack.push("ssh session", session); err != nil { - return err - } - - stdinFile, validIn := stdin.(*os.File) - stdoutFile, validOut := stdout.(*os.File) - if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { - restorePtyFn, err := configurePTY(ctx, stdinFile, stdoutFile, session) - defer restorePtyFn() - if err != nil { - return xerrors.Errorf("configure pty: %w", err) - } - } - - session.Stdin = stdin - session.Stdout = stdout - session.Stderr = stderr - - err = session.RequestPty("xterm-256color", 80, 24, gossh.TerminalModes{}) - if err != nil { - return xerrors.Errorf("request pty: %w", err) - } - - err = session.Shell() - if err != nil { - return xerrors.Errorf("start shell: %w", err) - } - - if validOut { - // Set initial window size. - width, height, err := term.GetSize(int(stdoutFile.Fd())) - if err == nil { - _ = session.WindowChange(height, width) - } - } - - err = session.Wait() - if err != nil { - if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { - // Clear the error since it's not useful beyond - // reporting status. - return ExitError(exitErr.ExitStatus(), nil) - } - // If the connection drops unexpectedly, we get an - // ExitMissingError but no other error details, so try to at - // least give the user a better message - if errors.Is(err, &gossh.ExitMissingError{}) { - return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) - } - return xerrors.Errorf("session ended: %w", err) - } - - return nil -} - func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { fs, ok := ctx.Value("fs").(afero.Fs) if !ok { diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 82659e1495afd..ea0e3f1534713 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -22,7 +22,6 @@ import ( "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -226,37 +225,6 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } -func TestCoderConnectPTY(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - stack := newCloserStack(ctx, logger, quartz.NewMock(t)) - - server := newSSHServer("127.0.0.1:0") - ln, err := net.Listen("tcp", server.server.Addr) - require.NoError(t, err) - - go func() { - _ = server.Serve(ln) - }() - t.Cleanup(func() { - _ = server.Close() - }) - - ptty := ptytest.New(t) - ptyDone := make(chan struct{}) - go func() { - err := runCoderConnectPTY(ctx, ln.Addr().String(), ptty.Output(), ptty.Input(), ptty.Output(), stack) - assert.NoError(t, err) - close(ptyDone) - }() - ptty.ExpectMatch("Connected!") - // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - ptty.WriteLine("exit") - <-ptyDone -} - func TestCoderConnectStdio(t *testing.T) { t.Parallel() @@ -290,7 +258,7 @@ func TestCoderConnectStdio(t *testing.T) { close(stdioDone) }() - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 1a2d2aa37425f..d76633f27c858 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -474,7 +474,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -543,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -606,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -774,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -836,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -895,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1083,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1742,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.StdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 12e70e03682de..1f7122dd655a2 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -12,10 +12,6 @@ OPTIONS: -e, --env string-array, $CODER_SSH_ENV Set environment variable(s) for session (key1=value1,key2=value2,...). - --force-new-tunnel bool - Force the creation of a new tunnel to the workspace, even if the Coder - Connect tunnel is available. - -A, --forward-agent bool, $CODER_SSH_FORWARD_AGENT Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK. diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index e7d1b75a616c6..c5bae755c8419 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -138,14 +138,6 @@ Specifies a directory to write network information periodically. Specifies the interval to update network information. -### --force-new-tunnel - -| | | -|------|-------------------| -| Type | bool | - -Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available. - ### --disable-autostart | | | From bb75fa27deedc26d326a9907be347338cfe3659b Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:40:39 +0000 Subject: [PATCH 05/10] review --- cli/ssh.go | 11 ++++++++++- cli/ssh_internal_test.go | 3 +-- cli/ssh_test.go | 17 ++++++++--------- cli/cliutil/stdioconn.go => testutil/rwconn.go | 2 +- 4 files changed, 20 insertions(+), 13 deletions(-) rename cli/cliutil/stdioconn.go => testutil/rwconn.go (96%) diff --git a/cli/ssh.go b/cli/ssh.go index 365f46ab07db9..bb04649535de3 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1504,7 +1504,7 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return err } - agentssh.Bicopy(ctx, conn, &cliutil.ReaderWriterConn{ + agentssh.Bicopy(ctx, conn, &StdioRwc{ Reader: stdin, Writer: stdout, }) @@ -1512,6 +1512,15 @@ func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, std return nil } +type StdioRwc struct { + io.Reader + io.Writer +} + +func (*StdioRwc) Close() error { + return nil +} + func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { fs, ok := ctx.Value("fs").(afero.Fs) if !ok { diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index ea0e3f1534713..d76ff1881680c 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -20,7 +20,6 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/quartz" - "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -258,7 +257,7 @@ func TestCoderConnectStdio(t *testing.T) { close(stdioDone) }() - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d76633f27c858..6f1703fe92236 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -43,7 +43,6 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" - "github.com/coder/coder/v2/cli/cliutil" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" @@ -474,7 +473,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -543,7 +542,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -606,7 +605,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -774,7 +773,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -836,7 +835,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -895,7 +894,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1083,7 +1082,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1742,7 +1741,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&cliutil.ReaderWriterConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ diff --git a/cli/cliutil/stdioconn.go b/testutil/rwconn.go similarity index 96% rename from cli/cliutil/stdioconn.go rename to testutil/rwconn.go index ed87fe552cbd5..a731e9c3c0ab0 100644 --- a/cli/cliutil/stdioconn.go +++ b/testutil/rwconn.go @@ -1,4 +1,4 @@ -package cliutil +package testutil import ( "io" From 00f18afef5d5474f33bc3f7507fb4bd2619962f2 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 13:50:57 +0000 Subject: [PATCH 06/10] typo --- cli/ssh_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 6f1703fe92236..15663dd4bbe9c 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2167,8 +2167,8 @@ func TestSSH_CoderConnect(t *testing.T) { err := inv.WithContext(withCoderConnectRunning(ctx)).Run() assert.NoError(t, err) }) - // Shouldn't fail to dial the coder connect host `--force-new-tunnel` - // is passed. + // Shouldn't fail to dial the coder connect host since + // `--force-new-tunnel` is passed. pty.ExpectMatch("Waiting") _ = agenttest.New(t, client.URL, agentToken) From 4ce57b72acc732609cfa629701be33757df398d6 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 29 Apr 2025 14:12:43 +0000 Subject: [PATCH 07/10] fix tests --- cli/ssh_test.go | 51 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 15663dd4bbe9c..7c0f299414953 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2154,28 +2154,57 @@ func TestSSH_CoderConnect(t *testing.T) { t.Run("Disabled", func(t *testing.T) { t.Parallel() - client, workspace, agentToken := setupWorkspaceForAgent(t) - inv, root := clitest.New(t, "ssh", workspace.Name, "--force-new-tunnel") - clitest.SetupConfig(t, client, root) - pty := ptytest.New(t).Attach(inv) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + inv, root := clitest.New(t, "ssh", "--force-new-tunnel", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + cmdDone := tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() + // Shouldn't fail to dial the Coder Connect host + // since `--force-new-tunnel` was passed assert.NoError(t, err) }) - // Shouldn't fail to dial the coder connect host since - // `--force-new-tunnel` is passed. - pty.ExpectMatch("Waiting") - _ = agenttest.New(t, client.URL, agentToken) - coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - pty.WriteLine("exit") + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + <-cmdDone }) } From e46a08405a720a219e5845bb2e973cf5bc571820 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 02:01:35 +0000 Subject: [PATCH 08/10] fixup --- cli/ssh.go | 1 - cli/ssh_test.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index bb04649535de3..ffa85f8690eab 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -289,7 +289,6 @@ func (r *RootCmd) ssh() *serpent.Command { workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) if exists { - _, _ = fmt.Fprintln(inv.Stderr, "Connecting to workspace via Coder Connect...") defer cancel() if networkInfoDir != "" { diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 7c0f299414953..2ec33f98a8437 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2177,7 +2177,7 @@ func TestSSH_CoderConnect(t *testing.T) { inv.Stderr = io.Discard cmdDone := tGo(t, func() { - err := inv.WithContext(ctx).Run() + err := inv.WithContext(withCoderConnectRunning(ctx)).Run() // Shouldn't fail to dial the Coder Connect host // since `--force-new-tunnel` was passed assert.NoError(t, err) From 76603ed54c7439d02c684085ecd327b3756a7a00 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 03:54:04 +0000 Subject: [PATCH 09/10] review --- cli/ssh.go | 21 ++++++++++++++++++++- cli/ssh_internal_test.go | 4 ++-- cli/ssh_test.go | 34 +++++++++++++++++++++------------- 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index ffa85f8690eab..f93fa79656858 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1494,8 +1494,27 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +type coderConnectDialerContextKey struct{} + +type coderConnectDialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context { + return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer) +} + +func testOrDefaultDialer(ctx context.Context) coderConnectDialer { + dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer) + if !ok || dialer == nil { + return &net.Dialer{} + } + return dialer +} + func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { - conn, err := net.Dial("tcp", addr) + dialer := testOrDefaultDialer(ctx) + conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return xerrors.Errorf("dial coder connect host: %w", err) } diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d76ff1881680c..caee1ec25b710 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -272,8 +272,8 @@ func TestCoderConnectStdio(t *testing.T) { require.NoError(t, err) defer session.Close() - // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. - err = session.Run("exit") + // We're not connected to a real shell + err = session.Run("") require.NoError(t, err) err = sshClient.Close() require.NoError(t, err) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 2ec33f98a8437..90d36f57dbb81 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -41,6 +41,7 @@ import ( "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" @@ -2127,9 +2128,12 @@ func TestSSH_CoderConnect(t *testing.T) { clitest.SetupConfig(t, client, root) _ = ptytest.New(t).Attach(inv) + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + errCh := make(chan error, 1) tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() errCh <- err }) @@ -2137,19 +2141,14 @@ func TestSSH_CoderConnect(t *testing.T) { coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) err := testutil.TryReceive(ctx, t, errCh) - // Making an SSH server available here is difficult, so we'll just check - // the command attempts to dial it. - require.ErrorContains(t, err, "dial coder connect host") - require.ErrorContains(t, err, "dev.myworkspace.myuser.coder") + // Our mock dialer will always fail with this error, if it was called + require.ErrorContains(t, err, "dial coder connect host \"dev.myworkspace.myuser.coder:22\" over tcp") // The network info file should be created since we passed `--stdio` - assert.Eventually(t, func() bool { - entries, err := afero.ReadDir(fs, "/net") - if err != nil { - return false - } - return len(entries) > 0 - }, testutil.WaitLong, testutil.IntervalFast) + entries, err := afero.ReadDir(fs, "/net") + require.NoError(t, err) + require.True(t, len(entries) > 0) + }) t.Run("Disabled", func(t *testing.T) { @@ -2176,8 +2175,11 @@ func TestSSH_CoderConnect(t *testing.T) { inv.Stdout = serverInput inv.Stderr = io.Discard + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + cmdDone := tGo(t, func() { - err := inv.WithContext(withCoderConnectRunning(ctx)).Run() + err := inv.WithContext(ctx).Run() // Shouldn't fail to dial the Coder Connect host // since `--force-new-tunnel` was passed assert.NoError(t, err) @@ -2209,6 +2211,12 @@ func TestSSH_CoderConnect(t *testing.T) { }) } +type fakeCoderConnectDialer struct{} + +func (*fakeCoderConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, xerrors.Errorf("dial coder connect host %q over %s", addr, network) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion From 8c050235ed4c610ed1cf2f2d56529431a3602563 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Wed, 30 Apr 2025 04:29:43 +0000 Subject: [PATCH 10/10] fmt --- cli/ssh_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 90d36f57dbb81..ab754626c54fa 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2148,7 +2148,6 @@ func TestSSH_CoderConnect(t *testing.T) { entries, err := afero.ReadDir(fs, "/net") require.NoError(t, err) require.True(t, len(entries) > 0) - }) t.Run("Disabled", func(t *testing.T) { pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy