diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000000000..285efe3dc9836 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,329 @@ +package agent + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "io" + "net" + "os/exec" + "os/user" + "sync" + "time" + + "cdr.dev/slog" + "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/pty" + "github.com/coder/retry" + + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" +) + +func DialSSH(conn *peer.Conn) (net.Conn, error) { + channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{ + Protocol: "ssh", + }) + if err != nil { + return nil, err + } + return channel.NetConn(), nil +} + +func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) { + netConn, err := DialSSH(conn) + if err != nil { + return nil, err + } + sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{ + Config: gossh.Config{ + Ciphers: []string{"arcfour"}, + }, + // SSH host validation isn't helpful, because obtaining a peer + // connection already signifies user-intent to dial a workspace. + // #nosec + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, err + } + return gossh.NewClient(sshConn, channels, requests), nil +} + +type Options struct { + Logger slog.Logger +} + +type Dialer func(ctx context.Context) (*peerbroker.Listener, error) + +func New(dialer Dialer, options *Options) io.Closer { + ctx, cancelFunc := context.WithCancel(context.Background()) + server := &server{ + clientDialer: dialer, + options: options, + closeCancel: cancelFunc, + closed: make(chan struct{}), + } + server.init(ctx) + return server +} + +type server struct { + clientDialer Dialer + options *Options + + closeCancel context.CancelFunc + closeMutex sync.Mutex + closed chan struct{} + + sshServer *ssh.Server +} + +func (s *server) init(ctx context.Context) { + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) + if err != nil { + panic(err) + } + sshLogger := s.options.Logger.Named("ssh-server") + forwardHandler := &ssh.ForwardedTCPHandler{} + s.sshServer = &ssh.Server{ + ChannelHandlers: ssh.DefaultChannelHandlers, + ConnectionFailedCallback: func(conn net.Conn, err error) { + sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) + }, + Handler: func(session ssh.Session) { + err := s.handleSSHSession(session) + if err != nil { + s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err)) + _ = session.Exit(1) + return + } + }, + HostSigners: []ssh.Signer{randomSigner}, + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return true + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + // Allow reverse port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) + return true + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + Config: gossh.Config{ + // "arcfour" is the fastest SSH cipher. We prioritize throughput + // over encryption here, because the WebRTC connection is already + // encrypted. If possible, we'd disable encryption entirely here. + Ciphers: []string{"arcfour"}, + }, + NoClientAuth: true, + } + }, + } + + go s.run(ctx) +} + +func (*server) handleSSHSession(session ssh.Session) error { + var ( + command string + args = []string{} + err error + ) + + username := session.User() + if username == "" { + currentUser, err := user.Current() + if err != nil { + return xerrors.Errorf("get current user: %w", err) + } + username = currentUser.Username + } + + // gliderlabs/ssh returns a command slice of zero + // when a shell is requested. + if len(session.Command()) == 0 { + command, err = usershell.Get(username) + if err != nil { + return xerrors.Errorf("get user shell: %w", err) + } + } else { + command = session.Command()[0] + if len(session.Command()) > 1 { + args = session.Command()[1:] + } + } + + signals := make(chan ssh.Signal) + breaks := make(chan bool) + defer close(signals) + defer close(breaks) + go func() { + for { + select { + case <-session.Context().Done(): + return + // Ignore signals and breaks for now! + case <-signals: + case <-breaks: + } + } + }() + + cmd := exec.CommandContext(session.Context(), command, args...) + cmd.Env = session.Environ() + + sshPty, windowSize, isPty := session.Pty() + if isPty { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + ptty, process, err := pty.Start(cmd) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + go func() { + for win := range windowSize { + err := ptty.Resize(uint16(win.Width), uint16(win.Height)) + if err != nil { + panic(err) + } + } + }() + go func() { + _, _ = io.Copy(ptty.Input(), session) + }() + go func() { + _, _ = io.Copy(session, ptty.Output()) + }() + _, _ = process.Wait() + _ = ptty.Close() + return nil + } + + cmd.Stdout = session + cmd.Stderr = session + // This blocks forever until stdin is received if we don't + // use StdinPipe. It's unknown what causes this. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return xerrors.Errorf("create stdin pipe: %w", err) + } + go func() { + _, _ = io.Copy(stdinPipe, session) + }() + err = cmd.Start() + if err != nil { + return xerrors.Errorf("start: %w", err) + } + _ = cmd.Wait() + return nil +} + +func (s *server) run(ctx context.Context) { + var peerListener *peerbroker.Listener + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + peerListener, err = s.clientDialer(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + if s.isClosed() { + return + } + s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + s.options.Logger.Debug(context.Background(), "connected") + break + } + select { + case <-ctx.Done(): + return + default: + } + + for { + conn, err := peerListener.Accept() + if err != nil { + if s.isClosed() { + return + } + s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) + s.run(ctx) + return + } + go s.handlePeerConn(ctx, conn) + } +} + +func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) { + for { + channel, err := conn.Accept(ctx) + if err != nil { + if errors.Is(err, peer.ErrClosed) || s.isClosed() { + return + } + s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) + return + } + + switch channel.Protocol() { + case "ssh": + s.sshServer.HandleConn(channel.NetConn()) + default: + s.options.Logger.Warn(ctx, "unhandled protocol from channel", + slog.F("protocol", channel.Protocol()), + slog.F("label", channel.Label()), + ) + } + } +} + +// isClosed returns whether the API is closed or not. +func (s *server) isClosed() bool { + select { + case <-s.closed: + return true + default: + return false + } +} + +func (s *server) Close() error { + s.closeMutex.Lock() + defer s.closeMutex.Unlock() + if s.isClosed() { + return nil + } + close(s.closed) + s.closeCancel() + _ = s.sshServer.Close() + return nil +} diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000000000..662c054eae146 --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,110 @@ +package agent_test + +import ( + "context" + "runtime" + "strings" + "testing" + + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/crypto/ssh" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestAgent(t *testing.T) { + t.Parallel() + t.Run("SessionExec", func(t *testing.T) { + t.Parallel() + api := setup(t) + stream, err := api.NegotiateConnection(context.Background()) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + sshClient, err := agent.DialSSHClient(conn) + require.NoError(t, err) + session, err := sshClient.NewSession() + require.NoError(t, err) + command := "echo test" + if runtime.GOOS == "windows" { + command = "cmd.exe /c echo test" + } + output, err := session.Output(command) + require.NoError(t, err) + require.Equal(t, "test", strings.TrimSpace(string(output))) + }) + + t.Run("SessionTTY", func(t *testing.T) { + t.Parallel() + api := setup(t) + stream, err := api.NegotiateConnection(context.Background()) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + sshClient, err := agent.DialSSHClient(conn) + require.NoError(t, err) + session, err := sshClient.NewSession() + require.NoError(t, err) + prompt := "$" + command := "bash" + if runtime.GOOS == "windows" { + command = "cmd.exe" + prompt = ">" + } + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + ptty := ptytest.New(t) + require.NoError(t, err) + session.Stdout = ptty.Output() + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Start(command) + require.NoError(t, err) + ptty.ExpectMatch(prompt) + ptty.WriteLine("echo test") + ptty.ExpectMatch("test") + ptty.WriteLine("exit") + err = session.Wait() + require.NoError(t, err) + }) +} + +func setup(t *testing.T) proto.DRPCPeerBrokerClient { + client, server := provisionersdk.TransportPipe() + closer := agent.New(func(ctx context.Context) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + }, &agent.Options{ + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + }) + t.Cleanup(func() { + _ = client.Close() + _ = server.Close() + _ = closer.Close() + }) + return proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) +} diff --git a/agent/usershell/usershell_darwin.go b/agent/usershell/usershell_darwin.go new file mode 100644 index 0000000000000..d2b9a454e0470 --- /dev/null +++ b/agent/usershell/usershell_darwin.go @@ -0,0 +1,10 @@ +package usershell + +import "os" + +// Get returns the $SHELL environment variable. +// TODO: This should use "dscl" to fetch the proper value. See: +// https://stackoverflow.com/questions/16375519/how-to-get-the-default-shell +func Get(username string) (string, error) { + return os.Getenv("SHELL"), nil +} diff --git a/agent/usershell/usershell_other.go b/agent/usershell/usershell_other.go new file mode 100644 index 0000000000000..6f69a1e270ac3 --- /dev/null +++ b/agent/usershell/usershell_other.go @@ -0,0 +1,31 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package usershell + +import ( + "os" + "strings" + + "golang.org/x/xerrors" +) + +// Get returns the /etc/passwd entry for the username provided. +func Get(username string) (string, error) { + contents, err := os.ReadFile("/etc/passwd") + if err != nil { + return "", xerrors.Errorf("read /etc/passwd: %w", err) + } + lines := strings.Split(string(contents), "\n") + for _, line := range lines { + if !strings.HasPrefix(line, username+":") { + continue + } + parts := strings.Split(line, ":") + if len(parts) < 7 { + return "", xerrors.Errorf("malformed user entry: %q", line) + } + return parts[6], nil + } + return "", xerrors.New("user not found in /etc/passwd and $SHELL not set") +} diff --git a/agent/usershell/usershell_other_test.go b/agent/usershell/usershell_other_test.go new file mode 100644 index 0000000000000..9469f31c70e70 --- /dev/null +++ b/agent/usershell/usershell_other_test.go @@ -0,0 +1,27 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package usershell_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/agent/usershell" +) + +func TestGet(t *testing.T) { + t.Parallel() + t.Run("Has", func(t *testing.T) { + t.Parallel() + shell, err := usershell.Get("root") + require.NoError(t, err) + require.NotEmpty(t, shell) + }) + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + _, err := usershell.Get("notauser") + require.Error(t, err) + }) +} diff --git a/agent/usershell/usershell_windows.go b/agent/usershell/usershell_windows.go new file mode 100644 index 0000000000000..91bff1d8297cd --- /dev/null +++ b/agent/usershell/usershell_windows.go @@ -0,0 +1,6 @@ +package usershell + +// Get returns the command prompt binary name. +func Get(username string) (string, error) { + return "cmd.exe", nil +} diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 889f6241a442a..40eba2fb53942 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -164,7 +164,7 @@ func AwaitProjectImportJob(t *testing.T, client *codersdk.Client, organization s provisionerJob, err = client.ProjectImportJob(context.Background(), organization, job) require.NoError(t, err) return provisionerJob.Status.Completed() - }, 3*time.Second, 25*time.Millisecond) + }, 5*time.Second, 25*time.Millisecond) return provisionerJob } @@ -176,7 +176,7 @@ func AwaitWorkspaceProvisionJob(t *testing.T, client *codersdk.Client, organizat provisionerJob, err = client.WorkspaceProvisionJob(context.Background(), organization, job) require.NoError(t, err) return provisionerJob.Status.Completed() - }, 3*time.Second, 25*time.Millisecond) + }, 5*time.Second, 25*time.Millisecond) return provisionerJob } diff --git a/go.mod b/go.mod index 61fb3b0e83cee..c018ade5929d3 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 + github.com/gliderlabs/ssh v0.3.3 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -64,6 +65,7 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect diff --git a/go.sum b/go.sum index 5db435337eecd..cb9a8a7679c7a 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0= github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= @@ -442,6 +444,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= +github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= diff --git a/pty/pty_other.go b/pty/pty_other.go index e2520a2387116..c3933878456cb 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -45,7 +45,7 @@ func (p *otherPty) Output() io.ReadWriter { func (p *otherPty) Resize(cols uint16, rows uint16) error { p.mutex.Lock() defer p.mutex.Unlock() - return pty.Setsize(p.tty, &pty.Winsize{ + return pty.Setsize(p.pty, &pty.Winsize{ Rows: rows, Cols: cols, }) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index b6a9f8ae2e5dd..fa6f1932a48c3 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -96,12 +96,15 @@ func (p *ptyWindows) Close() error { return nil } p.closed = true + _ = p.outputWrite.Close() + _ = p.outputRead.Close() + _ = p.inputWrite.Close() + _ = p.inputRead.Close() ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) - if ret != 0 { + if ret < 0 { return xerrors.Errorf("close pseudo console: %w", err) } - _ = p.outputRead.Close() - _ = p.inputWrite.Close() + return nil } diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 7ea5b7a119f0d..60cd88ce606a2 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -5,8 +5,10 @@ import ( "bytes" "fmt" "io" + "os" "os/exec" "regexp" + "runtime" "strings" "testing" "unicode/utf8" @@ -28,10 +30,10 @@ func New(t *testing.T) *PTY { return create(t, ptty) } -func Start(t *testing.T, cmd *exec.Cmd) *PTY { - ptty, err := pty.Start(cmd) +func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { + ptty, ps, err := pty.Start(cmd) require.NoError(t, err) - return create(t, ptty) + return create(t, ptty), ps } func create(t *testing.T, ptty pty.PTY) *PTY { @@ -86,10 +88,15 @@ func (p *PTY) ExpectMatch(str string) string { break } } + p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), "")) return buffer.String() } func (p *PTY) WriteLine(str string) { - _, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str) + newline := "\n" + if runtime.GOOS == "windows" { + newline = "\r\n" + } + _, err := fmt.Fprintf(p.PTY.Input(), "%s%s", str, newline) require.NoError(p.t, err) } diff --git a/pty/start.go b/pty/start.go index 2b75843ee16c2..d0cbcd667d7b7 100644 --- a/pty/start.go +++ b/pty/start.go @@ -1,7 +1,10 @@ package pty -import "os/exec" +import ( + "os" + "os/exec" +) -func Start(cmd *exec.Cmd) (PTY, error) { +func Start(cmd *exec.Cmd) (PTY, *os.Process, error) { return startPty(cmd) } diff --git a/pty/start_other.go b/pty/start_other.go index 2f1a74633130e..6709cb271b1e4 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -4,6 +4,7 @@ package pty import ( + "os" "os/exec" "syscall" @@ -11,10 +12,10 @@ import ( "golang.org/x/xerrors" ) -func startPty(cmd *exec.Cmd) (PTY, error) { +func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { ptty, tty, err := pty.Open() if err != nil { - return nil, xerrors.Errorf("open: %w", err) + return nil, nil, xerrors.Errorf("open: %w", err) } defer func() { _ = tty.Close() @@ -29,10 +30,11 @@ func startPty(cmd *exec.Cmd) (PTY, error) { err = cmd.Start() if err != nil { _ = ptty.Close() - return nil, xerrors.Errorf("start: %w", err) + return nil, nil, xerrors.Errorf("start: %w", err) } - return &otherPty{ + oPty := &otherPty{ pty: ptty, tty: tty, - }, nil + } + return oPty, cmd.Process, nil } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index a5e7d94b36af1..30c87935bcd69 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -7,8 +7,9 @@ import ( "os/exec" "testing" - "github.com/coder/coder/pty/ptytest" "go.uber.org/goleak" + + "github.com/coder/coder/pty/ptytest" ) func TestMain(m *testing.M) { @@ -19,7 +20,7 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("echo", "test")) + pty, _ := ptytest.Start(t, exec.Command("echo", "test")) pty.ExpectMatch("test") }) } diff --git a/pty/start_windows.go b/pty/start_windows.go index 136ba245736ab..1019a969aef2c 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -11,47 +11,48 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.org/x/xerrors" ) // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd) (PTY, error) { +func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { fullPath, err := exec.LookPath(cmd.Path) if err != nil { - return nil, err + return nil, nil, err } pathPtr, err := windows.UTF16PtrFromString(fullPath) if err != nil { - return nil, err + return nil, nil, err } argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) if err != nil { - return nil, err + return nil, nil, err } if cmd.Dir == "" { cmd.Dir, err = os.Getwd() if err != nil { - return nil, err + return nil, nil, err } } dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) if err != nil { - return nil, err + return nil, nil, err } pty, err := newPty() if err != nil { - return nil, err + return nil, nil, err } winPty := pty.(*ptyWindows) attrs, err := windows.NewProcThreadAttributeList(1) if err != nil { - return nil, err + return nil, nil, err } // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) if err != nil { - return nil, err + return nil, nil, err } startupInfo := &windows.StartupInfoEx{} @@ -73,12 +74,16 @@ func startPty(cmd *exec.Cmd) (PTY, error) { &processInfo, ) if err != nil { - return nil, err + return nil, nil, err } defer windows.CloseHandle(processInfo.Thread) defer windows.CloseHandle(processInfo.Process) - return pty, nil + process, err := os.FindProcess(int(processInfo.ProcessId)) + if err != nil { + return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) + } + return pty, process, nil } // Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index faee269776830..d0398d0dec019 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -20,12 +20,12 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) pty.ExpectMatch("test") }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe")) + pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) }) diff --git a/templates/null/main.tf b/templates/null/main.tf deleted file mode 100644 index 9bb3f2042e2a4..0000000000000 --- a/templates/null/main.tf +++ /dev/null @@ -1,5 +0,0 @@ -variable "bananas" { - description = "hello!" -} - -resource "null_resource" "example" {}
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: