Skip to content

Commit 92a95fb

Browse files
authored
fix: Rewrite ptytest to buffer stdout (#3170)
Fixes #2122
1 parent d7dee2c commit 92a95fb

File tree

3 files changed

+213
-56
lines changed

3 files changed

+213
-56
lines changed

pty/ptytest/ptytest.go

Lines changed: 171 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,116 +7,239 @@ import (
77
"io"
88
"os"
99
"os/exec"
10-
"regexp"
1110
"runtime"
1211
"strings"
12+
"sync"
1313
"testing"
1414
"time"
1515
"unicode/utf8"
1616

1717
"github.com/stretchr/testify/require"
18+
"golang.org/x/xerrors"
1819

1920
"github.com/coder/coder/pty"
2021
)
2122

22-
var (
23-
// Used to ensure terminal output doesn't have anything crazy!
24-
// See: https://stackoverflow.com/a/29497680
25-
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
26-
)
27-
2823
func New(t *testing.T) *PTY {
2924
ptty, err := pty.New()
3025
require.NoError(t, err)
3126

32-
return create(t, ptty)
27+
return create(t, ptty, "cmd")
3328
}
3429

3530
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
3631
ptty, ps, err := pty.Start(cmd)
3732
require.NoError(t, err)
38-
return create(t, ptty), ps
33+
return create(t, ptty, cmd.Args[0]), ps
3934
}
4035

41-
func create(t *testing.T, ptty pty.PTY) *PTY {
42-
reader, writer := io.Pipe()
43-
scanner := bufio.NewScanner(reader)
36+
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
37+
// Use pipe for logging.
38+
logDone := make(chan struct{})
39+
logr, logw := io.Pipe()
4440
t.Cleanup(func() {
45-
_ = reader.Close()
46-
_ = writer.Close()
41+
_ = logw.Close()
42+
_ = logr.Close()
43+
<-logDone // Guard against logging after test.
4744
})
4845
go func() {
49-
for scanner.Scan() {
50-
if scanner.Err() != nil {
51-
return
52-
}
53-
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
46+
defer close(logDone)
47+
s := bufio.NewScanner(logr)
48+
for s.Scan() {
49+
// Quote output to avoid terminal escape codes, e.g. bell.
50+
t.Logf("%s: stdout: %q", name, s.Text())
5451
}
5552
}()
5653

54+
// Write to log and output buffer.
55+
copyDone := make(chan struct{})
56+
out := newStdbuf()
57+
w := io.MultiWriter(logw, out)
58+
go func() {
59+
defer close(copyDone)
60+
_, err := io.Copy(w, ptty.Output())
61+
_ = out.closeErr(err)
62+
}()
5763
t.Cleanup(func() {
64+
_ = out.Close
5865
_ = ptty.Close()
66+
<-copyDone
5967
})
68+
6069
return &PTY{
6170
t: t,
6271
PTY: ptty,
72+
out: out,
6373

64-
outputWriter: writer,
65-
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
74+
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
6675
}
6776
}
6877

6978
type PTY struct {
7079
t *testing.T
7180
pty.PTY
81+
out *stdbuf
7282

73-
outputWriter io.Writer
74-
runeReader *bufio.Reader
83+
runeReader *bufio.Reader
7584
}
7685

7786
func (p *PTY) ExpectMatch(str string) string {
87+
p.t.Helper()
88+
89+
timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)
90+
defer cancel()
91+
7892
var buffer bytes.Buffer
79-
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
80-
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
81-
complete, cancelFunc := context.WithCancel(context.Background())
82-
defer cancelFunc()
93+
match := make(chan error, 1)
8394
go func() {
84-
timer := time.NewTimer(10 * time.Second)
85-
defer timer.Stop()
86-
select {
87-
case <-complete.Done():
88-
return
89-
case <-timer.C:
90-
}
91-
_ = p.Close()
92-
p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
95+
defer close(match)
96+
match <- func() error {
97+
for {
98+
r, _, err := p.runeReader.ReadRune()
99+
if err != nil {
100+
return err
101+
}
102+
_, err = buffer.WriteRune(r)
103+
if err != nil {
104+
return err
105+
}
106+
if strings.Contains(buffer.String(), str) {
107+
return nil
108+
}
109+
}
110+
}()
93111
}()
94-
for {
95-
var r rune
96-
r, _, err := p.runeReader.ReadRune()
97-
require.NoError(p.t, err)
98-
_, err = runeWriter.WriteRune(r)
99-
require.NoError(p.t, err)
100-
err = runeWriter.Flush()
101-
require.NoError(p.t, err)
102-
if strings.Contains(buffer.String(), str) {
103-
break
112+
113+
select {
114+
case err := <-match:
115+
if err != nil {
116+
p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String())
117+
return ""
104118
}
119+
p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String())
120+
return buffer.String()
121+
case <-timeout.Done():
122+
// Ensure goroutine is cleaned up before test exit.
123+
_ = p.out.closeErr(p.Close())
124+
<-match
125+
126+
p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
127+
return ""
105128
}
106-
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
107-
return buffer.String()
108129
}
109130

110131
func (p *PTY) Write(r rune) {
132+
p.t.Helper()
133+
111134
_, err := p.Input().Write([]byte{byte(r)})
112135
require.NoError(p.t, err)
113136
}
114137

115138
func (p *PTY) WriteLine(str string) {
139+
p.t.Helper()
140+
116141
newline := []byte{'\r'}
117142
if runtime.GOOS == "windows" {
118143
newline = append(newline, '\n')
119144
}
120145
_, err := p.Input().Write(append([]byte(str), newline...))
121146
require.NoError(p.t, err)
122147
}
148+
149+
// stdbuf is like a buffered stdout, it buffers writes until read.
150+
type stdbuf struct {
151+
r io.Reader
152+
153+
mu sync.Mutex // Protects following.
154+
b []byte
155+
more chan struct{}
156+
err error
157+
}
158+
159+
func newStdbuf() *stdbuf {
160+
return &stdbuf{more: make(chan struct{}, 1)}
161+
}
162+
163+
func (b *stdbuf) Read(p []byte) (int, error) {
164+
if b.r == nil {
165+
return b.readOrWaitForMore(p)
166+
}
167+
168+
n, err := b.r.Read(p)
169+
if xerrors.Is(err, io.EOF) {
170+
b.r = nil
171+
err = nil
172+
if n == 0 {
173+
return b.readOrWaitForMore(p)
174+
}
175+
}
176+
return n, err
177+
}
178+
179+
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
180+
b.mu.Lock()
181+
defer b.mu.Unlock()
182+
183+
// Deplete channel so that more check
184+
// is for future input into buffer.
185+
select {
186+
case <-b.more:
187+
default:
188+
}
189+
190+
if len(b.b) == 0 {
191+
if b.err != nil {
192+
return 0, b.err
193+
}
194+
195+
b.mu.Unlock()
196+
<-b.more
197+
b.mu.Lock()
198+
}
199+
200+
b.r = bytes.NewReader(b.b)
201+
b.b = b.b[len(b.b):]
202+
203+
return b.r.Read(p)
204+
}
205+
206+
func (b *stdbuf) Write(p []byte) (int, error) {
207+
if len(p) == 0 {
208+
return 0, nil
209+
}
210+
211+
b.mu.Lock()
212+
defer b.mu.Unlock()
213+
214+
if b.err != nil {
215+
return 0, b.err
216+
}
217+
218+
b.b = append(b.b, p...)
219+
220+
select {
221+
case b.more <- struct{}{}:
222+
default:
223+
}
224+
225+
return len(p), nil
226+
}
227+
228+
func (b *stdbuf) Close() error {
229+
return b.closeErr(nil)
230+
}
231+
232+
func (b *stdbuf) closeErr(err error) error {
233+
b.mu.Lock()
234+
defer b.mu.Unlock()
235+
if b.err != nil {
236+
return err
237+
}
238+
if err == nil {
239+
b.err = io.EOF
240+
} else {
241+
b.err = err
242+
}
243+
close(b.more)
244+
return err
245+
}

pty/ptytest/ptytest_internal_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package ptytest
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestStdbuf(t *testing.T) {
13+
t.Parallel()
14+
15+
var got bytes.Buffer
16+
17+
b := newStdbuf()
18+
done := make(chan struct{})
19+
go func() {
20+
defer close(done)
21+
_, err := io.Copy(&got, b)
22+
assert.NoError(t, err)
23+
}()
24+
25+
_, err := b.Write([]byte("hello "))
26+
require.NoError(t, err)
27+
_, err = b.Write([]byte("world\n"))
28+
require.NoError(t, err)
29+
_, err = b.Write([]byte("bye\n"))
30+
require.NoError(t, err)
31+
32+
err = b.Close()
33+
require.NoError(t, err)
34+
<-done
35+
36+
assert.Equal(t, "hello world\nbye\n", got.String())
37+
}

pty/ptytest/ptytest_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package ptytest_test
22

33
import (
44
"fmt"
5-
"runtime"
65
"strings"
76
"testing"
87

@@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) {
2221
pty.WriteLine("read")
2322
})
2423

24+
// See https://github.com/coder/coder/issues/2122 for the motivation
25+
// behind this test.
2526
t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) {
2627
t.Parallel()
2728

2829
tests := []struct {
2930
name string
3031
output string
31-
isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info.
32+
isPlatformBug bool
3233
}{
3334
{name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)},
34-
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true},
35-
{name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1
35+
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)},
36+
{name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1
3637
}
3738
for _, tt := range tests {
3839
tt := tt
3940
// nolint:paralleltest // Avoid parallel test to more easily identify the issue.
4041
t.Run(tt.name, func(t *testing.T) {
41-
if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
42-
t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122")
43-
}
44-
4542
cmd := cobra.Command{
4643
Use: "test",
4744
RunE: func(cmd *cobra.Command, args []string) error {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy