diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 18e647ef15c0b..f53fe207c72cf 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -609,7 +609,9 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag // and SSH server close may be delayed. cmd.SysProcAttr = cmdSysProcAttr() - // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. + // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. OpenSSH closes the + // pipes to the process when the session ends; which is what happens here since we wire the command up to the + // session for I/O. // c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271 cmd.Cancel = nil diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 23d9dcc7da3b7..08fa02ddb4565 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -8,7 +8,9 @@ import ( "context" "fmt" "net" + "os" "os/user" + "path/filepath" "runtime" "strings" "sync" @@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) { }) } +func TestSSHServer_ClosesStdin(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("bash doesn't exist on Windows") + } + + ctx := testutil.Context(t, testutil.WaitMedium) + logger := testutil.Logger(t) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) + require.NoError(t, err) + defer s.Close() + err = s.UpdateHostSigner(42) + assert.NoError(t, err) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + defer func() { + err := s.Close() + require.NoError(t, err) + <-done + }() + + c := sshClient(t, ln.Addr().String()) + + sess, err := c.NewSession() + require.NoError(t, err) + stdout, err := sess.StdoutPipe() + require.NoError(t, err) + stdin, err := sess.StdinPipe() + require.NoError(t, err) + defer stdin.Close() + + dir := t.TempDir() + err = os.MkdirAll(dir, 0o755) + require.NoError(t, err) + filePath := filepath.Join(dir, "result.txt") + + // the shell command `read` will block until data is written to stdin, or closed. It will return + // exit code 1 if it hits EOF, which is what we want to test. + cmdErrCh := make(chan error, 1) + go func() { + cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath)) + }() + + cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh) + require.NoError(t, cmdErr) + + readCh := make(chan error, 1) + go func() { + buf := make([]byte, 8) + _, err := stdout.Read(buf) + assert.Equal(t, "started\n", string(buf)) + readCh <- err + }() + err = testutil.RequireReceive(ctx, t, readCh) + require.NoError(t, err) + + sess.Close() + + var content []byte + testutil.Eventually(ctx, t, func(_ context.Context) bool { + content, err = os.ReadFile(filePath) + return err == nil + }, testutil.IntervalFast) + require.NoError(t, err) + require.Equal(t, "read exit code: 1\n", string(content)) +} + func sshClient(t *testing.T, addr string) *ssh.Client { conn, err := net.Dial("tcp", addr) require.NoError(t, err)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: