Skip to content

Commit 527f1f3

Browse files
mafredrijohnstcn
andauthored
feat: Add SSH agent forwarding support to coder agent (#1548)
* feat: Add SSH agent forwarding support to coder agent * feat: Add forward agent flag to `coder ssh` * refactor: Share setup between SSH tests, sync goroutines * feat: Add test for `coder ssh --forward-agent` * fix: Fix test flakes and implement Deans suggestion for helpers * fix: Add example to config-ssh * fix: Allow forwarding agent via -A Co-authored-by: Cian Johnston <cian@coder.com>
1 parent 22ef456 commit 527f1f3

File tree

4 files changed

+211
-69
lines changed

4 files changed

+211
-69
lines changed

agent/agent.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
391391
return err
392392
}
393393

394+
if ssh.AgentRequested(session) {
395+
l, err := ssh.NewAgentListener()
396+
if err != nil {
397+
return xerrors.Errorf("new agent listener: %w", err)
398+
}
399+
defer l.Close()
400+
go ssh.ForwardAgentConnections(l, session)
401+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
402+
}
403+
394404
sshPty, windowSize, isPty := session.Pty()
395405
if isPty {
396406
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

cli/configssh.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ func configSSH() *cobra.Command {
3838
Annotations: workspaceCommand,
3939
Use: "config-ssh",
4040
Short: "Populate your SSH config with Host entries for all of your workspaces",
41+
Example: `
42+
- You can use -o (or --ssh-option) so set SSH options to be used for all your
43+
workspaces.
44+
45+
` + cliui.Styles.Code.Render("$ coder config-ssh -o ForwardAgent=yes"),
4146
RunE: func(cmd *cobra.Command, args []string) error {
4247
client, err := createClient(cmd)
4348
if err != nil {

cli/ssh.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/mattn/go-isatty"
1616
"github.com/spf13/cobra"
1717
gossh "golang.org/x/crypto/ssh"
18+
gosshagent "golang.org/x/crypto/ssh/agent"
1819
"golang.org/x/term"
1920
"golang.org/x/xerrors"
2021

@@ -32,6 +33,7 @@ func ssh() *cobra.Command {
3233
var (
3334
stdio bool
3435
shuffle bool
36+
forwardAgent bool
3537
wsPollInterval time.Duration
3638
)
3739
cmd := &cobra.Command{
@@ -108,6 +110,17 @@ func ssh() *cobra.Command {
108110
return err
109111
}
110112

113+
if forwardAgent && os.Getenv("SSH_AUTH_SOCK") != "" {
114+
err = gosshagent.ForwardToRemote(sshClient, os.Getenv("SSH_AUTH_SOCK"))
115+
if err != nil {
116+
return xerrors.Errorf("forward agent failed: %w", err)
117+
}
118+
err = gosshagent.RequestAgentForwarding(sshSession)
119+
if err != nil {
120+
return xerrors.Errorf("request agent forwarding failed: %w", err)
121+
}
122+
}
123+
111124
stdoutFile, valid := cmd.OutOrStdout().(*os.File)
112125
if valid && isatty.IsTerminal(stdoutFile.Fd()) {
113126
state, err := term.MakeRaw(int(os.Stdin.Fd()))
@@ -156,8 +169,9 @@ func ssh() *cobra.Command {
156169
}
157170
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
158171
cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace")
159-
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
160172
_ = cmd.Flags().MarkHidden("shuffle")
173+
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
174+
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
161175

162176
return cmd
163177
}

cli/ssh_test.go

Lines changed: 181 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package cli_test
22

33
import (
4+
"context"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
7+
"crypto/rand"
8+
"errors"
49
"io"
510
"net"
11+
"path/filepath"
612
"runtime"
713
"testing"
814
"time"
@@ -11,9 +17,11 @@ import (
1117
"github.com/stretchr/testify/assert"
1218
"github.com/stretchr/testify/require"
1319
"golang.org/x/crypto/ssh"
20+
gosshagent "golang.org/x/crypto/ssh/agent"
1421

1522
"cdr.dev/slog"
1623
"cdr.dev/slog/sloggers/slogtest"
24+
1725
"github.com/coder/coder/agent"
1826
"github.com/coder/coder/cli/clitest"
1927
"github.com/coder/coder/coderd/coderdtest"
@@ -23,49 +31,53 @@ import (
2331
"github.com/coder/coder/pty/ptytest"
2432
)
2533

34+
func setupWorkspaceForSSH(t *testing.T) (*codersdk.Client, codersdk.Workspace, string) {
35+
t.Helper()
36+
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
37+
user := coderdtest.CreateFirstUser(t, client)
38+
agentToken := uuid.NewString()
39+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
40+
Parse: echo.ParseComplete,
41+
ProvisionDryRun: echo.ProvisionComplete,
42+
Provision: []*proto.Provision_Response{{
43+
Type: &proto.Provision_Response_Complete{
44+
Complete: &proto.Provision_Complete{
45+
Resources: []*proto.Resource{{
46+
Name: "dev",
47+
Type: "google_compute_instance",
48+
Agents: []*proto.Agent{{
49+
Id: uuid.NewString(),
50+
Auth: &proto.Agent_Token{
51+
Token: agentToken,
52+
},
53+
}},
54+
}},
55+
},
56+
},
57+
}},
58+
})
59+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
60+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
61+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
62+
63+
return client, workspace, agentToken
64+
}
65+
2666
func TestSSH(t *testing.T) {
27-
t.Skip("This is causing test flakes. TODO @cian fix this")
2867
t.Parallel()
2968
t.Run("ImmediateExit", func(t *testing.T) {
3069
t.Parallel()
31-
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
32-
user := coderdtest.CreateFirstUser(t, client)
33-
agentToken := uuid.NewString()
34-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
35-
Parse: echo.ParseComplete,
36-
ProvisionDryRun: echo.ProvisionComplete,
37-
Provision: []*proto.Provision_Response{{
38-
Type: &proto.Provision_Response_Complete{
39-
Complete: &proto.Provision_Complete{
40-
Resources: []*proto.Resource{{
41-
Name: "dev",
42-
Type: "google_compute_instance",
43-
Agents: []*proto.Agent{{
44-
Id: uuid.NewString(),
45-
Auth: &proto.Agent_Token{
46-
Token: agentToken,
47-
},
48-
}},
49-
}},
50-
},
51-
},
52-
}},
53-
})
54-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
55-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
56-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
70+
client, workspace, agentToken := setupWorkspaceForSSH(t)
5771
cmd, root := clitest.New(t, "ssh", workspace.Name)
5872
clitest.SetupConfig(t, client, root)
59-
doneChan := make(chan struct{})
6073
pty := ptytest.New(t)
6174
cmd.SetIn(pty.Input())
6275
cmd.SetErr(pty.Output())
6376
cmd.SetOut(pty.Output())
64-
go func() {
65-
defer close(doneChan)
77+
cmdDone := tGo(t, func() {
6678
err := cmd.Execute()
6779
assert.NoError(t, err)
68-
}()
80+
})
6981
pty.ExpectMatch("Waiting")
7082
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
7183
agentClient := codersdk.New(client.URL)
@@ -76,39 +88,16 @@ func TestSSH(t *testing.T) {
7688
t.Cleanup(func() {
7789
_ = agentCloser.Close()
7890
})
91+
7992
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
8093
pty.WriteLine("exit")
81-
<-doneChan
94+
<-cmdDone
8295
})
8396
t.Run("Stdio", func(t *testing.T) {
8497
t.Parallel()
85-
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
86-
user := coderdtest.CreateFirstUser(t, client)
87-
agentToken := uuid.NewString()
88-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
89-
Parse: echo.ParseComplete,
90-
ProvisionDryRun: echo.ProvisionComplete,
91-
Provision: []*proto.Provision_Response{{
92-
Type: &proto.Provision_Response_Complete{
93-
Complete: &proto.Provision_Complete{
94-
Resources: []*proto.Resource{{
95-
Name: "dev",
96-
Type: "google_compute_instance",
97-
Agents: []*proto.Agent{{
98-
Id: uuid.NewString(),
99-
Auth: &proto.Agent_Token{
100-
Token: agentToken,
101-
},
102-
}},
103-
}},
104-
},
105-
},
106-
}},
107-
})
108-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
109-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
110-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
111-
go func() {
98+
client, workspace, agentToken := setupWorkspaceForSSH(t)
99+
100+
_, _ = tGoContext(t, func(ctx context.Context) {
112101
// Run this async so the SSH command has to wait for
113102
// the build and agent to connect!
114103
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
@@ -117,25 +106,22 @@ func TestSSH(t *testing.T) {
117106
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
118107
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
119108
})
120-
t.Cleanup(func() {
121-
_ = agentCloser.Close()
122-
})
123-
}()
109+
<-ctx.Done()
110+
_ = agentCloser.Close()
111+
})
124112

125113
clientOutput, clientInput := io.Pipe()
126114
serverOutput, serverInput := io.Pipe()
127115

128116
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
129117
clitest.SetupConfig(t, client, root)
130-
doneChan := make(chan struct{})
131118
cmd.SetIn(clientOutput)
132119
cmd.SetOut(serverInput)
133120
cmd.SetErr(io.Discard)
134-
go func() {
135-
defer close(doneChan)
121+
cmdDone := tGo(t, func() {
136122
err := cmd.Execute()
137123
assert.NoError(t, err)
138-
}()
124+
})
139125

140126
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
141127
Reader: serverOutput,
@@ -157,8 +143,135 @@ func TestSSH(t *testing.T) {
157143
err = sshClient.Close()
158144
require.NoError(t, err)
159145
_ = clientOutput.Close()
160-
<-doneChan
146+
147+
<-cmdDone
148+
})
149+
//nolint:paralleltest // Disabled due to use of t.Setenv.
150+
t.Run("ForwardAgent", func(t *testing.T) {
151+
if runtime.GOOS == "windows" {
152+
t.Skip("Test not supported on windows")
153+
}
154+
155+
client, workspace, agentToken := setupWorkspaceForSSH(t)
156+
157+
_, _ = tGoContext(t, func(ctx context.Context) {
158+
// Run this async so the SSH command has to wait for
159+
// the build and agent to connect!
160+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
161+
agentClient := codersdk.New(client.URL)
162+
agentClient.SessionToken = agentToken
163+
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
164+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
165+
})
166+
<-ctx.Done()
167+
_ = agentCloser.Close()
168+
})
169+
170+
// Generate private key.
171+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
172+
require.NoError(t, err)
173+
kr := gosshagent.NewKeyring()
174+
kr.Add(gosshagent.AddedKey{
175+
PrivateKey: privateKey,
176+
})
177+
178+
// Start up ssh agent listening on unix socket.
179+
tmpdir := t.TempDir()
180+
agentSock := filepath.Join(tmpdir, "agent.sock")
181+
l, err := net.Listen("unix", agentSock)
182+
require.NoError(t, err)
183+
defer l.Close()
184+
_ = tGo(t, func() {
185+
for {
186+
fd, err := l.Accept()
187+
if err != nil {
188+
if !errors.Is(err, net.ErrClosed) {
189+
t.Logf("accept error: %v", err)
190+
}
191+
return
192+
}
193+
194+
err = gosshagent.ServeAgent(kr, fd)
195+
if !errors.Is(err, io.EOF) {
196+
assert.NoError(t, err)
197+
}
198+
}
199+
})
200+
201+
t.Setenv("SSH_AUTH_SOCK", agentSock)
202+
cmd, root := clitest.New(t,
203+
"ssh",
204+
workspace.Name,
205+
"--forward-agent",
206+
)
207+
clitest.SetupConfig(t, client, root)
208+
pty := ptytest.New(t)
209+
cmd.SetIn(pty.Input())
210+
cmd.SetOut(pty.Output())
211+
cmd.SetErr(io.Discard)
212+
cmdDone := tGo(t, func() {
213+
err := cmd.Execute()
214+
assert.NoError(t, err)
215+
})
216+
217+
// Ensure that SSH_AUTH_SOCK is set.
218+
// Linux: /tmp/auth-agent3167016167/listener.sock
219+
// macOS: /var/folders/ng/m1q0wft14hj0t3rtjxrdnzsr0000gn/T/auth-agent3245553419/listener.sock
220+
pty.WriteLine("env")
221+
pty.ExpectMatch("SSH_AUTH_SOCK=")
222+
// Ensure that ssh-add lists our key.
223+
pty.WriteLine("ssh-add -L")
224+
keys, err := kr.List()
225+
require.NoError(t, err)
226+
pty.ExpectMatch(keys[0].String())
227+
228+
// And we're done.
229+
pty.WriteLine("exit")
230+
<-cmdDone
231+
})
232+
}
233+
234+
// tGoContext runs fn in a goroutine passing a context that will be
235+
// canceled on test completion and wait until fn has finished executing.
236+
// Done and cancel are returned for optionally waiting until completion
237+
// or early cancellation.
238+
//
239+
// NOTE(mafredri): This could be moved to a helper library.
240+
func tGoContext(t *testing.T, fn func(context.Context)) (done <-chan struct{}, cancel context.CancelFunc) {
241+
t.Helper()
242+
243+
ctx, cancel := context.WithCancel(context.Background())
244+
doneC := make(chan struct{})
245+
t.Cleanup(func() {
246+
cancel()
247+
<-done
248+
})
249+
go func() {
250+
fn(ctx)
251+
close(doneC)
252+
}()
253+
254+
return doneC, cancel
255+
}
256+
257+
// tGo runs fn in a goroutine and waits until fn has completed before
258+
// test completion. Done is returned for optionally waiting for fn to
259+
// exit.
260+
//
261+
// NOTE(mafredri): This could be moved to a helper library.
262+
func tGo(t *testing.T, fn func()) (done <-chan struct{}) {
263+
t.Helper()
264+
265+
doneC := make(chan struct{})
266+
t.Cleanup(func() {
267+
<-doneC
161268
})
269+
go func() {
270+
fn()
271+
close(doneC)
272+
}()
273+
274+
return doneC
162275
}
163276

164277
type stdioConn struct {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy