diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index 42987ab7c..a656efc12 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -273,7 +273,6 @@ var ( errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)") errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)") errStringTooLong = errors.New("character string exceeds maximum length (255)") - errCompressedSRV = errors.New("compressed name in SRV resource data") ) // Internal constants. @@ -2028,10 +2027,6 @@ func (n *Name) pack(msg []byte, compression map[string]uint16, compressionOff in // unpack unpacks a domain name. func (n *Name) unpack(msg []byte, off int) (int, error) { - return n.unpackCompressed(msg, off, true /* allowCompression */) -} - -func (n *Name) unpackCompressed(msg []byte, off int, allowCompression bool) (int, error) { // currOff is the current working offset. currOff := off @@ -2076,9 +2071,6 @@ Loop: name = append(name, '.') currOff = endOff case 0xC0: // Pointer - if !allowCompression { - return off, errCompressedSRV - } if currOff >= len(msg) { return off, errInvalidPtr } @@ -2549,7 +2541,7 @@ func unpackSRVResource(msg []byte, off int) (SRVResource, error) { return SRVResource{}, &nestedError{"Port", err} } var target Name - if _, err := target.unpackCompressed(msg, off, false /* allowCompression */); err != nil { + if _, err := target.unpack(msg, off); err != nil { return SRVResource{}, &nestedError{"Target", err} } return SRVResource{priority, weight, port, target}, nil diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index c84d5a3aa..255530598 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -303,28 +303,6 @@ func TestNameUnpackTooLongName(t *testing.T) { } } -func TestIncompressibleName(t *testing.T) { - name := MustNewName("example.com.") - compression := map[string]uint16{} - buf, err := name.pack(make([]byte, 0, 100), compression, 0) - if err != nil { - t.Fatal("first Name.pack() =", err) - } - buf, err = name.pack(buf, compression, 0) - if err != nil { - t.Fatal("second Name.pack() =", err) - } - var n1 Name - off, err := n1.unpackCompressed(buf, 0, false /* allowCompression */) - if err != nil { - t.Fatal("unpacking incompressible name without pointers failed:", err) - } - var n2 Name - if _, err := n2.unpackCompressed(buf, off, false /* allowCompression */); err != errCompressedSRV { - t.Errorf("unpacking compressed incompressible name with pointers: got %v, want = %v", err, errCompressedSRV) - } -} - func checkErrorPrefix(err error, prefix string) bool { e, ok := err.(*nestedError) return ok && e.s == prefix @@ -1657,7 +1635,7 @@ func FuzzUnpackPack(f *testing.F) { msgPacked, err := m.Pack() if err != nil { - t.Fatalf("failed to pack message that was succesfully unpacked: %v", err) + t.Fatalf("failed to pack message that was successfully unpacked: %v", err) } var m2 Message diff --git a/go.mod b/go.mod index 3bd487f5a..36207106d 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.18 require ( - golang.org/x/crypto v0.18.0 - golang.org/x/sys v0.16.0 - golang.org/x/term v0.16.0 + golang.org/x/crypto v0.21.0 + golang.org/x/sys v0.18.0 + golang.org/x/term v0.18.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index 8eeaf16c6..69fb10498 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/html/token.go b/html/token.go index de67f938a..3c57880d6 100644 --- a/html/token.go +++ b/html/token.go @@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() { return } switch c { - case ' ', '\n', '\r', '\t', '\f', '/': - z.pendingAttr[0].end = z.raw.end - 1 - return case '=': if z.pendingAttr[0].start+1 == z.raw.end { // WHATWG 13.2.5.32, if we see an equals sign before the attribute name @@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() { continue } fallthrough - case '>': + case ' ', '\n', '\r', '\t', '\f', '/', '>': + // WHATWG 13.2.5.33 Attribute name state + // We need to reconsume the char in the after attribute name state to support the / character z.raw.end-- z.pendingAttr[0].end = z.raw.end return @@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() { if z.err != nil { return } + if c == '/' { + // WHATWG 13.2.5.34 After attribute name state + // U+002F SOLIDUS (/) - Switch to the self-closing start tag state. + return + } if c != '=' { z.raw.end-- return diff --git a/html/token_test.go b/html/token_test.go index b2383a951..8b0d5aab6 100644 --- a/html/token_test.go +++ b/html/token_test.go @@ -601,6 +601,21 @@ var tokenTests = []tokenTest{ `

`, `

`, }, + { + "forward slash before attribute name", + `

`, + `

`, + }, + { + "forward slash before attribute name with spaces around", + `

`, + `

`, + }, + { + "forward slash after attribute name followed by a character", + `

`, + `

`, + }, } func TestTokenizer(t *testing.T) { diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index c3bd9a1ee..6404aaf15 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -149,10 +149,7 @@ func parseProxy(proxy string) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) - if err != nil || - (proxyURL.Scheme != "http" && - proxyURL.Scheme != "https" && - proxyURL.Scheme != "socks5") { + if err != nil || proxyURL.Scheme == "" || proxyURL.Host == "" { // proxy was bogus. Try prepending "http://" to it and // see if that parses correctly. If not, we fall // through and complain about the original one. diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index d76373295..790afdab7 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -68,6 +68,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "cache.corp.example.com", }, want: "http://cache.corp.example.com", +}, { + // single label domain is recognized as scheme by url.Parse + cfg: httpproxy.Config{ + HTTPProxy: "localhost", + }, + want: "http://localhost", }, { cfg: httpproxy.Config{ HTTPProxy: "https://cache.corp.example.com", @@ -88,6 +94,12 @@ var proxyForURLTests = []proxyForURLTest{{ HTTPProxy: "socks5://127.0.0.1", }, want: "socks5://127.0.0.1", +}, { + // Preserve unknown schemes. + cfg: httpproxy.Config{ + HTTPProxy: "foo://host", + }, + want: "foo://host", }, { // Don't use secure for http cfg: httpproxy.Config{ diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go new file mode 100644 index 000000000..4237b1436 --- /dev/null +++ b/http2/clientconn_test.go @@ -0,0 +1,829 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Infrastructure for testing ClientConn.RoundTrip. +// Put actual tests in transport_test.go. + +package http2 + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "reflect" + "slices" + "testing" + "time" + + "golang.org/x/net/http2/hpack" +) + +// TestTestClientConn demonstrates usage of testClientConn. +func TestTestClientConn(t *testing.T) { + // newTestClientConn creates a *ClientConn and surrounding test infrastructure. + tc := newTestClientConn(t) + + // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames, + // and sends a SETTINGS frame to the client. + // + // Additional settings may be provided as optional parameters to greet. + tc.greet() + + // Request bodies must either be constant (bytes.Buffer, strings.Reader) + // or created with newRequestBody. + body := tc.newRequestBody() + body.writeBytes(10) // 10 arbitrary bytes... + body.closeWithError(io.EOF) // ...followed by EOF. + + // tc.roundTrip calls RoundTrip, but does not wait for it to return. + // It returns a testRoundTrip. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + // tc has a number of methods to check for expected frames sent. + // Here, we look for headers and the request body. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) + // Expect 10 bytes of request body in DATA frames. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: 10, + }) + + // tc.writeHeaders sends a HEADERS frame back to the client. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + // Now that we've received headers, RoundTrip has finished. + // testRoundTrip has various methods to examine the response, + // or to fetch the response and/or error returned by RoundTrip + rt.wantStatus(200) + rt.wantBody(nil) +} + +// A testClientConn allows testing ClientConn.RoundTrip against a fake server. +// +// A test using testClientConn consists of: +// - actions on the client (calling RoundTrip, making data available to Request.Body); +// - validation of frames sent by the client to the server; and +// - providing frames from the server to the client. +// +// testClientConn manages synchronization, so tests can generally be written as +// a linear sequence of actions and validations without additional synchronization. +type testClientConn struct { + t *testing.T + + tr *Transport + fr *Framer + cc *ClientConn + hooks *testSyncHooks + + encbuf bytes.Buffer + enc *hpack.Encoder + + roundtrips []*testRoundTrip + + rerr error // returned by Read + netConnClosed bool // set when the ClientConn closes the net.Conn + rbuf bytes.Buffer // sent to the test conn + wbuf bytes.Buffer // sent by the test conn +} + +func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { + tc := &testClientConn{ + t: t, + tr: cc.t, + cc: cc, + hooks: cc.t.syncHooks, + } + cc.tconn = (*testClientConnNetConn)(tc) + tc.enc = hpack.NewEncoder(&tc.encbuf) + tc.fr = NewFramer(&tc.rbuf, &tc.wbuf) + tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) + tc.fr.SetMaxReadFrameSize(10 << 20) + t.Cleanup(func() { + tc.sync() + if tc.rerr == nil { + tc.rerr = io.EOF + } + tc.sync() + }) + return tc +} + +func (tc *testClientConn) readClientPreface() { + tc.t.Helper() + // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. + buf := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(&tc.wbuf, buf); err != nil { + tc.t.Fatalf("reading preface: %v", err) + } + if !bytes.Equal(buf, clientPreface) { + tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface) + } +} + +func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + const singleUse = false + _, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() +} + +// sync waits for the ClientConn under test to reach a stable state, +// with all goroutines blocked on some input. +func (tc *testClientConn) sync() { + tc.hooks.waitInactive() +} + +// advance advances synthetic time by a duration. +func (tc *testClientConn) advance(d time.Duration) { + tc.hooks.advance(d) + tc.sync() +} + +// hasFrame reports whether a frame is available to be read. +func (tc *testClientConn) hasFrame() bool { + return tc.wbuf.Len() > 0 +} + +// readFrame reads the next frame from the conn. +func (tc *testClientConn) readFrame() Frame { + if tc.wbuf.Len() == 0 { + return nil + } + fr, err := tc.fr.ReadFrame() + if err != nil { + return nil + } + return fr +} + +// testClientConnReadFrame reads a frame of a specific type from the conn. +func testClientConnReadFrame[T any](tc *testClientConn) T { + tc.t.Helper() + var v T + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %T", v) + } + v, ok := fr.(T) + if !ok { + tc.t.Fatalf("got frame %T, want %T", fr, v) + } + return v +} + +// wantFrameType reads the next frame from the conn. +// It produces an error if the frame type is not the expected value. +func (tc *testClientConn) wantFrameType(want FrameType) { + tc.t.Helper() + fr := tc.readFrame() + if fr == nil { + tc.t.Fatalf("got no frame, want frame %v", want) + } + if got := fr.Header().Type; got != want { + tc.t.Fatalf("got frame %v, want %v", got, want) + } +} + +// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied. +// +// want is a list of func(*SomeFrame) bool. +// wantUnorderedFrames will call each func with frames of the appropriate type +// until the func returns true. +// It calls t.Fatal if an unexpected frame is received (no func has that frame type, +// or all funcs with that type have returned true), or if the conn runs out of frames +// with unsatisfied funcs. +// +// Example: +// +// // Read a SETTINGS frame, and any number of DATA frames for a stream. +// // The SETTINGS frame may appear anywhere in the sequence. +// // The last DATA frame must indicate the end of the stream. +// tc.wantUnorderedFrames( +// func(f *SettingsFrame) bool { +// return true +// }, +// func(f *DataFrame) bool { +// return f.StreamEnded() +// }, +// ) +func (tc *testClientConn) wantUnorderedFrames(want ...any) { + tc.t.Helper() + want = slices.Clone(want) + seen := 0 +frame: + for seen < len(want) && !tc.t.Failed() { + fr := tc.readFrame() + if fr == nil { + break + } + for i, f := range want { + if f == nil { + continue + } + typ := reflect.TypeOf(f) + if typ.Kind() != reflect.Func || + typ.NumIn() != 1 || + typ.NumOut() != 1 || + typ.Out(0) != reflect.TypeOf(true) { + tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) + } + if typ.In(0) == reflect.TypeOf(fr) { + out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)}) + if out[0].Bool() { + want[i] = nil + seen++ + } + continue frame + } + } + tc.t.Errorf("got unexpected frame type %T", fr) + } + if seen < len(want) { + for _, f := range want { + if f == nil { + continue + } + tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0)) + } + tc.t.Fatalf("did not see %v expected frame types", len(want)-seen) + } +} + +type wantHeader struct { + streamID uint32 + endStream bool + header http.Header +} + +// wantHeaders reads a HEADERS frame and potential CONTINUATION frames, +// and asserts that they contain the expected headers. +func (tc *testClientConn) wantHeaders(want wantHeader) { + tc.t.Helper() + got := testClientConnReadFrame[*MetaHeadersFrame](tc) + if got, want := got.StreamID, want.streamID; got != want { + tc.t.Fatalf("got stream ID %v, want %v", got, want) + } + if got, want := got.StreamEnded(), want.endStream; got != want { + tc.t.Fatalf("got stream ended %v, want %v", got, want) + } + gotHeader := make(http.Header) + for _, f := range got.Fields { + gotHeader[f.Name] = append(gotHeader[f.Name], f.Value) + } + for k, v := range want.header { + if !reflect.DeepEqual(v, gotHeader[k]) { + tc.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) + } + } +} + +type wantData struct { + streamID uint32 + endStream bool + size int +} + +// wantData reads zero or more DATA frames, and asserts that they match the expectation. +func (tc *testClientConn) wantData(want wantData) { + tc.t.Helper() + gotSize := 0 + gotEndStream := false + for tc.hasFrame() && !gotEndStream { + data := testClientConnReadFrame[*DataFrame](tc) + gotSize += len(data.Data()) + if data.StreamEnded() { + gotEndStream = true + } + } + if gotSize != want.size { + tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size) + } + if gotEndStream != want.endStream { + tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream) + } +} + +// testRequestBody is a Request.Body for use in tests. +type testRequestBody struct { + tc *testClientConn + + // At most one of buf or bytes can be set at any given time: + buf bytes.Buffer // specific bytes to read from the body + bytes int // body contains this many arbitrary bytes + + err error // read error (comes after any available bytes) +} + +func (tc *testClientConn) newRequestBody() *testRequestBody { + b := &testRequestBody{ + tc: tc, + } + return b +} + +// Read is called by the ClientConn to read from a request body. +func (b *testRequestBody) Read(p []byte) (n int, _ error) { + b.tc.cc.syncHooks.blockUntil(func() bool { + return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil + }) + switch { + case b.buf.Len() > 0: + return b.buf.Read(p) + case b.bytes > 0: + if len(p) > b.bytes { + p = p[:b.bytes] + } + b.bytes -= len(p) + for i := range p { + p[i] = 'A' + } + return len(p), nil + default: + return 0, b.err + } +} + +// Close is called by the ClientConn when it is done reading from a request body. +func (b *testRequestBody) Close() error { + return nil +} + +// writeBytes adds n arbitrary bytes to the body. +func (b *testRequestBody) writeBytes(n int) { + b.bytes += n + b.checkWrite() + b.tc.sync() +} + +// Write adds bytes to the body. +func (b *testRequestBody) Write(p []byte) (int, error) { + n, err := b.buf.Write(p) + b.checkWrite() + b.tc.sync() + return n, err +} + +func (b *testRequestBody) checkWrite() { + if b.bytes > 0 && b.buf.Len() > 0 { + b.tc.t.Fatalf("can't interleave Write and writeBytes on request body") + } + if b.err != nil { + b.tc.t.Fatalf("can't write to request body after closeWithError") + } +} + +// closeWithError sets an error which will be returned by Read. +func (b *testRequestBody) closeWithError(err error) { + b.err = err + b.tc.sync() +} + +// roundTrip starts a RoundTrip call. +// +// (Note that the RoundTrip won't complete until response headers are received, +// the request times out, or some other terminal condition is reached.) +func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tc.t, + donec: make(chan struct{}), + } + tc.roundtrips = append(tc.roundtrips, rt) + tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs } + tc.cc.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tc.cc.RoundTrip(req) + }) + tc.sync() + tc.hooks.newstream = nil + + tc.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} + +func (tc *testClientConn) greet(settings ...Setting) { + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings(settings...) + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement +} + +func (tc *testClientConn) writeSettings(settings ...Setting) { + tc.t.Helper() + if err := tc.fr.WriteSettings(settings...); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeSettingsAck() { + tc.t.Helper() + if err := tc.fr.WriteSettingsAck(); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) { + tc.t.Helper() + if err := tc.fr.WriteData(streamID, endStream, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + tc.t.Helper() + if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// makeHeaderBlockFragment encodes headers in a form suitable for inclusion +// in a HEADERS or CONTINUATION frame. +// +// It takes a list of alernating names and values. +func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { + if len(s)%2 != 0 { + tc.t.Fatalf("uneven list of header name/value pairs") + } + tc.encbuf.Reset() + for i := 0; i < len(s); i += 2 { + tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]}) + } + return tc.encbuf.Bytes() +} + +func (tc *testClientConn) writeHeaders(p HeadersFrameParam) { + tc.t.Helper() + if err := tc.fr.WriteHeaders(p); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// writeHeadersMode writes header frames, as modified by mode: +// +// - noHeader: Don't write the header. +// - oneHeader: Write a single HEADERS frame. +// - splitHeader: Write a HEADERS frame and CONTINUATION frame. +func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) { + tc.t.Helper() + switch mode { + case noHeader: + case oneHeader: + tc.writeHeaders(p) + case splitHeader: + if len(p.BlockFragment) < 2 { + panic("too small") + } + contData := p.BlockFragment[1:] + contEnd := p.EndHeaders + p.BlockFragment = p.BlockFragment[:1] + p.EndHeaders = false + tc.writeHeaders(p) + tc.writeContinuation(p.StreamID, contEnd, contData) + default: + panic("bogus mode") + } +} + +func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) { + tc.t.Helper() + if err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) { + tc.t.Helper() + if err := tc.fr.WriteRSTStream(streamID, code); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writePing(ack bool, data [8]byte) { + tc.t.Helper() + if err := tc.fr.WritePing(ack, data); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) { + tc.t.Helper() + if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) { + tc.t.Helper() + if err := tc.fr.WriteWindowUpdate(streamID, incr); err != nil { + tc.t.Fatal(err) + } + tc.sync() +} + +// closeWrite causes the net.Conn used by the ClientConn to return a error +// from Read calls. +func (tc *testClientConn) closeWrite(err error) { + tc.rerr = err + tc.sync() +} + +// inflowWindow returns the amount of inbound flow control available for a stream, +// or for the connection if streamID is 0. +func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + tc.cc.mu.Lock() + defer tc.cc.mu.Unlock() + if streamID == 0 { + return tc.cc.inflow.avail + tc.cc.inflow.unsent + } + cs := tc.cc.streams[streamID] + if cs == nil { + tc.t.Errorf("no stream with id %v", streamID) + return -1 + } + return cs.inflow.avail + cs.inflow.unsent +} + +// testRoundTrip manages a RoundTrip in progress. +type testRoundTrip struct { + t *testing.T + resp *http.Response + respErr error + donec chan struct{} + cs *clientStream +} + +// streamID returns the HTTP/2 stream ID of the request. +func (rt *testRoundTrip) streamID() uint32 { + if rt.cs == nil { + panic("stream ID unknown") + } + return rt.cs.ID +} + +// done reports whether RoundTrip has returned. +func (rt *testRoundTrip) done() bool { + select { + case <-rt.donec: + return true + default: + return false + } +} + +// result returns the result of the RoundTrip. +func (rt *testRoundTrip) result() (*http.Response, error) { + t := rt.t + t.Helper() + select { + case <-rt.donec: + default: + t.Fatalf("RoundTrip is not done; want it to be") + } + return rt.resp, rt.respErr +} + +// response returns the response of a successful RoundTrip. +// If the RoundTrip unexpectedly failed, it calls t.Fatal. +func (rt *testRoundTrip) response() *http.Response { + t := rt.t + t.Helper() + resp, err := rt.result() + if err != nil { + t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) + } + if resp == nil { + t.Fatalf("RoundTrip returned nil *Response and nil error") + } + return resp +} + +// err returns the (possibly nil) error result of RoundTrip. +func (rt *testRoundTrip) err() error { + t := rt.t + t.Helper() + _, err := rt.result() + return err +} + +// wantStatus indicates the expected response StatusCode. +func (rt *testRoundTrip) wantStatus(want int) { + t := rt.t + t.Helper() + if got := rt.response().StatusCode; got != want { + t.Fatalf("got response status %v, want %v", got, want) + } +} + +// body reads the contents of the response body. +func (rt *testRoundTrip) readBody() ([]byte, error) { + t := rt.t + t.Helper() + return io.ReadAll(rt.response().Body) +} + +// wantBody indicates the expected response body. +// (Note that this consumes the body.) +func (rt *testRoundTrip) wantBody(want []byte) { + t := rt.t + t.Helper() + got, err := rt.readBody() + if err != nil { + t.Fatalf("unexpected error reading response body: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) + } +} + +// wantHeaders indicates the expected response headers. +func (rt *testRoundTrip) wantHeaders(want http.Header) { + t := rt.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Header, want); diff != "" { + t.Fatalf("unexpected response headers:\n%v", diff) + } +} + +// wantTrailers indicates the expected response trailers. +func (rt *testRoundTrip) wantTrailers(want http.Header) { + t := rt.t + t.Helper() + res := rt.response() + if diff := diffHeaders(res.Trailer, want); diff != "" { + t.Fatalf("unexpected response trailers:\n%v", diff) + } +} + +func diffHeaders(got, want http.Header) string { + // nil and 0-length non-nil are equal. + if len(got) == 0 && len(want) == 0 { + return "" + } + // We could do a more sophisticated diff here. + // DeepEqual is good enough for now. + if reflect.DeepEqual(got, want) { + return "" + } + return fmt.Sprintf("got: %v\nwant: %v", got, want) +} + +// testClientConnNetConn implements net.Conn. +type testClientConnNetConn testClientConn + +func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) { + nc.cc.syncHooks.blockUntil(func() bool { + return nc.rerr != nil || nc.rbuf.Len() > 0 + }) + if nc.rbuf.Len() > 0 { + return nc.rbuf.Read(b) + } + return 0, nc.rerr +} + +func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) { + return nc.wbuf.Write(b) +} + +func (nc *testClientConnNetConn) Close() error { + nc.netConnClosed = true + return nil +} + +func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return } +func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil } +func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil } + +// A testTransport allows testing Transport.RoundTrip against fake servers. +// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling +// should use testClientConn instead. +type testTransport struct { + t *testing.T + tr *Transport + + ccs []*testClientConn +} + +func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport { + tr := &Transport{ + syncHooks: newTestSyncHooks(), + } + for _, o := range opts { + o(tr) + } + + tt := &testTransport{ + t: t, + tr: tr, + } + tr.syncHooks.newclientconn = func(cc *ClientConn) { + tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc)) + } + + t.Cleanup(func() { + tt.sync() + if len(tt.ccs) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + } + if tt.tr.syncHooks.total != 0 { + t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total) + } + }) + + return tt +} + +func (tt *testTransport) sync() { + tt.tr.syncHooks.waitInactive() +} + +func (tt *testTransport) advance(d time.Duration) { + tt.tr.syncHooks.advance(d) + tt.sync() +} + +func (tt *testTransport) hasConn() bool { + return len(tt.ccs) > 0 +} + +func (tt *testTransport) getConn() *testClientConn { + tt.t.Helper() + if len(tt.ccs) == 0 { + tt.t.Fatalf("no new ClientConns created; wanted one") + } + tc := tt.ccs[0] + tt.ccs = tt.ccs[1:] + tc.sync() + tc.readClientPreface() + return tc +} + +func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + rt := &testRoundTrip{ + t: tt.t, + donec: make(chan struct{}), + } + tt.tr.syncHooks.goRun(func() { + defer close(rt.donec) + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + }) + tt.sync() + + tt.t.Cleanup(func() { + if !rt.done() { + return + } + res, _ := rt.result() + if res != nil { + res.Body.Close() + } + }) + + return rt +} diff --git a/http2/frame.go b/http2/frame.go index c1f6b90dc..43557ab7e 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -1510,13 +1510,12 @@ func (mh *MetaHeadersFrame) checkPseudos() error { } func (fr *Framer) maxHeaderStringLen() int { - v := fr.maxHeaderListSize() - if uint32(int(v)) == v { - return int(v) + v := int(fr.maxHeaderListSize()) + if v < 0 { + // If maxHeaderListSize overflows an int, use no limit (0). + return 0 } - // They had a crazy big number for MaxHeaderBytes anyway, - // so give them unlimited header lengths: - return 0 + return v } // readMetaFrame returns 0 or more CONTINUATION frames from fr and @@ -1565,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1577,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { return nil, ConnectionError(ErrCodeCompression) } diff --git a/http2/pipe.go b/http2/pipe.go index 684d984fd..3b9f06b96 100644 --- a/http2/pipe.go +++ b/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } diff --git a/http2/server.go b/http2/server.go index ae94c6408..ce2e8b40e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -124,6 +124,7 @@ type Server struct { // IdleTimeout specifies how long until idle clients should be // closed with a GOAWAY frame. PING frames are not considered // activity for the purposes of IdleTimeout. + // If zero or negative, there is no timeout. IdleTimeout time.Duration // MaxUploadBufferPerConnection is the size of the initial flow @@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -924,7 +925,7 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } @@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { @@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } @@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { // Disable any read deadline set by the net/http package // prior to the upgrade. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } diff --git a/http2/server_push_test.go b/http2/server_push_test.go index 9882d9ef7..cda8f4336 100644 --- a/http2/server_push_test.go +++ b/http2/server_push_test.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "net/http" "reflect" + "runtime" "strconv" "sync" "testing" @@ -483,11 +484,7 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { ready := make(chan struct{}) errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - select { - case <-ready: - case <-time.After(5 * time.Second): - errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed") - } + <-ready if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want { errc <- fmt.Errorf("Push()=%v, want %v", got, want) } @@ -505,6 +502,10 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { case <-ready: return default: + if runtime.GOARCH == "wasm" { + // Work around https://go.dev/issue/65178 to avoid goroutine starvation. + runtime.Gosched() + } } st.sc.serveMsgCh <- func(loopNum int) { if !st.sc.pushEnabled { diff --git a/http2/server_test.go b/http2/server_test.go index 1fdd191ef..a931a06e5 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4578,13 +4578,16 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { sc := &serverConn{ serveG: newGoroutineLock(), } - const count = 1000 - for i := 0; i < count; i++ { - h := fmt.Sprintf("%v-%v", base, i) + count := 0 + added := 0 + for added < 10*maxCachedCanonicalHeadersKeysSize { + h := fmt.Sprintf("%v-%v", base, count) c := sc.canonicalHeader(h) if len(h) != len(c) { t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c) } + count++ + added += len(h) } total := 0 for k, v := range sc.canonHeader { @@ -4783,3 +4786,89 @@ Frames: close(s) } } + +func TestServerContinuationFlood(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }, func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 4096 + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + for i := 0; i < 1000; i++ { + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + fmt.Sprintf("x-%v", i), "1234567890", + )) + } + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-last-header", "1", + )) + + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection") + } + } + // We expect to have seen a GOAWAY before the connection closes, + // but the server will close the connection after one second + // whether or not it has finished sending the GOAWAY. On windows-amd64-race + // builders, this fairly consistently results in the connection closing without + // the GOAWAY being sent. + // + // Since the server's behavior is inherently racy here and the important thing + // is that the connection is closed, don't check for the GOAWAY having been sent. +} + +func TestServerContinuationAfterInvalidHeader(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.Header) + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + }) + st.fr.WriteContinuation(1, false, st.encodeHeaderRaw( + "x-invalid-header", "\x00", + )) + st.fr.WriteContinuation(1, true, st.encodeHeaderRaw( + "x-valid-header", "1", + )) + + var sawGoAway bool + for { + f, err := st.readFrame() + if err != nil { + break + } + switch f.(type) { + case *GoAwayFrame: + sawGoAway = true + case *HeadersFrame: + t.Fatalf("received HEADERS frame; want GOAWAY") + } + } + if !sawGoAway { + t.Errorf("connection closed with no GOAWAY frame; want one") + } +} diff --git a/http2/testsync.go b/http2/testsync.go new file mode 100644 index 000000000..61075bd16 --- /dev/null +++ b/http2/testsync.go @@ -0,0 +1,331 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import ( + "context" + "sync" + "time" +) + +// testSyncHooks coordinates goroutines in tests. +// +// For example, a call to ClientConn.RoundTrip involves several goroutines, including: +// - the goroutine running RoundTrip; +// - the clientStream.doRequest goroutine, which writes the request; and +// - the clientStream.readLoop goroutine, which reads the response. +// +// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines +// are blocked waiting for some condition such as reading the Request.Body or waiting for +// flow control to become available. +// +// The testSyncHooks also manage timers and synthetic time in tests. +// This permits us to, for example, start a request and cause it to time out waiting for +// response headers without resorting to time.Sleep calls. +type testSyncHooks struct { + // active/inactive act as a mutex and condition variable. + // + // - neither chan contains a value: testSyncHooks is locked. + // - active contains a value: unlocked, and at least one goroutine is not blocked + // - inactive contains a value: unlocked, and all goroutines are blocked + active chan struct{} + inactive chan struct{} + + // goroutine counts + total int // total goroutines + condwait map[*sync.Cond]int // blocked in sync.Cond.Wait + blocked []*testBlockedGoroutine // otherwise blocked + + // fake time + now time.Time + timers []*fakeTimer + + // Transport testing: Report various events. + newclientconn func(*ClientConn) + newstream func(*clientStream) +} + +// testBlockedGoroutine is a blocked goroutine. +type testBlockedGoroutine struct { + f func() bool // blocked until f returns true + ch chan struct{} // closed when unblocked +} + +func newTestSyncHooks() *testSyncHooks { + h := &testSyncHooks{ + active: make(chan struct{}, 1), + inactive: make(chan struct{}, 1), + condwait: map[*sync.Cond]int{}, + } + h.inactive <- struct{}{} + h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return h +} + +// lock acquires the testSyncHooks mutex. +func (h *testSyncHooks) lock() { + select { + case <-h.active: + case <-h.inactive: + } +} + +// waitInactive waits for all goroutines to become inactive. +func (h *testSyncHooks) waitInactive() { + for { + <-h.inactive + if !h.unlock() { + break + } + } +} + +// unlock releases the testSyncHooks mutex. +// It reports whether any goroutines are active. +func (h *testSyncHooks) unlock() (active bool) { + // Look for a blocked goroutine which can be unblocked. + blocked := h.blocked[:0] + unblocked := false + for _, b := range h.blocked { + if !unblocked && b.f() { + unblocked = true + close(b.ch) + } else { + blocked = append(blocked, b) + } + } + h.blocked = blocked + + // Count goroutines blocked on condition variables. + condwait := 0 + for _, count := range h.condwait { + condwait += count + } + + if h.total > condwait+len(blocked) { + h.active <- struct{}{} + return true + } else { + h.inactive <- struct{}{} + return false + } +} + +// goRun starts a new goroutine. +func (h *testSyncHooks) goRun(f func()) { + h.lock() + h.total++ + h.unlock() + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + f() + }() +} + +// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. +// It waits until f returns true before proceeding. +// +// Example usage: +// +// h.blockUntil(func() bool { +// // Is the context done yet? +// select { +// case <-ctx.Done(): +// default: +// return false +// } +// return true +// }) +// // Wait for the context to become done. +// <-ctx.Done() +// +// The function f passed to blockUntil must be non-blocking and idempotent. +func (h *testSyncHooks) blockUntil(f func() bool) { + if f() { + return + } + ch := make(chan struct{}) + h.lock() + h.blocked = append(h.blocked, &testBlockedGoroutine{ + f: f, + ch: ch, + }) + h.unlock() + <-ch +} + +// broadcast is sync.Cond.Broadcast. +func (h *testSyncHooks) condBroadcast(cond *sync.Cond) { + h.lock() + delete(h.condwait, cond) + h.unlock() + cond.Broadcast() +} + +// broadcast is sync.Cond.Wait. +func (h *testSyncHooks) condWait(cond *sync.Cond) { + h.lock() + h.condwait[cond]++ + h.unlock() +} + +// newTimer creates a new fake timer. +func (h *testSyncHooks) newTimer(d time.Duration) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + c: make(chan time.Time), + } + h.timers = append(h.timers, t) + return t +} + +// afterFunc creates a new fake AfterFunc timer. +func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + f: f, + } + h.timers = append(h.timers, t) + return t +} + +func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + t := h.afterFunc(d, cancel) + return ctx, func() { + t.Stop() + cancel() + } +} + +func (h *testSyncHooks) timeUntilEvent() time.Duration { + h.lock() + defer h.unlock() + var next time.Time + for _, t := range h.timers { + if next.IsZero() || t.when.Before(next) { + next = t.when + } + } + if d := next.Sub(h.now); d > 0 { + return d + } + return 0 +} + +// advance advances time and causes synthetic timers to fire. +func (h *testSyncHooks) advance(d time.Duration) { + h.lock() + defer h.unlock() + h.now = h.now.Add(d) + timers := h.timers[:0] + for _, t := range h.timers { + t := t // remove after go.mod depends on go1.22 + t.mu.Lock() + switch { + case t.when.After(h.now): + timers = append(timers, t) + case t.when.IsZero(): + // stopped timer + default: + t.when = time.Time{} + if t.c != nil { + close(t.c) + } + if t.f != nil { + h.total++ + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + t.f() + }() + } + } + t.mu.Unlock() + } + h.timers = timers +} + +// A timer wraps a time.Timer, or a synthetic equivalent in tests. +// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. +type timer interface { + C() <-chan time.Time + Stop() bool + Reset(d time.Duration) bool +} + +// timeTimer implements timer using real time. +type timeTimer struct { + t *time.Timer + c chan time.Time +} + +// newTimeTimer creates a new timer using real time. +func newTimeTimer(d time.Duration) timer { + ch := make(chan time.Time) + t := time.AfterFunc(d, func() { + close(ch) + }) + return &timeTimer{t, ch} +} + +// newTimeAfterFunc creates an AfterFunc timer using real time. +func newTimeAfterFunc(d time.Duration, f func()) timer { + return &timeTimer{ + t: time.AfterFunc(d, f), + } +} + +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } +func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } + +// fakeTimer implements timer using fake time. +type fakeTimer struct { + hooks *testSyncHooks + + mu sync.Mutex + when time.Time // when the timer will fire + c chan time.Time // closed when the timer fires; mutually exclusive with f + f func() // called when the timer fires; mutually exclusive with c +} + +func (t *fakeTimer) C() <-chan time.Time { return t.c } + +func (t *fakeTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + stopped := t.when.IsZero() + t.when = time.Time{} + return stopped +} + +func (t *fakeTimer) Reset(d time.Duration) bool { + if t.c != nil || t.f == nil { + panic("fakeTimer only supports Reset on AfterFunc timers") + } + t.mu.Lock() + defer t.mu.Unlock() + t.hooks.lock() + defer t.hooks.unlock() + active := !t.when.IsZero() + t.when = t.hooks.now.Add(d) + if !active { + t.hooks.timers = append(t.hooks.timers, t) + } + return active +} diff --git a/http2/transport.go b/http2/transport.go index df578b86c..ce375c8c7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -147,6 +147,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. // Note that a ping response will is considered a received frame, so if @@ -178,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -302,7 +310,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -344,6 +352,60 @@ type ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + + syncHooks *testSyncHooks // can be nil +} + +// Hook points used for testing. +// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +// goRun starts a new goroutine. +func (cc *ClientConn) goRun(f func()) { + if cc.syncHooks != nil { + cc.syncHooks.goRun(f) + return + } + go f() +} + +// condBroadcast is cc.cond.Broadcast. +func (cc *ClientConn) condBroadcast() { + if cc.syncHooks != nil { + cc.syncHooks.condBroadcast(cc.cond) + } + cc.cond.Broadcast() +} + +// condWait is cc.cond.Wait. +func (cc *ClientConn) condWait() { + if cc.syncHooks != nil { + cc.syncHooks.condWait(cc.cond) + } + cc.cond.Wait() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (cc *ClientConn) newTimer(d time.Duration) timer { + if cc.syncHooks != nil { + return cc.syncHooks.newTimer(d) + } + return newTimeTimer(d) +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer { + if cc.syncHooks != nil { + return cc.syncHooks.afterFunc(d, f) + } + return newTimeAfterFunc(d, f) +} + +func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if cc.syncHooks != nil { + return cc.syncHooks.contextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.cond.Broadcast() + cs.cc.condBroadcast() } } @@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.cond.Broadcast() + cc.condBroadcast() } } @@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - go func() { + cs.cc.goRun(func() { cs.reqBody.Close() close(reqBodyClosed) - }() + }) } type stickyErrWriter struct { @@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -658,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) { cc := &ClientConn{ t: t, tconn: c, @@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), + syncHooks: hooks, + } + if hooks != nil { + hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - go cc.readLoop() + cc.goRun(cc.readLoop) return cc, nil } @@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - go func() { + cc.goRun(func() { cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.cond.Wait() + cc.condWait() } - }() + }) shutdownEnterWaitStateHook() select { case <-done: @@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() return ctx.Err() } @@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() cc.closeConn() } @@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + cc.goRun(func() { + cs.doRequest(req) + }) waitDone := func() error { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.donec: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.donec: return nil @@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return err } + if streamf != nil { + streamf(cs) + } + for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } + var newStreamHook func(*clientStream) + if cc.syncHooks != nil { + newStreamHook = cc.syncHooks.newstream + cc.syncHooks.blockUntil(func() bool { + select { + case cc.reqHeaderMu <- struct{}{}: + <-cc.reqHeaderMu + case <-cs.reqCancel: + case <-ctx.Done(): + default: + return false + } + return true + }) + } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() + if newStreamHook != nil { + newStreamHook(cs) + } + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.peerClosed: + case <-respHeaderTimer: + case <-respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.peerClosed: return nil @@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { return nil } cc.pendingRequests++ - cc.cond.Wait() + cc.condWait() cc.pendingRequests-- select { case <-cs.abort: @@ -1871,8 +2015,24 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) cs.flow.take(take) return take, nil } - cc.cond.Wait() + cc.condWait() + } +} + +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } } + return "" } var errNilRequestURL = errors.New("http2: Request.URI is nil") @@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { @@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + cc.condBroadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() } @@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, @@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { for _, cs := range cc.streams { cs.flow.add(delta) } - cc.cond.Broadcast() + cc.condBroadcast() cc.initialWindowSize = s.Val case SettingHeaderTableSize: @@ -2911,9 +3065,18 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { fl = &cs.flow } if !fl.add(int32(f.Increment)) { + // For stream, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR + if cs != nil { + rl.endStreamError(cs, StreamError{ + StreamID: f.StreamID, + Code: ErrCodeFlowControl, + }) + return nil + } + return ConnectionError(ErrCodeFlowControl) } - cc.cond.Broadcast() + cc.condBroadcast() return nil } @@ -2955,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) - go func() { + var pingError error + errc := make(chan struct{}) + cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } - }() + }) + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-c: + case <-errc: + case <-ctx.Done(): + case <-cc.readerDone: + default: + return false + } + return true + }) + } select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: @@ -3141,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err } func (t *Transport) idleConnTimeout() time.Duration { + // to keep things backwards compatible, we use non-zero values of + // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying + // http1 transport, followed by 0 + if t.IdleConnTimeout != 0 { + return t.IdleConnTimeout + } + if t.t1 != nil { return t.t1.IdleConnTimeout } + return 0 } diff --git a/http2/transport_test.go b/http2/transport_test.go index a81131f29..11ff67b4c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -95,6 +95,88 @@ func startH2cServer(t *testing.T) net.Listener { return l } +func TestIdleConnTimeout(t *testing.T) { + for _, test := range []struct { + name string + idleConnTimeout time.Duration + wait time.Duration + baseTransport *http.Transport + wantNewConn bool + }{{ + name: "NoExpiry", + idleConnTimeout: 2 * time.Second, + wait: 1 * time.Second, + baseTransport: nil, + wantNewConn: false, + }, { + name: "H2TransportTimeoutExpires", + idleConnTimeout: 1 * time.Second, + wait: 2 * time.Second, + baseTransport: nil, + wantNewConn: true, + }, { + name: "H1TransportTimeoutExpires", + idleConnTimeout: 0 * time.Second, + wait: 1 * time.Second, + baseTransport: &http.Transport{ + IdleConnTimeout: 2 * time.Second, + }, + wantNewConn: false, + }} { + t.Run(test.name, func(t *testing.T) { + tt := newTestTransport(t, func(tr *Transport) { + tr.IdleConnTimeout = test.idleConnTimeout + }) + var tc *testClientConn + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // This request happens on a new conn if it's the first request + // (and there is no cached conn), or if the test timeout is long + // enough that old conns are being closed. + wantConn := i == 0 || test.wantNewConn + if has := tt.hasConn(); has != wantConn { + t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn) + } + if wantConn { + tc = tt.getConn() + // Read client's SETTINGS and first WINDOW_UPDATE, + // send our SETTINGS. + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.writeSettings() + } + if tt.hasConn() { + t.Fatalf("request %v: Transport has more than one conn", i) + } + + // Respond to the client's request. + hf := testClientConnReadFrame[*MetaHeadersFrame](tc) + tc.writeHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) + + // If this was a newly-accepted conn, read the SETTINGS ACK. + if wantConn { + tc.wantFrameType(FrameSettings) // ACK to our settings + } + + tt.advance(test.wait) + if got, want := tc.netConnClosed, test.wantNewConn; got != want { + t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) + } + } + }) + } +} + func TestTransportH2c(t *testing.T) { l := startH2cServer(t) defer l.Close() @@ -740,53 +822,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { return } -type clientTester struct { - t *testing.T - tr *Transport - sc, cc net.Conn // server and client conn - fr *Framer // server's framer - settings *SettingsFrame - client func() error - server func() error -} - -func newClientTester(t *testing.T) *clientTester { - var dialOnce struct { - sync.Mutex - dialed bool - } - ct := &clientTester{ - t: t, - } - ct.tr = &Transport{ - TLSClientConfig: tlsConfigInsecure, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialOnce.Lock() - defer dialOnce.Unlock() - if dialOnce.dialed { - return nil, errors.New("only one dial allowed in test mode") - } - dialOnce.dialed = true - return ct.cc, nil - }, - } - - ln := newLocalListener(t) - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - sc, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - ln.Close() - ct.cc = cc - ct.sc = sc - ct.fr = NewFramer(sc, sc) - return ct -} - func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp4", "127.0.0.1:0") if err == nil { @@ -799,284 +834,70 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func (ct *clientTester) greet(settings ...Setting) { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - ct.t.Fatalf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - ct.t.Fatalf("Reading client settings frame: %v", err) - } - var ok bool - if ct.settings, ok = f.(*SettingsFrame); !ok { - ct.t.Fatalf("Wanted client settings frame; got %v", f) - } - if err := ct.fr.WriteSettings(settings...); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } -} - -func (ct *clientTester) readNonSettingsFrame() (Frame, error) { - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return nil, err - } - if _, ok := f.(*SettingsFrame); ok { - continue - } - return f, nil - } -} - -// writeReadPing sends a PING and immediately reads the PING ACK. -// It will fail if any other unread data was pending on the connection, -// aside from SETTINGS frames. -func (ct *clientTester) writeReadPing() error { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := ct.fr.WritePing(false, data); err != nil { - return fmt.Errorf("Error writing PING: %v", err) - } - f, err := ct.readNonSettingsFrame() - if err != nil { - return err - } - p, ok := f.(*PingFrame) - if !ok { - return fmt.Errorf("got a %v, want a PING ACK", f) - } - if p.Flags&FlagPingAck == 0 { - return fmt.Errorf("got a PING, want a PING ACK") - } - if p.Data != data { - return fmt.Errorf("got PING data = %x, want %x", p.Data, data) - } - return nil -} - -func (ct *clientTester) inflowWindow(streamID uint32) int32 { - pool := ct.tr.connPoolOrDef.(*clientConnPool) - pool.mu.Lock() - defer pool.mu.Unlock() - if n := len(pool.keys); n != 1 { - ct.t.Errorf("clientConnPool contains %v keys, expected 1", n) - return -1 - } - for cc := range pool.keys { - cc.mu.Lock() - defer cc.mu.Unlock() - if streamID == 0 { - return cc.inflow.avail + cc.inflow.unsent - } - cs := cc.streams[streamID] - if cs == nil { - ct.t.Errorf("no stream with id %v", streamID) - return -1 - } - return cs.inflow.avail + cs.inflow.unsent - } - return -1 -} - -func (ct *clientTester) cleanup() { - ct.tr.CloseIdleConnections() - - // close both connections, ignore the error if its already closed - ct.sc.Close() - ct.cc.Close() -} - -func (ct *clientTester) run() { - var errOnce sync.Once - var wg sync.WaitGroup - - run := func(which string, fn func() error) { - defer wg.Done() - if err := fn(); err != nil { - errOnce.Do(func() { - ct.t.Errorf("%s: %v", which, err) - ct.cleanup() - }) - } - } +func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } +func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } - wg.Add(2) - go run("client", ct.client) - go run("server", ct.server) - wg.Wait() +func testTransportReqBodyAfterResponse(t *testing.T, status int) { + const bodySize = 10 << 20 - errOnce.Do(ct.cleanup) // clean up if no error -} + tc := newTestClientConn(t) + tc.greet() + + body := tc.newRequestBody() + body.writeBytes(bodySize / 2) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: false, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"PUT"}, + ":path": []string{"/"}, + }, + }) -func (ct *clientTester) readFrame() (Frame, error) { - return ct.fr.ReadFrame() -} + // Provide enough congestion window for the full request body. + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) -func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { - for { - f, err := ct.readFrame() - if err != nil { - return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - hf, ok := f.(*HeadersFrame) - if !ok { - return nil, fmt.Errorf("Got %T; want HeadersFrame", f) - } - return hf, nil - } -} + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: bodySize / 2, + }) -type countingReader struct { - n *int64 -} + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", strconv.Itoa(status), + ), + }) -func (r countingReader) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(i) + res := rt.response() + if res.StatusCode != status { + t.Fatalf("status code = %v; want %v", res.StatusCode, status) } - atomic.AddInt64(r.n, int64(len(p))) - return len(p), err -} -func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } -func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } + body.writeBytes(bodySize / 2) + body.closeWithError(io.EOF) -func testTransportReqBodyAfterResponse(t *testing.T, status int) { - const bodySize = 10 << 20 - clientDone := make(chan struct{}) - ct := newClientTester(t) - recvLen := make(chan int64, 1) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - body := &pipe{b: new(bytes.Buffer)} - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - if res.StatusCode != status { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) - } - io.Copy(body, io.LimitReader(neverEnding('A'), bodySize/2)) - body.CloseWithError(io.EOF) - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("unexpected body: %q", slurp) - } - res.Body.Close() - if status == 200 { - if got := <-recvLen; got != bodySize { - return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) - } - } else { - if got := <-recvLen; got == 0 || got >= bodySize { - return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) - } - } - return nil + if status == 200 { + // After a 200 response, client sends the remaining request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: bodySize / 2, + }) + } else { + // After a 403 response, client gives up and resets the stream. + tc.wantFrameType(FrameRSTStream) } - ct.server = func() error { - ct.greet() - defer close(recvLen) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var dataRecv int64 - var closed bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - //println(fmt.Sprintf("server got frame: %v", f)) - ended := false - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - if f.StreamEnded() { - return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) - } - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if dataRecv == 0 { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - dataRecv += int64(dataLen) - - if !closed && ((status != 200 && dataRecv > 0) || - (status == 200 && f.StreamEnded())) { - closed = true - if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { - return err - } - } - if f.StreamEnded() { - ended = true - } - case *RSTStreamFrame: - if status == 200 { - return fmt.Errorf("Unexpected client frame %v", f) - } - ended = true - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - if ended { - select { - case recvLen <- dataRecv: - default: - } - } - } - } - ct.run() + rt.wantBody(nil) } // See golang.org/issue/13444 @@ -1257,121 +1078,74 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy panic("invalid combination") } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) - if expect100Continue != noHeader { - req.Header.Set("Expect", "100-continue") - } - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("Slurp: %v", err) - } - wantBody := resBody - if !withData { - wantBody = "" - } - if string(slurp) != wantBody { - return fmt.Errorf("body = %q; want %q", slurp, wantBody) - } - if trailers == noHeader { - if len(res.Trailer) > 0 { - t.Errorf("Trailer = %v; want none", res.Trailer) - } - } else { - want := http.Header{"Some-Trailer": {"some-value"}} - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Trailer = %v; want %v", res.Trailer, want) - } - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + rt := tc.roundTrip(req) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - endStream := false - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *DataFrame: - if !f.StreamEnded() { - // No need to send flow control tokens. The test request body is tiny. - continue - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) - enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) - if trailers != noHeader { - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) - } - endStream = withData == false && trailers == noHeader - send(resHeader) - } - if withData { - endStream = trailers == noHeader - ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) - } - if trailers != noHeader { - endStream = true - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) - send(trailers) - } - if endStream { - return nil - } - case *HeadersFrame: - if expect100Continue != noHeader { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) - send(expect100Continue) - } - } - } + tc.wantFrameType(FrameHeaders) + + // Possibly 100-continue, or skip when noHeader. + tc.writeHeadersMode(expect100Continue, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + + // Client sends request body. + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: true, + size: len(reqBody), + }) + + hdr := []string{ + ":status", "200", + "x-foo", "blah", + "x-bar", "more", + } + if trailers != noHeader { + hdr = append(hdr, "trailer", "some-trailer") + } + tc.writeHeadersMode(resHeader, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: withData == false && trailers == noHeader, + BlockFragment: tc.makeHeaderBlockFragment(hdr...), + }) + if withData { + endStream := trailers == noHeader + tc.writeData(rt.streamID(), endStream, []byte(resBody)) + } + tc.writeHeadersMode(trailers, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "some-value", + ), + }) + + rt.wantStatus(200) + if !withData { + rt.wantBody(nil) + } else { + rt.wantBody([]byte(resBody)) + } + if trailers == noHeader { + rt.wantTrailers(nil) + } else { + rt.wantTrailers(http.Header{ + "Some-Trailer": {"some-value"}, + }) } - ct.run() } // Issue 26189, Issue 17739: ignore unknown 1xx responses @@ -1383,130 +1157,76 @@ func TestTransportUnknown1xx(t *testing.T) { return nil } - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 204 { - return fmt.Errorf("status code = %v; want 204", res.StatusCode) - } - want := `code=110 header=map[Foo-Bar:[110]] + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + for i := 110; i <= 114; i++ { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", fmt.Sprint(i), + "foo-bar", fmt.Sprint(i), + ), + }) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) + + res := rt.response() + if res.StatusCode != 204 { + t.Fatalf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] code=111 header=map[Foo-Bar:[111]] code=112 header=map[Foo-Bar:[112]] code=113 header=map[Foo-Bar:[113]] code=114 header=map[Foo-Bar:[114]] ` - if got := buf.String(); got != want { - t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - for i := 110; i <= 114; i++ { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) - enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) } - ct.run() - } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - if _, ok := res.Trailer["Some-Trailer"]; !ok { - return fmt.Errorf("expected Some-Trailer") - } - return nil - } - ct.server = func() error { - ct.greet() - - var n int - var hf *HeadersFrame - for hf == nil && n < 10 { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - hf, _ = f.(*HeadersFrame) - n++ - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - // send headers without Trailer header - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + "some-trailer", "I'm an undeclared Trailer!", + ), + }) - // send trailers - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - ct.run() + rt.wantStatus(200) + rt.wantBody(nil) + rt.wantTrailers(http.Header{ + "Some-Trailer": []string{"I'm an undeclared Trailer!"}, + }) } func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { @@ -1516,10 +1236,10 @@ func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { testTransportInvalidTrailer_Pseudo(t, splitHeader) } func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - }) + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), + ":colon", "foo", + "foo", "bar", + ) } func TestTransportInvalidTrailer_Capital1(t *testing.T) { @@ -1529,102 +1249,54 @@ func TestTransportInvalidTrailer_Capital2(t *testing.T) { testTransportInvalidTrailer_Capital(t, splitHeader) } func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { - testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) - }) + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), + "foo", "bar", + "Capital", "bad", + ) } func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), + "", "bad", + ) } func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { - testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), func(enc *hpack.Encoder) { - enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) - }) + testInvalidTrailer(t, oneHeader, headerFieldValueError("x"), + "x", "has\nnewline", + ) } -func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want 200", res.StatusCode) - } - slurp, err := ioutil.ReadAll(res.Body) - se, ok := err.(StreamError) - if !ok || se.Cause != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) - } - if len(slurp) > 0 { - return fmt.Errorf("body = %q; want nothing", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) +func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var endStream bool - send := func(mode headerType) { - hbf := buf.Bytes() - switch mode { - case oneHeader: - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: endStream, - BlockFragment: hbf, - }) - case splitHeader: - if len(hbf) < 2 { - panic("too small") - } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: false, - EndStream: endStream, - BlockFragment: hbf[:1], - }) - ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) - default: - panic("bogus mode") - } - } - // Response headers (1+ frames; 1 or 2 in this test, but never 0) - { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) - endStream = false - send(oneHeader) - } - // Trailers: - { - endStream = true - buf.Reset() - writeTrailer(enc) - send(trailers) - } - return nil - } - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "trailer", "declared", + ), + }) + tc.writeHeadersMode(mode, HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment(trailers...), + }) + + rt.wantStatus(200) + body, err := rt.readBody() + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr) + } + if len(body) > 0 { + t.Fatalf("body = %q; want nothing", body) } - ct.run() } // headerListSize returns the HTTP2 header list size of h. @@ -1900,115 +1572,80 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if e, ok := err.(StreamError); ok { - err = e.Cause - } - if err != errResponseHeaderListSize { - size := int64(0) - if res != nil { - res.Body.Close() - for k, vv := range res.Header { - for _, v := range vv { - size += int64(len(k)) + int64(len(v)) + 32 - } - } - } - return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + hdr := []string{":status", "200"} + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + hdr = append(hdr, large, large) + } + hbf := tc.makeHeaderBlockFragment(hdr...) + // Note: this number might change if our hpack implementation changes. + // That's fine. This is just a sanity check that our response can fit in a single + // header block fragment frame. + if size, want := len(hbf), 6329; size != want { + t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: hbf, + }) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - large := strings.Repeat("a", 1<<10) - for i := 0; i < 5042; i++ { - enc.WriteField(hpack.HeaderField{Name: large, Value: large}) - } - if size, want := buf.Len(), 6329; size != want { - // Note: this number might change if - // our hpack implementation - // changes. That's fine. This is - // just a sanity check that our - // response can fit in a single - // header block fragment frame. - return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + res, err := rt.result() + if e, ok := err.(StreamError); ok { + err = e.Cause + } + if err != errResponseHeaderListSize { + size := int64(0) + if res != nil { + res.Body.Close() + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 } - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil } } + t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) } - ct.run() } func TestTransportCookieHeaderSplit(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - req.Header.Add("Cookie", "a=b;c=d; e=f;") - req.Header.Add("Cookie", "e=f;g=h; ") - req.Header.Add("Cookie", "i=j") - _, err := ct.tr.RoundTrip(req) - return err - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - dec := hpack.NewDecoder(initialHeaderTableSize, nil) - hfs, err := dec.DecodeFull(f.HeaderBlockFragment()) - if err != nil { - return err - } - got := []string{} - want := []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"} - for _, hf := range hfs { - if hf.Name == "cookie" { - got = append(got, hf.Value) - } - } - if !reflect.DeepEqual(got, want) { - t.Errorf("Cookies = %#v, want %#v", got, want) - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + req.Header.Add("Cookie", "a=b;c=d; e=f;") + req.Header.Add("Cookie", "e=f;g=h; ") + req.Header.Add("Cookie", "i=j") + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + "cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"}, + }, + }) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if err := rt.err(); err != nil { + t.Fatalf("RoundTrip = %v, want success", err) } - ct.run() } // Test that the Transport returns a typed error from Response.Body.Read calls @@ -2224,55 +1861,49 @@ func TestTransportResponseHeaderTimeout_Body(t *testing.T) { } func testTransportResponseHeaderTimeout(t *testing.T, body bool) { - ct := newClientTester(t) - ct.tr.t1 = &http.Transport{ - ResponseHeaderTimeout: 5 * time.Millisecond, - } - ct.client = func() error { - c := &http.Client{Transport: ct.tr} - var err error - var n int64 - const bodySize = 4 << 20 - if body { - _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) - } else { - _, err = c.Get("https://dummy.tld/") - } - if !isTimeout(err) { - t.Errorf("client expected timeout error; got %#v", err) - } - if body && n != bodySize { - t.Errorf("only read %d bytes of body; want %d", n, bodySize) + const bodySize = 4 << 20 + tc := newTestClientConn(t, func(tr *Transport) { + tr.t1 = &http.Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, } - return nil + }) + tc.greet() + + var req *http.Request + var reqBody *testRequestBody + if body { + reqBody = tc.newRequestBody() + reqBody.writeBytes(bodySize) + reqBody.closeWithError(io.EOF) + req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody) + req.Header.Set("Content-Type", "text/foo") + } else { + req, _ = http.NewRequest("GET", "https://dummy.tld/", nil) } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - switch f := f.(type) { - case *DataFrame: - dataLen := len(f.Data()) - if dataLen > 0 { - if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { - return err - } - if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { - return err - } - } - case *RSTStreamFrame: - if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { - return nil - } - } - } + + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + tc.writeWindowUpdate(0, bodySize) + tc.writeWindowUpdate(rt.streamID(), bodySize) + + if body { + tc.wantData(wantData{ + endStream: true, + size: bodySize, + }) + } + + tc.advance(4 * time.Millisecond) + if rt.done() { + t.Fatalf("RoundTrip is done after 4ms; want still waiting") + } + tc.advance(1 * time.Millisecond) + + if err := rt.err(); !isTimeout(err) { + t.Fatalf("RoundTrip error: %v; want timeout error", err) } - ct.run() } func TestTransportDisableCompression(t *testing.T) { @@ -2484,7 +2115,8 @@ func TestTransportRejectsContentLengthWithSign(t *testing.T) { } // golang.org/issue/14048 -func TestTransportFailsOnInvalidHeaders(t *testing.T) { +// golang.org/issue/64766 +func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { var got []string for k := range r.Header { @@ -2497,6 +2129,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { tests := [...]struct { h http.Header + t http.Header wantErr string }{ 0: { @@ -2515,6 +2148,14 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { h: http.Header{"foo": {"foo\x01bar"}}, wantErr: `invalid HTTP header value for header "foo"`, }, + 4: { + t: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer value for header "foo"`, + }, + 5: { + t: http.Header{"x-\r\nda": {"foo\x01bar"}}, + wantErr: `invalid HTTP trailer name "x-\r\nda"`, + }, } tr := &Transport{TLSClientConfig: tlsConfigInsecure} @@ -2523,6 +2164,7 @@ func TestTransportFailsOnInvalidHeaders(t *testing.T) { for i, tt := range tests { req, _ := http.NewRequest("GET", st.ts.URL, nil) req.Header = tt.h + req.Trailer = tt.t res, err := tr.RoundTrip(req) var bad bool if tt.wantErr == "" { @@ -2658,115 +2300,61 @@ func TestTransportNewTLSConfig(t *testing.T) { // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) func TestTransportReadHeadResponse(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != 123 { - return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, // as the GFE does - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, nil) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) + tc.writeData(rt.streamID(), true, nil) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) } - ct.run() + rt.wantBody(nil) } func TestTransportReadHeadResponseWithBody(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) + // This test uses an invalid response format. + // Discard logger output to not spam tests output. + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) response := "redirecting to /elsewhere" - ct := newClientTester(t) - clientDone := make(chan struct{}) - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - if res.ContentLength != int64(len(response)) { - return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) - } - slurp, err := ioutil.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("ReadAll: %v", err) - } - if len(slurp) > 0 { - return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(hf.StreamID, true, []byte(response)) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", strconv.Itoa(len(response)), + ), + }) + tc.writeData(rt.streamID(), true, []byte(response)) - <-clientDone - return nil - } + res := rt.response() + if res.ContentLength != int64(len(response)) { + t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response)) } - ct.run() + rt.wantBody(nil) } type neverEnding byte @@ -2891,190 +2479,125 @@ func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { } func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { - ct := newClientTester(t) - clientDone := make(chan struct{}) + tc := newTestClientConn(t) + tc.greet() const goAwayErrCode = ErrCodeHTTP11Required // arbitrary const goAwayDebugData = "some debug data" - ct.client = func() error { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if failMidBody { - if err != nil { - return fmt.Errorf("unexpected client RoundTrip error: %v", err) - } - _, err = io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - } - want := GoAwayError{ - LastStreamID: 5, - ErrCode: goAwayErrCode, - DebugData: goAwayDebugData, - } - if !reflect.DeepEqual(err, want) { - t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - t.Logf("ReadFrame: %v", err) - return nil - } - hf, ok := f.(*HeadersFrame) - if !ok { - continue - } - if failMidBody { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - } - // Write two GOAWAY frames, to test that the Transport takes - // the interesting parts of both. - ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) - ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.sc.(*net.TCPConn).Close() - } - <-clientDone - return nil - } + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + if failMidBody { + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "123", + ), + }) } - ct.run() -} -func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { - ct := newClientTester(t) + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) + tc.writeGoAway(5, goAwayErrCode, nil) + tc.closeWrite(io.EOF) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) + res, err := rt.result() + whence := "RoundTrip" + if failMidBody { + whence = "Body.Read" if err != nil { - return err + t.Fatalf("RoundTrip error = %v, want success", err) } + _, err = res.Body.Read(make([]byte, 1)) + } - if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { - return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) - } - res.Body.Close() // leaving 4999 bytes unread - - return nil + want := GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want) } - ct.server = func() error { - ct.greet() +} - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } +func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) + initialInflow := tc.inflowWindow(0) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client to reset the + // stream before sending the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. + const streamNotEnded = false + if oneDataFrame { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000)) + } else { + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1)) + } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialInflow := ct.inflowWindow(0) - - // Two cases: - // - Send one DATA frame with 5000 bytes. - // - Send two DATA frames with 1 and 4999 bytes each. - // - // In both cases, the client should consume one byte of data, - // refund that byte, then refund the following 4999 bytes. - // - // In the second case, the server waits for the client to reset the - // stream before sending the second DATA frame. This tests the case - // where the client receives a DATA frame after it has reset the stream. - if oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) - } else { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) - } + res := rt.response() + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + t.Fatalf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + tc.sync() - wantRST := true - wantWUF := true - if !oneDataFrame { - wantWUF = false // flow control update is small, and will not be sent - } - for wantRST || wantWUF { - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + sentAdditionalData := false + tc.wantUnorderedFrames( + func(f *RSTStreamFrame) bool { + if f.ErrCode != ErrCodeCancel { + t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) } - switch f := f.(type) { - case *RSTStreamFrame: - if !wantRST { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) - } - wantRST = false - case *WindowUpdateFrame: - if !wantWUF { - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) - } - if f.Increment != 5000 { - return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) - } - wantWUF = false - default: - return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f)) + if !oneDataFrame { + // Send the remaining data now. + tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999)) + sentAdditionalData = true } - } - if !oneDataFrame { - ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) - f, err := ct.readNonSettingsFrame() - if err != nil { - return err + return true + }, + func(f *WindowUpdateFrame) bool { + if !oneDataFrame && !sentAdditionalData { + t.Fatalf("Got WindowUpdateFrame, don't expect one yet") } - wuf, ok := f.(*WindowUpdateFrame) - if !ok || wuf.Increment != 5000 { - return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f)) + if f.Increment != 5000 { + t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f)) } - } - if err := ct.writeReadPing(); err != nil { - return err - } - if got, want := ct.inflowWindow(0), initialInflow; got != want { - return fmt.Errorf("connection flow tokens = %v, want %v", got, want) - } - return nil + return true + }, + ) + + if got, want := tc.inflowWindow(0), initialInflow; got != want { + t.Fatalf("connection flow tokens = %v, want %v", got, want) } - ct.run() } // See golang.org/issue/16481 @@ -3090,199 +2613,124 @@ func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. func TestTransportAdjustsFlowControl(t *testing.T) { - ct := newClientTester(t) - clientDone := make(chan struct{}) - const bodySize = 1 << 20 - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + // Don't write our SETTINGS yet. - req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err + body := tc.newRequestBody() + body.writeBytes(bodySize) + body.closeWithError(io.EOF) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", body) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + + gotBytes := int64(0) + for { + f := testClientConnReadFrame[*DataFrame](tc) + gotBytes += int64(len(f.Data())) + // After we've got half the client's initial flow control window's worth + // of request body data, give it just enough flow control to finish. + if gotBytes >= initialWindowSize/2 { + break } - res.Body.Close() - return nil } - ct.server = func() error { - _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - var gotBytes int64 - var sentSettings bool - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - return nil - default: - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - } - switch f := f.(type) { - case *DataFrame: - gotBytes += int64(len(f.Data())) - // After we've got half the client's - // initial flow control window's worth - // of request body data, give it just - // enough flow control to finish. - if gotBytes >= initialWindowSize/2 && !sentSettings { - sentSettings = true - - ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) - ct.fr.WriteWindowUpdate(0, bodySize) - ct.fr.WriteSettingsAck() - } + tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) + tc.writeWindowUpdate(0, bodySize) + tc.writeSettingsAck() - if f.StreamEnded() { - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - } - } + tc.wantUnorderedFrames( + func(f *SettingsFrame) bool { return true }, + func(f *DataFrame) bool { + gotBytes += int64(len(f.Data())) + return f.StreamEnded() + }, + ) + + if gotBytes != bodySize { + t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize) } - ct.run() + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) } // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool, 1) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return err - } - defer res.Body.Close() - <-unblockClient - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + "content-length", "5000", + ), + }) - var hf *HeadersFrame - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - continue - } - var ok bool - hf, ok = f.(*HeadersFrame) - if !ok { - return fmt.Errorf("Got %T; want HeadersFrame", f) - } - break - } + initialConnWindow := tc.inflowWindow(0) + initialStreamWindow := tc.inflowWindow(rt.streamID()) - initialConnWindow := ct.inflowWindow(0) + pad := make([]byte, 5) + tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad) - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - initialStreamWindow := ct.inflowWindow(hf.StreamID) - pad := make([]byte, 5) - ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - if err := ct.writeReadPing(); err != nil { - return err - } - // Padding flow control should have been returned. - if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want { - t.Errorf("conn inflow window = %v, want %v", got, want) - } - if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want { - t.Errorf("stream inflow window = %v, want %v", got, want) - } - unblockClient <- true - return nil + // Padding flow control should have been returned. + if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want { + t.Errorf("conn inflow window = %v, want %v", got, want) + } + if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want { + t.Errorf("stream inflow window = %v, want %v", got, want) } - ct.run() } // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { - ct := newClientTester(t) + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + " content-type", "bogus", + ), + }) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - return errors.New("unexpected successful GET") - } - want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} - if !reflect.DeepEqual(want, err) { - t.Errorf("RoundTrip error = %#v; want %#v", err, want) - } - return nil + err := rt.err() + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(err, want) { + t.Fatalf("RoundTrip error = %#v; want %#v", err, want) } - ct.server = func() error { - ct.greet() - - hf, err := ct.firstHeaders() - if err != nil { - return err - } - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - for { - fr, err := ct.readFrame() - if err != nil { - return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) - } - if _, ok := fr.(*SettingsFrame); ok { - continue - } - if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { - t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) - } - break - } - - return nil + fr := testClientConnReadFrame[*RSTStreamFrame](tc) + if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) } - ct.run() } // byteAndEOFReader returns is in an io.Reader which reads one byte @@ -3576,26 +3024,24 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { } func TestTransportCloseAfterLostPing(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.tr.PingTimeout = 1 * time.Second - ct.tr.ReadIdleTimeout = 1 * time.Second - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - if err == nil || !strings.Contains(err.Error(), "client connection lost") { - return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - <-clientDone - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.PingTimeout = 1 * time.Second + tr.ReadIdleTimeout = 1 * time.Second + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.advance(1 * time.Second) + tc.wantFrameType(FramePing) + + tc.advance(1 * time.Second) + err := rt.err() + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + t.Fatalf("expected to get error about \"connection lost\", got %v", err) } - ct.run() } func TestTransportPingWriteBlocks(t *testing.T) { @@ -3628,418 +3074,231 @@ func TestTransportPingWriteBlocks(t *testing.T) { } } -func TestTransportPingWhenReading(t *testing.T) { - testCases := []struct { - name string - readIdleTimeout time.Duration - deadline time.Duration - expectedPingCount int - }{ - { - name: "two pings", - readIdleTimeout: 100 * time.Millisecond, - deadline: time.Second, - expectedPingCount: 2, - }, - { - name: "zero ping", - readIdleTimeout: time.Second, - deadline: 200 * time.Millisecond, - expectedPingCount: 0, - }, - { - name: "0 readIdleTimeout means no ping", - readIdleTimeout: 0 * time.Millisecond, - deadline: 500 * time.Millisecond, - expectedPingCount: 0, - }, - } - - for _, tc := range testCases { - tc := tc // capture range variable - t.Run(tc.name, func(t *testing.T) { - testTransportPingWhenReading(t, tc.readIdleTimeout, tc.deadline, tc.expectedPingCount) - }) - } -} +func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 1000 * time.Millisecond + }) + tc.greet() -func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) { - var pingCount int - ct := newClientTester(t) - ct.tr.ReadIdleTimeout = readIdleTimeout + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - ctx, cancel := context.WithTimeout(context.Background(), deadline) - defer cancel() - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip: %v", err) - } - defer res.Body.Close() - if res.StatusCode != 200 { - return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) - } - _, err = ioutil.ReadAll(res.Body) - if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil + for i := 0; i < 5; i++ { + // No ping yet... + tc.advance(999 * time.Millisecond) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - cancel() - return err + // ...ping now. + tc.advance(1 * time.Millisecond) + f := testClientConnReadFrame[*PingFrame](tc) + tc.writePing(true, f.Data) } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var streamID uint32 - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-ctx.Done(): - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - streamID = f.StreamID - case *PingFrame: - pingCount++ - if pingCount == expectedPingCount { - if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil { - return err - } - } - if err := ct.fr.WritePing(true, f.Data); err != nil { - return err - } - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // Cancel the request, Transport resets it and returns an error from body reads. + cancel() + tc.sync() + + tc.wantFrameType(FrameRSTStream) + _, err := rt.readBody() + if err == nil { + t.Fatalf("Response.Body.Read() = %v, want error", err) } - ct.run() } -func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { - ln := newLocalListener(t) - defer ln.Close() - - var ( - mu sync.Mutex - count int - conns []net.Conn - ) - var wg sync.WaitGroup - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - } - tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - mu.Lock() - defer mu.Unlock() - count++ - cc, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - return nil, fmt.Errorf("dial error: %v", err) - } - conns = append(conns, cc) - sc, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("accept error: %v", err) - } - conns = append(conns, sc) - ct := &clientTester{ - t: t, - tr: tr, - cc: cc, - sc: sc, - fr: NewFramer(sc, sc), - } - wg.Add(1) - go func(count int) { - defer wg.Done() - server(count, ct) - }(count) - return cc, nil - } +func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.ReadIdleTimeout = 0 // PINGs disabled + }) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - client(tr) - tr.CloseIdleConnections() - ln.Close() - for _, c := range conns { - c.Close() + // No PING is sent, even after a long delay. + tc.advance(1 * time.Minute) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - wg.Wait() } func TestTransportRetryAfterGOAWAY(t *testing.T) { - client := func(tr *Transport) { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := tr.RoundTrip(req) - if res != nil { - res.Body.Close() - if got := res.Header.Get("Foo"); got != "bar" { - err = fmt.Errorf("foo header = %q; want bar", got) - } - } - if err != nil { - t.Errorf("RoundTrip: %v", err) - } - } - - server := func(count int, ct *clientTester) { - switch count { - case 1: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server1 failed reading HEADERS: %v", err) - return - } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - t.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - case 2: - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - t.Errorf("server2 failed reading HEADERS: %v", err) - return - } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - t.Errorf("server2 failed writing response HEADERS: %v", err) - } - default: - t.Errorf("unexpected number of dials") - return - } - } + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil) + if rt.done() { + t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying") + } + + // Second attempt succeeds on a new connection. + tc = tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(200) } func TestTransportRetryAfterRefusedStream(t *testing.T) { - clientDone := make(chan struct{}) - client := func(tr *Transport) { - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - return - } - resp.Body.Close() - if resp.StatusCode != 204 { - t.Errorf("Status = %v; want 204", resp.StatusCode) - return - } + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a RST_STREAM. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK + tc.writeRSTStream(1, ErrCodeRefusedStream) + if rt.done() { + t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying") } - server := func(_ int, ct *clientTester) { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - var count int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - default: - t.Error(err) - } - return - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - t.Errorf("headers should have END_HEADERS be ended: %v", f) - return - } - count++ - if count == 1 { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - } else { - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - default: - t.Errorf("Unexpected client frame %v", f) - return - } - } - } + // Second attempt succeeds on the same connection. + tc.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) + tc.writeSettings() + tc.writeHeaders(HeadersFrameParam{ + StreamID: 3, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "204", + ), + }) - testClientMultipleDials(t, client, server) + rt.wantStatus(204) } func TestTransportRetryHasLimit(t *testing.T) { - // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. - if testing.Short() { - t.Skip("skipping long test in short mode") - } - retryBackoffHook = func(d time.Duration) *time.Timer { - return time.NewTimer(0) // fires immediately - } - defer func() { - retryBackoffHook = nil - }() - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + tt := newTestTransport(t) + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tt.roundTrip(req) + + // First attempt: Server sends a GOAWAY. + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + var totalDelay time.Duration + count := 0 + for streamID := uint32(1); ; streamID += 2 { + count++ + tc.wantHeaders(wantHeader{ + streamID: streamID, + endStream: true, + }) + if streamID == 1 { + tc.writeSettings() + tc.wantFrameType(FrameSettings) // settings ACK } - t.Logf("expected error, got: %v", err) - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - default: - return fmt.Errorf("Unexpected client frame %v", f) + tc.writeRSTStream(streamID, ErrCodeRefusedStream) + + d := tt.tr.syncHooks.timeUntilEvent() + if d == 0 { + if streamID == 1 { + continue } + break + } + totalDelay += d + if totalDelay > 5*time.Minute { + t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } + tt.advance(d) + } + if got, want := count, 5; got < count { + t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) + } + if rt.err() == nil { + t.Errorf("RoundTrip succeeded, want error") } - ct.run() } func TestTransportResponseDataBeforeHeaders(t *testing.T) { - // This test use not valid response format. - // Discarding logger output to not spam tests output. - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stderr) + // Discard log output complaining about protocol error. + log.SetOutput(io.Discard) + t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done + + tc := newTestClientConn(t) + tc.greet() + + // First request is normal to ensure the check is per stream and not per connection. + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt1.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - req := httptest.NewRequest("GET", "https://dummy.tld/", nil) - // First request is normal to ensure the check is per stream and not per connection. - _, err := ct.tr.RoundTrip(req) - if err != nil { - return fmt.Errorf("RoundTrip expected no error, got: %v", err) - } - // Second request returns a DATA frame with no HEADERS. - resp, err := ct.tr.RoundTrip(req) - if err == nil { - return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) - } - if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { - return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) - } - return nil - } - ct.server = func() error { - ct.greet() - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - return nil - } else if err != nil { - return err - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame: - case *HeadersFrame: - switch f.StreamID { - case 1: - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - case 3: - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - } - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + // Second request returns a DATA frame with no HEADERS. + rt2 := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeData(rt2.streamID(), true, []byte("payload")) + if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err) } - ct.run() } func TestTransportMaxFrameReadSize(t *testing.T) { @@ -4053,30 +3312,17 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - ct := newClientTester(t) - ct.tr.MaxReadFrameSize = test.maxReadFrameSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - return nil - } - ct.server = func() error { - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - var got uint32 - ct.settings.ForeachSetting(func(s Setting) error { - switch s.ID { - case SettingMaxFrameSize: - got = s.Val - } - return nil - }) - if got != test.want { - t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxReadFrameSize = test.maxReadFrameSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + got, ok := fr.Value(SettingMaxFrameSize) + if !ok { + t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want) + } else if got != test.want { + t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want) } - ct.run() } } @@ -4108,345 +3354,134 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { if err != nil { t.Fatal(err) } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - if got, want := res.StatusCode, 200; got != want { - t.Errorf("StatusCode = %v; want %v", got, want) - } - if res != nil && res.Body != nil { - res.Body.Close() - } - } - - if connCount != 1 { - t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) - } -} - -// tests Transport.StrictMaxConcurrentStreams -func TestTransportRequestsStallAtServerLimit(t *testing.T) { - const maxConcurrent = 2 - - greet := make(chan struct{}) // server sends initial SETTINGS frame - gotRequest := make(chan struct{}) // server received a request - clientDone := make(chan struct{}) - cancelClientRequest := make(chan struct{}) - - // Collect errors from goroutines. - var wg sync.WaitGroup - errs := make(chan error, 100) - defer func() { - wg.Wait() - close(errs) - for err := range errs { - t.Error(err) - } - }() - - // We will send maxConcurrent+2 requests. This checker goroutine waits for the - // following stages: - // 1. The first maxConcurrent requests are received by the server. - // 2. The client will cancel the next request - // 3. The server is unblocked so it can service the first maxConcurrent requests - // 4. The client will send the final request - wg.Add(1) - unblockClient := make(chan struct{}) - clientRequestCancelled := make(chan struct{}) - unblockServer := make(chan struct{}) - go func() { - defer wg.Done() - // Stage 1. - for k := 0; k < maxConcurrent; k++ { - <-gotRequest - } - // Stage 2. - close(unblockClient) - <-clientRequestCancelled - // Stage 3: give some time for the final RoundTrip call to be scheduled and - // verify that the final request is not sent. - time.Sleep(50 * time.Millisecond) - select { - case <-gotRequest: - errs <- errors.New("last request did not stall") - close(unblockServer) - return - default: - } - close(unblockServer) - // Stage 4. - <-gotRequest - }() - - ct := newClientTester(t) - ct.tr.StrictMaxConcurrentStreams = true - ct.client = func() error { - var wg sync.WaitGroup - defer func() { - wg.Wait() - close(clientDone) - ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - ct.cc.(*net.TCPConn).Close() - } - }() - for k := 0; k < maxConcurrent+2; k++ { - wg.Add(1) - go func(k int) { - defer wg.Done() - // Don't send the second request until after receiving SETTINGS from the server - // to avoid a race where we use the default SettingMaxConcurrentStreams, which - // is much larger than maxConcurrent. We have to send the first request before - // waiting because the first request triggers the dial and greet. - if k > 0 { - <-greet - } - // Block until maxConcurrent requests are sent before sending any more. - if k >= maxConcurrent { - <-unblockClient - } - body := newStaticCloseChecker("") - req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) - if k == maxConcurrent { - // This request will be canceled. - req.Cancel = cancelClientRequest - close(cancelClientRequest) - _, err := ct.tr.RoundTrip(req) - close(clientRequestCancelled) - if err == nil { - errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) - return - } - } else { - resp, err := ct.tr.RoundTrip(req) - if err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - return - } - ioutil.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != 204 { - errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) - return - } - } - if err := body.isClosed(); err != nil { - errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) - } - }(k) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() } - return nil } - ct.server = func() error { - var wg sync.WaitGroup - defer wg.Wait() + if connCount != 1 { + t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount) + } +} - ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) +// tests Transport.StrictMaxConcurrentStreams +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 - // Server write loop. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - writeResp := make(chan uint32, maxConcurrent+1) + tc := newTestClientConn(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) + tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) - wg.Add(1) - go func() { - defer wg.Done() - <-unblockServer - for id := range writeResp { - buf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: id, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - }() + cancelClientRequest := make(chan struct{}) - // Server read loop. - var nreq int - for { - f, err := ct.fr.ReadFrame() - if err != nil { - select { - case <-clientDone: - // If the client's done, it will have reported any errors on its side. - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame: - case *SettingsFrame: - // Wait for the client SETTINGS ack until ending the greet. - close(greet) - case *HeadersFrame: - if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) - } - gotRequest <- struct{}{} - nreq++ - writeResp <- f.StreamID - if nreq == maxConcurrent+1 { - close(writeResp) - } - case *DataFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) + // Start maxConcurrent+2 requests. + // The server does not respond to any of them yet. + var rts []*testRoundTrip + for k := 0; k < maxConcurrent+2; k++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil) + if k == maxConcurrent { + req.Cancel = cancelClientRequest + } + rt := tc.roundTrip(req) + rts = append(rts, rt) + + if k < maxConcurrent { + // We are under the stream limit, so the client sends the request. + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", k)}, + }, + }) + } else { + // We have reached the stream limit, + // so the client cannot send the request. + if fr := tc.readFrame(); fr != nil { + t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr) } } + + if rt.done() { + t.Fatalf("rt %v done", k) + } + } + + // Cancel the maxConcurrent'th request. + // The request should fail. + close(cancelClientRequest) + tc.sync() + if err := rts[maxConcurrent].err(); err == nil { + t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) + } + + // No requests should be complete, except for the canceled one. + for i, rt := range rts { + if i != maxConcurrent && rt.done() { + t.Fatalf("RoundTrip(%d) is done, but should not be", i) + } } - ct.run() + // Server responds to a request, unblocking the last one. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rts[0].streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.wantHeaders(wantHeader{ + streamID: rts[maxConcurrent+1].streamID(), + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)}, + }, + }) + rts[0].wantStatus(200) } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var reqSize, resSize uint32 = 8192, 16384 - ct.tr.MaxDecoderHeaderTableSize = reqSize - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.peerMaxHeaderTableSize, resSize; got != want { - return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want) - } - return nil + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxDecoderHeaderTableSize = reqSize + }) + + fr := testClientConnReadFrame[*SettingsFrame](tc) + if v, ok := fr.Value(SettingHeaderTableSize); !ok { + t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting") + } else if v != reqSize { + t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize) } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - var found bool - err = sf.ForeachSetting(func(s Setting) error { - if s.ID == SettingHeaderTableSize { - found = true - if got, want := s.Val, reqSize; got != want { - return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want) - } - } - return nil - }) - if err != nil { - return err - } - if !found { - return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting") - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) + if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want { + t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } - ct.run() } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { - ct := newClientTester(t) var peerAdvertisedMaxHeaderTableSize uint32 = 16384 - ct.tr.MaxEncoderHeaderTableSize = 8192 - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - cc, err := ct.tr.NewClientConn(ct.cc) - if err != nil { - return err - } - _, err = cc.RoundTrip(req) - if err != nil { - return err - } - if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want { - return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want) - } - return nil - } - ct.server = func() error { - buf := make([]byte, len(ClientPreface)) - _, err := io.ReadFull(ct.sc, buf) - if err != nil { - return fmt.Errorf("reading client preface: %v", err) - } - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - sf, ok := f.(*SettingsFrame) - if !ok { - ct.t.Fatalf("wanted client settings frame; got %v", f) - _ = sf // stash it away? - } - if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil { - ct.t.Fatal(err) - } - if err := ct.fr.WriteSettingsAck(); err != nil { - ct.t.Fatal(err) - } + tc := newTestClientConn(t, func(tr *Transport) { + tr.MaxEncoderHeaderTableSize = 8192 + }) + tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}) - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - return nil - } - } + if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want { + t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want) } - ct.run() } func TestAuthorityAddr(t *testing.T) { @@ -4530,40 +3565,24 @@ func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. func TestTransportNoBodyMeansNoDATA(t *testing.T) { - ct := newClientTester(t) - - unblockClient := make(chan bool) - - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) - ct.tr.RoundTrip(req) - <-unblockClient - return nil - } - ct.server = func() error { - defer close(unblockClient) - defer ct.cc.(*net.TCPConn).Close() - ct.greet() - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) - } - switch f := f.(type) { - default: - return fmt.Errorf("Got %T; want HeadersFrame", f) - case *WindowUpdateFrame, *SettingsFrame: - continue - case *HeadersFrame: - if !f.StreamEnded() { - return fmt.Errorf("got headers frame without END_STREAM") - } - return nil - } - } + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody) + rt := tc.roundTrip(req) + + tc.wantHeaders(wantHeader{ + streamID: rt.streamID(), + endStream: true, // END_STREAM should be set when body is http.NoBody + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + }, + }) + if fr := tc.readFrame(); fr != nil { + t.Fatalf("unexpected frame after headers: %v", fr) } - ct.run() } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { @@ -4642,41 +3661,22 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { - ct := newClientTester(t) - ct.client = func() error { - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - _, err := ct.tr.RoundTrip(req) - const substr = "malformed response from server: missing status pseudo header" - if !strings.Contains(fmt.Sprint(err), substr) { - return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) - } - return nil - } - ct.server = func() error { - ct.greet() - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } - switch f := f.(type) { - case *HeadersFrame: - enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, // we'll send some DATA to try to crash the transport - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte("payload")) - return nil - } - } - } - ct.run() + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: tc.makeHeaderBlockFragment( + "content-type", "text/html", // no :status header + ), + }) + tc.writeData(rt.streamID(), true, []byte("payload")) } func BenchmarkClientRequestHeaders(b *testing.B) { @@ -5024,95 +4024,42 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { - if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { - // So far we've only seen this be flaky on Windows and Plan 9, - // perhaps due to TCP behavior on shutdowns while - // unread data is in flight. This test should be - // fixed, but a skip is better than annoying people - // for now. - t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS) - } - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - defer close(clientDone) - - checkNoStreams := func() error { - cp, ok := ct.tr.connPool().(*clientConnPool) - if !ok { - return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) - } - cp.mu.Lock() - defer cp.mu.Unlock() - conns, ok := cp.conns["dummy.tld:443"] - if !ok { - return fmt.Errorf("missing connection") - } - if len(conns) != 1 { - return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) - } - if activeStreams(conns[0]) != 0 { - return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) - } - return nil - } - bodyReadError := errors.New("body read error") - body := &errReader{body, bodyReadError} - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != bodyReadError { - return fmt.Errorf("err = %v; want %v", err, bodyReadError) - } - if err = checkNoStreams(); err != nil { - return err + tc := newTestClientConn(t) + tc.greet() + + bodyReadError := errors.New("body read error") + b := tc.newRequestBody() + b.Write(body) + b.closeWithError(bodyReadError) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", b) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + var receivedBody []byte +readFrames: + for { + switch f := tc.readFrame().(type) { + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + break readFrames + default: + t.Fatalf("unexpected frame: %v", f) + case nil: + t.Fatalf("transport is idle, want RST_STREAM") } - return nil } - ct.server = func() error { - ct.greet() - var receivedBody []byte - var resetCount int - for { - f, err := ct.fr.ReadFrame() - t.Logf("server: ReadFrame = %v, %v", f, err) - if err != nil { - select { - case <-clientDone: - // If the client's done, it - // will have reported any - // errors on its side. - if bytes.Compare(receivedBody, body) != 0 { - return fmt.Errorf("body: %q; expected %q", receivedBody, body) - } - if resetCount != 1 { - return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) - } - return nil - default: - return err - } - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - receivedBody = append(receivedBody, f.Data()...) - case *RSTStreamFrame: - resetCount++ - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + if !bytes.Equal(receivedBody, body) { + t.Fatalf("body: %q; expected %q", receivedBody, body) + } + + if err := rt.err(); err != bodyReadError { + t.Fatalf("err = %v; want %v", err, bodyReadError) + } + + if got := activeStreams(tc.cc); got != 0 { + t.Fatalf("active streams count: %v; want 0", got) } - ct.run() } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -5125,59 +4072,18 @@ func TestTransportBodyEagerEndStream(t *testing.T) { const reqBody = "some request body" const resBody = "some response body" - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } - body := strings.NewReader(reqBody) - req, err := http.NewRequest("PUT", "https://dummy.tld/", body) - if err != nil { - return err - } - _, err = ct.tr.RoundTrip(req) - if err != nil { - return err - } - return nil - } - ct.server = func() error { - ct.greet() + tc := newTestClientConn(t) + tc.greet() - for { - f, err := ct.fr.ReadFrame() - if err != nil { - return err - } + body := strings.NewReader(reqBody) + req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) + tc.roundTrip(req) - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - case *DataFrame: - if !f.StreamEnded() { - ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) - return fmt.Errorf("data frame without END_STREAM %v", f) - } - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.Header().StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - ct.fr.WriteData(f.StreamID, true, []byte(resBody)) - return nil - case *RSTStreamFrame: - default: - return fmt.Errorf("Unexpected client frame %v", f) - } - } + tc.wantFrameType(FrameHeaders) + f := testClientConnReadFrame[*DataFrame](tc) + if !f.StreamEnded() { + t.Fatalf("data frame without END_STREAM %v", f) } - ct.run() } type chunkReader struct { @@ -5826,155 +4732,80 @@ func TestTransportCloseRequestBody(t *testing.T) { } } -// collectClientsConnPool is a ClientConnPool that wraps lower and -// collects what calls were made on it. -type collectClientsConnPool struct { - lower ClientConnPool - - mu sync.Mutex - getErrs int - got []*ClientConn -} - -func (p *collectClientsConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { - cc, err := p.lower.GetClientConn(req, addr) - p.mu.Lock() - defer p.mu.Unlock() - if err != nil { - p.getErrs++ - return nil, err - } - p.got = append(p.got, cc) - return cc, nil -} - -func (p *collectClientsConnPool) MarkDead(cc *ClientConn) { - p.lower.MarkDead(cc) -} - func TestTransportRetriesOnStreamProtocolError(t *testing.T) { - ct := newClientTester(t) - pool := &collectClientsConnPool{ - lower: &clientConnPool{t: ct.tr}, - } - ct.tr.ConnPool = pool + // This test verifies that + // - receiving a protocol error on a connection does not interfere with + // other requests in flight on that connection; + // - the connection is not reused for further requests; and + // - the failed request is retried on a new connecection. + tt := newTestTransport(t) + + // Start two requests. The first is a long request + // that will finish after the second. The second one + // will result in the protocol error. + + // Request #1: The long request. + req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt1 := tt.roundTrip(req1) + tc1 := tt.getConn() + tc1.wantFrameType(FrameSettings) + tc1.wantFrameType(FrameWindowUpdate) + tc1.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc1.writeSettings() + tc1.wantFrameType(FrameSettings) // settings ACK + + // Request #2(a): The short request. + req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt2 := tt.roundTrip(req2) + tc1.wantHeaders(wantHeader{ + streamID: 3, + endStream: true, + }) - gotProtoError := make(chan bool, 1) - ct.tr.CountError = func(errType string) { - if errType == "recv_rststream_PROTOCOL_ERROR" { - select { - case gotProtoError <- true: - default: - } - } + // Request #2(a) fails with ErrCodeProtocol. + tc1.writeRSTStream(3, ErrCodeProtocol) + if rt1.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress") } - ct.client = func() error { - // Start two requests. The first is a long request - // that will finish after the second. The second one - // will result in the protocol error. We check that - // after the first one closes, the connection then - // shuts down. - - // The long, outer request. - req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil) - res1, err := ct.tr.RoundTrip(req1) - if err != nil { - return err - } - if got, want := res1.Header.Get("Is-Long"), "1"; got != want { - return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want) - } - - req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil) - res, err := ct.tr.RoundTrip(req) - const want = "only one dial allowed in test mode" - if got := fmt.Sprint(err); got != want { - t.Errorf("didn't dial again: got %#q; want %#q", got, want) - } - if res != nil { - res.Body.Close() - } - select { - case <-gotProtoError: - default: - t.Errorf("didn't get stream protocol error") - } - - if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 { - t.Errorf("unexpected body read %v, %v", n, err) - } - - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.getErrs != 1 { - t.Errorf("pool get errors = %v; want 1", pool.getErrs) - } - if len(pool.got) == 2 { - if pool.got[0] != pool.got[1] { - t.Errorf("requests went on different connections") - } - cc := pool.got[0] - cc.mu.Lock() - if !cc.doNotReuse { - t.Error("ClientConn not marked doNotReuse") - } - cc.mu.Unlock() - - select { - case <-cc.readerDone: - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for reader to be done") - } - } else { - t.Errorf("pool get success = %v; want 2", len(pool.got)) - } - return nil + if rt2.done() { + t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress") } - ct.server = func() error { - ct.greet() - var sentErr bool - var numHeaders int - var firstStreamID uint32 - - var hbuf bytes.Buffer - enc := hpack.NewEncoder(&hbuf) - for { - f, err := ct.fr.ReadFrame() - if err == io.EOF { - // Client hung up on us, as it should at the end. - return nil - } - if err != nil { - return nil - } - switch f := f.(type) { - case *WindowUpdateFrame, *SettingsFrame: - case *HeadersFrame: - numHeaders++ - if numHeaders == 1 { - firstStreamID = f.StreamID - hbuf.Reset() - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: f.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: hbuf.Bytes(), - }) - continue - } - if !sentErr { - sentErr = true - ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol) - ct.fr.WriteData(firstStreamID, true, nil) - continue - } - } - } - } - ct.run() + // Request #2(b): The short request is retried on a new connection. + tc2 := tt.getConn() + tc2.wantFrameType(FrameSettings) + tc2.wantFrameType(FrameWindowUpdate) + tc2.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + }) + tc2.writeSettings() + tc2.wantFrameType(FrameSettings) // settings ACK + + // Request #2(b) succeeds. + tc2.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "201", + ), + }) + rt2.wantStatus(201) + + // Request #1 succeeds. + tc1.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc1.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) } func TestClientConnReservations(t *testing.T) { @@ -5987,7 +4818,7 @@ func TestClientConnReservations(t *testing.T) { tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - cc, err := tr.newClientConn(st.cc, false) + cc, err := tr.newClientConn(st.cc, false, nil) if err != nil { t.Fatal(err) } @@ -6026,39 +4857,27 @@ func TestClientConnReservations(t *testing.T) { } func TestTransportTimeoutServerHangs(t *testing.T) { - clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - defer close(clientDone) + tc := newTestClientConn(t) + tc.greet() - req, err := http.NewRequest("PUT", "https://dummy.tld/", nil) - if err != nil { - return err - } + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - req.Header.Add("Big", strings.Repeat("a", 1<<20)) - _, err = ct.tr.RoundTrip(req) - if err == nil { - return errors.New("error should not be nil") - } - if ne, ok := err.(net.Error); !ok || !ne.Timeout() { - return fmt.Errorf("error should be a net error timeout: %v", err) - } - return nil + tc.wantFrameType(FrameHeaders) + tc.advance(5 * time.Second) + if f := tc.readFrame(); f != nil { + t.Fatalf("unexpected frame: %v", f) } - ct.server = func() error { - ct.greet() - select { - case <-time.After(5 * time.Second): - case <-clientDone: - } - return nil + if rt.done() { + t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned") + } + + cancel() + tc.sync() + if rt.err() != context.Canceled { + t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } - ct.run() } func TestTransportContentLengthWithoutBody(t *testing.T) { @@ -6251,20 +5070,6 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { testTransportClosesConnAfterGoAway(t, 1) } -type closeOnceConn struct { - net.Conn - closed uint32 -} - -var errClosed = errors.New("Close of closed connection") - -func (c *closeOnceConn) Close() error { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return c.Conn.Close() - } - return errClosed -} - // testTransportClosesConnAfterGoAway verifies that the transport // closes a connection after reading a GOAWAY from it. // @@ -6272,53 +5077,35 @@ func (c *closeOnceConn) Close() error { // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { - ct := newClientTester(t) - ct.cc = &closeOnceConn{Conn: ct.cc} - - var wg sync.WaitGroup - wg.Add(1) - ct.client = func() error { - defer wg.Done() - req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - res, err := ct.tr.RoundTrip(req) - if err == nil { - res.Body.Close() - } - if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { - t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) - } - if err = ct.cc.Close(); err != errClosed { - return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err) - } - return nil + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeGoAway(lastStream, ErrCodeNo, nil) + + if lastStream > 0 { + // Send a valid response to first request. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) } - ct.server = func() error { - defer wg.Wait() - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - return fmt.Errorf("server failed reading HEADERS: %v", err) - } - if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil { - return fmt.Errorf("server failed writing GOAWAY: %v", err) - } - if lastStream > 0 { - // Send a valid response to first request. - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: true, - BlockFragment: buf.Bytes(), - }) - } - return nil + tc.closeWrite(io.EOF) + err := rt.err() + if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr { + t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr) + } + if !tc.netConnClosed { + t.Errorf("ClientConn did not close its net.Conn, expected it to") } - - ct.run() } type slowCloser struct { @@ -6520,3 +5307,32 @@ func TestDialRaceResumesDial(t *testing.T) { case <-successCh: } } + +func TestTransportDataAfter1xxHeader(t *testing.T) { + // Discard logger output to avoid spamming stderr. + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + // https://go.dev/issue/65927 - server sends a 1xx response, followed by a DATA frame. + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "100", + ), + }) + tc.writeData(rt.streamID(), true, []byte{0}) + err := rt.err() + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { + t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err) + } + tc.wantFrameType(FrameRSTStream) +} diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index 20f737b52..5b652a2b1 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -25,8 +25,8 @@ import ( "path/filepath" "sync" - "golang.org/x/net/internal/quic" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic" + "golang.org/x/net/quic/qlog" ) var ( @@ -148,7 +148,7 @@ func basicTest(ctx context.Context, config *quic.Config, urls []string) { g.Add(1) go func() { defer g.Done() - fetchFrom(ctx, l, addr, u) + fetchFrom(ctx, config, l, addr, u) }() } @@ -221,8 +221,8 @@ func parseURL(s string) (u *url.URL, authority string, err error) { return u, authority, nil } -func fetchFrom(ctx context.Context, l *quic.Endpoint, addr string, urls []*url.URL) { - conn, err := l.Dial(ctx, "udp", addr) +func fetchFrom(ctx context.Context, config *quic.Config, l *quic.Endpoint, addr string, urls []*url.URL) { + conn, err := l.Dial(ctx, "udp", addr, config) if err != nil { log.Printf("%v: %v", addr, err) return diff --git a/internal/quic/doc.go b/internal/quic/doc.go deleted file mode 100644 index 2fe17fe22..000000000 --- a/internal/quic/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package quic is an experimental, incomplete implementation of the QUIC protocol. -// This package is a work in progress, and is not ready for use at this time. -// -// This package implements (or will implement) RFC 9000, RFC 9001, and RFC 9002. -package quic diff --git a/internal/quic/ack_delay.go b/quic/ack_delay.go similarity index 100% rename from internal/quic/ack_delay.go rename to quic/ack_delay.go diff --git a/internal/quic/ack_delay_test.go b/quic/ack_delay_test.go similarity index 100% rename from internal/quic/ack_delay_test.go rename to quic/ack_delay_test.go diff --git a/internal/quic/acks.go b/quic/acks.go similarity index 91% rename from internal/quic/acks.go rename to quic/acks.go index ba860efb2..039b7b46e 100644 --- a/internal/quic/acks.go +++ b/quic/acks.go @@ -130,12 +130,19 @@ func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bo // there are no gaps. If it does not, there must be a gap. return true } - if acks.unackedAckEliciting >= 2 { - // "[...] after receiving at least two ack-eliciting packets." - // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 - return true + // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 + // + // This ack frequency takes a substantial toll on performance, however. + // Follow the behavior of Google QUICHE: + // Ack every other packet for the first 100 packets, and then ack every 10th packet. + // This keeps ack frequency high during the beginning of slow start when CWND is + // increasing rapidly. + packetsBeforeAck := 2 + if acks.seen.max() > 100 { + packetsBeforeAck = 10 } - return false + return acks.unackedAckEliciting >= packetsBeforeAck } // shouldSendAck reports whether the connection should send an ACK frame at this time, diff --git a/internal/quic/acks_test.go b/quic/acks_test.go similarity index 94% rename from internal/quic/acks_test.go rename to quic/acks_test.go index 4f1032910..d10f917ad 100644 --- a/internal/quic/acks_test.go +++ b/quic/acks_test.go @@ -7,6 +7,7 @@ package quic import ( + "slices" "testing" "time" ) @@ -198,7 +199,7 @@ func TestAcksSent(t *testing.T) { if len(gotNums) == 0 { wantDelay = 0 } - if !slicesEqual(gotNums, test.wantAcks) || gotDelay != wantDelay { + if !slices.Equal(gotNums, test.wantAcks) || gotDelay != wantDelay { t.Errorf("acks.acksToSend(T+%v) = %v, %v; want %v, %v", delay, gotNums, gotDelay, test.wantAcks, wantDelay) } } @@ -206,20 +207,6 @@ func TestAcksSent(t *testing.T) { } } -// slicesEqual reports whether two slices are equal. -// Replace this with slices.Equal once the module go.mod is go1.17 or newer. -func slicesEqual[E comparable](s1, s2 []E) bool { - if len(s1) != len(s2) { - return false - } - for i := range s1 { - if s1[i] != s2[i] { - return false - } - } - return true -} - func TestAcksDiscardAfterAck(t *testing.T) { acks := ackState{} now := time.Now() diff --git a/internal/quic/atomic_bits.go b/quic/atomic_bits.go similarity index 100% rename from internal/quic/atomic_bits.go rename to quic/atomic_bits.go diff --git a/quic/bench_test.go b/quic/bench_test.go new file mode 100644 index 000000000..636b71327 --- /dev/null +++ b/quic/bench_test.go @@ -0,0 +1,170 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "fmt" + "io" + "math" + "sync" + "testing" +) + +// BenchmarkThroughput is based on the crypto/tls benchmark of the same name. +func BenchmarkThroughput(b *testing.B) { + for size := 1; size <= 64; size <<= 1 { + name := fmt.Sprintf("%dMiB", size) + b.Run(name, func(b *testing.B) { + throughput(b, int64(size<<20)) + }) + } +} + +func throughput(b *testing.B, totalBytes int64) { + // Same buffer size as crypto/tls's BenchmarkThroughput, for consistency. + const bufsize = 32 << 10 + + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + go func() { + buf := make([]byte, bufsize) + for i := 0; i < b.N; i++ { + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + if _, err := io.CopyBuffer(sconn, sconn, buf); err != nil { + panic(fmt.Errorf("CopyBuffer: %v", err)) + } + sconn.Close() + } + }() + + b.SetBytes(totalBytes) + buf := make([]byte, bufsize) + chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf)))) + for i := 0; i < b.N; i++ { + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + closec := make(chan struct{}) + go func() { + defer close(closec) + buf := make([]byte, bufsize) + if _, err := io.CopyBuffer(io.Discard, cconn, buf); err != nil { + panic(fmt.Errorf("Discard: %v", err)) + } + }() + for j := 0; j < chunks; j++ { + _, err := cconn.Write(buf) + if err != nil { + b.Fatalf("Write: %v", err) + } + } + cconn.CloseWrite() + <-closec + cconn.Close() + } +} + +func BenchmarkReadByte(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 1<<20) + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + for { + if _, err := sconn.Write(buf); err != nil { + break + } + sconn.Flush() + } + }() + + b.SetBytes(1) + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Flush() + for i := 0; i < b.N; i++ { + _, err := cconn.ReadByte() + if err != nil { + b.Fatalf("ReadByte: %v", err) + } + } + cconn.Close() +} + +func BenchmarkWriteByte(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + n, err := io.Copy(io.Discard, sconn) + if n != int64(b.N) || err != nil { + b.Errorf("server io.Copy() = %v, %v; want %v, nil", n, err, b.N) + } + }() + + b.SetBytes(1) + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Flush() + for i := 0; i < b.N; i++ { + if err := cconn.WriteByte(0); err != nil { + b.Fatalf("WriteByte: %v", err) + } + } + cconn.Close() +} + +func BenchmarkStreamCreation(b *testing.B) { + cli, srv := newLocalConnPair(b, &Config{}, &Config{}) + + go func() { + for i := 0; i < b.N; i++ { + sconn, err := srv.AcceptStream(context.Background()) + if err != nil { + panic(fmt.Errorf("AcceptStream: %v", err)) + } + sconn.Close() + } + }() + + buf := make([]byte, 1) + for i := 0; i < b.N; i++ { + cconn, err := cli.NewStream(context.Background()) + if err != nil { + b.Fatalf("NewStream: %v", err) + } + cconn.Write(buf) + cconn.Flush() + cconn.Read(buf) + cconn.Close() + } +} diff --git a/internal/quic/config.go b/quic/config.go similarity index 96% rename from internal/quic/config.go rename to quic/config.go index b045b7b92..5d420312b 100644 --- a/internal/quic/config.go +++ b/quic/config.go @@ -107,6 +107,13 @@ type Config struct { QLogLogger *slog.Logger } +// Clone returns a shallow clone of c, or nil if c is nil. +// It is safe to clone a [Config] that is being used concurrently by a QUIC endpoint. +func (c *Config) Clone() *Config { + n := *c + return &n +} + func configDefault[T ~int64](v, def, limit T) T { switch { case v == 0: diff --git a/internal/quic/config_test.go b/quic/config_test.go similarity index 100% rename from internal/quic/config_test.go rename to quic/config_test.go diff --git a/internal/quic/congestion_reno.go b/quic/congestion_reno.go similarity index 83% rename from internal/quic/congestion_reno.go rename to quic/congestion_reno.go index 982cbf4bb..a53983524 100644 --- a/internal/quic/congestion_reno.go +++ b/quic/congestion_reno.go @@ -7,6 +7,8 @@ package quic import ( + "context" + "log/slog" "math" "time" ) @@ -40,6 +42,9 @@ type ccReno struct { // true if we haven't sent that packet yet. sendOnePacketInRecovery bool + // inRecovery is set when we are in the recovery state. + inRecovery bool + // underutilized is set if the congestion window is underutilized // due to insufficient application data, flow control limits, or // anti-amplification limits. @@ -100,12 +105,19 @@ func (c *ccReno) canSend() bool { // congestion controller permits sending data, but no data is sent. // // https://www.rfc-editor.org/rfc/rfc9002#section-7.8 -func (c *ccReno) setUnderutilized(v bool) { +func (c *ccReno) setUnderutilized(log *slog.Logger, v bool) { + if c.underutilized == v { + return + } + oldState := c.state() c.underutilized = v + if logEnabled(log, QLogLevelPacket) { + logCongestionStateUpdated(log, oldState, c.state()) + } } // packetSent indicates that a packet has been sent. -func (c *ccReno) packetSent(now time.Time, space numberSpace, sent *sentPacket) { +func (c *ccReno) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { if !sent.inFlight { return } @@ -185,7 +197,11 @@ func (c *ccReno) packetLost(now time.Time, space numberSpace, sent *sentPacket, } // packetBatchEnd is called at the end of processing a batch of acked or lost packets. -func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { +func (c *ccReno) packetBatchEnd(now time.Time, log *slog.Logger, space numberSpace, rtt *rttState, maxAckDelay time.Duration) { + if logEnabled(log, QLogLevelPacket) { + oldState := c.state() + defer func() { logCongestionStateUpdated(log, oldState, c.state()) }() + } if !c.ackLastLoss.IsZero() && !c.ackLastLoss.Before(c.recoveryStartTime) { // Enter the recovery state. // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.3.2 @@ -196,8 +212,10 @@ func (c *ccReno) packetBatchEnd(now time.Time, space numberSpace, rtt *rttState, // Clear congestionPendingAcks to avoid increasing the congestion // window based on acks in a frame that sends us into recovery. c.congestionPendingAcks = 0 + c.inRecovery = true } else if c.congestionPendingAcks > 0 { // We are in slow start or congestion avoidance. + c.inRecovery = false if c.congestionWindow < c.slowStartThreshold { // When the congestion window is less than the slow start threshold, // we are in slow start and increase the window by the number of @@ -253,3 +271,38 @@ func (c *ccReno) minimumCongestionWindow() int { // https://www.rfc-editor.org/rfc/rfc9002.html#section-7.2-4 return 2 * c.maxDatagramSize } + +func logCongestionStateUpdated(log *slog.Logger, oldState, newState congestionState) { + if oldState == newState { + return + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:congestion_state_updated", + slog.String("old", oldState.String()), + slog.String("new", newState.String()), + ) +} + +type congestionState string + +func (s congestionState) String() string { return string(s) } + +const ( + congestionSlowStart = congestionState("slow_start") + congestionCongestionAvoidance = congestionState("congestion_avoidance") + congestionApplicationLimited = congestionState("application_limited") + congestionRecovery = congestionState("recovery") +) + +func (c *ccReno) state() congestionState { + switch { + case c.inRecovery: + return congestionRecovery + case c.underutilized: + return congestionApplicationLimited + case c.congestionWindow < c.slowStartThreshold: + return congestionSlowStart + default: + return congestionCongestionAvoidance + } +} diff --git a/internal/quic/congestion_reno_test.go b/quic/congestion_reno_test.go similarity index 99% rename from internal/quic/congestion_reno_test.go rename to quic/congestion_reno_test.go index e9af6452c..cda7a90a8 100644 --- a/internal/quic/congestion_reno_test.go +++ b/quic/congestion_reno_test.go @@ -470,7 +470,7 @@ func (c *ccTest) setRTT(smoothedRTT, rttvar time.Duration) { func (c *ccTest) setUnderutilized(v bool) { c.t.Helper() c.t.Logf("set underutilized = %v", v) - c.cc.setUnderutilized(v) + c.cc.setUnderutilized(nil, v) } func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket)) *sentPacket { @@ -488,7 +488,7 @@ func (c *ccTest) packetSent(space numberSpace, size int, fns ...func(*sentPacket f(sent) } c.t.Logf("packet sent: num=%v.%v, size=%v", space, sent.num, sent.size) - c.cc.packetSent(c.now, space, sent) + c.cc.packetSent(c.now, nil, space, sent) return sent } @@ -519,7 +519,7 @@ func (c *ccTest) packetDiscarded(space numberSpace, sent *sentPacket) { func (c *ccTest) packetBatchEnd(space numberSpace) { c.t.Helper() c.t.Logf("(end of batch)") - c.cc.packetBatchEnd(c.now, space, &c.rtt, c.maxAckDelay) + c.cc.packetBatchEnd(c.now, nil, space, &c.rtt, c.maxAckDelay) } func (c *ccTest) wantCanSend(want bool) { diff --git a/internal/quic/conn.go b/quic/conn.go similarity index 96% rename from internal/quic/conn.go rename to quic/conn.go index 6d79013eb..38e8fe8f4 100644 --- a/internal/quic/conn.go +++ b/quic/conn.go @@ -25,6 +25,7 @@ type Conn struct { config *Config testHooks connTestHooks peerAddr netip.AddrPort + localAddr netip.AddrPort msgc chan any donec chan struct{} // closed when conn loop exits @@ -36,6 +37,7 @@ type Conn struct { connIDState connIDState loss lossState streams streamsState + path pathState // Packet protection keys, CRYPTO streams, and TLS state. keysInitial fixedKeyPair @@ -92,12 +94,12 @@ type newServerConnIDs struct { retrySrcConnID []byte // source from server's Retry } -func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) { c := &Conn{ side: side, endpoint: e, config: config, - peerAddr: peerAddr, + peerAddr: unmapAddrPort(peerAddr), msgc: make(chan any, 1), donec: make(chan struct{}), peerAckDelayExponent: -1, @@ -144,7 +146,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip c.lifetimeInit() c.restartIdleTimer(now) - if err := c.startTLS(now, initialConnID, transportParameters{ + if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), originalDstConnID: cids.originalDstConnID, retrySrcConnID: cids.retrySrcConnID, @@ -210,7 +212,7 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) { case handshakeSpace: c.keysHandshake.discard() } - c.loss.discardKeys(now, space) + c.loss.discardKeys(now, c.log, space) } // receiveTransportParameters applies transport parameters sent by the peer. @@ -317,7 +319,11 @@ func (c *Conn) loop(now time.Time) { } switch m := m.(type) { case *datagram: - c.handleDatagram(now, m) + if !c.handleDatagram(now, m) { + if c.logEnabled(QLogLevelPacket) { + c.logPacketDropped(m) + } + } m.recycle() case timerEvent: // A connection timer has expired. diff --git a/internal/quic/conn_async_test.go b/quic/conn_async_test.go similarity index 100% rename from internal/quic/conn_async_test.go rename to quic/conn_async_test.go diff --git a/internal/quic/conn_close.go b/quic/conn_close.go similarity index 100% rename from internal/quic/conn_close.go rename to quic/conn_close.go diff --git a/internal/quic/conn_close_test.go b/quic/conn_close_test.go similarity index 98% rename from internal/quic/conn_close_test.go rename to quic/conn_close_test.go index 63d4911e8..213975011 100644 --- a/internal/quic/conn_close_test.go +++ b/quic/conn_close_test.go @@ -249,8 +249,9 @@ func TestConnCloseUnblocksNewStream(t *testing.T) { func TestConnCloseUnblocksStreamRead(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetReadContext(ctx) buf := make([]byte, 16) - _, err := s.ReadContext(ctx, buf) + _, err := s.Read(buf) return err }, permissiveTransportParameters) } @@ -258,8 +259,9 @@ func TestConnCloseUnblocksStreamRead(t *testing.T) { func TestConnCloseUnblocksStreamWrite(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetWriteContext(ctx) buf := make([]byte, 32) - _, err := s.WriteContext(ctx, buf) + _, err := s.Write(buf) return err }, permissiveTransportParameters, func(c *Config) { c.MaxStreamWriteBufferSize = 16 @@ -269,11 +271,12 @@ func TestConnCloseUnblocksStreamWrite(t *testing.T) { func TestConnCloseUnblocksStreamClose(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) + s.SetWriteContext(ctx) buf := make([]byte, 16) - _, err := s.WriteContext(ctx, buf) + _, err := s.Write(buf) if err != nil { return err } - return s.CloseContext(ctx) + return s.Close() }, permissiveTransportParameters) } diff --git a/internal/quic/conn_flow.go b/quic/conn_flow.go similarity index 100% rename from internal/quic/conn_flow.go rename to quic/conn_flow.go diff --git a/internal/quic/conn_flow_test.go b/quic/conn_flow_test.go similarity index 90% rename from internal/quic/conn_flow_test.go rename to quic/conn_flow_test.go index 39c879346..260684bdb 100644 --- a/internal/quic/conn_flow_test.go +++ b/quic/conn_flow_test.go @@ -12,39 +12,34 @@ import ( ) func TestConnInflowReturnOnRead(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { c.MaxConnReadBufferSize = 64 }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, - data: make([]byte, 64), + data: make([]byte, 8), }) - const readSize = 8 - if n, err := s.ReadContext(ctx, make([]byte, readSize)); n != readSize || err != nil { - t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize) - } - tc.wantFrame("available window increases, send a MAX_DATA", - packetType1RTT, debugFrameMaxData{ - max: 64 + readSize, - }) - if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64-readSize || err != nil { - t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize) + if n, err := s.Read(make([]byte, 8)); n != 8 || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 8) } tc.wantFrame("available window increases, send a MAX_DATA", packetType1RTT, debugFrameMaxData{ - max: 128, + max: 64 + 8, }) // Peer can write up to the new limit. tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, - off: 64, + off: 8, data: make([]byte, 64), }) - tc.wantIdle("connection is idle") - if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64 || err != nil { - t.Fatalf("offset 64: s.Read() = %v, %v; want %v, nil", n, err, 64) + if n, err := s.Read(make([]byte, 64+1)); n != 64 { + t.Fatalf("s.Read() = %v, %v; want %v, anything", n, err, 64) } + tc.wantFrame("available window increases, send a MAX_DATA", + packetType1RTT, debugFrameMaxData{ + max: 64 + 8 + 64, + }) + tc.wantIdle("connection is idle") } func TestConnInflowReturnOnRacingReads(t *testing.T) { @@ -64,11 +59,11 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { tc.ignoreFrame(frameTypeAck) tc.writeFrames(packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, uniStream, 0), - data: make([]byte, 32), + data: make([]byte, 16), }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: newStreamID(clientSide, uniStream, 1), - data: make([]byte, 32), + data: make([]byte, 1), }) s1, err := tc.conn.AcceptStream(ctx) if err != nil { @@ -79,10 +74,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { t.Fatalf("conn.AcceptStream() = %v", err) } read1 := runAsync(tc, func(ctx context.Context) (int, error) { - return s1.ReadContext(ctx, make([]byte, 16)) + return s1.Read(make([]byte, 16)) }) read2 := runAsync(tc, func(ctx context.Context) (int, error) { - return s2.ReadContext(ctx, make([]byte, 1)) + return s2.Read(make([]byte, 1)) }) // This MAX_DATA might extend the window by 16 or 17, depending on // whether the second write occurs before the update happens. @@ -90,10 +85,10 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { packetType1RTT, debugFrameMaxData{}) tc.wantIdle("redundant MAX_DATA is not sent") if _, err := read1.result(); err != nil { - t.Errorf("ReadContext #1 = %v", err) + t.Errorf("Read #1 = %v", err) } if _, err := read2.result(); err != nil { - t.Errorf("ReadContext #2 = %v", err) + t.Errorf("Read #2 = %v", err) } } @@ -204,7 +199,6 @@ func TestConnInflowResetViolation(t *testing.T) { } func TestConnInflowMultipleStreams(t *testing.T) { - ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { c.MaxConnReadBufferSize = 128 }) @@ -220,21 +214,26 @@ func TestConnInflowMultipleStreams(t *testing.T) { } { tc.writeFrames(packetType1RTT, debugFrameStream{ id: id, - data: make([]byte, 32), + data: make([]byte, 1), }) - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream() = %v", err) - } + s := tc.acceptStream() streams = append(streams, s) - if n, err := s.ReadContext(ctx, make([]byte, 1)); err != nil || n != 1 { + if n, err := s.Read(make([]byte, 1)); err != nil || n != 1 { t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err) } } tc.wantIdle("streams have read data, but not enough to update MAX_DATA") - if n, err := streams[0].ReadContext(ctx, make([]byte, 32)); err != nil || n != 31 { - t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err) + for _, s := range streams { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 1, + data: make([]byte, 31), + }) + } + + if n, err := streams[0].Read(make([]byte, 32)); n != 31 { + t.Fatalf("s.Read() = %v, %v; want 31, anything", n, err) } tc.wantFrame("read enough data to trigger a MAX_DATA update", packetType1RTT, debugFrameMaxData{ diff --git a/internal/quic/conn_id.go b/quic/conn_id.go similarity index 100% rename from internal/quic/conn_id.go rename to quic/conn_id.go diff --git a/internal/quic/conn_id_test.go b/quic/conn_id_test.go similarity index 100% rename from internal/quic/conn_id_test.go rename to quic/conn_id_test.go diff --git a/internal/quic/conn_loss.go b/quic/conn_loss.go similarity index 96% rename from internal/quic/conn_loss.go rename to quic/conn_loss.go index 85bda314e..623ebdd7c 100644 --- a/internal/quic/conn_loss.go +++ b/quic/conn_loss.go @@ -20,6 +20,10 @@ import "fmt" // See RFC 9000, Section 13.3 for a complete list of information which is retransmitted on loss. // https://www.rfc-editor.org/rfc/rfc9000#section-13.3 func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) { + if fate == packetLost && c.logEnabled(QLogLevelPacket) { + c.logPacketLost(space, sent) + } + // The list of frames in a sent packet is marshaled into a buffer in the sentPacket // by the packetWriter. Unmarshal that buffer here. This code must be kept in sync with // packetWriter.append*. diff --git a/internal/quic/conn_loss_test.go b/quic/conn_loss_test.go similarity index 93% rename from internal/quic/conn_loss_test.go rename to quic/conn_loss_test.go index 818816335..81d537803 100644 --- a/internal/quic/conn_loss_test.go +++ b/quic/conn_loss_test.go @@ -308,9 +308,9 @@ func TestLostMaxDataFrame(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, off: 0, - data: make([]byte, maxWindowSize), + data: make([]byte, maxWindowSize-1), }) - if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 { + if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1) } tc.wantFrame("conn window is extended after reading data", @@ -319,7 +319,12 @@ func TestLostMaxDataFrame(t *testing.T) { }) // MAX_DATA = 64, which is only one more byte, so we don't send the frame. - if n, err := s.Read(buf); err != nil || n != 1 { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: maxWindowSize - 1, + data: make([]byte, 1), + }) + if n, err := s.Read(buf[:1]); err != nil || n != 1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) } tc.wantIdle("read doesn't extend window enough to send another MAX_DATA") @@ -348,9 +353,9 @@ func TestLostMaxStreamDataFrame(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, off: 0, - data: make([]byte, maxWindowSize), + data: make([]byte, maxWindowSize-1), }) - if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 { + if n, err := s.Read(buf[:maxWindowSize]); err != nil || n != maxWindowSize-1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1) } tc.wantFrame("stream window is extended after reading data", @@ -360,6 +365,11 @@ func TestLostMaxStreamDataFrame(t *testing.T) { }) // MAX_STREAM_DATA = 64, which is only one more byte, so we don't send the frame. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: maxWindowSize - 1, + data: make([]byte, 1), + }) if n, err := s.Read(buf); err != nil || n != 1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) } @@ -433,7 +443,8 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) { if err != nil { t.Fatalf("AcceptStream() = %v", err) } - s.CloseContext(ctx) + s.SetWriteContext(ctx) + s.Close() if styp == bidiStream { tc.wantFrame("stream is closed", packetType1RTT, debugFrameStream{ @@ -480,7 +491,7 @@ func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) { if err != nil { t.Fatalf("AcceptStream() = %v", err) } - if err := s.CloseContext(ctx); err != nil { + if err := s.Close(); err != nil { t.Fatalf("stream.Close() = %v", err) } tc.wantFrame("closing stream updates peer's MAX_STREAMS", @@ -512,7 +523,7 @@ func TestLostStreamDataBlockedFrame(t *testing.T) { }) w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, []byte{0, 1, 2, 3}) + return s.Write([]byte{0, 1, 2, 3}) }) defer w.cancel() tc.wantFrame("write is blocked by flow control", @@ -564,7 +575,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) { data := []byte{0, 1, 2, 3} w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, data) + return s.Write(data) }) defer w.cancel() tc.wantFrame("write is blocked by flow control", @@ -652,6 +663,29 @@ func TestLostRetireConnectionIDFrame(t *testing.T) { }) } +func TestLostPathResponseFrame(t *testing.T) { + // "Responses to path validation using PATH_RESPONSE frames are sent just once." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.12 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypePing) + + data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + tc.writeFrames(packetType1RTT, debugFramePathChallenge{ + data: data, + }) + tc.wantFrame("response to PATH_CHALLENGE", + packetType1RTT, debugFramePathResponse{ + data: data, + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantIdle("lost PATH_RESPONSE frame is not retransmitted") + }) +} + func TestLostHandshakeDoneFrame(t *testing.T) { // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16 diff --git a/internal/quic/conn_recv.go b/quic/conn_recv.go similarity index 86% rename from internal/quic/conn_recv.go rename to quic/conn_recv.go index 045bf861c..b1354cd3a 100644 --- a/internal/quic/conn_recv.go +++ b/quic/conn_recv.go @@ -13,11 +13,28 @@ import ( "time" ) -func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { +func (c *Conn) handleDatagram(now time.Time, dgram *datagram) (handled bool) { + if !c.localAddr.IsValid() { + // We don't have any way to tell in the general case what address we're + // sending packets from. Set our address from the destination address of + // the first packet received from the peer. + c.localAddr = dgram.localAddr + } + if dgram.peerAddr.IsValid() && dgram.peerAddr != c.peerAddr { + if c.side == clientSide { + // "If a client receives packets from an unknown server address, + // the client MUST discard these packets." + // https://www.rfc-editor.org/rfc/rfc9000#section-9-6 + return false + } + // We currently don't support connection migration, + // so for now the server also drops packets from an unknown address. + return false + } buf := dgram.b c.loss.datagramReceived(now, len(buf)) if c.isDraining() { - return + return false } for len(buf) > 0 { var n int @@ -27,19 +44,19 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { if c.side == serverSide && len(dgram.b) < paddedInitialDatagramSize { // Discard client-sent Initial packets in too-short datagrams. // https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4 - return + return false } - n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf) + n = c.handleLongHeader(now, dgram, ptype, initialSpace, c.keysInitial.r, buf) case packetTypeHandshake: - n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf) + n = c.handleLongHeader(now, dgram, ptype, handshakeSpace, c.keysHandshake.r, buf) case packetType1RTT: - n = c.handle1RTT(now, buf) + n = c.handle1RTT(now, dgram, buf) case packetTypeRetry: c.handleRetry(now, buf) - return + return true case packetTypeVersionNegotiation: c.handleVersionNegotiation(now, buf) - return + return true default: n = -1 } @@ -56,17 +73,20 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen { var token statelessResetToken copy(token[:], buf[len(buf)-len(token):]) - c.handleStatelessReset(now, token) + if c.handleStatelessReset(now, token) { + return true + } } // Invalid data at the end of a datagram is ignored. - break + return false } c.idleHandlePacketReceived(now) buf = buf[n:] } + return true } -func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { +func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { if !k.isSet() { return skipLongHeaderPacket(buf) } @@ -105,7 +125,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa c.logLongPacketReceived(p, buf[:n]) } c.connIDState.handlePacket(c, p.ptype, p.srcConnID) - ackEliciting := c.handleFrames(now, ptype, space, p.payload) + ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload) c.acks[space].receive(now, space, p.num, ackEliciting) if p.ptype == packetTypeHandshake && c.side == serverSide { c.loss.validateClientAddress() @@ -118,7 +138,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa return n } -func (c *Conn) handle1RTT(now time.Time, buf []byte) int { +func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int { if !c.keysAppData.canRead() { // 1-RTT packets extend to the end of the datagram, // so skip the remainder of the datagram if we can't parse this. @@ -155,7 +175,7 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { if c.logEnabled(QLogLevelPacket) { c.log1RTTPacketReceived(p, buf) } - ackEliciting := c.handleFrames(now, packetType1RTT, appDataSpace, p.payload) + ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload) c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting) return len(buf) } @@ -192,7 +212,7 @@ func (c *Conn) handleRetry(now time.Time, pkt []byte) { c.connIDState.handleRetryPacket(p.srcConnID) // We need to resend any data we've already sent in Initial packets. // We must not reuse already sent packet numbers. - c.loss.discardPackets(initialSpace, c.handleAckOrLoss) + c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss) // TODO: Discard 0-RTT packets as well, once we support 0-RTT. } @@ -232,7 +252,7 @@ func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) { c.abortImmediately(now, errVersionNegotiation) } -func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { +func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { if len(payload) == 0 { // "An endpoint MUST treat receipt of a packet containing no frames // as a connection error of type PROTOCOL_VIOLATION." @@ -353,6 +373,16 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, return } n = c.handleRetireConnectionIDFrame(now, space, payload) + case frameTypePathChallenge: + if !frameOK(c, ptype, __01) { + return + } + n = c.handlePathChallengeFrame(now, dgram, space, payload) + case frameTypePathResponse: + if !frameOK(c, ptype, ___1) { + return + } + n = c.handlePathResponseFrame(now, space, payload) case frameTypeConnectionCloseTransport: // Transport CONNECTION_CLOSE is OK in all spaces. n = c.handleConnectionCloseTransportFrame(now, payload) @@ -416,7 +446,7 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) if c.peerAckDelayExponent >= 0 { delay = ackDelay.Duration(uint8(c.peerAckDelayExponent)) } - c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss) + c.loss.receiveAckEnd(now, c.log, space, delay, c.handleAckOrLoss) if space == appDataSpace { c.keysAppData.handleAckFor(largest) } @@ -526,6 +556,24 @@ func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, p return n } +func (c *Conn) handlePathChallengeFrame(now time.Time, dgram *datagram, space numberSpace, payload []byte) int { + data, n := consumePathChallengeFrame(payload) + if n < 0 { + return -1 + } + c.handlePathChallenge(now, dgram, data) + return n +} + +func (c *Conn) handlePathResponseFrame(now time.Time, space numberSpace, payload []byte) int { + data, n := consumePathResponseFrame(payload) + if n < 0 { + return -1 + } + c.handlePathResponse(now, data) + return n +} + func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int { code, _, reason, n := consumeConnectionCloseTransportFrame(payload) if n < 0 { @@ -562,10 +610,11 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa var errStatelessReset = errors.New("received stateless reset") -func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) { +func (c *Conn) handleStatelessReset(now time.Time, resetToken statelessResetToken) (valid bool) { if !c.connIDState.isValidStatelessResetToken(resetToken) { - return + return false } c.setFinalError(errStatelessReset) c.enterDraining(now) + return true } diff --git a/internal/quic/conn_send.go b/quic/conn_send.go similarity index 95% rename from internal/quic/conn_send.go rename to quic/conn_send.go index ccb467591..a87cac232 100644 --- a/internal/quic/conn_send.go +++ b/quic/conn_send.go @@ -22,7 +22,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Assumption: The congestion window is not underutilized. // If congestion control, pacing, and anti-amplification all permit sending, // but we have no packet to send, then we will declare the window underutilized. - c.loss.cc.setUnderutilized(false) + underutilized := false + defer func() { + c.loss.cc.setUnderutilized(c.log, underutilized) + }() // Send one datagram on each iteration of this loop, // until we hit a limit or run out of data to send. @@ -80,7 +83,6 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { - c.idleHandlePacketSent(now, sentInitial) // Client initial packets and ack-eliciting server initial packaets // need to be sent in a datagram padded to at least 1200 bytes. // We can't add the padding yet, however, since we may want to @@ -111,8 +113,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { c.logPacketSent(packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { - c.idleHandlePacketSent(now, sent) - c.loss.packetSent(now, handshakeSpace, sent) + c.packetSent(now, handshakeSpace, sent) if c.side == clientSide { // "[...] a client MUST discard Initial keys when it first // sends a Handshake packet [...]" @@ -142,8 +143,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { c.logPacketSent(packetType1RTT, pnum, nil, dstConnID, c.w.packetLen(), c.w.payload()) } if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { - c.idleHandlePacketSent(now, sent) - c.loss.packetSent(now, appDataSpace, sent) + c.packetSent(now, appDataSpace, sent) } } @@ -152,7 +152,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if limit == ccOK { // We have nothing to send, and congestion control does not // block sending. The congestion window is underutilized. - c.loss.cc.setUnderutilized(true) + underutilized = true } return next } @@ -175,14 +175,22 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // with a Handshake packet, then we've discarded Initial keys // since constructing the packet and shouldn't record it as in-flight. if c.keysInitial.canWrite() { - c.loss.packetSent(now, initialSpace, sentInitial) + c.packetSent(now, initialSpace, sentInitial) } } - c.endpoint.sendDatagram(buf, c.peerAddr) + c.endpoint.sendDatagram(datagram{ + b: buf, + peerAddr: c.peerAddr, + }) } } +func (c *Conn) packetSent(now time.Time, space numberSpace, sent *sentPacket) { + c.idleHandlePacketSent(now, sent) + c.loss.packetSent(now, c.log, space, sent) +} + func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) { if c.lifetime.localErr != nil { c.appendConnectionCloseFrame(now, space, c.lifetime.localErr) @@ -263,6 +271,13 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, return } + // PATH_RESPONSE + if pad, ok := c.appendPathFrames(); !ok { + return + } else if pad { + defer c.w.appendPaddingTo(smallestMaxDatagramSize) + } + // All stream-related frames. This should come last in the packet, // so large amounts of STREAM data don't crowd out other frames // we may need to send. diff --git a/internal/quic/conn_send_test.go b/quic/conn_send_test.go similarity index 100% rename from internal/quic/conn_send_test.go rename to quic/conn_send_test.go diff --git a/internal/quic/conn_streams.go b/quic/conn_streams.go similarity index 100% rename from internal/quic/conn_streams.go rename to quic/conn_streams.go diff --git a/internal/quic/conn_streams_test.go b/quic/conn_streams_test.go similarity index 95% rename from internal/quic/conn_streams_test.go rename to quic/conn_streams_test.go index 6815e403e..dc81ad991 100644 --- a/internal/quic/conn_streams_test.go +++ b/quic/conn_streams_test.go @@ -230,8 +230,8 @@ func TestStreamsWriteQueueFairness(t *testing.T) { t.Fatal(err) } streams = append(streams, s) - if n, err := s.WriteContext(ctx, data); n != len(data) || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(data)) + if n, err := s.Write(data); n != len(data) || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) } // Wait for the stream to finish writing whatever frames it can before // congestion control blocks it. @@ -298,7 +298,7 @@ func TestStreamsShutdown(t *testing.T) { side: localStream, styp: uniStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - s.CloseContext(canceledContext()) + s.Close() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { tc.writeAckForAll() @@ -311,7 +311,7 @@ func TestStreamsShutdown(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameResetStream{ id: s.id, }) - s.CloseContext(canceledContext()) + s.Close() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { tc.writeAckForAll() @@ -321,8 +321,8 @@ func TestStreamsShutdown(t *testing.T) { side: localStream, styp: bidiStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - s.CloseContext(canceledContext()) - tc.wantIdle("all frames after CloseContext are ignored") + s.Close() + tc.wantIdle("all frames after Close are ignored") tc.writeAckForAll() }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { @@ -335,13 +335,12 @@ func TestStreamsShutdown(t *testing.T) { side: remoteStream, styp: uniStream, setup: func(t *testing.T, tc *testConn, s *Stream) { - ctx := canceledContext() tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, fin: true, }) - if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF { - t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err) + if n, err := s.Read(make([]byte, 16)); n != 0 || err != io.EOF { + t.Errorf("Read() = %v, %v; want 0, io.EOF", n, err) } }, shutdown: func(t *testing.T, tc *testConn, s *Stream) { @@ -451,17 +450,14 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) { id: op.id, }) case acceptOp: - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream() = %q; want stream %v", err, stringID(op.id)) - } + s := tc.acceptStream() if s.id != op.id { - t.Fatalf("accepted stram %v; want stream %v", err, stringID(op.id)) + t.Fatalf("accepted stream %v; want stream %v", stringID(s.id), stringID(op.id)) } t.Logf("accepted stream %v", stringID(op.id)) // Immediately close the stream, so the stream becomes done when the // peer closes its end. - s.CloseContext(ctx) + s.Close() } p := tc.readPacket() if p != nil { diff --git a/internal/quic/conn_test.go b/quic/conn_test.go similarity index 98% rename from internal/quic/conn_test.go rename to quic/conn_test.go index ddf0740e2..f4f1818a6 100644 --- a/internal/quic/conn_test.go +++ b/quic/conn_test.go @@ -21,7 +21,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) var ( @@ -168,6 +168,7 @@ type testConn struct { sentDatagrams [][]byte sentPackets []*testPacket sentFrames []debugFrame + lastDatagram *testDatagram lastPacket *testPacket recvDatagram chan *datagram @@ -241,8 +242,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { endpoint.configTestConn = configTestConn conn, err := endpoint.e.newConn( endpoint.now, + config, side, cids, + "", netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) @@ -382,6 +385,17 @@ func (tc *testConn) cleanup() { <-tc.conn.donec } +func (tc *testConn) acceptStream() *Stream { + tc.t.Helper() + s, err := tc.conn.AcceptStream(canceledContext()) + if err != nil { + tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err) + } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return s +} + func logDatagram(t *testing.T, text string, d *testDatagram) { t.Helper() if !*testVV { @@ -442,6 +456,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { dstConnID: dstConnID, srcConnID: tc.peerConnID, }}, + addr: tc.conn.peerAddr, } if ptype == packetTypeInitial && tc.conn.side == serverSide { d.paddedSize = 1200 @@ -564,6 +579,7 @@ func (tc *testConn) readDatagram() *testDatagram { } p.frames = frames } + tc.lastDatagram = d return d } @@ -645,6 +661,12 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) { } func packetEqual(a, b *testPacket) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } ac := *a ac.frames = nil ac.header = 0 diff --git a/internal/quic/crypto_stream.go b/quic/crypto_stream.go similarity index 100% rename from internal/quic/crypto_stream.go rename to quic/crypto_stream.go diff --git a/internal/quic/crypto_stream_test.go b/quic/crypto_stream_test.go similarity index 100% rename from internal/quic/crypto_stream_test.go rename to quic/crypto_stream_test.go diff --git a/internal/quic/dgram.go b/quic/dgram.go similarity index 58% rename from internal/quic/dgram.go rename to quic/dgram.go index 79e6650fa..615589373 100644 --- a/internal/quic/dgram.go +++ b/quic/dgram.go @@ -12,10 +12,25 @@ import ( ) type datagram struct { - b []byte - addr netip.AddrPort + b []byte + localAddr netip.AddrPort + peerAddr netip.AddrPort + ecn ecnBits } +// Explicit Congestion Notification bits. +// +// https://www.rfc-editor.org/rfc/rfc3168.html#section-5 +type ecnBits byte + +const ( + ecnMask = 0b000000_11 + ecnNotECT = 0b000000_00 + ecnECT1 = 0b000000_01 + ecnECT0 = 0b000000_10 + ecnCE = 0b000000_11 +) + var datagramPool = sync.Pool{ New: func() any { return &datagram{ @@ -26,7 +41,9 @@ var datagramPool = sync.Pool{ func newDatagram() *datagram { m := datagramPool.Get().(*datagram) - m.b = m.b[:cap(m.b)] + *m = datagram{ + b: m.b[:cap(m.b)], + } return m } diff --git a/quic/doc.go b/quic/doc.go new file mode 100644 index 000000000..2fd10f087 --- /dev/null +++ b/quic/doc.go @@ -0,0 +1,45 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package quic implements the QUIC protocol. +// +// This package is a work in progress. +// It is not ready for production usage. +// Its API is subject to change without notice. +// +// This package is low-level. +// Most users will use it indirectly through an HTTP/3 implementation. +// +// # Usage +// +// An [Endpoint] sends and receives traffic on a network address. +// Create an Endpoint to either accept inbound QUIC connections +// or create outbound ones. +// +// A [Conn] is a QUIC connection. +// +// A [Stream] is a QUIC stream, an ordered, reliable byte stream. +// +// # Cancelation +// +// All blocking operations may be canceled using a context.Context. +// When performing an operation with a canceled context, the operation +// will succeed if doing so does not require blocking. For example, +// reading from a stream will return data when buffered data is available, +// even if the stream context is canceled. +// +// # Limitations +// +// This package is a work in progress. +// Known limitations include: +// +// - Performance is untuned. +// - 0-RTT is not supported. +// - Address migration is not supported. +// - Server preferred addresses are not supported. +// - The latency spin bit is not supported. +// - Stream send/receive windows are configurable, +// but are fixed and do not adapt to available throughput. +// - Path MTU discovery is not implemented. +package quic diff --git a/internal/quic/endpoint.go b/quic/endpoint.go similarity index 77% rename from internal/quic/endpoint.go rename to quic/endpoint.go index 82a08a18c..a55336b24 100644 --- a/internal/quic/endpoint.go +++ b/quic/endpoint.go @@ -22,11 +22,11 @@ import ( // // Multiple goroutines may invoke methods on an Endpoint simultaneously. type Endpoint struct { - config *Config - udpConn udpConn - testHooks endpointTestHooks - resetGen statelessResetTokenGenerator - retry retryState + listenConfig *Config + packetConn packetConn + testHooks endpointTestHooks + resetGen statelessResetTokenGenerator + retry retryState acceptQueue queue[*Conn] // new inbound connections connsMap connsMap // only accessed by the listen loop @@ -42,19 +42,20 @@ type endpointTestHooks interface { newConn(c *Conn) } -// A udpConn is a UDP connection. -// It is implemented by net.UDPConn. -type udpConn interface { +// A packetConn is the interface to sending and receiving UDP packets. +type packetConn interface { Close() error - LocalAddr() net.Addr - ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) - WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) + LocalAddr() netip.AddrPort + Read(f func(*datagram)) + Write(datagram) error } // Listen listens on a local network address. -// The configuration config must be non-nil. -func Listen(network, address string, config *Config) (*Endpoint, error) { - if config.TLSConfig == nil { +// +// The config is used to for connections accepted by the endpoint. +// If the config is nil, the endpoint will not accept connections. +func Listen(network, address string, listenConfig *Config) (*Endpoint, error) { + if listenConfig != nil && listenConfig.TLSConfig == nil { return nil, errors.New("TLSConfig is not set") } a, err := net.ResolveUDPAddr(network, address) @@ -65,21 +66,29 @@ func Listen(network, address string, config *Config) (*Endpoint, error) { if err != nil { return nil, err } - return newEndpoint(udpConn, config, nil) + pc, err := newNetUDPConn(udpConn) + if err != nil { + return nil, err + } + return newEndpoint(pc, listenConfig, nil) } -func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { +func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) { e := &Endpoint{ - config: config, - udpConn: udpConn, - testHooks: hooks, - conns: make(map[*Conn]struct{}), - acceptQueue: newQueue[*Conn](), - closec: make(chan struct{}), - } - e.resetGen.init(config.StatelessResetKey) + listenConfig: config, + packetConn: pc, + testHooks: hooks, + conns: make(map[*Conn]struct{}), + acceptQueue: newQueue[*Conn](), + closec: make(chan struct{}), + } + var statelessResetKey [32]byte + if config != nil { + statelessResetKey = config.StatelessResetKey + } + e.resetGen.init(statelessResetKey) e.connsMap.init() - if config.RequireAddressValidation { + if config != nil && config.RequireAddressValidation { if err := e.retry.init(); err != nil { return nil, err } @@ -90,8 +99,7 @@ func newEndpoint(udpConn udpConn, config *Config, hooks endpointTestHooks) (*End // LocalAddr returns the local network address. func (e *Endpoint) LocalAddr() netip.AddrPort { - a, _ := e.udpConn.LocalAddr().(*net.UDPAddr) - return a.AddrPort() + return e.packetConn.LocalAddr() } // Close closes the Endpoint. @@ -103,25 +111,31 @@ func (e *Endpoint) LocalAddr() netip.AddrPort { // It waits for the peers of any open connection to acknowledge the connection has been closed. func (e *Endpoint) Close(ctx context.Context) error { e.acceptQueue.close(errors.New("endpoint closed")) + + // It isn't safe to call Conn.Abort or conn.exit with connsMu held, + // so copy the list of conns. + var conns []*Conn e.connsMu.Lock() if !e.closing { - e.closing = true + e.closing = true // setting e.closing prevents new conns from being created for c := range e.conns { - c.Abort(localTransportError{code: errNo}) + conns = append(conns, c) } if len(e.conns) == 0 { - e.udpConn.Close() + e.packetConn.Close() } } e.connsMu.Unlock() + + for _, c := range conns { + c.Abort(localTransportError{code: errNo}) + } select { case <-e.closec: case <-ctx.Done(): - e.connsMu.Lock() - for c := range e.conns { + for _, c := range conns { c.exit() } - e.connsMu.Unlock() return ctx.Err() } return nil @@ -133,14 +147,15 @@ func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { } // Dial creates and returns a connection to a network address. -func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, error) { +// The config cannot be nil. +func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) { u, err := net.ResolveUDPAddr(network, address) if err != nil { return nil, err } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := e.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) + c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr) if err != nil { return nil, err } @@ -151,13 +166,13 @@ func (e *Endpoint) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (e *Endpoint) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { +func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) { e.connsMu.Lock() defer e.connsMu.Unlock() if e.closing { return nil, errors.New("endpoint closed") } - c, err := newConn(now, side, cids, peerAddr, e.config, e) + c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e) if err != nil { return nil, err } @@ -194,34 +209,18 @@ func (e *Endpoint) connDrained(c *Conn) { defer e.connsMu.Unlock() delete(e.conns, c) if e.closing && len(e.conns) == 0 { - e.udpConn.Close() + e.packetConn.Close() } } func (e *Endpoint) listen() { defer close(e.closec) - for { - m := newDatagram() - // TODO: Read and process the ECN (explicit congestion notification) field. - // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4 - n, _, _, addr, err := e.udpConn.ReadMsgUDPAddrPort(m.b, nil) - if err != nil { - // The user has probably closed the endpoint. - // We currently don't surface errors from other causes; - // we could check to see if the endpoint has been closed and - // record the unexpected error if it has not. - return - } - if n == 0 { - continue - } + e.packetConn.Read(func(m *datagram) { if e.connsMap.updateNeeded.Load() { e.connsMap.applyUpdates() } - m.addr = addr - m.b = m.b[:n] e.handleDatagram(m) - } + }) } func (e *Endpoint) handleDatagram(m *datagram) { @@ -271,7 +270,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { // If this is a 1-RTT packet, there's nothing productive we can do with it. // Send a stateless reset if possible. if !isLongHeader(m.b[0]) { - e.maybeSendStatelessReset(m.b, m.addr) + e.maybeSendStatelessReset(m.b, m.peerAddr) return } p, ok := parseGenericLongHeaderPacket(m.b) @@ -285,7 +284,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { return default: // Unknown version. - e.sendVersionNegotiation(p, m.addr) + e.sendVersionNegotiation(p, m.peerAddr) return } if getPacketType(m.b) != packetTypeInitial { @@ -296,14 +295,18 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { // https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16 return } + if e.listenConfig == nil { + // We are not configured to accept connections. + return + } cids := newServerConnIDs{ srcConnID: p.srcConnID, dstConnID: p.dstConnID, } - if e.config.RequireAddressValidation { + if e.listenConfig.RequireAddressValidation { var ok bool cids.retrySrcConnID = p.dstConnID - cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.addr) + cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr) if !ok { return } @@ -311,7 +314,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { cids.originalDstConnID = p.dstConnID } var err error - c, err := e.newConn(now, serverSide, cids, m.addr) + c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. @@ -323,7 +326,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { m = nil // don't recycle, sendMsg takes ownership } -func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { +func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) { if !e.resetGen.canReset { // Config.StatelessResetKey isn't set, so we don't send stateless resets. return @@ -364,17 +367,21 @@ func (e *Endpoint) maybeSendStatelessReset(b []byte, addr netip.AddrPort) { b[0] &^= headerFormLong // clear long header bit b[0] |= fixedBit // set fixed bit copy(b[len(b)-statelessResetTokenLen:], token[:]) - e.sendDatagram(b, addr) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) } -func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { +func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) { m := newDatagram() m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) - e.sendDatagram(m.b, addr) + m.peerAddr = peerAddr + e.sendDatagram(*m) m.recycle() } -func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort, code transportError) { +func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) { keys := initialKeys(in.dstConnID, serverSide) var w packetWriter p := longPacket{ @@ -393,12 +400,14 @@ func (e *Endpoint) sendConnectionClose(in genericLongPacket, addr netip.AddrPort if len(buf) == 0 { return } - e.sendDatagram(buf, addr) + e.sendDatagram(datagram{ + b: buf, + peerAddr: peerAddr, + }) } -func (e *Endpoint) sendDatagram(p []byte, addr netip.AddrPort) error { - _, err := e.udpConn.WriteToUDPAddrPort(p, addr) - return err +func (e *Endpoint) sendDatagram(dgram datagram) error { + return e.packetConn.Write(dgram) } // A connsMap is an endpoint's mapping of conn ids and reset tokens to conns. diff --git a/internal/quic/endpoint_test.go b/quic/endpoint_test.go similarity index 88% rename from internal/quic/endpoint_test.go rename to quic/endpoint_test.go index ab6cd1cf5..d5f436e6d 100644 --- a/internal/quic/endpoint_test.go +++ b/quic/endpoint_test.go @@ -12,12 +12,11 @@ import ( "crypto/tls" "io" "log/slog" - "net" "net/netip" "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestConnect(t *testing.T) { @@ -63,12 +62,13 @@ func TestStreamTransfer(t *testing.T) { } } -func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { +func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() e1 := newLocalEndpoint(t, serverSide, conf1) e2 := newLocalEndpoint(t, clientSide, conf2) - c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String()) + conf2 = makeTestConfig(conf2, clientSide) + c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String(), conf2) if err != nil { t.Fatal(err) } @@ -79,11 +79,26 @@ func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverCon return c2, c1 } -func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { +func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint { t.Helper() + conf = makeTestConfig(conf, side) + e, err := Listen("udp", "127.0.0.1:0", conf) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + e.Close(canceledContext()) + }) + return e +} + +func makeTestConfig(conf *Config, side connSide) *Config { + if conf == nil { + return nil + } + newConf := *conf + conf = &newConf if conf.TLSConfig == nil { - newConf := *conf - conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } if conf.QLogLogger == nil { @@ -92,14 +107,7 @@ func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint { Dir: *qlogdir, })) } - e, err := Listen("udp", "127.0.0.1:0", conf) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - e.Close(context.Background()) - }) - return e + return conf } type testEndpoint struct { @@ -190,13 +198,9 @@ func (te *testEndpoint) writeDatagram(d *testDatagram) { for len(buf) < d.paddedSize { buf = append(buf, 0) } - addr := d.addr - if !addr.IsValid() { - addr = testClientAddr - } te.write(&datagram{ - b: buf, - addr: addr, + b: buf, + peerAddr: d.addr, }) } @@ -303,25 +307,24 @@ func (te *testEndpointUDPConn) Close() error { return nil } -func (te *testEndpointUDPConn) LocalAddr() net.Addr { - return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443")) +func (te *testEndpointUDPConn) LocalAddr() netip.AddrPort { + return netip.MustParseAddrPort("127.0.0.1:443") } -func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { +func (te *testEndpointUDPConn) Read(f func(*datagram)) { for { select { case d, ok := <-te.recvc: if !ok { - return 0, 0, 0, netip.AddrPort{}, io.EOF + return } - n = copy(b, d.b) - return n, 0, 0, d.addr, nil + f(d) case <-te.idlec: } } } -func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), b...)) - return len(b), nil +func (te *testEndpointUDPConn) Write(dgram datagram) error { + te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...)) + return nil } diff --git a/internal/quic/errors.go b/quic/errors.go similarity index 100% rename from internal/quic/errors.go rename to quic/errors.go diff --git a/internal/quic/files_test.go b/quic/files_test.go similarity index 100% rename from internal/quic/files_test.go rename to quic/files_test.go diff --git a/internal/quic/frame_debug.go b/quic/frame_debug.go similarity index 98% rename from internal/quic/frame_debug.go rename to quic/frame_debug.go index 0902c385f..17234dd7c 100644 --- a/internal/quic/frame_debug.go +++ b/quic/frame_debug.go @@ -77,6 +77,7 @@ func parseDebugFrame(b []byte) (f debugFrame, n int) { // debugFramePadding is a sequence of PADDING frames. type debugFramePadding struct { size int + to int // alternate for writing packets: pad to } func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) { @@ -95,6 +96,10 @@ func (f debugFramePadding) write(w *packetWriter) bool { if w.avail() == 0 { return false } + if f.to > 0 { + w.appendPaddingTo(f.to) + return true + } for i := 0; i < f.size && w.avail() > 0; i++ { w.b = append(w.b, frameTypePadding) } @@ -584,7 +589,7 @@ func (f debugFrameRetireConnectionID) LogValue() slog.Value { // debugFramePathChallenge is a PATH_CHALLENGE frame. type debugFramePathChallenge struct { - data uint64 + data pathChallengeData } func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { @@ -593,7 +598,7 @@ func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { } func (f debugFramePathChallenge) String() string { - return fmt.Sprintf("PATH_CHALLENGE Data=%016x", f.data) + return fmt.Sprintf("PATH_CHALLENGE Data=%x", f.data) } func (f debugFramePathChallenge) write(w *packetWriter) bool { @@ -603,13 +608,13 @@ func (f debugFramePathChallenge) write(w *packetWriter) bool { func (f debugFramePathChallenge) LogValue() slog.Value { return slog.GroupValue( slog.String("frame_type", "path_challenge"), - slog.String("data", fmt.Sprintf("%016x", f.data)), + slog.String("data", fmt.Sprintf("%x", f.data)), ) } // debugFramePathResponse is a PATH_RESPONSE frame. type debugFramePathResponse struct { - data uint64 + data pathChallengeData } func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { @@ -618,7 +623,7 @@ func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { } func (f debugFramePathResponse) String() string { - return fmt.Sprintf("PATH_RESPONSE Data=%016x", f.data) + return fmt.Sprintf("PATH_RESPONSE Data=%x", f.data) } func (f debugFramePathResponse) write(w *packetWriter) bool { @@ -628,7 +633,7 @@ func (f debugFramePathResponse) write(w *packetWriter) bool { func (f debugFramePathResponse) LogValue() slog.Value { return slog.GroupValue( slog.String("frame_type", "path_response"), - slog.String("data", fmt.Sprintf("%016x", f.data)), + slog.String("data", fmt.Sprintf("%x", f.data)), ) } diff --git a/internal/quic/gate.go b/quic/gate.go similarity index 100% rename from internal/quic/gate.go rename to quic/gate.go diff --git a/internal/quic/gate_test.go b/quic/gate_test.go similarity index 100% rename from internal/quic/gate_test.go rename to quic/gate_test.go diff --git a/internal/quic/gotraceback_test.go b/quic/gotraceback_test.go similarity index 100% rename from internal/quic/gotraceback_test.go rename to quic/gotraceback_test.go diff --git a/internal/quic/idle.go b/quic/idle.go similarity index 100% rename from internal/quic/idle.go rename to quic/idle.go diff --git a/internal/quic/idle_test.go b/quic/idle_test.go similarity index 100% rename from internal/quic/idle_test.go rename to quic/idle_test.go diff --git a/internal/quic/key_update_test.go b/quic/key_update_test.go similarity index 100% rename from internal/quic/key_update_test.go rename to quic/key_update_test.go diff --git a/internal/quic/log.go b/quic/log.go similarity index 100% rename from internal/quic/log.go rename to quic/log.go diff --git a/internal/quic/loss.go b/quic/loss.go similarity index 90% rename from internal/quic/loss.go rename to quic/loss.go index a59081fd5..796b5f7a3 100644 --- a/internal/quic/loss.go +++ b/quic/loss.go @@ -7,6 +7,8 @@ package quic import ( + "context" + "log/slog" "math" "time" ) @@ -179,7 +181,7 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber { } // packetSent records a sent packet. -func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacket) { +func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) { sent.time = now c.spaces[space].add(sent) size := sent.size @@ -187,13 +189,16 @@ func (c *lossState) packetSent(now time.Time, space numberSpace, sent *sentPacke c.antiAmplificationLimit = max(0, c.antiAmplificationLimit-size) } if sent.inFlight { - c.cc.packetSent(now, space, sent) + c.cc.packetSent(now, log, space, sent) c.pacer.packetSent(now, size, c.cc.congestionWindow, c.rtt.smoothedRTT) if sent.ackEliciting { c.spaces[space].lastAckEliciting = sent.num c.ptoExpired = false // reset expired PTO timer after sending probe } c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } if sent.ackEliciting { c.consecutiveNonAckElicitingPackets = 0 @@ -267,7 +272,7 @@ func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex // receiveAckEnd finishes processing an ack frame. // The lossf function is called for each packet newly detected as lost. -func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) { +func (c *lossState) receiveAckEnd(now time.Time, log *slog.Logger, space numberSpace, ackDelay time.Duration, lossf func(numberSpace, *sentPacket, packetFate)) { c.spaces[space].sentPacketList.clean() // Update the RTT sample when the largest acknowledged packet in the ACK frame // is newly acknowledged, and at least one newly acknowledged packet is ack-eliciting. @@ -286,13 +291,30 @@ func (c *lossState) receiveAckEnd(now time.Time, space numberSpace, ackDelay tim // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2.1-3 c.timer = time.Time{} c.detectLoss(now, lossf) - c.cc.packetBatchEnd(now, space, &c.rtt, c.maxAckDelay) + c.cc.packetBatchEnd(now, log, space, &c.rtt, c.maxAckDelay) + + if logEnabled(log, QLogLevelPacket) { + var ssthresh slog.Attr + if c.cc.slowStartThreshold != math.MaxInt { + ssthresh = slog.Int("ssthresh", c.cc.slowStartThreshold) + } + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Duration("min_rtt", c.rtt.minRTT), + slog.Duration("smoothed_rtt", c.rtt.smoothedRTT), + slog.Duration("latest_rtt", c.rtt.latestRTT), + slog.Duration("rtt_variance", c.rtt.rttvar), + slog.Int("congestion_window", c.cc.congestionWindow), + slog.Int("bytes_in_flight", c.cc.bytesInFlight), + ssthresh, + ) + } } // discardPackets declares that packets within a number space will not be delivered // and that data contained in them should be resent. // For example, after receiving a Retry packet we discard already-sent Initial packets. -func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *sentPacket, packetFate)) { +func (c *lossState) discardPackets(space numberSpace, log *slog.Logger, lossf func(numberSpace, *sentPacket, packetFate)) { for i := 0; i < c.spaces[space].size; i++ { sent := c.spaces[space].nth(i) sent.lost = true @@ -300,10 +322,13 @@ func (c *lossState) discardPackets(space numberSpace, lossf func(numberSpace, *s lossf(numberSpace(space), sent, packetLost) } c.spaces[space].clean() + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } // discardKeys is called when dropping packet protection keys for a number space. -func (c *lossState) discardKeys(now time.Time, space numberSpace) { +func (c *lossState) discardKeys(now time.Time, log *slog.Logger, space numberSpace) { // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.4 for i := 0; i < c.spaces[space].size; i++ { sent := c.spaces[space].nth(i) @@ -313,6 +338,9 @@ func (c *lossState) discardKeys(now time.Time, space numberSpace) { c.spaces[space].maxAcked = -1 c.spaces[space].lastAckEliciting = -1 c.scheduleTimer(now) + if logEnabled(log, QLogLevelPacket) { + logBytesInFlight(log, c.cc.bytesInFlight) + } } func (c *lossState) lossDuration() time.Duration { @@ -459,3 +487,10 @@ func (c *lossState) ptoBasePeriod() time.Duration { } return pto } + +func logBytesInFlight(log *slog.Logger, bytesInFlight int) { + log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:metrics_updated", + slog.Int("bytes_in_flight", bytesInFlight), + ) +} diff --git a/internal/quic/loss_test.go b/quic/loss_test.go similarity index 99% rename from internal/quic/loss_test.go rename to quic/loss_test.go index efbf1649e..1fb9662e4 100644 --- a/internal/quic/loss_test.go +++ b/quic/loss_test.go @@ -1060,7 +1060,7 @@ func TestLossPersistentCongestion(t *testing.T) { maxDatagramSize: 1200, }) test.send(initialSpace, 0, testSentPacketSize(1200)) - test.c.cc.setUnderutilized(true) + test.c.cc.setUnderutilized(nil, true) test.advance(10 * time.Millisecond) test.ack(initialSpace, 0*time.Millisecond, i64range[packetNumber]{0, 1}) @@ -1377,7 +1377,7 @@ func (c *lossTest) setRTTVar(d time.Duration) { func (c *lossTest) setUnderutilized(v bool) { c.t.Logf("set congestion window underutilized: %v", v) - c.c.cc.setUnderutilized(v) + c.c.cc.setUnderutilized(nil, v) } func (c *lossTest) advance(d time.Duration) { @@ -1438,7 +1438,7 @@ func (c *lossTest) send(spaceID numberSpace, opts ...any) { sent := &sentPacket{} *sent = prototype sent.num = num - c.c.packetSent(c.now, spaceID, sent) + c.c.packetSent(c.now, nil, spaceID, sent) } } @@ -1462,7 +1462,7 @@ func (c *lossTest) ack(spaceID numberSpace, ackDelay time.Duration, rs ...i64ran c.t.Logf("ack %v delay=%v [%v,%v)", spaceID, ackDelay, r.start, r.end) c.c.receiveAckRange(c.now, spaceID, i, r.start, r.end, c.onAckOrLoss) } - c.c.receiveAckEnd(c.now, spaceID, ackDelay, c.onAckOrLoss) + c.c.receiveAckEnd(c.now, nil, spaceID, ackDelay, c.onAckOrLoss) } func (c *lossTest) onAckOrLoss(space numberSpace, sent *sentPacket, fate packetFate) { @@ -1491,7 +1491,7 @@ func (c *lossTest) discardKeys(spaceID numberSpace) { c.t.Helper() c.checkUnexpectedEvents() c.t.Logf("discard %s keys", spaceID) - c.c.discardKeys(c.now, spaceID) + c.c.discardKeys(c.now, nil, spaceID) } func (c *lossTest) setMaxAckDelay(d time.Duration) { diff --git a/internal/quic/main_test.go b/quic/main_test.go similarity index 100% rename from internal/quic/main_test.go rename to quic/main_test.go diff --git a/internal/quic/math.go b/quic/math.go similarity index 100% rename from internal/quic/math.go rename to quic/math.go diff --git a/internal/quic/pacer.go b/quic/pacer.go similarity index 100% rename from internal/quic/pacer.go rename to quic/pacer.go diff --git a/internal/quic/pacer_test.go b/quic/pacer_test.go similarity index 100% rename from internal/quic/pacer_test.go rename to quic/pacer_test.go diff --git a/internal/quic/packet.go b/quic/packet.go similarity index 100% rename from internal/quic/packet.go rename to quic/packet.go diff --git a/internal/quic/packet_codec_test.go b/quic/packet_codec_test.go similarity index 99% rename from internal/quic/packet_codec_test.go rename to quic/packet_codec_test.go index 475e18c1d..3b39795ef 100644 --- a/internal/quic/packet_codec_test.go +++ b/quic/packet_codec_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestParseLongHeaderPacket(t *testing.T) { @@ -517,7 +517,7 @@ func TestFrameEncodeDecode(t *testing.T) { s: "PATH_CHALLENGE Data=0123456789abcdef", j: `{"frame_type":"path_challenge","data":"0123456789abcdef"}`, f: debugFramePathChallenge{ - data: 0x0123456789abcdef, + data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, }, b: []byte{ 0x1a, // Type (i) = 0x1a, @@ -527,7 +527,7 @@ func TestFrameEncodeDecode(t *testing.T) { s: "PATH_RESPONSE Data=0123456789abcdef", j: `{"frame_type":"path_response","data":"0123456789abcdef"}`, f: debugFramePathResponse{ - data: 0x0123456789abcdef, + data: pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, }, b: []byte{ 0x1b, // Type (i) = 0x1b, diff --git a/internal/quic/packet_number.go b/quic/packet_number.go similarity index 100% rename from internal/quic/packet_number.go rename to quic/packet_number.go diff --git a/internal/quic/packet_number_test.go b/quic/packet_number_test.go similarity index 100% rename from internal/quic/packet_number_test.go rename to quic/packet_number_test.go diff --git a/internal/quic/packet_parser.go b/quic/packet_parser.go similarity index 98% rename from internal/quic/packet_parser.go rename to quic/packet_parser.go index 02ef9fb14..feef9eac7 100644 --- a/internal/quic/packet_parser.go +++ b/quic/packet_parser.go @@ -463,18 +463,17 @@ func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) { return seq, n } -func consumePathChallengeFrame(b []byte) (data uint64, n int) { +func consumePathChallengeFrame(b []byte) (data pathChallengeData, n int) { n = 1 - var nn int - data, nn = consumeUint64(b[n:]) - if nn < 0 { - return 0, -1 + nn := copy(data[:], b[n:]) + if nn != len(data) { + return data, -1 } n += nn return data, n } -func consumePathResponseFrame(b []byte) (data uint64, n int) { +func consumePathResponseFrame(b []byte) (data pathChallengeData, n int) { return consumePathChallengeFrame(b) // identical frame format } diff --git a/internal/quic/packet_protection.go b/quic/packet_protection.go similarity index 100% rename from internal/quic/packet_protection.go rename to quic/packet_protection.go diff --git a/internal/quic/packet_protection_test.go b/quic/packet_protection_test.go similarity index 100% rename from internal/quic/packet_protection_test.go rename to quic/packet_protection_test.go diff --git a/internal/quic/packet_test.go b/quic/packet_test.go similarity index 100% rename from internal/quic/packet_test.go rename to quic/packet_test.go diff --git a/internal/quic/packet_writer.go b/quic/packet_writer.go similarity index 95% rename from internal/quic/packet_writer.go rename to quic/packet_writer.go index b4e54ce4b..e4d71e622 100644 --- a/internal/quic/packet_writer.go +++ b/quic/packet_writer.go @@ -141,7 +141,7 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked) k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num) - return w.finish(p.num) + return w.finish(p.ptype, p.num) } // start1RTTPacket starts writing a 1-RTT (short header) packet. @@ -183,7 +183,7 @@ func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConn hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked) w.padPacketLength(pnumLen) k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum) - return w.finish(pnum) + return w.finish(packetType1RTT, pnum) } // padPacketLength pads out the payload of the current packet to the minimum size, @@ -204,9 +204,10 @@ func (w *packetWriter) padPacketLength(pnumLen int) int { } // finish finishes the current packet after protection is applied. -func (w *packetWriter) finish(pnum packetNumber) *sentPacket { +func (w *packetWriter) finish(ptype packetType, pnum packetNumber) *sentPacket { w.b = w.b[:len(w.b)+aeadOverhead] w.sent.size = len(w.b) - w.pktOff + w.sent.ptype = ptype w.sent.num = pnum sent := w.sent w.sent = nil @@ -242,10 +243,7 @@ func (w *packetWriter) appendPingFrame() (added bool) { return false } w.b = append(w.b, frameTypePing) - // Mark this packet as ack-eliciting and in-flight, - // but there's no need to record the presence of a PING frame in it. - w.sent.ackEliciting = true - w.sent.inFlight = true + w.sent.markAckEliciting() // no need to record the frame itself return true } @@ -387,11 +385,7 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b w.b = appendVarint(w.b, uint64(size)) start := len(w.b) w.b = w.b[:start+size] - if fin { - w.sent.appendAckElicitingFrame(frameTypeStreamBase | streamFinBit) - } else { - w.sent.appendAckElicitingFrame(frameTypeStreamBase) - } + w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit)) w.sent.appendInt(uint64(id)) w.sent.appendOffAndSize(off, size) return w.b[start:][:size], true @@ -498,23 +492,23 @@ func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) { return true } -func (w *packetWriter) appendPathChallengeFrame(data uint64) (added bool) { +func (w *packetWriter) appendPathChallengeFrame(data pathChallengeData) (added bool) { if w.avail() < 1+8 { return false } w.b = append(w.b, frameTypePathChallenge) - w.b = binary.BigEndian.AppendUint64(w.b, data) - w.sent.appendAckElicitingFrame(frameTypePathChallenge) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself return true } -func (w *packetWriter) appendPathResponseFrame(data uint64) (added bool) { +func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bool) { if w.avail() < 1+8 { return false } w.b = append(w.b, frameTypePathResponse) - w.b = binary.BigEndian.AppendUint64(w.b, data) - w.sent.appendAckElicitingFrame(frameTypePathResponse) + w.b = append(w.b, data[:]...) + w.sent.markAckEliciting() // no need to record the frame itself return true } diff --git a/quic/path.go b/quic/path.go new file mode 100644 index 000000000..8c237dd45 --- /dev/null +++ b/quic/path.go @@ -0,0 +1,89 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "time" + +type pathState struct { + // Response to a peer's PATH_CHALLENGE. + // This is not a sentVal, because we don't resend lost PATH_RESPONSE frames. + // We only track the most recent PATH_CHALLENGE. + // If the peer sends a second PATH_CHALLENGE before we respond to the first, + // we'll drop the first response. + sendPathResponse pathResponseType + data pathChallengeData +} + +// pathChallengeData is data carried in a PATH_CHALLENGE or PATH_RESPONSE frame. +type pathChallengeData [64 / 8]byte + +type pathResponseType uint8 + +const ( + pathResponseNotNeeded = pathResponseType(iota) + pathResponseSmall // send PATH_RESPONSE, do not expand datagram + pathResponseExpanded // send PATH_RESPONSE, expand datagram to 1200 bytes +) + +func (c *Conn) handlePathChallenge(_ time.Time, dgram *datagram, data pathChallengeData) { + // A PATH_RESPONSE is sent in a datagram expanded to 1200 bytes, + // except when this would exceed the anti-amplification limit. + // + // Rather than maintaining anti-amplification state for each path + // we may be sending a PATH_RESPONSE on, follow the following heuristic: + // + // If we receive a PATH_CHALLENGE in an expanded datagram, + // respond with an expanded datagram. + // + // If we receive a PATH_CHALLENGE in a non-expanded datagram, + // then the peer is presumably blocked by its own anti-amplification limit. + // Respond with a non-expanded datagram. Receiving this PATH_RESPONSE + // will validate the path to the peer, remove its anti-amplification limit, + // and permit it to send a followup PATH_CHALLENGE in an expanded datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-8.2.1 + if len(dgram.b) >= smallestMaxDatagramSize { + c.path.sendPathResponse = pathResponseExpanded + } else { + c.path.sendPathResponse = pathResponseSmall + } + c.path.data = data +} + +func (c *Conn) handlePathResponse(now time.Time, _ pathChallengeData) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + // + // We never send PATH_CHALLENGE frames. + c.abort(now, localTransportError{ + code: errProtocolViolation, + reason: "PATH_RESPONSE received when no PATH_CHALLENGE sent", + }) +} + +// appendPathFrames appends path validation related frames to the current packet. +// If the return value pad is true, then the packet should be padded to 1200 bytes. +func (c *Conn) appendPathFrames() (pad, ok bool) { + if c.path.sendPathResponse == pathResponseNotNeeded { + return pad, true + } + // We're required to send the PATH_RESPONSE on the path where the + // PATH_CHALLENGE was received (RFC 9000, Section 8.2.2). + // + // At the moment, we don't support path migration and reject packets if + // the peer changes its source address, so just sending the PATH_RESPONSE + // in a regular datagram is fine. + if !c.w.appendPathResponseFrame(c.path.data) { + return pad, false + } + if c.path.sendPathResponse == pathResponseExpanded { + pad = true + } + c.path.sendPathResponse = pathResponseNotNeeded + return pad, true +} diff --git a/quic/path_test.go b/quic/path_test.go new file mode 100644 index 000000000..a309ed14b --- /dev/null +++ b/quic/path_test.go @@ -0,0 +1,66 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" +) + +func TestPathChallengeReceived(t *testing.T) { + for _, test := range []struct { + name string + padTo int + wantPadding int + }{{ + name: "unexpanded", + padTo: 0, + wantPadding: 0, + }, { + name: "expanded", + padTo: 1200, + wantPadding: 1200, + }} { + // "The recipient of [a PATH_CHALLENGE] frame MUST generate + // a PATH_RESPONSE frame [...] containing the same Data value." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + tc.writeFrames(packetType1RTT, debugFramePathChallenge{ + data: data, + }, debugFramePadding{ + to: test.padTo, + }) + tc.wantFrame("response to PATH_CHALLENGE", + packetType1RTT, debugFramePathResponse{ + data: data, + }) + if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want { + t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want) + } + tc.wantIdle("connection is idle") + } +} + +func TestPathResponseMismatchReceived(t *testing.T) { + // "If the content of a PATH_RESPONSE frame does not match the content of + // a PATH_CHALLENGE frame previously sent by the endpoint, + // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.18-4 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFramePathResponse{ + data: pathChallengeData{}, + }) + tc.wantFrame("invalid PATH_RESPONSE causes the connection to close", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }, + ) +} diff --git a/internal/quic/ping.go b/quic/ping.go similarity index 100% rename from internal/quic/ping.go rename to quic/ping.go diff --git a/internal/quic/ping_test.go b/quic/ping_test.go similarity index 100% rename from internal/quic/ping_test.go rename to quic/ping_test.go diff --git a/internal/quic/pipe.go b/quic/pipe.go similarity index 71% rename from internal/quic/pipe.go rename to quic/pipe.go index d3a448df3..75cf76db2 100644 --- a/internal/quic/pipe.go +++ b/quic/pipe.go @@ -17,14 +17,14 @@ import ( // Writing past the end of the window extends it. // Data may be discarded from the start of the pipe, advancing the window. type pipe struct { - start int64 - end int64 - head *pipebuf - tail *pipebuf + start int64 // stream position of first stored byte + end int64 // stream position just past the last stored byte + head *pipebuf // if non-nil, then head.off + len(head.b) > start + tail *pipebuf // if non-nil, then tail.off + len(tail.b) == end } type pipebuf struct { - off int64 + off int64 // stream position of b[0] b []byte next *pipebuf } @@ -111,6 +111,7 @@ func (p *pipe) copy(off int64, b []byte) { // read calls f with the data in [off, off+n) // The data may be provided sequentially across multiple calls to f. +// Note that read (unlike an io.Reader) does not consume the read data. func (p *pipe) read(off int64, n int, f func([]byte) error) error { if off < p.start { panic("invalid read range") @@ -135,6 +136,30 @@ func (p *pipe) read(off int64, n int, f func([]byte) error) error { return nil } +// peek returns a reference to up to n bytes of internal data buffer, starting at p.start. +// The returned slice is valid until the next call to discardBefore. +// The length of the returned slice will be in the range [0,n]. +func (p *pipe) peek(n int64) []byte { + pb := p.head + if pb == nil { + return nil + } + b := pb.b[p.start-pb.off:] + return b[:min(int64(len(b)), n)] +} + +// availableBuffer returns the available contiguous, allocated buffer space +// following the pipe window. +// +// This is used by the stream write fast path, which makes multiple writes into the pipe buffer +// without a lock, and then adjusts p.end at a later time with a lock held. +func (p *pipe) availableBuffer() []byte { + if p.tail == nil { + return nil + } + return p.tail.b[p.end-p.tail.off:] +} + // discardBefore discards all data prior to off. func (p *pipe) discardBefore(off int64) { for p.head != nil && p.head.end() < off { diff --git a/internal/quic/pipe_test.go b/quic/pipe_test.go similarity index 100% rename from internal/quic/pipe_test.go rename to quic/pipe_test.go diff --git a/internal/quic/qlog.go b/quic/qlog.go similarity index 92% rename from internal/quic/qlog.go rename to quic/qlog.go index 82ad92ac8..36831252c 100644 --- a/internal/quic/qlog.go +++ b/quic/qlog.go @@ -39,7 +39,11 @@ const ( ) func (c *Conn) logEnabled(level slog.Level) bool { - return c.log != nil && c.log.Enabled(context.Background(), level) + return logEnabled(c.log, level) +} + +func logEnabled(log *slog.Logger, level slog.Level) bool { + return log != nil && log.Enabled(context.Background(), level) } // slogHexstring returns a slog.Attr for a value of the hexstring type. @@ -147,6 +151,12 @@ func (c *Conn) logConnectionClosed() { ) } +func (c *Conn) logPacketDropped(dgram *datagram) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "connectivity:packet_dropped", + ) +} + func (c *Conn) logLongPacketReceived(p longPacket, pkt []byte) { var frames slog.Attr if c.logEnabled(QLogLevelFrame) { @@ -252,3 +262,13 @@ func (c *Conn) packetFramesAttr(payload []byte) slog.Attr { } return slog.Any("frames", frames) } + +func (c *Conn) logPacketLost(space numberSpace, sent *sentPacket) { + c.log.LogAttrs(context.Background(), QLogLevelPacket, + "recovery:packet_lost", + slog.Group("header", + slog.String("packet_type", sent.ptype.qlogString()), + slog.Uint64("packet_number", uint64(sent.num)), + ), + ) +} diff --git a/internal/quic/qlog/handler.go b/quic/qlog/handler.go similarity index 100% rename from internal/quic/qlog/handler.go rename to quic/qlog/handler.go diff --git a/internal/quic/qlog/json_writer.go b/quic/qlog/json_writer.go similarity index 99% rename from internal/quic/qlog/json_writer.go rename to quic/qlog/json_writer.go index b2fa3e03e..6fb8d33b2 100644 --- a/internal/quic/qlog/json_writer.go +++ b/quic/qlog/json_writer.go @@ -45,15 +45,15 @@ func (w *jsonWriter) writeRecordEnd() { func (w *jsonWriter) writeAttrs(attrs []slog.Attr) { w.buf.WriteByte('{') for _, a := range attrs { - if a.Key == "" { - continue - } w.writeAttr(a) } w.buf.WriteByte('}') } func (w *jsonWriter) writeAttr(a slog.Attr) { + if a.Key == "" { + return + } w.writeName(a.Key) w.writeValue(a.Value) } diff --git a/internal/quic/qlog/json_writer_test.go b/quic/qlog/json_writer_test.go similarity index 96% rename from internal/quic/qlog/json_writer_test.go rename to quic/qlog/json_writer_test.go index 6da556641..03cf6947c 100644 --- a/internal/quic/qlog/json_writer_test.go +++ b/quic/qlog/json_writer_test.go @@ -85,6 +85,15 @@ func TestJSONWriterAttrs(t *testing.T) { `}}`) } +func TestJSONWriterAttrEmpty(t *testing.T) { + w := newTestJSONWriter() + w.writeRecordStart() + var a slog.Attr + w.writeAttr(a) + w.writeRecordEnd() + wantJSONRecord(t, w, `{}`) +} + func TestJSONWriterObjectEmpty(t *testing.T) { w := newTestJSONWriter() w.writeRecordStart() diff --git a/internal/quic/qlog/qlog.go b/quic/qlog/qlog.go similarity index 99% rename from internal/quic/qlog/qlog.go rename to quic/qlog/qlog.go index e54c839f0..f33c6b0fd 100644 --- a/internal/quic/qlog/qlog.go +++ b/quic/qlog/qlog.go @@ -29,7 +29,7 @@ const ( // VantageClient traces follow a connection from the client's perspective. VantageClient = Vantage("client") - // VantageClient traces follow a connection from the server's perspective. + // VantageServer traces follow a connection from the server's perspective. VantageServer = Vantage("server") ) diff --git a/internal/quic/qlog/qlog_test.go b/quic/qlog/qlog_test.go similarity index 100% rename from internal/quic/qlog/qlog_test.go rename to quic/qlog/qlog_test.go diff --git a/internal/quic/qlog_test.go b/quic/qlog_test.go similarity index 74% rename from internal/quic/qlog_test.go rename to quic/qlog_test.go index e98b11838..c0b5cd170 100644 --- a/internal/quic/qlog_test.go +++ b/quic/qlog_test.go @@ -7,6 +7,7 @@ package quic import ( + "bytes" "encoding/hex" "encoding/json" "fmt" @@ -16,7 +17,7 @@ import ( "testing" "time" - "golang.org/x/net/internal/quic/qlog" + "golang.org/x/net/quic/qlog" ) func TestQLogHandshake(t *testing.T) { @@ -159,6 +160,98 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { } } +func TestQLogRecovery(t *testing.T) { + qr := &qlogRecord{} + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, qr.config) + + // Ignore events from the handshake. + qr.ev = nil + + data := make([]byte, 16) + s.Write(data) + s.CloseWrite() + tc.wantFrame("created stream 0", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + fin: true, + data: data, + }) + tc.writeAckForAll() + tc.wantIdle("connection should be idle now") + + // Don't check the contents of fields, but verify that recovery metrics are logged. + qr.wantEvents(t, jsonEvent{ + "name": "recovery:metrics_updated", + "data": map[string]any{ + "bytes_in_flight": nil, + }, + }, jsonEvent{ + "name": "recovery:metrics_updated", + "data": map[string]any{ + "bytes_in_flight": 0, + "congestion_window": nil, + "latest_rtt": nil, + "min_rtt": nil, + "rtt_variance": nil, + "smoothed_rtt": nil, + }, + }) +} + +func TestQLogLoss(t *testing.T) { + qr := &qlogRecord{} + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, qr.config) + + // Ignore events from the handshake. + qr.ev = nil + + data := make([]byte, 16) + s.Write(data) + s.CloseWrite() + tc.wantFrame("created stream 0", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + fin: true, + data: data, + }) + + const pto = false + tc.triggerLossOrPTO(packetType1RTT, pto) + + qr.wantEvents(t, jsonEvent{ + "name": "recovery:packet_lost", + "data": map[string]any{ + "header": map[string]any{ + "packet_number": nil, + "packet_type": "1RTT", + }, + }, + }) +} + +func TestQLogPacketDropped(t *testing.T) { + qr := &qlogRecord{} + tc := newTestConn(t, clientSide, permissiveTransportParameters, qr.config) + tc.handshake() + + // A garbage-filled datagram with a DCID matching this connection. + dgram := bytes.Join([][]byte{ + {headerFormShort | fixedBit}, + testLocalConnID(0), + make([]byte, 100), + []byte{1, 2, 3, 4}, // random data, to avoid this looking like a stateless reset + }, nil) + tc.endpoint.write(&datagram{ + b: dgram, + }) + + qr.wantEvents(t, jsonEvent{ + "name": "connectivity:packet_dropped", + }) +} + type nopCloseWriter struct { io.Writer } @@ -193,14 +286,15 @@ func jsonPartialEqual(got, want any) (equal bool) { } return v } + if want == nil { + return true // match anything + } got = cmpval(got) want = cmpval(want) if reflect.TypeOf(got) != reflect.TypeOf(want) { return false } switch w := want.(type) { - case nil: - // Match anything. case map[string]any: // JSON object: Every field in want must match a field in got. g := got.(map[string]any) diff --git a/internal/quic/queue.go b/quic/queue.go similarity index 100% rename from internal/quic/queue.go rename to quic/queue.go diff --git a/internal/quic/queue_test.go b/quic/queue_test.go similarity index 100% rename from internal/quic/queue_test.go rename to quic/queue_test.go diff --git a/internal/quic/quic.go b/quic/quic.go similarity index 100% rename from internal/quic/quic.go rename to quic/quic.go diff --git a/internal/quic/quic_test.go b/quic/quic_test.go similarity index 100% rename from internal/quic/quic_test.go rename to quic/quic_test.go diff --git a/internal/quic/rangeset.go b/quic/rangeset.go similarity index 98% rename from internal/quic/rangeset.go rename to quic/rangeset.go index 4966a99d2..b8b2e9367 100644 --- a/internal/quic/rangeset.go +++ b/quic/rangeset.go @@ -50,7 +50,7 @@ func (s *rangeset[T]) add(start, end T) { if end <= r.end { return } - // Possibly coalesce subsquent ranges into range i. + // Possibly coalesce subsequent ranges into range i. r.end = end j := i + 1 for ; j < len(*s) && r.end >= (*s)[j].start; j++ { diff --git a/internal/quic/rangeset_test.go b/quic/rangeset_test.go similarity index 100% rename from internal/quic/rangeset_test.go rename to quic/rangeset_test.go diff --git a/internal/quic/retry.go b/quic/retry.go similarity index 96% rename from internal/quic/retry.go rename to quic/retry.go index 31cb57b88..5dc39d1d9 100644 --- a/internal/quic/retry.go +++ b/quic/retry.go @@ -139,7 +139,7 @@ func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []by return additional } -func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, addr netip.AddrPort) (origDstConnID []byte, ok bool) { +func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) { // The retry token is at the start of an Initial packet's data. token, n := consumeUint8Bytes(p.data) if n < 0 { @@ -151,22 +151,22 @@ func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, ad if len(token) == 0 { // The sender has not provided a token. // Send a Retry packet to them with one. - e.sendRetry(now, p, addr) + e.sendRetry(now, p, peerAddr) return nil, false } - origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, addr) + origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr) if !ok { // This does not seem to be a valid token. // Close the connection with an INVALID_TOKEN error. // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-5 - e.sendConnectionClose(p, addr, errInvalidToken) + e.sendConnectionClose(p, peerAddr, errInvalidToken) return nil, false } return origDstConnID, true } -func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.AddrPort) { - token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, addr) +func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) { + token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr) if err != nil { return } @@ -175,7 +175,10 @@ func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, addr netip.Addr srcConnID: srcConnID, token: token, }) - e.sendDatagram(b, addr) + e.sendDatagram(datagram{ + b: b, + peerAddr: peerAddr, + }) } type retryPacket struct { diff --git a/internal/quic/retry_test.go b/quic/retry_test.go similarity index 99% rename from internal/quic/retry_test.go rename to quic/retry_test.go index 8f36e1bd3..c898ad331 100644 --- a/internal/quic/retry_test.go +++ b/quic/retry_test.go @@ -436,8 +436,8 @@ func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { }) pkt[len(pkt)-1] ^= 1 // invalidate the integrity tag tc.endpoint.write(&datagram{ - b: pkt, - addr: testClientAddr, + b: pkt, + peerAddr: testClientAddr, }) tc.wantIdle("client ignores Retry with invalid integrity tag") } @@ -521,7 +521,7 @@ func TestParseInvalidRetryPackets(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { if _, ok := parseRetryPacket(test.pkt, originalDstConnID); ok { - t.Errorf("parseRetryPacket succeded, want failure") + t.Errorf("parseRetryPacket succeeded, want failure") } }) } diff --git a/internal/quic/rtt.go b/quic/rtt.go similarity index 100% rename from internal/quic/rtt.go rename to quic/rtt.go diff --git a/internal/quic/rtt_test.go b/quic/rtt_test.go similarity index 100% rename from internal/quic/rtt_test.go rename to quic/rtt_test.go diff --git a/internal/quic/sent_packet.go b/quic/sent_packet.go similarity index 90% rename from internal/quic/sent_packet.go rename to quic/sent_packet.go index 4f11aa136..226152327 100644 --- a/internal/quic/sent_packet.go +++ b/quic/sent_packet.go @@ -14,9 +14,10 @@ import ( // A sentPacket tracks state related to an in-flight packet we sent, // to be committed when the peer acks it or resent if the packet is lost. type sentPacket struct { - num packetNumber - size int // size in bytes - time time.Time // time sent + num packetNumber + size int // size in bytes + time time.Time // time sent + ptype packetType ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1 inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 @@ -58,6 +59,12 @@ func (sent *sentPacket) reset() { } } +// markAckEliciting marks the packet as containing an ack-eliciting frame. +func (sent *sentPacket) markAckEliciting() { + sent.ackEliciting = true + sent.inFlight = true +} + // The append* methods record information about frames in the packet. func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) { diff --git a/internal/quic/sent_packet_list.go b/quic/sent_packet_list.go similarity index 100% rename from internal/quic/sent_packet_list.go rename to quic/sent_packet_list.go diff --git a/internal/quic/sent_packet_list_test.go b/quic/sent_packet_list_test.go similarity index 100% rename from internal/quic/sent_packet_list_test.go rename to quic/sent_packet_list_test.go diff --git a/internal/quic/sent_packet_test.go b/quic/sent_packet_test.go similarity index 100% rename from internal/quic/sent_packet_test.go rename to quic/sent_packet_test.go diff --git a/internal/quic/sent_val.go b/quic/sent_val.go similarity index 100% rename from internal/quic/sent_val.go rename to quic/sent_val.go diff --git a/internal/quic/sent_val_test.go b/quic/sent_val_test.go similarity index 100% rename from internal/quic/sent_val_test.go rename to quic/sent_val_test.go diff --git a/internal/quic/stateless_reset.go b/quic/stateless_reset.go similarity index 100% rename from internal/quic/stateless_reset.go rename to quic/stateless_reset.go diff --git a/internal/quic/stateless_reset_test.go b/quic/stateless_reset_test.go similarity index 99% rename from internal/quic/stateless_reset_test.go rename to quic/stateless_reset_test.go index 45a49e81e..9458d2ea9 100644 --- a/internal/quic/stateless_reset_test.go +++ b/quic/stateless_reset_test.go @@ -57,8 +57,8 @@ func newDatagramForReset(cid []byte, size int, addr netip.AddrPort) *datagram { dgram = append(dgram, byte(len(dgram))) // semi-random junk } return &datagram{ - b: dgram, - addr: addr, + b: dgram, + peerAddr: addr, } } diff --git a/internal/quic/stream.go b/quic/stream.go similarity index 81% rename from internal/quic/stream.go rename to quic/stream.go index fb9c1cf3c..cb45534f8 100644 --- a/internal/quic/stream.go +++ b/quic/stream.go @@ -14,11 +14,31 @@ import ( "math" ) +// A Stream is an ordered byte stream. +// +// Streams may be bidirectional, read-only, or write-only. +// Methods inappropriate for a stream's direction +// (for example, [Write] to a read-only stream) +// return errors. +// +// It is not safe to perform concurrent reads from or writes to a stream. +// It is safe, however, to read and write at the same time. +// +// Reads and writes are buffered. +// It is generally not necessary to wrap a stream in a [bufio.ReadWriter] +// or otherwise apply additional buffering. +// +// To cancel reads or writes, use the [SetReadContext] and [SetWriteContext] methods. type Stream struct { id streamID conn *Conn - // ingate's lock guards all receive-related state. + // Contexts used for read/write operations. + // Intentionally not mutex-guarded, to allow the race detector to catch concurrent access. + inctx context.Context + outctx context.Context + + // ingate's lock guards receive-related state. // // The gate condition is set if a read from the stream will not block, // either because the stream has available data or because the read will fail. @@ -32,7 +52,7 @@ type Stream struct { inclosed sentVal // set by CloseRead inresetcode int64 // RESET_STREAM code received from the peer; -1 if not reset - // outgate's lock guards all send-related state. + // outgate's lock guards send-related state. // // The gate condition is set if a write to the stream will not block, // either because the stream has available flow control or because @@ -52,6 +72,12 @@ type Stream struct { outresetcode uint64 // reset code to send in RESET_STREAM outdone chan struct{} // closed when all data sent + // Unsynchronized buffers, used for lock-free fast path. + inbuf []byte // received data + inbufoff int // bytes of inbuf which have been consumed + outbuf []byte // written data + outbufoff int // bytes of outbuf which contain data to write + // Atomic stream state bits. // // These bits provide a fast way to coordinate between the @@ -152,6 +178,8 @@ func newStream(c *Conn, id streamID) *Stream { inresetcode: -1, // -1 indicates no RESET_STREAM received ingate: newLockedGate(), outgate: newLockedGate(), + inctx: context.Background(), + outctx: context.Background(), } if !s.IsReadOnly() { s.outdone = make(chan struct{}) @@ -159,6 +187,22 @@ func newStream(c *Conn, id streamID) *Stream { return s } +// SetReadContext sets the context used for reads from the stream. +// +// It is not safe to call SetReadContext concurrently. +func (s *Stream) SetReadContext(ctx context.Context) { + s.inctx = ctx +} + +// SetWriteContext sets the context used for writes to the stream. +// The write context is also used by Close when waiting for writes to be +// received by the peer. +// +// It is not safe to call SetWriteContext concurrently. +func (s *Stream) SetWriteContext(ctx context.Context) { + s.outctx = ctx +} + // IsReadOnly reports whether the stream is read-only // (a unidirectional stream created by the peer). func (s *Stream) IsReadOnly() bool { @@ -172,29 +216,42 @@ func (s *Stream) IsWriteOnly() bool { } // Read reads data from the stream. -// See ReadContext for more details. -func (s *Stream) Read(b []byte) (n int, err error) { - return s.ReadContext(context.Background(), b) -} - -// ReadContext reads data from the stream. // -// ReadContext returns as soon as at least one byte of data is available. +// Read returns as soon as at least one byte of data is available. // -// If the peer closes the stream cleanly, ReadContext returns io.EOF after +// If the peer closes the stream cleanly, Read returns io.EOF after // returning all data sent by the peer. -// If the peer aborts reads on the stream, ReadContext returns +// If the peer aborts reads on the stream, Read returns // an error wrapping StreamResetCode. -func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { +// +// It is not safe to call Read concurrently. +func (s *Stream) Read(b []byte) (n int, err error) { if s.IsWriteOnly() { return 0, errors.New("read from write-only stream") } - if err := s.ingate.waitAndLock(ctx, s.conn.testHooks); err != nil { + if len(s.inbuf) > s.inbufoff { + // Fast path: If s.inbuf contains unread bytes, return them immediately + // without taking a lock. + n = copy(b, s.inbuf[s.inbufoff:]) + s.inbufoff += n + return n, nil + } + if err := s.ingate.waitAndLock(s.inctx, s.conn.testHooks); err != nil { return 0, err } + if s.inbufoff > 0 { + // Discard bytes consumed by the fast path above. + s.in.discardBefore(s.in.start + int64(s.inbufoff)) + s.inbufoff = 0 + s.inbuf = nil + } + // bytesRead contains the number of bytes of connection-level flow control to return. + // We return flow control for bytes read by this Read call, as well as bytes moved + // to the fast-path read buffer (s.inbuf). + var bytesRead int64 defer func() { s.inUnlock() - s.conn.handleStreamBytesReadOffLoop(int64(n)) // must be done with ingate unlocked + s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked }() if s.inresetcode != -1 { return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode)) @@ -212,22 +269,50 @@ func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { if size := int(s.inset[0].end - s.in.start); size < len(b) { b = b[:size] } + bytesRead = int64(len(b)) start := s.in.start end := start + int64(len(b)) s.in.copy(start, b) s.in.discardBefore(end) + if end == s.insize { + // We have read up to the end of the stream. + // No need to update stream flow control. + return len(b), io.EOF + } + if len(s.inset) > 0 && s.inset[0].start <= s.in.start && s.inset[0].end > s.in.start { + // If we have more readable bytes available, put the next chunk of data + // in s.inbuf for lock-free reads. + s.inbuf = s.in.peek(s.inset[0].end - s.in.start) + bytesRead += int64(len(s.inbuf)) + } if s.insize == -1 || s.insize > s.inwin { - if shouldUpdateFlowControl(s.inmaxbuf, s.in.start+s.inmaxbuf-s.inwin) { + newWindow := s.in.start + int64(len(s.inbuf)) + s.inmaxbuf + addedWindow := newWindow - s.inwin + if shouldUpdateFlowControl(s.inmaxbuf, addedWindow) { // Update stream flow control with a STREAM_MAX_DATA frame. s.insendmax.setUnsent() } } - if end == s.insize { - return len(b), io.EOF - } return len(b), nil } +// ReadByte reads and returns a single byte from the stream. +// +// It is not safe to call ReadByte concurrently. +func (s *Stream) ReadByte() (byte, error) { + if len(s.inbuf) > s.inbufoff { + b := s.inbuf[s.inbufoff] + s.inbufoff++ + return b, nil + } + var b [1]byte + n, err := s.Read(b[:]) + if n > 0 { + return b[0], err + } + return 0, err +} + // shouldUpdateFlowControl determines whether to send a flow control window update. // // We want to balance keeping the peer well-supplied with flow control with not sending @@ -237,21 +322,22 @@ func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool { } // Write writes data to the stream. -// See WriteContext for more details. -func (s *Stream) Write(b []byte) (n int, err error) { - return s.WriteContext(context.Background(), b) -} - -// WriteContext writes data to the stream. // -// WriteContext writes data to the stream write buffer. +// Write writes data to the stream write buffer. // Buffered data is only sent when the buffer is sufficiently full. // Call the Flush method to ensure buffered data is sent. -func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) { +func (s *Stream) Write(b []byte) (n int, err error) { if s.IsReadOnly() { return 0, errors.New("write to read-only stream") } + if len(b) > 0 && len(s.outbuf)-s.outbufoff >= len(b) { + // Fast path: The data to write fits in s.outbuf. + copy(s.outbuf[s.outbufoff:], b) + s.outbufoff += len(b) + return len(b), nil + } canWrite := s.outgate.lock() + s.flushFastOutputBuffer() for { // The first time through this loop, we may or may not be write blocked. // We exit the loop after writing all data, so on subsequent passes through @@ -259,7 +345,7 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) if len(b) > 0 && !canWrite { // Our send buffer is full. Wait for the peer to ack some data. s.outUnlock() - if err := s.outgate.waitAndLock(ctx, s.conn.testHooks); err != nil { + if err := s.outgate.waitAndLock(s.outctx, s.conn.testHooks); err != nil { return n, err } // Successfully returning from waitAndLockGate means we are no longer @@ -311,13 +397,54 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) // If we have bytes left to send, we're blocked. canWrite = false } + if lim := s.out.start + s.outmaxbuf - s.out.end - 1; lim > 0 { + // If s.out has space allocated and available to be written into, + // then reference it in s.outbuf for fast-path writes. + // + // It's perhaps a bit pointless to limit s.outbuf to the send buffer limit. + // We've already allocated this buffer so we aren't saving any memory + // by not using it. + // For now, we limit it anyway to make it easier to reason about limits. + // + // We set the limit to one less than the send buffer limit (the -1 above) + // so that a write which completely fills the buffer will overflow + // s.outbuf and trigger a flush. + s.outbuf = s.out.availableBuffer() + if int64(len(s.outbuf)) > lim { + s.outbuf = s.outbuf[:lim] + } + } s.outUnlock() return n, nil } +// WriteBytes writes a single byte to the stream. +func (s *Stream) WriteByte(c byte) error { + if s.outbufoff < len(s.outbuf) { + s.outbuf[s.outbufoff] = c + s.outbufoff++ + return nil + } + b := [1]byte{c} + _, err := s.Write(b[:]) + return err +} + +func (s *Stream) flushFastOutputBuffer() { + if s.outbuf == nil { + return + } + // Commit data previously written to s.outbuf. + // s.outbuf is a reference to a buffer in s.out, so we just need to record + // that the output buffer has been extended. + s.out.end += int64(s.outbufoff) + s.outbuf = nil + s.outbufoff = 0 +} + // Flush flushes data written to the stream. // It does not wait for the peer to acknowledge receipt of the data. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) Flush() { s.outgate.lock() defer s.outUnlock() @@ -325,6 +452,7 @@ func (s *Stream) Flush() { } func (s *Stream) flushLocked() { + s.flushFastOutputBuffer() s.outopened.set() if s.outflushed < s.outwin { s.outunsent.add(s.outflushed, min(s.outwin, s.out.end)) @@ -333,27 +461,21 @@ func (s *Stream) flushLocked() { } // Close closes the stream. -// See CloseContext for more details. -func (s *Stream) Close() error { - return s.CloseContext(context.Background()) -} - -// CloseContext closes the stream. // Any blocked stream operations will be unblocked and return errors. // -// CloseContext flushes any data in the stream write buffer and waits for the peer to +// Close flushes any data in the stream write buffer and waits for the peer to // acknowledge receipt of the data. // If the stream has been reset, it waits for the peer to acknowledge the reset. // If the context expires before the peer receives the stream's data, -// CloseContext discards the buffer and returns the context error. -func (s *Stream) CloseContext(ctx context.Context) error { +// Close discards the buffer and returns the context error. +func (s *Stream) Close() error { s.CloseRead() if s.IsReadOnly() { return nil } s.CloseWrite() // TODO: Return code from peer's RESET_STREAM frame? - if err := s.conn.waitOnDone(ctx, s.outdone); err != nil { + if err := s.conn.waitOnDone(s.outctx, s.outdone); err != nil { return err } s.outgate.lock() @@ -369,7 +491,7 @@ func (s *Stream) CloseContext(ctx context.Context) error { // // CloseRead notifies the peer that the stream has been closed for reading. // It does not wait for the peer to acknowledge the closure. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) CloseRead() { if s.IsWriteOnly() { return @@ -394,7 +516,7 @@ func (s *Stream) CloseRead() { // // CloseWrite sends any data in the stream write buffer to the peer. // It does not wait for the peer to acknowledge receipt of the data. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. func (s *Stream) CloseWrite() { if s.IsReadOnly() { return @@ -412,7 +534,7 @@ func (s *Stream) CloseWrite() { // Reset sends the application protocol error code, which must be // less than 2^62, to the peer. // It does not wait for the peer to acknowledge receipt of the error. -// Use CloseContext to wait for the peer's acknowledgement. +// Use Close to wait for the peer's acknowledgement. // // Reset does not affect reads. // Use CloseRead to abort reads on the stream. @@ -446,6 +568,8 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) { // extra RESET_STREAM in this case is harmless. s.outreset.set() s.outresetcode = code + s.outbuf = nil + s.outbufoff = 0 s.out.discardBefore(s.out.end) s.outunsent = rangeset[int64]{} s.outblocked.clear() @@ -488,8 +612,9 @@ func (s *Stream) inUnlock() { // inUnlockNoQueue is inUnlock, // but reports whether s has frames to write rather than notifying the Conn. func (s *Stream) inUnlockNoQueue() streamState { - canRead := s.inset.contains(s.in.start) || // data available to read - s.insize == s.in.start || // at EOF + nextByte := s.in.start + int64(len(s.inbuf)) + canRead := s.inset.contains(nextByte) || // data available to read + s.insize == s.in.start+int64(len(s.inbuf)) || // at EOF s.inresetcode != -1 || // reset by peer s.inclosed.isSet() // closed locally defer s.ingate.unlock(canRead) diff --git a/internal/quic/stream_limits.go b/quic/stream_limits.go similarity index 100% rename from internal/quic/stream_limits.go rename to quic/stream_limits.go diff --git a/internal/quic/stream_limits_test.go b/quic/stream_limits_test.go similarity index 96% rename from internal/quic/stream_limits_test.go rename to quic/stream_limits_test.go index 3f291e9f4..8fed825d7 100644 --- a/internal/quic/stream_limits_test.go +++ b/quic/stream_limits_test.go @@ -200,7 +200,6 @@ func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) { func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { if styp == uniStream { c.MaxUniRemoteStreams = 4 @@ -218,13 +217,9 @@ func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { id: newStreamID(clientSide, styp, int64(i)), fin: true, }) - s, err := tc.conn.AcceptStream(ctx) - if err != nil { - t.Fatalf("AcceptStream = %v", err) - } - streams = append(streams, s) + streams = append(streams, tc.acceptStream()) } - streams[3].CloseContext(ctx) + streams[3].Close() if styp == bidiStream { tc.wantFrame("stream is closed", packetType1RTT, debugFrameStream{ @@ -254,7 +249,7 @@ func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameStopSending{ id: s.id, }) - tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM", + tc.wantFrame("received STOP_SENDING, send RESET_STREAM", packetType1RTT, debugFrameResetStream{ id: s.id, }) diff --git a/internal/quic/stream_test.go b/quic/stream_test.go similarity index 91% rename from internal/quic/stream_test.go rename to quic/stream_test.go index 00e392dba..9f857f29d 100644 --- a/internal/quic/stream_test.go +++ b/quic/stream_test.go @@ -19,7 +19,6 @@ import ( func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const writeBufferSize = 4 tc := newTestConn(t, clientSide, permissiveTransportParameters, func(c *Config) { @@ -28,15 +27,12 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { tc.handshake() tc.ignoreFrame(frameTypeAck) - s, err := tc.conn.newLocalStream(ctx, styp) - if err != nil { - t.Fatal(err) - } + s := newLocalStream(t, tc, styp) // Non-blocking write. - n, err := s.WriteContext(ctx, want) + n, err := s.Write(want) if n != writeBufferSize || err != context.Canceled { - t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) + t.Fatalf("s.Write() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) } s.Flush() tc.wantFrame("first write buffer of data sent", @@ -48,7 +44,8 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { // Blocking write, which must wait for buffer space. w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want[writeBufferSize:]) + s.SetWriteContext(ctx) + n, err := s.Write(want[writeBufferSize:]) s.Flush() return n, err }) @@ -75,7 +72,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { }) if n, err := w.result(); n != len(want)-writeBufferSize || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", + t.Fatalf("s.Write() = %v, %v; want %v, nil", len(want)-writeBufferSize, err, writeBufferSize) } }) @@ -99,10 +96,11 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { } // Data is written to the stream output buffer, but we have no flow control. - _, err = s.WriteContext(ctx, want[:1]) + _, err = s.Write(want[:1]) if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } + s.Flush() tc.wantFrame("write blocked by flow control triggers a STREAM_DATA_BLOCKED frame", packetType1RTT, debugFrameStreamDataBlocked{ id: s.id, @@ -110,10 +108,11 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { }) // Write more data. - _, err = s.WriteContext(ctx, want[1:]) + _, err = s.Write(want[1:]) if err != nil { t.Fatalf("write with available output buffer: unexpected error: %v", err) } + s.Flush() tc.wantIdle("adding more blocked data does not trigger another STREAM_DATA_BLOCKED") // Provide some flow control window. @@ -172,7 +171,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { if err != nil { t.Fatal(err) } - s.WriteContext(ctx, want[:1]) + s.Write(want[:1]) s.Flush() tc.wantFrame("sent data (1 byte) fits within flow control limit", packetType1RTT, debugFrameStream{ @@ -188,7 +187,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { }) // Write [1,4). - s.WriteContext(ctx, want[1:]) + s.Write(want[1:]) tc.wantFrame("stream limit is 4 bytes, ignoring decrease in MAX_STREAM_DATA", packetType1RTT, debugFrameStream{ id: s.id, @@ -208,7 +207,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { }) // Write [1,4). - s.WriteContext(ctx, want[4:]) + s.Write(want[4:]) tc.wantFrame("stream limit is 8 bytes, ignoring decrease in MAX_STREAM_DATA", packetType1RTT, debugFrameStream{ id: s.id, @@ -220,7 +219,6 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const maxWriteBuffer = 4 tc := newTestConn(t, clientSide, func(p *transportParameters) { @@ -238,12 +236,10 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { // Write more data than StreamWriteBufferSize. // The peer has given us plenty of flow control, // so we're just blocked by our local limit. - s, err := tc.conn.newLocalStream(ctx, styp) - if err != nil { - t.Fatal(err) - } + s := newLocalStream(t, tc, styp) w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want) + s.SetWriteContext(ctx) + return s.Write(want) }) tc.wantFrame("stream write should send as much data as write buffer allows", packetType1RTT, debugFrameStream{ @@ -266,7 +262,7 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { w.cancel() n, err := w.result() if n != 2*maxWriteBuffer || err == nil { - t.Fatalf("WriteContext() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer) + t.Fatalf("Write() = %v, %v; want %v bytes, error", n, err, 2*maxWriteBuffer) } }) } @@ -397,7 +393,6 @@ func TestStreamReceive(t *testing.T) { }}, }} { testStreamTypes(t, test.name, func(t *testing.T, styp streamType) { - ctx := canceledContext() tc := newTestConn(t, serverSide) tc.handshake() sid := newStreamID(clientSide, styp, 0) @@ -413,21 +408,17 @@ func TestStreamReceive(t *testing.T) { fin: f.fin, }) if s == nil { - var err error - s, err = tc.conn.AcceptStream(ctx) - if err != nil { - tc.t.Fatalf("conn.AcceptStream() = %v", err) - } + s = tc.acceptStream() } for { - n, err := s.ReadContext(ctx, got[total:]) - t.Logf("s.ReadContext() = %v, %v", n, err) + n, err := s.Read(got[total:]) + t.Logf("s.Read() = %v, %v", n, err) total += n if f.wantEOF && err != io.EOF { - t.Fatalf("ReadContext() error = %v; want io.EOF", err) + t.Fatalf("Read() error = %v; want io.EOF", err) } if !f.wantEOF && err == io.EOF { - t.Fatalf("ReadContext() error = io.EOF, want something else") + t.Fatalf("Read() error = io.EOF, want something else") } if err != nil { break @@ -468,8 +459,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) { } tc.wantIdle("stream window is not extended before data is read") buf := make([]byte, maxWindowSize+1) - if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != nil { - t.Fatalf("s.ReadContext() = %v, %v; want %v, nil", n, err, maxWindowSize) + if n, err := s.Read(buf); n != maxWindowSize || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, maxWindowSize) } tc.wantFrame("stream window is extended after reading data", packetType1RTT, debugFrameMaxStreamData{ @@ -482,8 +473,8 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) { data: make([]byte, maxWindowSize), fin: true, }) - if n, err := s.ReadContext(ctx, buf); n != maxWindowSize || err != io.EOF { - t.Fatalf("s.ReadContext() = %v, %v; want %v, io.EOF", n, err, maxWindowSize) + if n, err := s.Read(buf); n != maxWindowSize || err != io.EOF { + t.Fatalf("s.Read() = %v, %v; want %v, io.EOF", n, err, maxWindowSize) } tc.wantIdle("stream window is not extended after FIN") }) @@ -549,6 +540,32 @@ func TestStreamReceiveDuplicateDataDoesNotViolateLimits(t *testing.T) { }) } +func TestStreamReceiveEmptyEOF(t *testing.T) { + // A stream receives some data, we read a byte of that data + // (causing the rest to be pulled into the s.inbuf buffer), + // and then we receive a FIN with no additional data. + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters) + want := []byte{1, 2, 3} + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + data: want, + }) + if got, err := s.ReadByte(); got != want[0] || err != nil { + t.Fatalf("s.ReadByte() = %v, %v; want %v, nil", got, err, want[0]) + } + + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 3, + fin: true, + }) + if got, err := io.ReadAll(s); !bytes.Equal(got, want[1:]) || err != nil { + t.Fatalf("io.ReadAll(s) = {%x}, %v; want {%x}, nil", got, err, want[1:]) + } + }) +} + func finalSizeTest(t *testing.T, wantErr transportError, f func(tc *testConn, sid streamID) (finalSize int64), opts ...any) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { for _, test := range []struct { @@ -673,18 +690,19 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { t.Fatalf("AcceptStream() = %v", err) } - // ReadContext succeeds immediately, since we already have data. + // Read succeeds immediately, since we already have data. got := make([]byte, len(want)) read := runAsync(tc, func(ctx context.Context) (int, error) { - return s.ReadContext(ctx, got) + return s.Read(got) }) if n, err := read.result(); n != write1size || err != nil { - t.Fatalf("ReadContext = %v, %v; want %v, nil", n, err, write1size) + t.Fatalf("Read = %v, %v; want %v, nil", n, err, write1size) } - // ReadContext blocks waiting for more data. + // Read blocks waiting for more data. read = runAsync(tc, func(ctx context.Context) (int, error) { - return s.ReadContext(ctx, got[write1size:]) + s.SetReadContext(ctx) + return s.Read(got[write1size:]) }) tc.writeFrames(packetType1RTT, debugFrameStream{ id: sid, @@ -693,7 +711,7 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { fin: true, }) if n, err := read.result(); n != len(want)-write1size || err != io.EOF { - t.Fatalf("ReadContext = %v, %v; want %v, io.EOF", n, err, len(want)-write1size) + t.Fatalf("Read = %v, %v; want %v, io.EOF", n, err, len(want)-write1size) } if !bytes.Equal(got, want) { t.Fatalf("read bytes %x, want %x", got, want) @@ -935,7 +953,8 @@ func TestStreamResetBlockedStream(t *testing.T) { }) tc.ignoreFrame(frameTypeStreamDataBlocked) writing := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, []byte{0, 1, 2, 3, 4, 5, 6, 7}) + s.SetWriteContext(ctx) + return s.Write([]byte{0, 1, 2, 3, 4, 5, 6, 7}) }) tc.wantFrame("stream writes data until write buffer fills", packetType1RTT, debugFrameStream{ @@ -972,7 +991,7 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { want := make([]byte, 4096) rand.Read(want) // doesn't need to be crypto/rand, but non-deprecated and harmless w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want) + n, err := s.Write(want) s.Flush() return n, err }) @@ -992,7 +1011,7 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { got = append(got, sf.data...) } if n, err := w.result(); n != len(want) || err != nil { - t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", n, err, len(want)) + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(want)) } if !bytes.Equal(got, want) { t.Fatalf("mismatch in received stream data") @@ -1000,17 +1019,16 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { } func TestStreamCloseWaitsForAcks(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) - s.WriteContext(ctx, data) + s.Write(data) s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, data: data, }) - if err := s.CloseContext(ctx); err != context.Canceled { + if err := s.Close(); err != context.Canceled { t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err) } tc.wantFrame("conn sends FIN for closed stream", @@ -1021,21 +1039,22 @@ func TestStreamCloseWaitsForAcks(t *testing.T) { data: []byte{}, }) closing := runAsync(tc, func(ctx context.Context) (struct{}, error) { - return struct{}{}, s.CloseContext(ctx) + s.SetWriteContext(ctx) + return struct{}{}, s.Close() }) if _, err := closing.result(); err != errNotDone { - t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err) + t.Fatalf("s.Close() = %v, want it to block waiting for acks", err) } tc.writeAckForAll() if _, err := closing.result(); err != nil { - t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err) + t.Fatalf("s.Close() = %v, want nil (all data acked)", err) } } func TestStreamCloseReadOnly(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, permissiveTransportParameters) - if err := s.CloseContext(canceledContext()); err != nil { - t.Errorf("s.CloseContext() = %v, want nil", err) + if err := s.Close(); err != nil { + t.Errorf("s.Close() = %v, want nil", err) } tc.wantFrame("closed stream sends STOP_SENDING", packetType1RTT, debugFrameStopSending{ @@ -1069,17 +1088,16 @@ func TestStreamCloseUnblocked(t *testing.T) { }, }} { t.Run(test.name, func(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) - s.WriteContext(ctx, data) + s.Write(data) s.Flush() tc.wantFrame("conn sends data for the stream", packetType1RTT, debugFrameStream{ id: s.id, data: data, }) - if err := s.CloseContext(ctx); err != context.Canceled { + if err := s.Close(); err != context.Canceled { t.Fatalf("s.Close() = %v, want context.Canceled (data not acked yet)", err) } tc.wantFrame("conn sends FIN for closed stream", @@ -1090,34 +1108,34 @@ func TestStreamCloseUnblocked(t *testing.T) { data: []byte{}, }) closing := runAsync(tc, func(ctx context.Context) (struct{}, error) { - return struct{}{}, s.CloseContext(ctx) + s.SetWriteContext(ctx) + return struct{}{}, s.Close() }) if _, err := closing.result(); err != errNotDone { - t.Fatalf("s.CloseContext() = %v, want it to block waiting for acks", err) + t.Fatalf("s.Close() = %v, want it to block waiting for acks", err) } test.unblock(tc, s) _, err := closing.result() switch { case err == errNotDone: - t.Fatalf("s.CloseContext() still blocking; want it to have returned") + t.Fatalf("s.Close() still blocking; want it to have returned") case err == nil && !test.success: - t.Fatalf("s.CloseContext() = nil, want error") + t.Fatalf("s.Close() = nil, want error") case err != nil && test.success: - t.Fatalf("s.CloseContext() = %v, want nil (all data acked)", err) + t.Fatalf("s.Close() = %v, want nil (all data acked)", err) } }) } } func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { - ctx := canceledContext() tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters, func(p *transportParameters) { //p.initialMaxData = 0 p.initialMaxStreamDataUni = 0 }) tc.ignoreFrame(frameTypeStreamDataBlocked) - if _, err := s.WriteContext(ctx, []byte{0, 1}); err != nil { + if _, err := s.Write([]byte{0, 1}); err != nil { t.Fatalf("s.Write = %v", err) } s.CloseWrite() @@ -1149,7 +1167,6 @@ func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { - ctx := canceledContext() tc, s := newTestConnAndRemoteStream(t, serverSide, styp) data := []byte{0, 1, 2, 3, 4, 5, 6, 7} tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -1157,7 +1174,7 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { data: data, }) got := make([]byte, 4) - if n, err := s.ReadContext(ctx, got); n != len(got) || err != nil { + if n, err := s.Read(got); n != len(got) || err != nil { t.Fatalf("Read start of stream: got %v, %v; want %v, nil", n, err, len(got)) } const sentCode = 42 @@ -1167,8 +1184,8 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { code: sentCode, }) wantErr := StreamErrorCode(sentCode) - if n, err := s.ReadContext(ctx, got); n != 0 || !errors.Is(err, wantErr) { - t.Fatalf("Read reset stream: got %v, %v; want 0, %v", n, err, wantErr) + if _, err := io.ReadAll(s); !errors.Is(err, wantErr) { + t.Fatalf("Read reset stream: ReadAll got error %v; want %v", err, wantErr) } }) } @@ -1177,8 +1194,9 @@ func TestStreamPeerResetWakesBlockedRead(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp) reader := runAsync(tc, func(ctx context.Context) (int, error) { + s.SetReadContext(ctx) got := make([]byte, 4) - return s.ReadContext(ctx, got) + return s.Read(got) }) const sentCode = 42 tc.writeFrames(packetType1RTT, debugFrameResetStream{ @@ -1333,7 +1351,6 @@ func TestStreamFlushImplicitExact(t *testing.T) { id: s.id, data: want[0:4], }) - }) } @@ -1348,7 +1365,8 @@ func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) { want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} w := runAsync(tc, func(ctx context.Context) (int, error) { - n, err := s.WriteContext(ctx, want) + s.SetWriteContext(ctx) + n, err := s.Write(want) return n, err }) @@ -1401,7 +1419,10 @@ func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opt tc := newTestConn(t, side, opts...) tc.handshake() tc.ignoreFrame(frameTypeAck) - return tc, newLocalStream(t, tc, styp) + s := newLocalStream(t, tc, styp) + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return tc, s } func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream { @@ -1411,6 +1432,8 @@ func newLocalStream(t *testing.T, tc *testConn, styp streamType) *Stream { if err != nil { t.Fatalf("conn.newLocalStream(%v) = %v", styp, err) } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) return s } @@ -1419,7 +1442,10 @@ func newTestConnAndRemoteStream(t *testing.T, side connSide, styp streamType, op tc := newTestConn(t, side, opts...) tc.handshake() tc.ignoreFrame(frameTypeAck) - return tc, newRemoteStream(t, tc, styp) + s := newRemoteStream(t, tc, styp) + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) + return tc, s } func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream { @@ -1432,6 +1458,8 @@ func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream { if err != nil { t.Fatalf("conn.AcceptStream() = %v", err) } + s.SetReadContext(canceledContext()) + s.SetWriteContext(canceledContext()) return s } diff --git a/internal/quic/tls.go b/quic/tls.go similarity index 88% rename from internal/quic/tls.go rename to quic/tls.go index a37e26fb8..e2f2e5bde 100644 --- a/internal/quic/tls.go +++ b/quic/tls.go @@ -11,14 +11,24 @@ import ( "crypto/tls" "errors" "fmt" + "net" "time" ) // startTLS starts the TLS handshake. -func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error { +func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string, params transportParameters) error { + tlsConfig := c.config.TLSConfig + if a, _, err := net.SplitHostPort(peerHostname); err == nil { + peerHostname = a + } + if tlsConfig.ServerName == "" && peerHostname != "" { + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = peerHostname + } + c.keysInitial = initialKeys(initialConnID, c.side) - qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig} + qconfig := &tls.QUICConfig{TLSConfig: tlsConfig} if c.side == clientSide { c.tls = tls.QUICClient(qconfig) } else { diff --git a/internal/quic/tls_test.go b/quic/tls_test.go similarity index 100% rename from internal/quic/tls_test.go rename to quic/tls_test.go diff --git a/internal/quic/tlsconfig_test.go b/quic/tlsconfig_test.go similarity index 100% rename from internal/quic/tlsconfig_test.go rename to quic/tlsconfig_test.go diff --git a/internal/quic/transport_params.go b/quic/transport_params.go similarity index 100% rename from internal/quic/transport_params.go rename to quic/transport_params.go diff --git a/internal/quic/transport_params_test.go b/quic/transport_params_test.go similarity index 100% rename from internal/quic/transport_params_test.go rename to quic/transport_params_test.go diff --git a/quic/udp.go b/quic/udp.go new file mode 100644 index 000000000..0a578286b --- /dev/null +++ b/quic/udp.go @@ -0,0 +1,30 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "net/netip" + +// Per-plaform consts describing support for various features. +// +// const udpECNSupport indicates whether the platform supports setting +// the ECN (Explicit Congestion Notification) IP header bits. +// +// const udpInvalidLocalAddrIsError indicates whether sending a packet +// from an local address not associated with the system is an error. +// For example, assuming 127.0.0.2 is not a local address, does sending +// from it (using IP_PKTINFO or some other such feature) result in an error? + +// unmapAddrPort returns a with any IPv4-mapped IPv6 address prefix removed. +func unmapAddrPort(a netip.AddrPort) netip.AddrPort { + if a.Addr().Is4In6() { + return netip.AddrPortFrom( + a.Addr().Unmap(), + a.Port(), + ) + } + return a +} diff --git a/quic/udp_darwin.go b/quic/udp_darwin.go new file mode 100644 index 000000000..2eb2e9f9f --- /dev/null +++ b/quic/udp_darwin.go @@ -0,0 +1,38 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && darwin + +package quic + +import ( + "encoding/binary" + + "golang.org/x/sys/unix" +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = true +) + +// Confusingly, on Darwin the contents of the IP_TOS option differ depending on whether +// it is used as an inbound or outbound cmsg. + +func parseIPTOS(b []byte) (ecnBits, bool) { + // Single byte. The low two bits are the ECN field. + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + // 32-bit integer. + // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073 + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} diff --git a/quic/udp_linux.go b/quic/udp_linux.go new file mode 100644 index 000000000..6f191ed39 --- /dev/null +++ b/quic/udp_linux.go @@ -0,0 +1,33 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && linux + +package quic + +import ( + "golang.org/x/sys/unix" +) + +// See udp.go. +const ( + udpECNSupport = true + udpInvalidLocalAddrIsError = false +) + +// The IP_TOS socket option is a single byte containing the IP TOS field. +// The low two bits are the ECN field. + +func parseIPTOS(b []byte) (ecnBits, bool) { + if len(b) != 1 { + return 0, false + } + return ecnBits(b[0] & ecnMask), true +} + +func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 1) + data[0] = byte(ecn) + return b +} diff --git a/quic/udp_msg.go b/quic/udp_msg.go new file mode 100644 index 000000000..0b600a2b4 --- /dev/null +++ b/quic/udp_msg.go @@ -0,0 +1,247 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && !quicbasicnet && (darwin || linux) + +package quic + +import ( + "encoding/binary" + "net" + "net/netip" + "sync" + "unsafe" + + "golang.org/x/sys/unix" +) + +// Network interface for platforms using sendmsg/recvmsg with cmsgs. + +type netUDPConn struct { + c *net.UDPConn + localAddr netip.AddrPort +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + a, _ := uc.LocalAddr().(*net.UDPAddr) + localAddr := a.AddrPort() + if localAddr.Addr().IsUnspecified() { + // If the conn is not bound to a specified (non-wildcard) address, + // then set localAddr.Addr to an invalid netip.Addr. + // This better conveys that this is not an address we should be using, + // and is a bit more efficient to test against. + localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port()) + } + + sc, err := uc.SyscallConn() + if err != nil { + return nil, err + } + sc.Control(func(fd uintptr) { + // Ask for ECN info and (when we aren't bound to a fixed local address) + // destination info. + // + // If any of these calls fail, we won't get the requested information. + // That's fine, we'll gracefully handle the lack. + unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + if !localAddr.IsValid() { + unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + } + }) + + return &netUDPConn{ + c: uc, + localAddr: localAddr, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + // We shouldn't ever see all of these messages at the same time, + // but the total is small so just allocate enough space for everything we use. + const ( + inPktinfoSize = 12 // int + in_addr + in_addr + in6PktinfoSize = 20 // in6_addr + int + ipTOSSize = 4 + ipv6TclassSize = 4 + ) + control := make([]byte, 0+ + unix.CmsgSpace(inPktinfoSize)+ + unix.CmsgSpace(in6PktinfoSize)+ + unix.CmsgSpace(ipTOSSize)+ + unix.CmsgSpace(ipv6TclassSize)) + + for { + d := newDatagram() + n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control) + if err != nil { + return + } + if n == 0 { + continue + } + d.localAddr = c.localAddr + d.peerAddr = unmapAddrPort(peerAddr) + d.b = d.b[:n] + parseControl(d, control[:controlLen]) + f(d) + } +} + +var cmsgPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +func (c *netUDPConn) Write(dgram datagram) error { + controlp := cmsgPool.Get().(*[]byte) + control := *controlp + defer func() { + *controlp = control[:0] + cmsgPool.Put(controlp) + }() + + localIP := dgram.localAddr.Addr() + if localIP.IsValid() { + if localIP.Is4() { + control = appendCmsgIPSourceAddrV4(control, localIP) + } else { + control = appendCmsgIPSourceAddrV6(control, localIP) + } + } + if dgram.ecn != ecnNotECT { + if dgram.peerAddr.Addr().Is4() { + control = appendCmsgECNv4(control, dgram.ecn) + } else { + control = appendCmsgECNv6(control, dgram.ecn) + } + } + + _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr) + return err +} + +func parseControl(d *datagram, control []byte) { + for len(control) > 0 { + hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control) + if err != nil { + return + } + control = remainder + switch hdr.Level { + case unix.IPPROTO_IP: + switch hdr.Type { + case unix.IP_TOS, unix.IP_RECVTOS: + // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, + // just check for both.) + if ecn, ok := parseIPTOS(data); ok { + d.ecn = ecn + } + case unix.IP_PKTINFO: + if a, ok := parseInPktinfo(data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + case unix.IPPROTO_IPV6: + switch hdr.Type { + case unix.IPV6_TCLASS: + // 32-bit integer containing the traffic class field. + // The low two bits are the ECN field. + if ecn, ok := parseIPv6TCLASS(data); ok { + d.ecn = ecn + } + case unix.IPV6_PKTINFO: + if a, ok := parseIn6Pktinfo(data); ok { + d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) + } + } + } + } +} + +// IPV6_TCLASS is specified by RFC 3542 as an int. + +func parseIPv6TCLASS(b []byte) (ecnBits, bool) { + if len(b) != 4 { + return 0, false + } + return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true +} + +func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4) + binary.NativeEndian.PutUint32(data, uint32(ecn)) + return b +} + +// struct in_pktinfo { +// unsigned int ipi_ifindex; /* send/recv interface index */ +// struct in_addr ipi_spec_dst; /* Local address */ +// struct in_addr ipi_addr; /* IP Header dst address */ +// }; + +// parseInPktinfo returns the destination address from an IP_PKTINFO. +func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) { + if len(b) != 12 { + return netip.Addr{}, false + } + return netip.AddrFrom4([4]byte(b[8:][:4])), true +} + +// appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address +// for an outbound datagram. +func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* send/recv interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* IP Header dst address */ + // }; + b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12) + ip := src.As4() + copy(data[4:], ip[:]) + return b +} + +// struct in6_pktinfo { +// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ +// unsigned int ipi6_ifindex; /* send/recv interface index */ +// }; + +// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO. +func parseIn6Pktinfo(b []byte) (netip.Addr, bool) { + if len(b) != 20 { + return netip.Addr{}, false + } + return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true +} + +// appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address +// for an outbound datagram. +func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { + b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20) + ip := src.As16() + copy(data[0:], ip[:]) + return b +} + +// appendCmsg appends a cmsg with the given level, type, and size to b. +// It returns the new buffer, and the data section of the cmsg. +func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { + off := len(b) + b = append(b, make([]byte, unix.CmsgSpace(size))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off])) + h.Level = level + h.Type = typ + h.SetLen(unix.CmsgLen(size)) + return b, b[off+unix.CmsgSpace(0):][:size] +} diff --git a/quic/udp_other.go b/quic/udp_other.go new file mode 100644 index 000000000..28be6d200 --- /dev/null +++ b/quic/udp_other.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 && (quicbasicnet || !(darwin || linux)) + +package quic + +import ( + "net" + "net/netip" +) + +// Lowest common denominator network interface: Basic net.UDPConn, no cmsgs. +// We will not be able to send or receive ECN bits, +// and we will not know what our local address is. +// +// The quicbasicnet build tag allows selecting this interface on any platform. + +// See udp.go. +const ( + udpECNSupport = false + udpInvalidLocalAddrIsError = false +) + +type netUDPConn struct { + c *net.UDPConn +} + +func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) { + return &netUDPConn{ + c: uc, + }, nil +} + +func (c *netUDPConn) Close() error { return c.c.Close() } + +func (c *netUDPConn) LocalAddr() netip.AddrPort { + a, _ := c.c.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +func (c *netUDPConn) Read(f func(*datagram)) { + for { + dgram := newDatagram() + n, _, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(dgram.b, nil) + if err != nil { + return + } + if n == 0 { + continue + } + dgram.peerAddr = unmapAddrPort(peerAddr) + dgram.b = dgram.b[:n] + f(dgram) + } +} + +func (c *netUDPConn) Write(dgram datagram) error { + _, err := c.c.WriteToUDPAddrPort(dgram.b, dgram.peerAddr) + return err +} diff --git a/quic/udp_test.go b/quic/udp_test.go new file mode 100644 index 000000000..d3732c140 --- /dev/null +++ b/quic/udp_test.go @@ -0,0 +1,189 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "fmt" + "net" + "net/netip" + "runtime" + "testing" +) + +func TestUDPSourceUnspecified(t *testing.T) { + // Send datagram with no source address set. + runUDPTest(t, func(t *testing.T, test udpTest) { + t.Logf("%v", test.dstAddr) + data := []byte("source unspecified") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if !bytes.Equal(got.b, data) { + t.Errorf("got datagram {%x}, want {%x}", got.b, data) + } + }) +} + +func TestUDPSourceSpecified(t *testing.T) { + // Send datagram with source address set. + runUDPTest(t, func(t *testing.T, test udpTest) { + data := []byte("source specified") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + localAddr: test.src.LocalAddr(), + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if !bytes.Equal(got.b, data) { + t.Errorf("got datagram {%x}, want {%x}", got.b, data) + } + }) +} + +func TestUDPSourceInvalid(t *testing.T) { + // Send datagram with source address set to an address not associated with the connection. + if !udpInvalidLocalAddrIsError { + t.Skipf("%v: sending from invalid source succeeds", runtime.GOOS) + } + runUDPTest(t, func(t *testing.T, test udpTest) { + var localAddr netip.AddrPort + if test.src.LocalAddr().Addr().Is4() { + localAddr = netip.MustParseAddrPort("127.0.0.2:1234") + } else { + localAddr = netip.MustParseAddrPort("[::2]:1234") + } + data := []byte("source invalid") + if err := test.src.Write(datagram{ + b: data, + peerAddr: test.dstAddr, + localAddr: localAddr, + }); err == nil { + t.Errorf("Write with invalid localAddr succeeded; want error") + } + }) +} + +func TestUDPECN(t *testing.T) { + if !udpECNSupport { + t.Skipf("%v: no ECN support", runtime.GOOS) + } + // Send datagrams with ECN bits set, verify the ECN bits are received. + runUDPTest(t, func(t *testing.T, test udpTest) { + for _, ecn := range []ecnBits{ecnNotECT, ecnECT1, ecnECT0, ecnCE} { + if err := test.src.Write(datagram{ + b: []byte{1, 2, 3, 4}, + peerAddr: test.dstAddr, + ecn: ecn, + }); err != nil { + t.Fatalf("Write: %v", err) + } + got := <-test.dgramc + if got.ecn != ecn { + t.Errorf("sending ECN bits %x, got %x", ecn, got.ecn) + } + } + }) +} + +type udpTest struct { + src *netUDPConn + dst *netUDPConn + dstAddr netip.AddrPort + dgramc chan *datagram +} + +// runUDPTest calls f with a pair of UDPConns in a matrix of network variations: +// udp, udp4, and udp6, and variations on binding to an unspecified address (0.0.0.0) +// or a specified one. +func runUDPTest(t *testing.T, f func(t *testing.T, u udpTest)) { + for _, test := range []struct { + srcNet, srcAddr, dstNet, dstAddr string + }{ + {"udp4", "127.0.0.1", "udp", ""}, + {"udp4", "127.0.0.1", "udp4", ""}, + {"udp4", "127.0.0.1", "udp4", "127.0.0.1"}, + {"udp6", "::1", "udp", ""}, + {"udp6", "::1", "udp6", ""}, + {"udp6", "::1", "udp6", "::1"}, + } { + spec := "spec" + if test.dstAddr == "" { + spec = "unspec" + } + t.Run(fmt.Sprintf("%v/%v/%v", test.srcNet, test.dstNet, spec), func(t *testing.T) { + // See: https://go.googlesource.com/go/+/refs/tags/go1.22.0/src/net/ipsock.go#47 + // On these platforms, conns with network="udp" cannot accept IPv6. + switch runtime.GOOS { + case "dragonfly", "openbsd": + if test.srcNet == "udp6" && test.dstNet == "udp" { + t.Skipf("%v: no support for mapping IPv4 address to IPv6", runtime.GOOS) + } + } + if runtime.GOARCH == "wasm" && test.srcNet == "udp6" { + t.Skipf("%v: IPv6 tests fail when using wasm fake net", runtime.GOARCH) + } + + srcAddr := netip.AddrPortFrom(netip.MustParseAddr(test.srcAddr), 0) + srcConn, err := net.ListenUDP(test.srcNet, net.UDPAddrFromAddrPort(srcAddr)) + if err != nil { + // If ListenUDP fails here, we presumably don't have + // IPv4/IPv6 configured. + t.Skipf("ListenUDP(%q, %v) = %v", test.srcNet, srcAddr, err) + } + t.Cleanup(func() { srcConn.Close() }) + src, err := newNetUDPConn(srcConn) + if err != nil { + t.Fatalf("newNetUDPConn: %v", err) + } + + var dstAddr netip.AddrPort + if test.dstAddr != "" { + dstAddr = netip.AddrPortFrom(netip.MustParseAddr(test.dstAddr), 0) + } + dstConn, err := net.ListenUDP(test.dstNet, net.UDPAddrFromAddrPort(dstAddr)) + if err != nil { + t.Skipf("ListenUDP(%q, nil) = %v", test.dstNet, err) + } + dst, err := newNetUDPConn(dstConn) + if err != nil { + dstConn.Close() + t.Fatalf("newNetUDPConn: %v", err) + } + + dgramc := make(chan *datagram) + go func() { + defer close(dgramc) + dst.Read(func(dgram *datagram) { + dgramc <- dgram + }) + }() + t.Cleanup(func() { + dstConn.Close() + for range dgramc { + t.Errorf("test read unexpected datagram") + } + }) + + f(t, udpTest{ + src: src, + dst: dst, + dstAddr: netip.AddrPortFrom( + srcAddr.Addr(), + dst.LocalAddr().Port(), + ), + dgramc: dgramc, + }) + }) + } +} diff --git a/internal/quic/version_test.go b/quic/version_test.go similarity index 96% rename from internal/quic/version_test.go rename to quic/version_test.go index 92fabd7b3..0bd8bac14 100644 --- a/internal/quic/version_test.go +++ b/quic/version_test.go @@ -39,10 +39,10 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { }) gotPkt := te.read() if gotPkt == nil { - t.Fatalf("got no response; want Version Negotiaion") + t.Fatalf("got no response; want Version Negotiation") } if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { - t.Fatalf("got packet type %v; want Version Negotiaion", got) + t.Fatalf("got packet type %v; want Version Negotiation", got) } gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) if got, want := gotDst, srcConnID; !bytes.Equal(got, want) { diff --git a/internal/quic/wire.go b/quic/wire.go similarity index 100% rename from internal/quic/wire.go rename to quic/wire.go diff --git a/internal/quic/wire_test.go b/quic/wire_test.go similarity index 100% rename from internal/quic/wire_test.go rename to quic/wire_test.go diff --git a/websocket/client.go b/websocket/client.go index 69a4ac7ee..1e64157f3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -6,10 +6,12 @@ package websocket import ( "bufio" + "context" "io" "net" "net/http" "net/url" + "time" ) // DialError is an error that occurs while dialling a websocket server. @@ -79,28 +81,59 @@ func parseAuthority(location *url.URL) string { // DialConfig opens a new client connection to a WebSocket with a config. func DialConfig(config *Config) (ws *Conn, err error) { - var client net.Conn + return config.DialContext(context.Background()) +} + +// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation. +func (config *Config) DialContext(ctx context.Context) (*Conn, error) { if config.Location == nil { return nil, &DialError{config, ErrBadWebSocketLocation} } if config.Origin == nil { return nil, &DialError{config, ErrBadWebSocketOrigin} } + dialer := config.Dialer if dialer == nil { dialer = &net.Dialer{} } - client, err = dialWithDialer(dialer, config) - if err != nil { - goto Error - } - ws, err = NewClient(config, client) + + client, err := dialWithDialer(ctx, dialer, config) if err != nil { - client.Close() - goto Error + return nil, &DialError{config, err} } - return -Error: - return nil, &DialError{config, err} + // Cleanup the connection if we fail to create the websocket successfully + success := false + defer func() { + if !success { + _ = client.Close() + } + }() + + var ws *Conn + var wsErr error + doneConnecting := make(chan struct{}) + go func() { + defer close(doneConnecting) + ws, err = NewClient(config, client) + if err != nil { + wsErr = &DialError{config, err} + } + }() + + // The websocket.NewClient() function can block indefinitely, make sure that we + // respect the deadlines specified by the context. + select { + case <-ctx.Done(): + // Force the pending operations to fail, terminating the pending connection attempt + _ = client.SetDeadline(time.Now()) + <-doneConnecting // Wait for the goroutine that tries to establish the connection to finish + return nil, &DialError{config, ctx.Err()} + case <-doneConnecting: + if wsErr == nil { + success = true // Disarm the deferred connection cleanup + } + return ws, wsErr + } } diff --git a/websocket/dial.go b/websocket/dial.go index 2dab943a4..8a2d83c47 100644 --- a/websocket/dial.go +++ b/websocket/dial.go @@ -5,18 +5,23 @@ package websocket import ( + "context" "crypto/tls" "net" ) -func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) { +func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) { switch config.Location.Scheme { case "ws": - conn, err = dialer.Dial("tcp", parseAuthority(config.Location)) + conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) case "wss": - conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig) + tlsDialer := &tls.Dialer{ + NetDialer: dialer, + Config: config.TlsConfig, + } + conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) default: err = ErrBadScheme } diff --git a/websocket/dial_test.go b/websocket/dial_test.go index aa03e30dd..dd844872c 100644 --- a/websocket/dial_test.go +++ b/websocket/dial_test.go @@ -5,10 +5,13 @@ package websocket import ( + "context" "crypto/tls" + "errors" "fmt" "log" "net" + "net/http" "net/http/httptest" "testing" "time" @@ -41,3 +44,37 @@ func TestDialConfigTLSWithDialer(t *testing.T) { t.Fatalf("expected timeout error, got %#v", neterr) } } + +func TestDialConfigTLSWithTimeouts(t *testing.T) { + t.Parallel() + + finishedRequest := make(chan bool) + + // Context for cancellation + ctx, cancel := context.WithCancel(context.Background()) + + // This is a TLS server that blocks each request indefinitely (and cancels the context) + tlsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancel() + <-finishedRequest + })) + + tlsServerAddr := tlsServer.Listener.Addr().String() + log.Print("Test TLS WebSocket server listening on ", tlsServerAddr) + defer tlsServer.Close() + defer close(finishedRequest) + + config, _ := NewConfig(fmt.Sprintf("wss://%s/echo", tlsServerAddr), "http://localhost") + config.TlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + _, err := config.DialContext(ctx) + dialerr, ok := err.(*DialError) + if !ok { + t.Fatalf("DialError expected, got %#v", err) + } + if !errors.Is(dialerr.Err, context.Canceled) { + t.Fatalf("context.Canceled error expected, got %#v", dialerr.Err) + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy