Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func portForward() *cobra.Command {
},
),
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
Expand All @@ -72,21 +75,21 @@ func portForward() *cobra.Command {
return err
}

workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}

err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID)
Expand All @@ -96,15 +99,14 @@ func portForward() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}

conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil)
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer conn.Close()

// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
Expand All @@ -116,11 +118,11 @@ func portForward() *cobra.Command {
}
}
)
defer cancel()
defer closeAllListeners()

for i, spec := range specs {
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
Expand All @@ -129,7 +131,10 @@ func portForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
wg.Add(1)
go func() {
defer wg.Done()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

Expand Down
82 changes: 48 additions & 34 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ func ssh() *cobra.Command {
Short: "SSH into a workspace",
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

client, err := createClient(cmd)
if err != nil {
return err
Expand All @@ -68,14 +71,14 @@ func ssh() *cobra.Command {
}
}

workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
if err != nil {
return err
}

// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
Expand All @@ -85,42 +88,33 @@ func ssh() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}

var (
sshClient *gossh.Client
sshSession *gossh.Session
)
var newSSHClient func() (*gossh.Client, error)

if !wireguard {
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
if err != nil {
return err
}
defer conn.Close()

stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
defer rawSSH.Close()

go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}

sshClient, err = conn.SSHClient()
if err != nil {
return err
}

sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
newSSHClient = conn.SSHClient
} else {
// TODO: more granual control of Tailscale logging.
peerwg.Logf = tslogger.Discard
Expand All @@ -133,8 +127,9 @@ func ssh() *cobra.Command {
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
defer wgn.Close()

err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
Expand All @@ -155,10 +150,11 @@ func ssh() *cobra.Command {
}

if stdio {
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP())
if err != nil {
return err
}
defer rawSSH.Close()

go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
Expand All @@ -167,16 +163,29 @@ func ssh() *cobra.Command {
return nil
}

sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
newSSHClient = func() (*gossh.Client, error) {
return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP())
}
}

sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
sshClient, err := newSSHClient()
if err != nil {
return err
}
defer sshClient.Close()

sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
defer sshSession.Close()

// Ensure context cancellation is propagated to the
// SSH session, e.g. to cancel `Wait()` at the end.
go func() {
<-ctx.Done()
_ = sshSession.Close()
}()

if identityAgent == "" {
identityAgent = os.Getenv("SSH_AUTH_SOCK")
Expand All @@ -203,15 +212,18 @@ func ssh() *cobra.Command {
_ = term.Restore(int(stdinFile.Fd()), state)
}()

windowChange := listenWindowSize(cmd.Context())
windowChange := listenWindowSize(ctx)
go func() {
for {
select {
case <-cmd.Context().Done():
case <-ctx.Done():
return
case <-windowChange:
}
width, height, _ := term.GetSize(int(stdoutFile.Fd()))
width, height, err := term.GetSize(int(stdoutFile.Fd()))
if err != nil {
continue
}
_ = sshSession.WindowChange(height, width)
}
}()
Expand All @@ -224,13 +236,17 @@ func ssh() *cobra.Command {

sshSession.Stdin = cmd.InOrStdin()
sshSession.Stdout = cmd.OutOrStdout()
sshSession.Stderr = cmd.OutOrStdout()
sshSession.Stderr = cmd.ErrOrStderr()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a drive-by change I did. It seemed wrong but perhaps I didn't understand the purpose?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me!


err = sshSession.Shell()
if err != nil {
return err
}

// Put cancel at the top of the defer stack to initiate
// shutdown of services.
defer cancel()

err = sshSession.Wait()
if err != nil {
// If the connection drops unexpectedly, we get an ExitMissingError but no other
Expand Down Expand Up @@ -259,16 +275,14 @@ func ssh() *cobra.Command {
// getWorkspaceAgent returns the workspace and agent selected using either the
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
// if `shuffle` is true.
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
ctx := cmd.Context()

func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
var (
workspace codersdk.Workspace
workspaceParts = strings.Split(in, ".")
err error
)
if shuffle {
workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{
workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: codersdk.Me,
})
if err != nil {
Expand Down
5 changes: 1 addition & 4 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestSSH(t *testing.T) {
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
cmd.SetErr(io.Discard)
cmd.SetErr(pty.Output())
cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx)
assert.NoError(t, err)
Expand All @@ -248,9 +248,6 @@ func TestSSH(t *testing.T) {

// And we're done.
pty.WriteLine("exit")
// Read output to prevent hang on macOS, see:
// https://github.com/coder/coder/issues/2122
pty.ExpectMatch("exit")
<-cmdDone
})
}
Expand Down
20 changes: 13 additions & 7 deletions cli/wireguardtunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
},
),
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

specs, err := parsePortForwards(tcpForwards, nil, nil)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
Expand All @@ -69,21 +72,21 @@ func wireguardPortForward() *cobra.Command {
return err
}

workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}

err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
Expand All @@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
defer wgn.Close()

err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
Expand All @@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {

// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
Expand All @@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
}
}
)
defer cancel()
defer closeAllListeners()

for i, spec := range specs {
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
Expand All @@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
wg.Add(1)
go func() {
defer wg.Done()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy