Skip to content

Commit 83c63d4

Browse files
authored
fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds (#3354)
* fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds We could turn it into a practice to wrap `cmd.Context()` so that we have more fine-grained control of cancellation. Sometimes in tests we may be running commands with a context that is never canceled. Related to #3221 * fix: Set ssh session stderr to stderr
1 parent 5ae19f0 commit 83c63d4

File tree

4 files changed

+74
-52
lines changed

4 files changed

+74
-52
lines changed

cli/portforward.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ func portForward() *cobra.Command {
5555
},
5656
),
5757
RunE: func(cmd *cobra.Command, args []string) error {
58+
ctx, cancel := context.WithCancel(cmd.Context())
59+
defer cancel()
60+
5861
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
5962
if err != nil {
6063
return xerrors.Errorf("parse port-forward specs: %w", err)
@@ -72,21 +75,21 @@ func portForward() *cobra.Command {
7275
return err
7376
}
7477

75-
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
78+
workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
7679
if err != nil {
7780
return err
7881
}
7982
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
8083
return xerrors.New("workspace must be in start transition to port-forward")
8184
}
8285
if workspace.LatestBuild.Job.CompletedAt == nil {
83-
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
86+
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
8487
if err != nil {
8588
return err
8689
}
8790
}
8891

89-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
92+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
9093
WorkspaceName: workspace.Name,
9194
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
9295
return client.WorkspaceAgent(ctx, agent.ID)
@@ -96,15 +99,14 @@ func portForward() *cobra.Command {
9699
return xerrors.Errorf("await agent: %w", err)
97100
}
98101

99-
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
102+
conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil)
100103
if err != nil {
101104
return xerrors.Errorf("dial workspace agent: %w", err)
102105
}
103106
defer conn.Close()
104107

105108
// Start all listeners.
106109
var (
107-
ctx, cancel = context.WithCancel(cmd.Context())
108110
wg = new(sync.WaitGroup)
109111
listeners = make([]net.Listener, len(specs))
110112
closeAllListeners = func() {
@@ -116,11 +118,11 @@ func portForward() *cobra.Command {
116118
}
117119
}
118120
)
119-
defer cancel()
121+
defer closeAllListeners()
122+
120123
for i, spec := range specs {
121124
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
122125
if err != nil {
123-
closeAllListeners()
124126
return err
125127
}
126128
listeners[i] = l
@@ -129,7 +131,10 @@ func portForward() *cobra.Command {
129131
// Wait for the context to be canceled or for a signal and close
130132
// all listeners.
131133
var closeErr error
134+
wg.Add(1)
132135
go func() {
136+
defer wg.Done()
137+
133138
sigs := make(chan os.Signal, 1)
134139
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
135140

cli/ssh.go

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ func ssh() *cobra.Command {
5151
Short: "SSH into a workspace",
5252
Args: cobra.ArbitraryArgs,
5353
RunE: func(cmd *cobra.Command, args []string) error {
54+
ctx, cancel := context.WithCancel(cmd.Context())
55+
defer cancel()
56+
5457
client, err := createClient(cmd)
5558
if err != nil {
5659
return err
@@ -68,14 +71,14 @@ func ssh() *cobra.Command {
6871
}
6972
}
7073

71-
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
74+
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
7275
if err != nil {
7376
return err
7477
}
7578

7679
// OpenSSH passes stderr directly to the calling TTY.
7780
// This is required in "stdio" mode so a connecting indicator can be displayed.
78-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
81+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
7982
WorkspaceName: workspace.Name,
8083
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
8184
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@@ -85,42 +88,33 @@ func ssh() *cobra.Command {
8588
return xerrors.Errorf("await agent: %w", err)
8689
}
8790

88-
var (
89-
sshClient *gossh.Client
90-
sshSession *gossh.Session
91-
)
91+
var newSSHClient func() (*gossh.Client, error)
9292

9393
if !wireguard {
94-
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
94+
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
9595
if err != nil {
9696
return err
9797
}
9898
defer conn.Close()
9999

100-
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
100+
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
101101
defer stopPolling()
102102

103103
if stdio {
104104
rawSSH, err := conn.SSH()
105105
if err != nil {
106106
return err
107107
}
108+
defer rawSSH.Close()
109+
108110
go func() {
109111
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
110112
}()
111113
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
112114
return nil
113115
}
114116

115-
sshClient, err = conn.SSHClient()
116-
if err != nil {
117-
return err
118-
}
119-
120-
sshSession, err = sshClient.NewSession()
121-
if err != nil {
122-
return err
123-
}
117+
newSSHClient = conn.SSHClient
124118
} else {
125119
// TODO: more granual control of Tailscale logging.
126120
peerwg.Logf = tslogger.Discard
@@ -133,8 +127,9 @@ func ssh() *cobra.Command {
133127
if err != nil {
134128
return xerrors.Errorf("create wireguard network: %w", err)
135129
}
130+
defer wgn.Close()
136131

137-
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
132+
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
138133
Recipient: workspaceAgent.ID,
139134
NodePublicKey: wgn.NodePrivateKey.Public(),
140135
DiscoPublicKey: wgn.DiscoPublicKey,
@@ -155,10 +150,11 @@ func ssh() *cobra.Command {
155150
}
156151

157152
if stdio {
158-
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
153+
rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP())
159154
if err != nil {
160155
return err
161156
}
157+
defer rawSSH.Close()
162158

163159
go func() {
164160
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
@@ -167,16 +163,29 @@ func ssh() *cobra.Command {
167163
return nil
168164
}
169165

170-
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
171-
if err != nil {
172-
return err
166+
newSSHClient = func() (*gossh.Client, error) {
167+
return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP())
173168
}
169+
}
174170

175-
sshSession, err = sshClient.NewSession()
176-
if err != nil {
177-
return err
178-
}
171+
sshClient, err := newSSHClient()
172+
if err != nil {
173+
return err
174+
}
175+
defer sshClient.Close()
176+
177+
sshSession, err := sshClient.NewSession()
178+
if err != nil {
179+
return err
179180
}
181+
defer sshSession.Close()
182+
183+
// Ensure context cancellation is propagated to the
184+
// SSH session, e.g. to cancel `Wait()` at the end.
185+
go func() {
186+
<-ctx.Done()
187+
_ = sshSession.Close()
188+
}()
180189

181190
if identityAgent == "" {
182191
identityAgent = os.Getenv("SSH_AUTH_SOCK")
@@ -203,15 +212,18 @@ func ssh() *cobra.Command {
203212
_ = term.Restore(int(stdinFile.Fd()), state)
204213
}()
205214

206-
windowChange := listenWindowSize(cmd.Context())
215+
windowChange := listenWindowSize(ctx)
207216
go func() {
208217
for {
209218
select {
210-
case <-cmd.Context().Done():
219+
case <-ctx.Done():
211220
return
212221
case <-windowChange:
213222
}
214-
width, height, _ := term.GetSize(int(stdoutFile.Fd()))
223+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
224+
if err != nil {
225+
continue
226+
}
215227
_ = sshSession.WindowChange(height, width)
216228
}
217229
}()
@@ -224,13 +236,17 @@ func ssh() *cobra.Command {
224236

225237
sshSession.Stdin = cmd.InOrStdin()
226238
sshSession.Stdout = cmd.OutOrStdout()
227-
sshSession.Stderr = cmd.OutOrStdout()
239+
sshSession.Stderr = cmd.ErrOrStderr()
228240

229241
err = sshSession.Shell()
230242
if err != nil {
231243
return err
232244
}
233245

246+
// Put cancel at the top of the defer stack to initiate
247+
// shutdown of services.
248+
defer cancel()
249+
234250
err = sshSession.Wait()
235251
if err != nil {
236252
// If the connection drops unexpectedly, we get an ExitMissingError but no other
@@ -259,16 +275,14 @@ func ssh() *cobra.Command {
259275
// getWorkspaceAgent returns the workspace and agent selected using either the
260276
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
261277
// if `shuffle` is true.
262-
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
263-
ctx := cmd.Context()
264-
278+
func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
265279
var (
266280
workspace codersdk.Workspace
267281
workspaceParts = strings.Split(in, ".")
268282
err error
269283
)
270284
if shuffle {
271-
workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{
285+
workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
272286
Owner: codersdk.Me,
273287
})
274288
if err != nil {

cli/ssh_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func TestSSH(t *testing.T) {
229229
pty := ptytest.New(t)
230230
cmd.SetIn(pty.Input())
231231
cmd.SetOut(pty.Output())
232-
cmd.SetErr(io.Discard)
232+
cmd.SetErr(pty.Output())
233233
cmdDone := tGo(t, func() {
234234
err := cmd.ExecuteContext(ctx)
235235
assert.NoError(t, err)
@@ -248,9 +248,6 @@ func TestSSH(t *testing.T) {
248248

249249
// And we're done.
250250
pty.WriteLine("exit")
251-
// Read output to prevent hang on macOS, see:
252-
// https://github.com/coder/coder/issues/2122
253-
pty.ExpectMatch("exit")
254251
<-cmdDone
255252
})
256253
}

cli/wireguardtunnel.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
5252
},
5353
),
5454
RunE: func(cmd *cobra.Command, args []string) error {
55+
ctx, cancel := context.WithCancel(cmd.Context())
56+
defer cancel()
57+
5558
specs, err := parsePortForwards(tcpForwards, nil, nil)
5659
if err != nil {
5760
return xerrors.Errorf("parse port-forward specs: %w", err)
@@ -69,21 +72,21 @@ func wireguardPortForward() *cobra.Command {
6972
return err
7073
}
7174

72-
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
75+
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
7376
if err != nil {
7477
return err
7578
}
7679
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
7780
return xerrors.New("workspace must be in start transition to port-forward")
7881
}
7982
if workspace.LatestBuild.Job.CompletedAt == nil {
80-
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
83+
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
8184
if err != nil {
8285
return err
8386
}
8487
}
8588

86-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
89+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
8790
WorkspaceName: workspace.Name,
8891
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
8992
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
101104
if err != nil {
102105
return xerrors.Errorf("create wireguard network: %w", err)
103106
}
107+
defer wgn.Close()
104108

105-
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
109+
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
106110
Recipient: workspaceAgent.ID,
107111
NodePublicKey: wgn.NodePrivateKey.Public(),
108112
DiscoPublicKey: wgn.DiscoPublicKey,
@@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {
124128

125129
// Start all listeners.
126130
var (
127-
ctx, cancel = context.WithCancel(cmd.Context())
128131
wg = new(sync.WaitGroup)
129132
listeners = make([]net.Listener, len(specs))
130133
closeAllListeners = func() {
@@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
136139
}
137140
}
138141
)
139-
defer cancel()
142+
defer closeAllListeners()
143+
140144
for i, spec := range specs {
141145
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
142146
if err != nil {
143-
closeAllListeners()
144147
return err
145148
}
146149
listeners[i] = l
@@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
149152
// Wait for the context to be canceled or for a signal and close
150153
// all listeners.
151154
var closeErr error
155+
wg.Add(1)
152156
go func() {
157+
defer wg.Done()
158+
153159
sigs := make(chan os.Signal, 1)
154160
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
155161

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy