diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 178998b5a21f9..43fbaec1109e2 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -7,112 +7,137 @@ import ( "io" "os" "os/exec" - "regexp" "runtime" "strings" + "sync" "testing" "time" "unicode/utf8" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/pty" ) -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - func New(t *testing.T) *PTY { ptty, err := pty.New() require.NoError(t, err) - return create(t, ptty) + return create(t, ptty, "cmd") } func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { ptty, ps, err := pty.Start(cmd) require.NoError(t, err) - return create(t, ptty), ps + return create(t, ptty, cmd.Args[0]), ps } -func create(t *testing.T, ptty pty.PTY) *PTY { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) +func create(t *testing.T, ptty pty.PTY, name string) *PTY { + // Use pipe for logging. + logDone := make(chan struct{}) + logr, logw := io.Pipe() t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() + _ = logw.Close() + _ = logr.Close() + <-logDone // Guard against logging after test. }) go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + defer close(logDone) + s := bufio.NewScanner(logr) + for s.Scan() { + // Quote output to avoid terminal escape codes, e.g. bell. + t.Logf("%s: stdout: %q", name, s.Text()) } }() + // Write to log and output buffer. + copyDone := make(chan struct{}) + out := newStdbuf() + w := io.MultiWriter(logw, out) + go func() { + defer close(copyDone) + _, err := io.Copy(w, ptty.Output()) + _ = out.closeErr(err) + }() t.Cleanup(func() { + _ = out.Close _ = ptty.Close() + <-copyDone }) + return &PTY{ t: t, PTY: ptty, + out: out, - outputWriter: writer, - runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), + runeReader: bufio.NewReaderSize(out, utf8.UTFMax), } } type PTY struct { t *testing.T pty.PTY + out *stdbuf - outputWriter io.Writer - runeReader *bufio.Reader + runeReader *bufio.Reader } func (p *PTY) ExpectMatch(str string) string { + p.t.Helper() + + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var buffer bytes.Buffer - multiWriter := io.MultiWriter(&buffer, p.outputWriter) - runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) - complete, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() + match := make(chan error, 1) go func() { - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case <-complete.Done(): - return - case <-timer.C: - } - _ = p.Close() - p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + defer close(match) + match <- func() error { + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + if strings.Contains(buffer.String(), str) { + return nil + } + } + }() }() - for { - var r rune - r, _, err := p.runeReader.ReadRune() - require.NoError(p.t, err) - _, err = runeWriter.WriteRune(r) - require.NoError(p.t, err) - err = runeWriter.Flush() - require.NoError(p.t, err) - if strings.Contains(buffer.String(), str) { - break + + select { + case err := <-match: + if err != nil { + p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String()) + return "" } + p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String()) + return buffer.String() + case <-timeout.Done(): + // Ensure goroutine is cleaned up before test exit. + _ = p.out.closeErr(p.Close()) + <-match + + p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String()) + return "" } - p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), "")) - return buffer.String() } func (p *PTY) Write(r rune) { + p.t.Helper() + _, err := p.Input().Write([]byte{byte(r)}) require.NoError(p.t, err) } func (p *PTY) WriteLine(str string) { + p.t.Helper() + newline := []byte{'\r'} if runtime.GOOS == "windows" { newline = append(newline, '\n') @@ -120,3 +145,101 @@ func (p *PTY) WriteLine(str string) { _, err := p.Input().Write(append([]byte(str), newline...)) require.NoError(p.t, err) } + +// stdbuf is like a buffered stdout, it buffers writes until read. +type stdbuf struct { + r io.Reader + + mu sync.Mutex // Protects following. + b []byte + more chan struct{} + err error +} + +func newStdbuf() *stdbuf { + return &stdbuf{more: make(chan struct{}, 1)} +} + +func (b *stdbuf) Read(p []byte) (int, error) { + if b.r == nil { + return b.readOrWaitForMore(p) + } + + n, err := b.r.Read(p) + if xerrors.Is(err, io.EOF) { + b.r = nil + err = nil + if n == 0 { + return b.readOrWaitForMore(p) + } + } + return n, err +} + +func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Deplete channel so that more check + // is for future input into buffer. + select { + case <-b.more: + default: + } + + if len(b.b) == 0 { + if b.err != nil { + return 0, b.err + } + + b.mu.Unlock() + <-b.more + b.mu.Lock() + } + + b.r = bytes.NewReader(b.b) + b.b = b.b[len(b.b):] + + return b.r.Read(p) +} + +func (b *stdbuf) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.err != nil { + return 0, b.err + } + + b.b = append(b.b, p...) + + select { + case b.more <- struct{}{}: + default: + } + + return len(p), nil +} + +func (b *stdbuf) Close() error { + return b.closeErr(nil) +} + +func (b *stdbuf) closeErr(err error) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.err != nil { + return err + } + if err == nil { + b.err = io.EOF + } else { + b.err = err + } + close(b.more) + return err +} diff --git a/pty/ptytest/ptytest_internal_test.go b/pty/ptytest/ptytest_internal_test.go new file mode 100644 index 0000000000000..29154178636f6 --- /dev/null +++ b/pty/ptytest/ptytest_internal_test.go @@ -0,0 +1,37 @@ +package ptytest + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStdbuf(t *testing.T) { + t.Parallel() + + var got bytes.Buffer + + b := newStdbuf() + done := make(chan struct{}) + go func() { + defer close(done) + _, err := io.Copy(&got, b) + assert.NoError(t, err) + }() + + _, err := b.Write([]byte("hello ")) + require.NoError(t, err) + _, err = b.Write([]byte("world\n")) + require.NoError(t, err) + _, err = b.Write([]byte("bye\n")) + require.NoError(t, err) + + err = b.Close() + require.NoError(t, err) + <-done + + assert.Equal(t, "hello world\nbye\n", got.String()) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go index 7dfba01f04478..764ede12aec2c 100644 --- a/pty/ptytest/ptytest_test.go +++ b/pty/ptytest/ptytest_test.go @@ -2,7 +2,6 @@ package ptytest_test import ( "fmt" - "runtime" "strings" "testing" @@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) { pty.WriteLine("read") }) + // See https://github.com/coder/coder/issues/2122 for the motivation + // behind this test. t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) { t.Parallel() tests := []struct { name string output string - isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info. + isPlatformBug bool }{ {name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)}, - {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true}, - {name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1 + {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)}, + {name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1 } for _, tt := range tests { tt := tt // nolint:paralleltest // Avoid parallel test to more easily identify the issue. t.Run(tt.name, func(t *testing.T) { - if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122") - } - cmd := cobra.Command{ Use: "test", RunE: func(cmd *cobra.Command, args []string) error {
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: