Skip to content

Commit d0e2060

Browse files
authored
feat(agent): add second SSH listener on port 22 (#16627)
Fixes: coder/internal#377 Added an additional SSH listener on port 22, so the agent now listens on both, port one and port 22. --- Change-Id: Ifd986b260f8ac317e37d65111cd4e0bd1dc38af8 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent ca23abe commit d0e2060

File tree

6 files changed

+153
-95
lines changed

6 files changed

+153
-95
lines changed

agent/agent.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,19 +1362,22 @@ func (a *agent) createTailnet(
13621362
return nil, xerrors.Errorf("update host signer: %w", err)
13631363
}
13641364

1365-
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
1366-
if err != nil {
1367-
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
1368-
}
1369-
defer func() {
1365+
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
1366+
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
13701367
if err != nil {
1371-
_ = sshListener.Close()
1368+
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
1369+
}
1370+
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
1371+
defer func() {
1372+
if err != nil {
1373+
_ = sshListener.Close()
1374+
}
1375+
}()
1376+
if err = a.trackGoroutine(func() {
1377+
_ = a.sshServer.Serve(sshListener)
1378+
}); err != nil {
1379+
return nil, err
13721380
}
1373-
}()
1374-
if err = a.trackGoroutine(func() {
1375-
_ = a.sshServer.Serve(sshListener)
1376-
}); err != nil {
1377-
return nil, err
13781381
}
13791382

13801383
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))

agent/agent_test.go

Lines changed: 120 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -65,38 +65,48 @@ func TestMain(m *testing.M) {
6565
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
6666
}
6767

68+
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
69+
6870
// NOTE: These tests only work when your default shell is bash for some reason.
6971

7072
func TestAgent_Stats_SSH(t *testing.T) {
7173
t.Parallel()
72-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
73-
defer cancel()
7474

75-
//nolint:dogsled
76-
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
75+
for _, port := range sshPorts {
76+
port := port
77+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
78+
t.Parallel()
7779

78-
sshClient, err := conn.SSHClient(ctx)
79-
require.NoError(t, err)
80-
defer sshClient.Close()
81-
session, err := sshClient.NewSession()
82-
require.NoError(t, err)
83-
defer session.Close()
84-
stdin, err := session.StdinPipe()
85-
require.NoError(t, err)
86-
err = session.Shell()
87-
require.NoError(t, err)
80+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
81+
defer cancel()
8882

89-
var s *proto.Stats
90-
require.Eventuallyf(t, func() bool {
91-
var ok bool
92-
s, ok = <-stats
93-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
94-
}, testutil.WaitLong, testutil.IntervalFast,
95-
"never saw stats: %+v", s,
96-
)
97-
_ = stdin.Close()
98-
err = session.Wait()
99-
require.NoError(t, err)
83+
//nolint:dogsled
84+
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
85+
86+
sshClient, err := conn.SSHClientOnPort(ctx, port)
87+
require.NoError(t, err)
88+
defer sshClient.Close()
89+
session, err := sshClient.NewSession()
90+
require.NoError(t, err)
91+
defer session.Close()
92+
stdin, err := session.StdinPipe()
93+
require.NoError(t, err)
94+
err = session.Shell()
95+
require.NoError(t, err)
96+
97+
var s *proto.Stats
98+
require.Eventuallyf(t, func() bool {
99+
var ok bool
100+
s, ok = <-stats
101+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
102+
}, testutil.WaitLong, testutil.IntervalFast,
103+
"never saw stats: %+v", s,
104+
)
105+
_ = stdin.Close()
106+
err = session.Wait()
107+
require.NoError(t, err)
108+
})
109+
}
100110
}
101111

102112
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
@@ -278,15 +288,23 @@ func TestAgent_Stats_Magic(t *testing.T) {
278288

279289
func TestAgent_SessionExec(t *testing.T) {
280290
t.Parallel()
281-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
282291

283-
command := "echo test"
284-
if runtime.GOOS == "windows" {
285-
command = "cmd.exe /c echo test"
292+
for _, port := range sshPorts {
293+
port := port
294+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
295+
t.Parallel()
296+
297+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
298+
299+
command := "echo test"
300+
if runtime.GOOS == "windows" {
301+
command = "cmd.exe /c echo test"
302+
}
303+
output, err := session.Output(command)
304+
require.NoError(t, err)
305+
require.Equal(t, "test", strings.TrimSpace(string(output)))
306+
})
286307
}
287-
output, err := session.Output(command)
288-
require.NoError(t, err)
289-
require.Equal(t, "test", strings.TrimSpace(string(output)))
290308
}
291309

292310
//nolint:tparallel // Sub tests need to run sequentially.
@@ -396,25 +414,33 @@ func TestAgent_SessionTTYShell(t *testing.T) {
396414
// it seems like it could be either.
397415
t.Skip("ConPTY appears to be inconsistent on Windows.")
398416
}
399-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
400-
command := "sh"
401-
if runtime.GOOS == "windows" {
402-
command = "cmd.exe"
417+
418+
for _, port := range sshPorts {
419+
port := port
420+
t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
421+
t.Parallel()
422+
423+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
424+
command := "sh"
425+
if runtime.GOOS == "windows" {
426+
command = "cmd.exe"
427+
}
428+
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
429+
require.NoError(t, err)
430+
ptty := ptytest.New(t)
431+
session.Stdout = ptty.Output()
432+
session.Stderr = ptty.Output()
433+
session.Stdin = ptty.Input()
434+
err = session.Start(command)
435+
require.NoError(t, err)
436+
_ = ptty.Peek(ctx, 1) // wait for the prompt
437+
ptty.WriteLine("echo test")
438+
ptty.ExpectMatch("test")
439+
ptty.WriteLine("exit")
440+
err = session.Wait()
441+
require.NoError(t, err)
442+
})
403443
}
404-
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
405-
require.NoError(t, err)
406-
ptty := ptytest.New(t)
407-
session.Stdout = ptty.Output()
408-
session.Stderr = ptty.Output()
409-
session.Stdin = ptty.Input()
410-
err = session.Start(command)
411-
require.NoError(t, err)
412-
_ = ptty.Peek(ctx, 1) // wait for the prompt
413-
ptty.WriteLine("echo test")
414-
ptty.ExpectMatch("test")
415-
ptty.WriteLine("exit")
416-
err = session.Wait()
417-
require.NoError(t, err)
418444
}
419445

420446
func TestAgent_SessionTTYExitCode(t *testing.T) {
@@ -608,37 +634,41 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
608634
//nolint:dogsled // Allow the blank identifiers.
609635
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
610636

611-
sshClient, err := conn.SSHClient(ctx)
612-
require.NoError(t, err)
613-
t.Cleanup(func() {
614-
_ = sshClient.Close()
615-
})
616-
617637
//nolint:paralleltest // These tests need to swap the banner func.
618-
for i, test := range tests {
619-
test := test
620-
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
621-
// Set new banner func and wait for the agent to call it to update the
622-
// banner.
623-
ready := make(chan struct{}, 2)
624-
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
625-
select {
626-
case ready <- struct{}{}:
627-
default:
628-
}
629-
return []codersdk.BannerConfig{test.banner}, nil
630-
})
631-
<-ready
632-
<-ready // Wait for two updates to ensure the value has propagated.
633-
634-
session, err := sshClient.NewSession()
635-
require.NoError(t, err)
636-
t.Cleanup(func() {
637-
_ = session.Close()
638-
})
638+
for _, port := range sshPorts {
639+
port := port
639640

640-
testSessionOutput(t, session, test.expected, test.unexpected, nil)
641+
sshClient, err := conn.SSHClientOnPort(ctx, port)
642+
require.NoError(t, err)
643+
t.Cleanup(func() {
644+
_ = sshClient.Close()
641645
})
646+
647+
for i, test := range tests {
648+
test := test
649+
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
650+
// Set new banner func and wait for the agent to call it to update the
651+
// banner.
652+
ready := make(chan struct{}, 2)
653+
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
654+
select {
655+
case ready <- struct{}{}:
656+
default:
657+
}
658+
return []codersdk.BannerConfig{test.banner}, nil
659+
})
660+
<-ready
661+
<-ready // Wait for two updates to ensure the value has propagated.
662+
663+
session, err := sshClient.NewSession()
664+
require.NoError(t, err)
665+
t.Cleanup(func() {
666+
_ = session.Close()
667+
})
668+
669+
testSessionOutput(t, session, test.expected, test.unexpected, nil)
670+
})
671+
}
642672
}
643673
}
644674

@@ -2424,6 +2454,17 @@ func setupSSHSession(
24242454
banner codersdk.BannerConfig,
24252455
prepareFS func(fs afero.Fs),
24262456
opts ...func(*agenttest.Client, *agent.Options),
2457+
) *ssh.Session {
2458+
return setupSSHSessionOnPort(t, manifest, banner, prepareFS, workspacesdk.AgentSSHPort, opts...)
2459+
}
2460+
2461+
func setupSSHSessionOnPort(
2462+
t *testing.T,
2463+
manifest agentsdk.Manifest,
2464+
banner codersdk.BannerConfig,
2465+
prepareFS func(fs afero.Fs),
2466+
port uint16,
2467+
opts ...func(*agenttest.Client, *agent.Options),
24272468
) *ssh.Session {
24282469
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
24292470
defer cancel()
@@ -2437,7 +2478,7 @@ func setupSSHSession(
24372478
if prepareFS != nil {
24382479
prepareFS(fs)
24392480
}
2440-
sshClient, err := conn.SSHClient(ctx)
2481+
sshClient, err := conn.SSHClientOnPort(ctx, port)
24412482
require.NoError(t, err)
24422483
t.Cleanup(func() {
24432484
_ = sshClient.Close()

agent/usershell/usershell_darwin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func Get(username string) (string, error) {
1818
return "", xerrors.Errorf("username is nonlocal path: %s", username)
1919
}
2020
//nolint: gosec // input checked above
21-
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
21+
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output() //nolint:gocritic
2222
s, ok := strings.CutPrefix(string(out), "UserShell: ")
2323
if ok {
2424
return strings.TrimSpace(s), nil

codersdk/workspacesdk/agentconn.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,24 +165,36 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w
165165
// SSH pipes the SSH protocol over the returned net.Conn.
166166
// This connects to the built-in SSH server in the workspace agent.
167167
func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
168+
return c.SSHOnPort(ctx, AgentSSHPort)
169+
}
170+
171+
// SSHOnPort pipes the SSH protocol over the returned net.Conn.
172+
// This connects to the built-in SSH server in the workspace agent on the specified port.
173+
func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) {
168174
ctx, span := tracing.StartSpan(ctx)
169175
defer span.End()
170176

171177
if !c.AwaitReachable(ctx) {
172178
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
173179
}
174180

175-
c.Conn.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
176-
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), AgentSSHPort))
181+
c.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
182+
return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port))
177183
}
178184

179185
// SSHClient calls SSH to create a client that uses a weak cipher
180186
// to improve throughput.
181187
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
188+
return c.SSHClientOnPort(ctx, AgentSSHPort)
189+
}
190+
191+
// SSHClientOnPort calls SSH to create a client on a specific port
192+
// that uses a weak cipher to improve throughput.
193+
func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) {
182194
ctx, span := tracing.StartSpan(ctx)
183195
defer span.End()
184196

185-
netConn, err := c.SSH(ctx)
197+
netConn, err := c.SSHOnPort(ctx, port)
186198
if err != nil {
187199
return nil, xerrors.Errorf("ssh: %w", err)
188200
}

codersdk/workspacesdk/workspacesdk.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ var ErrSkipClose = xerrors.New("skip tailnet close")
3131

3232
const (
3333
AgentSSHPort = tailnet.WorkspaceAgentSSHPort
34+
AgentStandardSSHPort = tailnet.WorkspaceAgentStandardSSHPort
3435
AgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
3536
AgentSpeedtestPort = tailnet.WorkspaceAgentSpeedtestPort
3637
// AgentHTTPAPIServerPort serves a HTTP server with endpoints for e.g.

tailnet/conn.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const (
5252
WorkspaceAgentSSHPort = 1
5353
WorkspaceAgentReconnectingPTYPort = 2
5454
WorkspaceAgentSpeedtestPort = 3
55+
WorkspaceAgentStandardSSHPort = 22
5556
)
5657

5758
// EnvMagicsockDebugLogging enables super-verbose logging for the magicsock
@@ -745,7 +746,7 @@ func (c *Conn) forwardTCP(src, dst netip.AddrPort) (handler func(net.Conn), opts
745746
return nil, nil, false
746747
}
747748
// See: https://github.com/tailscale/tailscale/blob/c7cea825aea39a00aca71ea02bab7266afc03e7c/wgengine/netstack/netstack.go#L888
748-
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == 22 {
749+
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == WorkspaceAgentStandardSSHPort {
749750
opt := tcpip.KeepaliveIdleOption(72 * time.Hour)
750751
opts = append(opts, &opt)
751752
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy