Skip to content

Commit 71dc91e

Browse files
authored
fix: fix loss of buffered input on cliui.Prompt (#15421)
fixes coder/internal#203
1 parent 0987281 commit 71dc91e

File tree

2 files changed

+84
-21
lines changed

2 files changed

+84
-21
lines changed

cli/cliui/prompt.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package cliui
22

33
import (
4-
"bufio"
54
"bytes"
65
"encoding/json"
76
"fmt"
7+
"io"
88
"os"
99
"os/signal"
1010
"strings"
@@ -96,14 +96,13 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) {
9696
signal.Notify(interrupt, os.Interrupt)
9797
defer signal.Stop(interrupt)
9898

99-
reader := bufio.NewReader(inv.Stdin)
100-
line, err = reader.ReadString('\n')
99+
line, err = readUntil(inv.Stdin, '\n')
101100

102101
// Check if the first line beings with JSON object or array chars.
103102
// This enables multiline JSON to be pasted into an input, and have
104103
// it parse properly.
105104
if err == nil && (strings.HasPrefix(line, "{") || strings.HasPrefix(line, "[")) {
106-
line, err = promptJSON(reader, line)
105+
line, err = promptJSON(inv.Stdin, line)
107106
}
108107
}
109108
if err != nil {
@@ -144,7 +143,7 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) {
144143
}
145144
}
146145

147-
func promptJSON(reader *bufio.Reader, line string) (string, error) {
146+
func promptJSON(reader io.Reader, line string) (string, error) {
148147
var data bytes.Buffer
149148
for {
150149
_, _ = data.WriteString(line)
@@ -162,7 +161,7 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) {
162161
// Read line-by-line. We can't use a JSON decoder
163162
// here because it doesn't work by newline, so
164163
// reads will block.
165-
line, err = reader.ReadString('\n')
164+
line, err = readUntil(reader, '\n')
166165
if err != nil {
167166
break
168167
}
@@ -179,3 +178,29 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) {
179178
}
180179
return line, nil
181180
}
181+
182+
// readUntil the first occurrence of delim in the input, returning a string containing the data up
183+
// to and including the delimiter. Unlike `bufio`, it only reads until the delimiter and no further
184+
// bytes. If readUntil encounters an error before finding a delimiter, it returns the data read
185+
// before the error and the error itself (often io.EOF). readUntil returns err != nil if and only if
186+
// the returned data does not end in delim.
187+
func readUntil(r io.Reader, delim byte) (string, error) {
188+
var (
189+
have []byte
190+
b = make([]byte, 1)
191+
)
192+
for {
193+
n, err := r.Read(b)
194+
if n > 0 {
195+
have = append(have, b[0])
196+
if b[0] == delim {
197+
// match `bufio` in that we only return non-nil if we didn't find the delimiter,
198+
// regardless of whether we also erred.
199+
return string(have), nil
200+
}
201+
}
202+
if err != nil {
203+
return string(have), err
204+
}
205+
}
206+
}

cli/cliui/prompt_test.go

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/stretchr/testify/assert"
1212
"github.com/stretchr/testify/require"
13+
"golang.org/x/xerrors"
1314

1415
"github.com/coder/coder/v2/cli/cliui"
1516
"github.com/coder/coder/v2/pty"
@@ -22,26 +23,29 @@ func TestPrompt(t *testing.T) {
2223
t.Parallel()
2324
t.Run("Success", func(t *testing.T) {
2425
t.Parallel()
26+
ctx := testutil.Context(t, testutil.WaitShort)
2527
ptty := ptytest.New(t)
2628
msgChan := make(chan string)
2729
go func() {
28-
resp, err := newPrompt(ptty, cliui.PromptOptions{
30+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
2931
Text: "Example",
3032
}, nil)
3133
assert.NoError(t, err)
3234
msgChan <- resp
3335
}()
3436
ptty.ExpectMatch("Example")
3537
ptty.WriteLine("hello")
36-
require.Equal(t, "hello", <-msgChan)
38+
resp := testutil.RequireRecvCtx(ctx, t, msgChan)
39+
require.Equal(t, "hello", resp)
3740
})
3841

3942
t.Run("Confirm", func(t *testing.T) {
4043
t.Parallel()
44+
ctx := testutil.Context(t, testutil.WaitShort)
4145
ptty := ptytest.New(t)
4246
doneChan := make(chan string)
4347
go func() {
44-
resp, err := newPrompt(ptty, cliui.PromptOptions{
48+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
4549
Text: "Example",
4650
IsConfirm: true,
4751
}, nil)
@@ -50,18 +54,20 @@ func TestPrompt(t *testing.T) {
5054
}()
5155
ptty.ExpectMatch("Example")
5256
ptty.WriteLine("yes")
53-
require.Equal(t, "yes", <-doneChan)
57+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
58+
require.Equal(t, "yes", resp)
5459
})
5560

5661
t.Run("Skip", func(t *testing.T) {
5762
t.Parallel()
63+
ctx := testutil.Context(t, testutil.WaitShort)
5864
ptty := ptytest.New(t)
5965
var buf bytes.Buffer
6066

6167
// Copy all data written out to a buffer. When we close the ptty, we can
6268
// no longer read from the ptty.Output(), but we can read what was
6369
// written to the buffer.
64-
dataRead, doneReading := context.WithTimeout(context.Background(), testutil.WaitShort)
70+
dataRead, doneReading := context.WithCancel(ctx)
6571
go func() {
6672
// This will throw an error sometimes. The underlying ptty
6773
// has its own cleanup routines in t.Cleanup. Instead of
@@ -74,7 +80,7 @@ func TestPrompt(t *testing.T) {
7480

7581
doneChan := make(chan string)
7682
go func() {
77-
resp, err := newPrompt(ptty, cliui.PromptOptions{
83+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
7884
Text: "ShouldNotSeeThis",
7985
IsConfirm: true,
8086
}, func(inv *serpent.Invocation) {
@@ -85,7 +91,8 @@ func TestPrompt(t *testing.T) {
8591
doneChan <- resp
8692
}()
8793

88-
require.Equal(t, "yes", <-doneChan)
94+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
95+
require.Equal(t, "yes", resp)
8996
// Close the reader to end the io.Copy
9097
require.NoError(t, ptty.Close(), "close eof reader")
9198
// Wait for the IO copy to finish
@@ -96,42 +103,47 @@ func TestPrompt(t *testing.T) {
96103
})
97104
t.Run("JSON", func(t *testing.T) {
98105
t.Parallel()
106+
ctx := testutil.Context(t, testutil.WaitShort)
99107
ptty := ptytest.New(t)
100108
doneChan := make(chan string)
101109
go func() {
102-
resp, err := newPrompt(ptty, cliui.PromptOptions{
110+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
103111
Text: "Example",
104112
}, nil)
105113
assert.NoError(t, err)
106114
doneChan <- resp
107115
}()
108116
ptty.ExpectMatch("Example")
109117
ptty.WriteLine("{}")
110-
require.Equal(t, "{}", <-doneChan)
118+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
119+
require.Equal(t, "{}", resp)
111120
})
112121

113122
t.Run("BadJSON", func(t *testing.T) {
114123
t.Parallel()
124+
ctx := testutil.Context(t, testutil.WaitShort)
115125
ptty := ptytest.New(t)
116126
doneChan := make(chan string)
117127
go func() {
118-
resp, err := newPrompt(ptty, cliui.PromptOptions{
128+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
119129
Text: "Example",
120130
}, nil)
121131
assert.NoError(t, err)
122132
doneChan <- resp
123133
}()
124134
ptty.ExpectMatch("Example")
125135
ptty.WriteLine("{a")
126-
require.Equal(t, "{a", <-doneChan)
136+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
137+
require.Equal(t, "{a", resp)
127138
})
128139

129140
t.Run("MultilineJSON", func(t *testing.T) {
130141
t.Parallel()
142+
ctx := testutil.Context(t, testutil.WaitShort)
131143
ptty := ptytest.New(t)
132144
doneChan := make(chan string)
133145
go func() {
134-
resp, err := newPrompt(ptty, cliui.PromptOptions{
146+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
135147
Text: "Example",
136148
}, nil)
137149
assert.NoError(t, err)
@@ -141,11 +153,37 @@ func TestPrompt(t *testing.T) {
141153
ptty.WriteLine(`{
142154
"test": "wow"
143155
}`)
144-
require.Equal(t, `{"test":"wow"}`, <-doneChan)
156+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
157+
require.Equal(t, `{"test":"wow"}`, resp)
158+
})
159+
160+
t.Run("InvalidValid", func(t *testing.T) {
161+
t.Parallel()
162+
ctx := testutil.Context(t, testutil.WaitShort)
163+
ptty := ptytest.New(t)
164+
doneChan := make(chan string)
165+
go func() {
166+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
167+
Text: "Example",
168+
Validate: func(s string) error {
169+
t.Logf("validate: %q", s)
170+
if s != "valid" {
171+
return xerrors.New("invalid")
172+
}
173+
return nil
174+
},
175+
}, nil)
176+
assert.NoError(t, err)
177+
doneChan <- resp
178+
}()
179+
ptty.ExpectMatch("Example")
180+
ptty.WriteLine("foo\nbar\nbaz\n\n\nvalid\n")
181+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
182+
require.Equal(t, "valid", resp)
145183
})
146184
}
147185

148-
func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) {
186+
func newPrompt(ctx context.Context, ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) {
149187
value := ""
150188
cmd := &serpent.Command{
151189
Handler: func(inv *serpent.Invocation) error {
@@ -163,7 +201,7 @@ func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *ser
163201
inv.Stdout = ptty.Output()
164202
inv.Stderr = ptty.Output()
165203
inv.Stdin = ptty.Input()
166-
return value, inv.WithContext(context.Background()).Run()
204+
return value, inv.WithContext(ctx).Run()
167205
}
168206

169207
func TestPasswordTerminalState(t *testing.T) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy