diff --git a/cli/ssh.go b/cli/ssh.go index e02443e7032c6..f93fa79656858 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" @@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string + forceNewTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command { containerUser string ) client := new(codersdk.Client) + wsClient := workspacesdk.New(client) cmd := &serpent.Command{ Annotations: workspaceCommand, Use: "ssh ", @@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command { parsedEnv = append(parsedEnv, [2]string{k, v}) } - deploymentSSHConfig := codersdk.SSHConfigResponse{ + cliConfig := codersdk.SSHConfigResponse{ HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix, } workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname( ctx, inv, client, - inv.Args[0], deploymentSSHConfig, disableAutostart) + inv.Args[0], cliConfig, disableAutostart) if err != nil { return err } @@ -275,10 +278,44 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + // If we're in stdio mode, check to see if we can use Coder Connect. + // We don't support Coder Connect over non-stdio coder ssh yet. + if stdio && !forceNewTunnel { + connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) + if err != nil { + return xerrors.Errorf("get agent connection info: %w", err) + } + coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", + workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) + exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) + if exists { + defer cancel() + + if networkInfoDir != "" { + if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { + logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) + } + } + + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) + defer stopPolling() + + usageAppName := getUsageAppName(usageApp) + if usageAppName != "" { + closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{ + AgentID: workspaceAgent.ID, + AppName: usageAppName, + }) + defer closeUsage() + } + return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack) + } + } + if r.disableDirect { _, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.") } - conn, err := workspacesdk.New(client). + conn, err := wsClient. DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ Logger: logger, BlockEndpoints: r.disableDirect, @@ -662,6 +699,12 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.StringOf(&containerUser), Hidden: true, // Hidden until this features is at least in beta. }, + { + Flag: "force-new-tunnel", + Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Value: serpent.BoolOf(&forceNewTunnel), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd @@ -1374,12 +1417,13 @@ func setStatsCallback( } type sshNetworkStats struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` + UsingCoderConnect bool `json:"using_coder_connect"` } func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { @@ -1450,6 +1494,76 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +type coderConnectDialerContextKey struct{} + +type coderConnectDialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context { + return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer) +} + +func testOrDefaultDialer(ctx context.Context) coderConnectDialer { + dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer) + if !ok || dialer == nil { + return &net.Dialer{} + } + return dialer +} + +func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { + dialer := testOrDefaultDialer(ctx) + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("tcp conn", conn); err != nil { + return err + } + + agentssh.Bicopy(ctx, conn, &StdioRwc{ + Reader: stdin, + Writer: stdout, + }) + + return nil +} + +type StdioRwc struct { + io.Reader + io.Writer +} + +func (*StdioRwc) Close() error { + return nil +} + +func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { + fs, ok := ctx.Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + // The VS Code extension obtains the PID of the SSH process to + // find the log file associated with a SSH session. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. + networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + stats := &sshNetworkStats{ + UsingCoderConnect: true, + } + rawStats, err := json.Marshal(stats) + if err != nil { + return xerrors.Errorf("marshal network stats: %w", err) + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) + if err != nil { + return xerrors.Errorf("write network stats: %w", err) + } + return nil +} + // Converts workspace name input to owner/workspace.agent format // Possible valid input formats: // workspace diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d5e4c049347b2..caee1ec25b710 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -3,13 +3,17 @@ package cli import ( "context" "fmt" + "io" + "net" "net/url" "sync" "testing" "time" + gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "cdr.dev/slog" @@ -220,6 +224,87 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } +func TestCoderConnectStdio(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + stdioDone := make(chan struct{}) + go func() { + err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack) + assert.NoError(t, err) + close(stdioDone) + }() + + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // We're not connected to a real shell + err = session.Run("") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-stdioDone +} + +type sshServer struct { + server *gliderssh.Server +} + +func newSSHServer(addr string) *sshServer { + return &sshServer{ + server: &gliderssh.Server{ + Addr: addr, + Handler: func(s gliderssh.Session) { + _, _ = io.WriteString(s.Stderr(), "Connected!") + }, + }, + } +} + +func (s *sshServer) Serve(ln net.Listener) error { + return s.server.Serve(ln) +} + +func (s *sshServer) Close() error { + return s.server.Close() +} + type fakeCloser struct { closes *[]*fakeCloser err error diff --git a/cli/ssh_test.go b/cli/ssh_test.go index c8ad072270169..ab754626c54fa 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -41,6 +41,7 @@ import ( "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" @@ -473,7 +474,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -542,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -605,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -773,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -835,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -894,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1082,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1741,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -2110,6 +2111,111 @@ func TestSSH_Container(t *testing.T) { }) } +func TestSSH_CoderConnect(t *testing.T) { + t.Parallel() + + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) + + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) + + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(ctx).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Our mock dialer will always fail with this error, if it was called + require.ErrorContains(t, err, "dial coder connect host \"dev.myworkspace.myuser.coder:22\" over tcp") + + // The network info file should be created since we passed `--stdio` + entries, err := afero.ReadDir(fs, "/net") + require.NoError(t, err) + require.True(t, len(entries) > 0) + }) + + t.Run("Disabled", func(t *testing.T) { + t.Parallel() + client, workspace, agentToken := setupWorkspaceForAgent(t) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + inv, root := clitest.New(t, "ssh", "--force-new-tunnel", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + // Shouldn't fail to dial the Coder Connect host + // since `--force-new-tunnel` was passed + assert.NoError(t, err) + }) + + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-cmdDone + }) +} + +type fakeCoderConnectDialer struct{} + +func (*fakeCoderConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, xerrors.Errorf("dial coder connect host %q over %s", addr, network) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion @@ -2153,35 +2259,6 @@ func tGo(t *testing.T, fn func()) (done <-chan struct{}) { return doneC } -type stdioConn struct { - io.Reader - io.Writer -} - -func (*stdioConn) Close() (err error) { - return nil -} - -func (*stdioConn) LocalAddr() net.Addr { - return nil -} - -func (*stdioConn) RemoteAddr() net.Addr { - return nil -} - -func (*stdioConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetWriteDeadline(_ time.Time) error { - return nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index fa569080f7dd2..97b4268c68780 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -185,14 +185,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port)) } -// SSHClient calls SSH to create a client that uses a weak cipher -// to improve throughput. +// SSHClient calls SSH to create a client func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -// that uses a weak cipher to improve throughput. func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/testutil/rwconn.go b/testutil/rwconn.go new file mode 100644 index 0000000000000..a731e9c3c0ab0 --- /dev/null +++ b/testutil/rwconn.go @@ -0,0 +1,36 @@ +package testutil + +import ( + "io" + "net" + "time" +) + +type ReaderWriterConn struct { + io.Reader + io.Writer +} + +func (*ReaderWriterConn) Close() (err error) { + return nil +} + +func (*ReaderWriterConn) LocalAddr() net.Addr { + return nil +} + +func (*ReaderWriterConn) RemoteAddr() net.Addr { + return nil +} + +func (*ReaderWriterConn) SetDeadline(_ time.Time) error { + return nil +} + +func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error { + return nil +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy