diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index d2eae33e..00000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1 +0,0 @@ -* @nhooyr diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index 3d9829ef..00000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,39 +0,0 @@ -name: ci - -on: [push, pull_request] - -jobs: - fmt: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/fmt.sh - uses: ./ci/container - with: - args: ./ci/fmt.sh - - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/lint.sh - uses: ./ci/container - with: - args: ./ci/lint.sh - - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/test.sh - uses: ./ci/container - with: - args: ./ci/test.sh - env: - NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} - NETLIFY_SITE_ID: 9b3ee4dc-8297-4774-b4b9-a61561fbbce7 - - name: Upload coverage.html - uses: actions/upload-artifact@v2 - with: - name: coverage.html - path: ./ci/out/coverage.html diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..e9b4b5f6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +name: ci +on: [push, pull_request] +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +jobs: + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: ./ci/fmt.sh + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: go version + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: ./ci/lint.sh + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage.html + path: ./ci/out/coverage.html + + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: ./ci/bench.sh diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml new file mode 100644 index 00000000..2ba9ce34 --- /dev/null +++ b/.github/workflows/daily.yml @@ -0,0 +1,54 @@ +name: daily +on: + workflow_dispatch: + schedule: + - cron: '42 0 * * *' # daily at 00:42 +concurrency: + group: ${{ github.workflow }} + cancel-in-progress: true + +jobs: + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/bench.sh + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage.html + path: ./ci/out/coverage.html + bench-dev: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: dev + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/bench.sh + test-dev: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: dev + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage-dev.html + path: ./ci/out/coverage.html diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 6961e5c8..00000000 --- a/.gitignore +++ /dev/null @@ -1 +0,0 @@ -websocket.test diff --git a/LICENSE.txt b/LICENSE.txt index b5b5fef3..77b5bef6 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,21 +1,13 @@ -MIT License - -Copyright (c) 2018 Anmol Sethi - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +Copyright (c) 2023 Anmol Sethi + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/README.md b/README.md index df20c581..d093746d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # websocket -[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) -[![coverage](https://img.shields.io/badge/coverage-88%25-success)](https://nhooyrio-websocket-coverage.netlify.app) +[![Go Reference](https://pkg.go.dev/badge/nhooyr.io/websocket.svg)](https://pkg.go.dev/nhooyr.io/websocket) +[![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://nhooyr.io/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. ## Install -```bash +```sh go get nhooyr.io/websocket ``` @@ -16,26 +16,37 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) -- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections - Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap +See GitHub issues for minor issues but the major future enhancements are: + +- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217) +- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340) +- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267) +- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246) +- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209) +- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16) + - WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) +- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402) ## Examples For a production quality example that demonstrates the complete API, see the -[echo example](./examples/echo). +[echo example](./internal/examples/echo). -For a full stack example, see the [chat example](./examples/chat). +For a full stack example, see the [chat example](./internal/examples/chat). ### Server @@ -45,7 +56,7 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { if err != nil { // ... } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() @@ -72,7 +83,7 @@ c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } -defer c.Close(websocket.StatusInternalError, "the sky is falling") +defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { @@ -91,6 +102,8 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) +- No extra goroutine per connection to support cancellation with context.Context. This costs nhooyr.io/websocket 2 KB of memory per connection. + - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411) Advantages of nhooyr.io/websocket: @@ -107,14 +120,13 @@ Advantages of nhooyr.io/websocket: - Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) -- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). + Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) -- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) -- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) #### golang.org/x/net/websocket @@ -129,4 +141,15 @@ to nhooyr.io/websocket. [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). -However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. +However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws + +When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. + +#### lesismal/nbio + +[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is +event driven for performance reasons. + +However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio + +When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/accept.go b/accept.go index 18536bdb..285b3103 100644 --- a/accept.go +++ b/accept.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -51,7 +52,7 @@ type AcceptOptions struct { OriginPatterns []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode @@ -63,6 +64,14 @@ type AcceptOptions struct { CompressionThreshold int } +func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { + var o AcceptOptions + if opts != nil { + o = *opts + } + return &o +} + // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // @@ -77,17 +86,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") - if opts == nil { - opts = &AcceptOptions{} - } - opts = &*opts - errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } + opts = opts.cloneWithDefaults() if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { @@ -118,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) - if err != nil { - return nil, err + copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) + if ok { + w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) @@ -180,10 +185,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } - if r.Header.Get("Sec-WebSocket-Key") == "" { + websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") + if len(websocketSecKeys) == 0 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } + if len(websocketSecKeys) > 1 { + return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") + } + + // The RFC states to remove any leading or trailing whitespace. + websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) + if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { + return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) + } + return 0, nil } @@ -211,7 +227,10 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { return nil } } - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if u.Host == "" { + return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) + } + return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { @@ -230,26 +249,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { return "" } -func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { +func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { - return nil, nil + return nil, false } - - for _, ext := range websocketExtensions(r.Header) { + for _, ext := range extensions { switch ext.name { + // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... + // See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": - return acceptDeflate(w, ext, mode) - // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 - // case "x-webkit-deflate-frame": - // return acceptWebkitDeflate(w, ext, mode) + copts, ok := acceptDeflate(ext, mode) + if ok { + return copts, true + } } } - return nil, nil + return nil, false } -func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { +func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() - for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -258,55 +277,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - } - - if strings.HasPrefix(p, "client_max_window_bits") { - // We cannot adjust the read sliding window so cannot make use of this. + case "client_max_window_bits", + "server_max_window_bits=15": continue } - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - - copts.setHeader(w.Header()) - - return copts, nil -} - -func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { - copts := mode.opts() - // The peer must explicitly request it. - copts.serverNoContextTakeover = false - - for _, p := range ext.params { - if p == "no_context_takeover" { - copts.serverNoContextTakeover = true + if strings.HasPrefix(p, "client_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } - - // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead - // of ignoring it as the draft spec is unclear. It says the server can ignore it - // but the server has no way of signalling to the client it was ignored as the parameters - // are set one way. - // Thus us ignoring it would make the client think we understood it which would cause issues. - // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 - // - // Either way, we're only implementing this for webkit which never sends the max_window_bits - // parameter so we don't need to worry about it. - err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - - s := "x-webkit-deflate-frame" - if copts.clientNoContextTakeover { - s += "; no_context_takeover" + return nil, false } - w.Header().Set("Sec-WebSocket-Extensions", s) - - return copts, nil + return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { diff --git a/accept_js.go b/accept_js.go deleted file mode 100644 index daad4b79..00000000 --- a/accept_js.go +++ /dev/null @@ -1,20 +0,0 @@ -package websocket - -import ( - "errors" - "net/http" -) - -// AcceptOptions represents Accept's options. -type AcceptOptions struct { - Subprotocols []string - InsecureSkipVerify bool - OriginPatterns []string - CompressionMode CompressionMode - CompressionThreshold int -} - -// Accept is stubbed out for Wasm. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - return nil, errors.New("unimplemented") -} diff --git a/accept_test.go b/accept_test.go index e114d1ad..18233b1e 100644 --- a/accept_test.go +++ b/accept_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -9,9 +10,11 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/xrand" ) func TestAccept(t *testing.T) { @@ -35,28 +38,76 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) + assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`) }) - t.Run("badCompression", func(t *testing.T) { + // #247 + t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) { t.Parallel() - w := mockHijacker{ - ResponseWriter: httptest.NewRecorder(), - } + w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") - r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + r.Header.Set("Origin", "https://harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `unsupported permessage-deflate parameter`) + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`) + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + newRequest := func(extensions string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + r.Header.Set("Sec-WebSocket-Extensions", extensions) + return r + } + errHijack := errors.New("hijack error") + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errHijack + }, + } + } + + t.Run("withoutFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") + }) + t.Run("withFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar, permessage-deflate") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", + w.Header().Get("Sec-WebSocket-Extensions"), + CompressionNoContextTakeover.opts().String(), + ) + }) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -67,7 +118,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) @@ -87,11 +138,47 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + t.Run("closeRace", func(t *testing.T) { + t.Parallel() + + server, _ := net.Pipe() + + rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)) + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return server, rw, nil + }, + } + } + w := newResponseWriter() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + c, err := Accept(w, r, nil) + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.Close(StatusInternalError, "the sky is falling") + wg.Done() + }() + go func() { + c.CloseNow() + wg.Done() + }() + wg.Wait() + assert.Success(t, err) + }) } func Test_verifyClientHandshake(t *testing.T) { @@ -134,7 +221,15 @@ func Test_verifyClientHandshake(t *testing.T) { }, }, { - name: "badWebSocketKey", + name: "missingWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + }, + }, + { + name: "emptyWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", @@ -142,13 +237,43 @@ func Test_verifyClientHandshake(t *testing.T) { "Sec-WebSocket-Key": "", }, }, + { + name: "shortWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": xrand.Base64(15), + }, + }, + { + name: "invalidWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "notbase64", + }, + }, + { + name: "extraWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + // Kinda cheeky, but http headers are case-insensitive. + // If 2 sec keys are present, this is a failure condition. + "Sec-WebSocket-Key": xrand.Base64(16), + "sec-webSocket-key": xrand.Base64(16), + }, + }, { name: "badHTTPVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), }, http1: true, }, @@ -158,7 +283,17 @@ func Test_verifyClientHandshake(t *testing.T) { "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), + }, + success: true, + }, + { + name: "successSecKeyExtraSpace", + h: map[string]string{ + "Connection": "keep-alive, Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": " " + xrand.Base64(16) + " ", }, success: true, }, @@ -178,7 +313,7 @@ func Test_verifyClientHandshake(t *testing.T) { } for k, v := range tc.h { - r.Header.Set(k, v) + r.Header.Add(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) @@ -325,59 +460,54 @@ func Test_authenticateOrigin(t *testing.T) { } } -func Test_acceptCompression(t *testing.T) { +func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { - name string - mode CompressionMode - reqSecWebSocketExtensions string - respSecWebSocketExtensions string - expCopts *compressionOptions - error bool + name string + mode CompressionMode + header string + expCopts *compressionOptions + expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, + expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, + expOK: false, }, { - name: "permessage-deflate", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, + expOK: true, }, { - name: "permessage-deflate/error", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; meow", - error: true, + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow", + expOK: false, + }, + { + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + expOK: true, }, - // { - // name: "x-webkit-deflate-frame", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // expCopts: &compressionOptions{ - // clientNoContextTakeover: true, - // serverNoContextTakeover: true, - // }, - // }, - // { - // name: "x-webkit-deflate/error", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", - // error: true, - // }, } for _, tc := range testCases { @@ -385,19 +515,11 @@ func Test_acceptCompression(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - - w := httptest.NewRecorder() - copts, err := acceptCompression(r, w, tc.mode) - if tc.error { - assert.Error(t, err) - return - } - - assert.Success(t, err) + h := http.Header{} + h.Set("Sec-WebSocket-Extensions", tc.header) + copts, ok := selectDeflate(websocketExtensions(h), tc.mode) + assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) - assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/autobahn_test.go b/autobahn_test.go index e56a4912..57ceebd5 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket_test @@ -5,8 +6,9 @@ package websocket_test import ( "context" "encoding/json" + "errors" "fmt" - "io/ioutil" + "io" "net" "os" "os/exec" @@ -19,6 +21,7 @@ import ( "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/internal/util" ) var excludedAutobahnCases = []string{ @@ -28,25 +31,43 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 - // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } var autobahnCases = []string{"*"} +// Used to run individual test cases. autobahnCases runs only those cases matched +// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases +// is niled. +var onlyAutobahnCases = []string{} + func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_TEST") == "" { + if os.Getenv("AUTOBAHN") == "" { t.SkipNow() } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + if os.Getenv("AUTOBAHN") == "fast" { + // These are the slow tests. + excludedAutobahnCases = append(excludedAutobahnCases, + "9.*", "12.*", "13.*", + ) + } + + if len(onlyAutobahnCases) > 0 { + excludedAutobahnCases = []string{} + autobahnCases = onlyAutobahnCases + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() - wstestURL, closeFn, err := wstestClientServer(ctx) + wstestURL, closeFn, err := wstestServer(t, ctx) assert.Success(t, err) - defer closeFn() + defer func() { + assert.Success(t, closeFn()) + }() err = waitWS(ctx, wstestURL) assert.Success(t, err) @@ -61,7 +82,9 @@ func TestAutobahn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ + CompressionMode: websocket.CompressionContextTakeover, + }) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) @@ -73,7 +96,7 @@ func TestAutobahn(t *testing.T) { assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") + checkWSTestIndex(t, "./ci/out/autobahn-report/index.json") } func waitWS(ctx context.Context, url string) error { @@ -92,17 +115,24 @@ func waitWS(ctx context.Context, url string) error { return ctx.Err() } -func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { +func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) { + defer errd.Wrap(&err, "failed to start autobahn wstest server") + serverAddr, err := unusedListenAddr() if err != nil { return "", nil, err } + _, serverPort, err := net.SplitHostPort(serverAddr) + if err != nil { + return "", nil, err + } url = "ws://" + serverAddr + const outDir = "ci/out/autobahn-report" specFile, err := tempJSONFile(map[string]interface{}{ "url": url, - "outdir": "ci/out/wstestClientReports", + "outdir": outDir, "cases": autobahnCases, "exclude-cases": excludedAutobahnCases, }) @@ -110,26 +140,71 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er return "", nil, fmt.Errorf("failed to write spec: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + ctx, cancel := context.WithTimeout(ctx, time.Hour) defer func() { if err != nil { cancel() } }() - args := []string{"--mode", "fuzzingserver", "--spec", specFile, + dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite") + dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + tb.Log(dockerPull) + err = dockerPull.Run() + if err != nil { + return "", nil, fmt.Errorf("failed to pull docker image: %w", err) + } + + wd, err := os.Getwd() + if err != nil { + return "", nil, err + } + + var args []string + args = append(args, "run", "-i", "--rm", + "-v", fmt.Sprintf("%s:%[1]s", specFile), + "-v", fmt.Sprintf("%s/ci:/ci", wd), + fmt.Sprintf("-p=%s:%s", serverAddr, serverPort), + "crossbario/autobahn-testsuite", + ) + args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) + ) + wstest := exec.CommandContext(ctx, "docker", args...) + wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + tb.Log(wstest) err = wstest.Start() if err != nil { return "", nil, fmt.Errorf("failed to start wstest: %w", err) } - return url, func() { - wstest.Process.Kill() + return url, func() error { + err = wstest.Process.Kill() + if err != nil { + return fmt.Errorf("failed to kill wstest: %w", err) + } + err = wstest.Wait() + var ee *exec.ExitError + if errors.As(err, &ee) && ee.ExitCode() == -1 { + return nil + } + return err }, nil } @@ -146,7 +221,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { if err != nil { return 0, err } - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { return 0, err } @@ -161,7 +236,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { } func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) + wstestOut, err := os.ReadFile(path) assert.Success(t, err) var indexJSON map[string]map[string]struct { @@ -206,7 +281,7 @@ func unusedListenAddr() (_ string, err error) { } func tempJSONFile(v interface{}) (string, error) { - f, err := ioutil.TempFile("", "temp.json") + f, err := os.CreateTemp("", "temp.json") if err != nil { return "", fmt.Errorf("temp file: %w", err) } diff --git a/ci/all.sh b/ci/all.sh deleted file mode 100755 index 1ee7640f..00000000 --- a/ci/all.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -main() { - cd "$(dirname "$0")/.." - - ./ci/fmt.sh - ./ci/lint.sh - ./ci/test.sh "$@" -} - -main "$@" diff --git a/ci/bench.sh b/ci/bench.sh new file mode 100755 index 00000000..30c06986 --- /dev/null +++ b/ci/bench.sh @@ -0,0 +1,20 @@ +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." + +go test --run=^$ --bench=. --benchmem "$@" ./... +# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test +( + cd ./internal/thirdparty + go test --run=^$ --bench=. --benchmem "$@" . + + GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem + fi +) diff --git a/ci/container/Dockerfile b/ci/container/Dockerfile deleted file mode 100644 index 0c6c2a54..00000000 --- a/ci/container/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM golang - -RUN apt-get update -RUN apt-get install -y npm shellcheck chromium - -ENV GO111MODULE=on -RUN go get golang.org/x/tools/cmd/goimports -RUN go get mvdan.cc/sh/v3/cmd/shfmt -RUN go get golang.org/x/tools/cmd/stringer -RUN go get golang.org/x/lint/golint -RUN go get github.com/agnivade/wasmbrowsertest - -RUN npm --unsafe-perm=true install -g prettier -RUN npm --unsafe-perm=true install -g netlify-cli diff --git a/ci/fmt.sh b/ci/fmt.sh index e6a2d689..31d0c15d 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -1,38 +1,24 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +go mod tidy +(cd ./internal/thirdparty && go mod tidy) +(cd ./internal/examples && go mod tidy) +gofmt -w -s . +go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . - go mod tidy - gofmt -w -s . - goimports -w "-local=$(go list -m)" . +npx prettier@3.0.3 \ + --write \ + --log-level=warn \ + --print-width=90 \ + --no-semi \ + --single-quote \ + --arrow-parens=avoid \ + $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") - prettier \ - --write \ - --print-width=120 \ - --no-semi \ - --trailing-comma=all \ - --loglevel=warn \ - --arrow-parens=avoid \ - $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") - shfmt -i 2 -w -s -sr $(git ls-files "*.sh") +go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go - stringer -type=opcode,MessageType,StatusCode -output=stringer.go - - if [[ ${CI-} ]]; then - ensure_fmt - fi -} - -ensure_fmt() { - if [[ $(git ls-files --other --modified --exclude-standard) ]]; then - git -c color.ui=always --no-pager diff - echo - echo "Please run the following locally:" - echo " ./ci/fmt.sh" - exit 1 - fi -} - -main "$@" +if [ "${CI-}" ]; then + git diff --exit-code +fi diff --git a/ci/lint.sh b/ci/lint.sh index e1053d13..3cf8eee4 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,16 +1,33 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +go vet ./... +GOOS=js GOARCH=wasm go vet ./... - go vet ./... - GOOS=js GOARCH=wasm go vet ./... - - golint -set_exit_status ./... - GOOS=js GOARCH=wasm golint -set_exit_status ./... +go install honnef.co/go/tools/cmd/staticcheck@latest +staticcheck ./... +GOOS=js GOARCH=wasm staticcheck ./... - shellcheck --exclude=SC2046 $(git ls-files "*.sh") +govulncheck() { + tmpf=$(mktemp) + if ! command govulncheck "$@" >"$tmpf" 2>&1; then + cat "$tmpf" + fi } +go install golang.org/x/vuln/cmd/govulncheck@latest +govulncheck ./... +GOOS=js GOARCH=wasm govulncheck ./... -main "$@" +( + cd ./internal/examples + go vet ./... + staticcheck ./... + govulncheck ./... +) +( + cd ./internal/thirdparty + go vet ./... + staticcheck ./... + govulncheck ./... +) diff --git a/ci/test.sh b/ci/test.sh index 95ef7101..a3007614 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -1,25 +1,36 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +( + cd ./internal/examples + go test "$@" ./... +) +( + cd ./internal/thirdparty + go test "$@" ./... +) - go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... - sed -i '/stringer\.go/d' ci/out/coverage.prof - sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof - sed -i '/examples/d' ci/out/coverage.prof +( + GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask + fi +) - # Last line is the total coverage. - go tool cover -func ci/out/coverage.prof | tail -n1 - go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html +go install github.com/agnivade/wasmbrowsertest@latest +go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... +sed -i.bak '/stringer\.go/d' ci/out/coverage.prof +sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof +sed -i.bak '/examples/d' ci/out/coverage.prof - if [[ ${CI-} && ${GITHUB_REF-} == *master ]]; then - local deployDir - deployDir="$(mktemp -d)" - cp ci/out/coverage.html "$deployDir/index.html" - netlify deploy --prod "--dir=$deployDir" - fi -} +# Last line is the total coverage. +go tool cover -func ci/out/coverage.prof | tail -n1 -main "$@" +go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html diff --git a/close.go b/close.go index 7cbc19e9..31504b0e 100644 --- a/close.go +++ b/close.go @@ -1,8 +1,17 @@ +//go:build !js +// +build !js + package websocket import ( + "context" + "encoding/binary" "errors" "fmt" + "net" + "time" + + "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -74,3 +83,266 @@ func CloseStatus(err error) StatusCode { } return -1 } + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err +} + +// CloseNow closes the WebSocket connection without attempting a close handshake. +// Use when you do not want the overhead of the close handshake. +func (c *Conn) CloseNow() (err error) { + defer errd.Wrap(&err, "failed to immediately close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.close() + + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + return err +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err + } + + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } + return nil +} + +func (c *Conn) writeClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } + + var p []byte + var err error + if ce.Code != StatusNoStatusRcvd { + p, err = ce.bytes() + if err != nil { + return err + } + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + return nil +} + +func (c *Conn) waitCloseHandshake() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.readMu.lock(ctx) + if err != nil { + return err + } + defer c.readMu.unlock() + + for i := int64(0); i < c.msgReader.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + closeRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closeRead { + select { + case <-c.closeReadDone: + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = fmt.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) casClosing() bool { + c.closeMu.Lock() + defer c.closeMu.Unlock() + if !c.closing { + c.closing = true + return true + } + return false +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_notjs.go b/close_notjs.go deleted file mode 100644 index 4251311d..00000000 --- a/close_notjs.go +++ /dev/null @@ -1,211 +0,0 @@ -// +build !js - -package websocket - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "log" - "time" - - "nhooyr.io/websocket/internal/errd" -) - -// Close performs the WebSocket close handshake with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// All data messages received from the peer during the close handshake will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - return c.closeHandshake(code, reason) -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() - - if writeErr != nil { - return writeErr - } - - if CloseStatus(closeHandshakeErr) == -1 { - return closeHandshakeErr - } - - return nil -} - -var errAlreadyWroteClose = errors.New("already wrote close") - -func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return errAlreadyWroteClose - } - - ce := CloseError{ - Code: code, - Reason: reason, - } - - var p []byte - var marshalErr error - if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - if marshalErr != nil { - log.Printf("websocket: %v", marshalErr) - } - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil - } - - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) - - if marshalErr != nil { - return marshalErr - } - return writeErr -} - -func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.readMu.lock(ctx) - if err != nil { - return err - } - defer c.readMu.unlock() - - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - - for { - h, err := c.readLoop(ctx) - if err != nil { - return err - } - - for i := int64(0); i < h.payloadLength; i++ { - _, err := c.br.ReadByte() - if err != nil { - return err - } - } - } -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -func (ce CloseError) bytes() ([]byte, error) { - p, err := ce.bytesErr() - if err != nil { - err = fmt.Errorf("failed to marshal close frame: %w", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytesErr() - } - return p, err -} - -const maxCloseReason = maxControlPayload - 2 - -func (ce CloseError) bytesErr() ([]byte, error) { - if len(ce.Reason) > maxCloseReason { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) - } - - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) - } -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/close_test.go b/close_test.go index 00a48d9e..6bf3c256 100644 --- a/close_test.go +++ b/close_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket diff --git a/compress.go b/compress.go index 80b46d1c..1f3adcfb 100644 --- a/compress.go +++ b/compress.go @@ -1,39 +1,233 @@ +//go:build !js +// +build !js + package websocket -// CompressionMode represents the modes available to the deflate extension. +import ( + "compress/flate" + "io" + "sync" +) + +// CompressionMode represents the modes available to the permessage-deflate extension. // See https://tools.ietf.org/html/rfc7692 // -// A compatibility layer is implemented for the older deflate-frame extension used -// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 -// It will work the same in every way except that we cannot signal to the peer we -// want to use no context takeover on our side, we can only signal that they should. -// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +// Works in all modern browsers except Safari which does not implement the permessage-deflate extension. +// +// Compression is only used if the peer supports the mode selected. type CompressionMode int const ( - // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed - // for every message. This applies to both server and client side. - // - // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be lower if the connections - // are long lived and seldom used. + // CompressionDisabled disables the negotiation of the permessage-deflate extension. // - // The message will only be compressed if greater than 512 bytes. - CompressionNoContextTakeover CompressionMode = iota + // This is the default. Do not enable compression without benchmarking for your particular use case first. + CompressionDisabled CompressionMode = iota - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this can be very efficient. - // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from + // previous messages. i.e compression context across messages is preserved. + // + // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. + // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's + // that are used when reading and then returned. + // + // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. + // + // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. CompressionContextTakeover - // CompressionDisabled disables the deflate extension. + // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with + // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from + // a sync.Pool. + // + // This means less efficient compression as the sliding window from previous messages will not be used but the + // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. + // Especially if the connections are long lived and seldom written to. // - // Use this if you are using a predominantly binary protocol with very - // little duplication in between messages or CPU and memory are more - // important than bandwidth. - CompressionDisabled + // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. + // + // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. + CompressionNoContextTakeover ) + +func (m CompressionMode) opts() *compressionOptions { + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) String() string { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + return s +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to read more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + if n == 0 { + n = 32768 + } + + p := slidingWindowPool(n) + sw2, ok := p.Get().(*slidingWindow) + if ok { + *sw = *sw2 + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + sw.buf = sw.buf[:0] + swPoolMu.Lock() + swPool[cap(sw.buf)].Put(sw) + swPoolMu.Unlock() +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/compress_notjs.go b/compress_notjs.go deleted file mode 100644 index 809a272c..00000000 --- a/compress_notjs.go +++ /dev/null @@ -1,181 +0,0 @@ -// +build !js - -package websocket - -import ( - "io" - "net/http" - "sync" - - "github.com/klauspost/compress/flate" -) - -func (m CompressionMode) opts() *compressionOptions { - return &compressionOptions{ - clientNoContextTakeover: m == CompressionNoContextTakeover, - serverNoContextTakeover: m == CompressionNoContextTakeover, - } -} - -type compressionOptions struct { - clientNoContextTakeover bool - serverNoContextTakeover bool -} - -func (copts *compressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.clientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.serverNoContextTakeover { - s += "; server_no_context_takeover" - } - h.Set("Sec-WebSocket-Extensions", s) -} - -// These bytes are required to get flate.Reader to return. -// They are removed when sending to avoid the overhead as -// WebSocket framing tell's when the message has ended but then -// we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. -const deflateMessageTail = "\x00\x00\xff\xff" - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte -} - -func (tw *trimLastFourBytesWriter) reset() { - if tw != nil && tw.tail != nil { - tw.tail = tw.tail[:0] - } -} - -func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { - if tw.tail == nil { - tw.tail = make([]byte, 0, 4) - } - - extra := len(tw.tail) + len(p) - 4 - - if extra <= 0 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(tw.tail) { - extra = len(tw.tail) - } - if extra > 0 { - _, err := tw.w.Write(tw.tail[:extra]) - if err != nil { - return 0, err - } - - // Shift remaining bytes in tail over. - n := copy(tw.tail, tw.tail[extra:]) - tw.tail = tw.tail[:n] - } - - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Otherwise, only the last 4 bytes are. - tw.tail = append(tw.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := tw.w.Write(p) - return n + 4, err -} - -var flateReaderPool sync.Pool - -func getFlateReader(r io.Reader, dict []byte) io.Reader { - fr, ok := flateReaderPool.Get().(io.Reader) - if !ok { - return flate.NewReaderDict(r, dict) - } - fr.(flate.Resetter).Reset(r, dict) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -type slidingWindow struct { - buf []byte -} - -var swPoolMu sync.RWMutex -var swPool = map[int]*sync.Pool{} - -func slidingWindowPool(n int) *sync.Pool { - swPoolMu.RLock() - p, ok := swPool[n] - swPoolMu.RUnlock() - if ok { - return p - } - - p = &sync.Pool{} - - swPoolMu.Lock() - swPool[n] = p - swPoolMu.Unlock() - - return p -} - -func (sw *slidingWindow) init(n int) { - if sw.buf != nil { - return - } - - if n == 0 { - n = 32768 - } - - p := slidingWindowPool(n) - buf, ok := p.Get().([]byte) - if ok { - sw.buf = buf[:0] - } else { - sw.buf = make([]byte, 0, n) - } -} - -func (sw *slidingWindow) close() { - if sw.buf == nil { - return - } - - swPoolMu.Lock() - swPool[cap(sw.buf)].Put(sw.buf) - swPoolMu.Unlock() - sw.buf = nil -} - -func (sw *slidingWindow) write(p []byte) { - if len(p) >= cap(sw.buf) { - sw.buf = sw.buf[:cap(sw.buf)] - p = p[len(p)-cap(sw.buf):] - copy(sw.buf, p) - return - } - - left := cap(sw.buf) - len(sw.buf) - if left < len(p) { - // We need to shift spaceNeeded bytes from the end to make room for p at the end. - spaceNeeded := len(p) - left - copy(sw.buf, sw.buf[spaceNeeded:]) - sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] - } - - sw.buf = append(sw.buf, p...) -} diff --git a/compress_test.go b/compress_test.go index 2c4c896c..667e1408 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,8 +1,12 @@ +//go:build !js // +build !js package websocket import ( + "bytes" + "compress/flate" + "io" "strings" "testing" @@ -32,3 +36,27 @@ func Test_slidingWindow(t *testing.T) { }) } } + +func BenchmarkFlateWriter(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + w, _ := flate.NewWriter(io.Discard, flate.BestSpeed) + // We have to write a byte to get the writer to allocate to its full extent. + w.Write([]byte{'a'}) + w.Flush() + } +} + +func BenchmarkFlateReader(b *testing.B) { + b.ReportAllocs() + + var buf bytes.Buffer + w, _ := flate.NewWriter(&buf, flate.BestSpeed) + w.Write([]byte{'a'}) + w.Flush() + + for i := 0; i < b.N; i++ { + r := flate.NewReader(bytes.NewReader(buf.Bytes())) + io.ReadAll(r) + } +} diff --git a/conn.go b/conn.go index a41808be..8690fb3b 100644 --- a/conn.go +++ b/conn.go @@ -1,5 +1,20 @@ +//go:build !js +// +build !js + package websocket +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int @@ -11,3 +26,270 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +// +// This applies to context expirations as well unfortunately. +// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 +type Conn struct { + noCopy noCopy + + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} + + // Read state. + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + + // Write state. + msgWriter *msgWriter + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closed chan struct{} + closeMu sync.Mutex + closing bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriter.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close() + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close() error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return net.ErrClosed + } + runtime.SetFinalizer(c, nil) + close(c.closed) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err +} + +func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.close() + return + case <-writeCtx.Done(): + c.close() + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}, 1) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return net.ErrClosed + case <-ctx.Done(): + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return net.ErrClosed + case <-ctx.Done(): + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return net.ErrClosed + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} + +type noCopy struct{} + +func (*noCopy) Lock() {} diff --git a/conn_notjs.go b/conn_notjs.go deleted file mode 100644 index 0c85ab77..00000000 --- a/conn_notjs.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "runtime" - "strconv" - "sync" - "sync/atomic" -) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - - // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error - - // Write state. - msgWriterState *msgWriterState - writeFrameMu *mu - writeBuf []byte - writeHeaderBuf [8]byte - writeHeader header - - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriterState = newMsgWriterState(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 128 - if !c.msgWriterState.flateContextTakeover() { - c.flateThreshold = 512 - } - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return - } - c.setCloseErrLocked(err) - close(c.closed) - runtime.SetFinalizer(c, nil) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.rwc.Close() - - go func() { - c.msgWriterState.close() - - c.msgReader.close() - }() -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, errors.New("timed out")) - case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) - - err := c.ping(ctx, strconv.Itoa(int(p))) - if err != nil { - return fmt.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}, 1) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} -} - -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } -} - -func (m *mu) forceLock() { - m.ch <- struct{}{} -} - -func (m *mu) lock(ctx context.Context) error { - select { - case <-m.c.closed: - return m.c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err - case m.ch <- struct{}{}: - // To make sure the connection is certainly alive. - // As it's possible the send on m.ch was selected - // over the receive on closed. - select { - case <-m.c.closed: - // Make sure to release. - m.unlock() - return m.c.closeErr - default: - } - return nil - } -} - -func (m *mu) unlock() { - select { - case <-m.ch: - default: - } -} diff --git a/conn_test.go b/conn_test.go index c2c41292..2b44ad22 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,13 +1,13 @@ -// +build !js +//go:build !js package websocket_test import ( "bytes" "context" + "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -16,10 +16,6 @@ import ( "testing" "time" - "github.com/gin-gonic/gin" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/duration" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" @@ -27,7 +23,6 @@ import ( "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/wsjson" - "nhooyr.io/websocket/wspb" ) func TestConn(t *testing.T) { @@ -37,7 +32,7 @@ func TestConn(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { - return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) + return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1)) } for i := 0; i < 5; i++ { @@ -49,7 +44,6 @@ func TestConn(t *testing.T) { CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -67,8 +61,9 @@ func TestConn(t *testing.T) { }) t.Run("badClose", func(t *testing.T) { - tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() + tt, c1, c2 := newConnTest(t, nil, nil) + + c2.CloseRead(tt.ctx) err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") @@ -76,7 +71,6 @@ func TestConn(t *testing.T) { t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) @@ -92,7 +86,6 @@ func TestConn(t *testing.T) { t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c2.CloseRead(tt.ctx) @@ -105,7 +98,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goDiscardLoop(c2) @@ -138,7 +130,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) @@ -147,12 +138,13 @@ func TestConn(t *testing.T) { defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) - assert.Equal(t, "write error", context.DeadlineExceeded, err) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("unexpected error: %#v", err) + } }) t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -163,8 +155,8 @@ func TestConn(t *testing.T) { n1.SetDeadline(time.Time{}) assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) - assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) - assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) + assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String()) + assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -174,7 +166,7 @@ func TestConn(t *testing.T) { return n2.Close() }) - b, err := ioutil.ReadAll(n1) + b, err := io.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) @@ -192,21 +184,47 @@ func TestConn(t *testing.T) { t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) + c2.CloseRead(tt.ctx) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) + return err + }) + + _, err := io.ReadAll(n1) + assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) + + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + }) + + t.Run("netConn/readLimit", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + s := strings.Repeat("papa", 1<<20) + errs := xsync.Go(func() error { + _, err := n2.Write([]byte(s)) if err != nil { return err } - return nil + return n2.Close() }) - _, err := ioutil.ReadAll(n1) - assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) + b, err := io.ReadAll(n1) + assert.Success(t, err) + + _, err = n1.Read(nil) + assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: @@ -214,11 +232,24 @@ func TestConn(t *testing.T) { case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } + + assert.Equal(t, "read msg", s, string(b)) + }) + + t.Run("netConn/pastDeadline", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + n1.SetDeadline(time.Now().Add(-time.Minute)) + n2.SetDeadline(time.Now().Add(-time.Minute)) + + // No panic we're good. }) t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -246,21 +277,67 @@ func TestConn(t *testing.T) { assert.Success(t, err) }) - t.Run("wspb", func(t *testing.T) { - tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() + t.Run("HTTPClient.Timeout", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ + HTTPClient: &http.Client{Timeout: time.Second * 5}, + }, nil) tt.goEchoLoop(c2) - exp := ptypes.DurationProto(100) - err := wspb.Write(tt.ctx, c1, exp) - assert.Success(t, err) + c1.SetReadLimit(1 << 30) + + exp := xrand.String(xrand.Int(131072)) - act := &duration.Duration{} - err = wspb.Read(tt.ctx, c1, act) + werr := xsync.Go(func() error { + return wsjson.Write(tt.ctx, c1, exp) + }) + + var act interface{} + err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) + select { + case err := <-werr: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) + + t.Run("CloseNow", func(t *testing.T) { + _, c1, c2 := newConnTest(t, nil, nil) + + err1 := c1.CloseNow() + err2 := c2.CloseNow() + assert.Success(t, err1) + assert.Success(t, err2) + err1 = c1.CloseNow() + err2 = c2.CloseNow() + assert.ErrorIs(t, websocket.ErrClosed, err1) + assert.ErrorIs(t, websocket.ErrClosed, err2) + }) + + t.Run("MidReadClose", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + tt.goEchoLoop(c2) + + c1.SetReadLimit(131072) + + for i := 0; i < 5; i++ { + err := wstest.Echo(tt.ctx, c1, 131072) + assert.Success(t, err) + } + + err := wsjson.Write(tt.ctx, c1, "four") + assert.Success(t, err) + _, _, err = c1.Reader(tt.ctx) + assert.Success(t, err) + err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) @@ -268,6 +345,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.SkipNow() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ @@ -283,7 +363,7 @@ func TestWasm(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() @@ -305,8 +385,6 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { type connTest struct { t testing.TB ctx context.Context - - doneFuncs []func() } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { @@ -317,30 +395,20 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} - tt.appendDone(cancel) + t.Cleanup(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } - tt.appendDone(func() { - c2.Close(websocket.StatusInternalError, "") - c1.Close(websocket.StatusInternalError, "") + t.Cleanup(func() { + c2.CloseNow() + c1.CloseNow() }) return tt, c1, c2 } -func (tt *connTest) appendDone(f func()) { - tt.doneFuncs = append(tt.doneFuncs, f) -} - -func (tt *connTest) cleanup() { - for i := len(tt.doneFuncs) - 1; i >= 0; i-- { - tt.doneFuncs[i]() - } -} - func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) @@ -348,7 +416,7 @@ func (tt *connTest) goEchoLoop(c *websocket.Conn) { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-echoLoopErr if err != nil { @@ -370,7 +438,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } } }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-discardLoopErr if err != nil { @@ -389,7 +457,7 @@ func BenchmarkConn(b *testing.B) { mode: websocket.CompressionDisabled, }, { - name: "compress", + name: "compressContextTakeover", mode: websocket.CompressionContextTakeover, }, { @@ -404,7 +472,6 @@ func BenchmarkConn(b *testing.B) { }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) - defer bb.cleanup() bb.goEchoLoop(c2) @@ -438,7 +505,7 @@ func BenchmarkConn(b *testing.B) { typ, r, err := c1.Reader(bb.ctx) if err != nil { - b.Fatal(err) + b.Fatal(i, err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) @@ -494,36 +561,55 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp return assertCloseStatus(websocket.StatusNormalClosure, err) } -func TestGin(t *testing.T) { - t.Parallel() +func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) { + exp := xrand.String(xrand.Int(131072)) - gin.SetMode(gin.ReleaseMode) - r := gin.New() - r.GET("/", func(ginCtx *gin.Context) { - err := echoServer(ginCtx.Writer, ginCtx.Request, nil) - if err != nil { - t.Error(err) - } + werr := xsync.Go(func() error { + return wsjson.Write(ctx, c, exp) }) - s := httptest.NewServer(r) - defer s.Close() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - c, _, err := websocket.Dial(ctx, s.URL, nil) - assert.Success(t, err) - defer c.Close(websocket.StatusInternalError, "") + var act interface{} + c.SetReadLimit(1 << 30) + err := wsjson.Read(ctx, c, &act) + assert.Success(tb, err) + assert.Equal(tb, "read msg", exp, act) + + select { + case err := <-werr: + assert.Success(tb, err) + case <-ctx.Done(): + tb.Fatal(ctx.Err()) + } +} - err = wsjson.Write(ctx, c, "hello") - assert.Success(t, err) +func assertClose(tb testing.TB, c *websocket.Conn) { + tb.Helper() + err := c.Close(websocket.StatusNormalClosure, "") + assert.Success(tb, err) +} - var v interface{} - err = wsjson.Read(ctx, c, &v) - assert.Success(t, err) - assert.Equal(t, "read msg", "hello", v) +func TestConcurrentClosePing(t *testing.T) { + t.Parallel() + for i := 0; i < 64; i++ { + func() { + c1, c2 := wstest.Pipe(nil, nil) + defer c1.CloseNow() + defer c2.CloseNow() + c1.CloseRead(context.Background()) + c2.CloseRead(context.Background()) + errc := xsync.Go(func() error { + for range time.Tick(time.Millisecond) { + err := c1.Ping(context.Background()) + if err != nil { + return err + } + } + panic("unreachable") + }) - err = c.Close(websocket.StatusNormalClosure, "") - assert.Success(t, err) + time.Sleep(10 * time.Millisecond) + assert.Success(t, c1.Close(websocket.StatusNormalClosure, "")) + <-errc + }() + } } diff --git a/dial.go b/dial.go index 7a7787ff..e4c4daa1 100644 --- a/dial.go +++ b/dial.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -10,7 +11,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -30,11 +30,15 @@ type DialOptions struct { // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header + // Host optionally overrides the Host HTTP header to send. If empty, the value + // of URL.Host will be used. + Host string + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode @@ -46,6 +50,45 @@ type DialOptions struct { CompressionThreshold int } +func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { + var cancel context.CancelFunc + + var o DialOptions + if opts != nil { + o = *opts + } + if o.HTTPClient == nil { + o.HTTPClient = http.DefaultClient + } + if o.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) + + newClient := *o.HTTPClient + newClient.Timeout = 0 + o.HTTPClient = &newClient + } + if o.HTTPHeader == nil { + o.HTTPHeader = http.Header{} + } + newClient := *o.HTTPClient + oldCheckRedirect := o.HTTPClient.CheckRedirect + newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + switch req.URL.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + } + if oldCheckRedirect != nil { + return oldCheckRedirect(req, via) + } + return nil + } + o.HTTPClient = &newClient + + return ctx, cancel, &o +} + // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. @@ -66,26 +109,10 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") - if opts == nil { - opts = &DialOptions{} - } - - opts = &*opts - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } else if opts.HTTPClient.Timeout > 0 { - var cancel context.CancelFunc - - ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + var cancel context.CancelFunc + ctx, cancel, opts = opts.cloneWithDefaults(ctx) + if cancel != nil { defer cancel() - - newClient := *opts.HTTPClient - newClient.Timeout = 0 - opts.HTTPClient = &newClient - } - - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} } secWebSocketKey, err := secWebSocketKey(rand) @@ -114,9 +141,9 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( }) defer timer.Stop() - b, _ := ioutil.ReadAll(r) + b, _ := io.ReadAll(r) respBody.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + resp.Body = io.NopCloser(bytes.NewReader(b)) } }() @@ -157,7 +184,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } - req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create new http request: %w", err) + } + if len(opts.Host) > 0 { + req.Host = opts.Host + } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") @@ -167,7 +200,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { - copts.setHeader(req.Header) + req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) @@ -243,7 +276,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } - copts = &*copts + _copts := *copts + copts = &_copts for _, p := range ext.params { switch p { @@ -254,6 +288,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress copts.serverNoContextTakeover = true continue } + if strings.HasPrefix(p, "server_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. + continue + } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } diff --git a/dial_test.go b/dial_test.go index 28c255c6..237a2874 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,19 +1,24 @@ +//go:build !js // +build !js -package websocket +package websocket_test import ( + "bytes" "context" "crypto/rand" "io" - "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/util" + "nhooyr.io/websocket/internal/xsync" ) func TestBadDials(t *testing.T) { @@ -23,10 +28,11 @@ func TestBadDials(t *testing.T) { t.Parallel() testCases := []struct { - name string - url string - opts *DialOptions - rand readerFunc + name string + url string + opts *websocket.DialOptions + rand util.ReaderFunc + nilCtx bool }{ { name: "badURL", @@ -46,6 +52,11 @@ func TestBadDials(t *testing.T) { return 0, io.EOF }, }, + { + name: "nilContext", + url: "http://localhost", + nilCtx: true, + }, } for _, tc := range testCases { @@ -53,14 +64,18 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + var ctx context.Context + var cancel func() + if !tc.nilCtx { + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + } if tc.rand == nil { tc.rand = rand.Reader.Read } - _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } @@ -72,10 +87,10 @@ func TestBadDials(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ - Body: ioutil.NopCloser(strings.NewReader("hi")), + Body: io.NopCloser(strings.NewReader("hi")), }, nil }), }) @@ -92,22 +107,82 @@ func TestBadDials(t *testing.T) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") - h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, - Body: ioutil.NopCloser(strings.NewReader("hi")), + Body: io.NopCloser(strings.NewReader("hi")), }, nil } - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } +func Test_verifyHostOverride(t *testing.T) { + testCases := []struct { + name string + host string + exp string + }{ + { + name: "noOverride", + host: "", + exp: "example.com", + }, + { + name: "hostOverride", + host: "example.net", + exp: "example.net", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Host", tc.exp, r.Host) + + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: mockBody{bytes.NewBufferString("hi")}, + }, nil + } + + c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(rt), + Host: tc.host, + }) + assert.Success(t, err) + c.CloseNow() + }) + } + +} + +type mockBody struct { + *bytes.Buffer +} + +func (mb mockBody) Close() error { + return nil +} + func Test_verifyServerHandshake(t *testing.T) { t.Parallel() @@ -201,18 +276,18 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key, err := secWebSocketKey(rand.Reader) + key, err := websocket.SecWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) } - opts := &DialOptions{ + opts := &websocket.DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } - _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) + _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) if tc.success { assert.Success(t, err) } else { @@ -233,3 +308,113 @@ type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestDialRedirect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{}, + } + if r.URL.Scheme != "https" { + resp.Header.Set("Location", "wss://example.com") + resp.StatusCode = http.StatusFound + return resp, nil + } + resp.Header.Set("Connection", "Upgrade") + resp.Header.Set("Upgrade", "meow") + resp.StatusCode = http.StatusSwitchingProtocols + return resp, nil + }), + }) + assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") +} + +type forwardProxy struct { + hc *http.Client +} + +func newForwardProxy() *forwardProxy { + return &forwardProxy{ + hc: &http.Client{}, + } +} + +func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + defer cancel() + + r = r.WithContext(ctx) + r.RequestURI = "" + resp, err := fc.hc.Do(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer resp.Body.Close() + + for k, v := range resp.Header { + w.Header()[k] = v + } + w.Header().Set("PROXIED", "true") + w.WriteHeader(resp.StatusCode) + if resprw, ok := resp.Body.(io.ReadWriter); ok { + c, brw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + brw.Flush() + + errc1 := xsync.Go(func() error { + _, err := io.Copy(c, resprw) + return err + }) + errc2 := xsync.Go(func() error { + _, err := io.Copy(resprw, c) + return err + }) + select { + case <-errc1: + case <-errc2: + case <-r.Context().Done(): + } + } else { + io.Copy(w, resp.Body) + } +} + +func TestDialViaProxy(t *testing.T) { + t.Parallel() + + ps := httptest.NewServer(newForwardProxy()) + defer ps.Close() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := echoServer(w, r, nil) + assert.Success(t, err) + })) + defer s.Close() + + psu, err := url.Parse(ps.URL) + assert.Success(t, err) + proxyTransport := http.DefaultTransport.(*http.Transport).Clone() + proxyTransport.Proxy = http.ProxyURL(psu) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: proxyTransport, + }, + }) + assert.Success(t, err) + assert.Equal(t, "", "true", resp.Header.Get("PROXIED")) + + assertEcho(t, ctx, c) + assertClose(t, c) +} diff --git a/doc.go b/doc.go index efa920e3..2ab648a6 100644 --- a/doc.go +++ b/doc.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js // Package websocket implements the RFC 6455 WebSocket protocol. @@ -12,11 +13,11 @@ // // The examples are the best way to understand how to correctly use the library. // -// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. +// The wsjson subpackage contain helpers for JSON and protobuf messages. // // More documentation at https://nhooyr.io/websocket. // -// Wasm +// # Wasm // // The client side supports compiling to Wasm. // It wraps the WebSocket browser API. @@ -25,8 +26,9 @@ // // Some important caveats to be aware of: // -// - Accept always errors out -// - Conn.Ping is no-op -// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op -// - *http.Response from Dial is &http.Response{} with a 101 status code on success +// - Accept always errors out +// - Conn.Ping is no-op +// - Conn.CloseNow is Close(StatusGoingAway, "") +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "nhooyr.io/websocket" diff --git a/example_test.go b/example_test.go index 632c4d6e..590c0411 100644 --- a/example_test.go +++ b/example_test.go @@ -20,7 +20,7 @@ func ExampleAccept() { log.Println(err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() @@ -50,7 +50,7 @@ func ExampleDial() { if err != nil { log.Fatal(err) } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { @@ -71,7 +71,7 @@ func ExampleCloseStatus() { if err != nil { log.Fatal(err) } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { @@ -88,7 +88,7 @@ func Example_writeOnly() { log.Println(err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) defer cancel() @@ -135,64 +135,37 @@ func Example_crossOrigin() { log.Fatal(err) } -// This example demonstrates how to create a WebSocket server -// that gracefully exits when sent a signal. -// -// It starts a WebSocket server that keeps every connection open -// for 10 seconds. -// If you CTRL+C while a connection is open, it will wait at most 30s -// for all connections to terminate before shutting down. -// func ExampleGrace() { -// fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// c, err := websocket.Accept(w, r, nil) -// if err != nil { -// log.Println(err) -// return -// } -// defer c.Close(websocket.StatusInternalError, "the sky is falling") -// -// ctx := c.CloseRead(r.Context()) -// select { -// case <-ctx.Done(): -// case <-time.After(time.Second * 10): -// } -// -// c.Close(websocket.StatusNormalClosure, "") -// }) -// -// var g websocket.Grace -// s := &http.Server{ -// Handler: g.Handler(fn), -// ReadTimeout: time.Second * 15, -// WriteTimeout: time.Second * 15, -// } -// -// errc := make(chan error, 1) -// go func() { -// errc <- s.ListenAndServe() -// }() -// -// sigs := make(chan os.Signal, 1) -// signal.Notify(sigs, os.Interrupt) -// select { -// case err := <-errc: -// log.Printf("failed to listen and serve: %v", err) -// case sig := <-sigs: -// log.Printf("terminating: %v", sig) -// } -// -// ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) -// defer cancel() -// s.Shutdown(ctx) -// g.Shutdown(ctx) -// } +func ExampleConn_Ping() { + // Dials a server and pings it 5 times. + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) + if err != nil { + log.Fatal(err) + } + defer c.CloseNow() + + // Required to read the Pongs from the server. + ctx = c.CloseRead(ctx) + + for i := 0; i < 5; i++ { + err = c.Ping(ctx) + if err != nil { + log.Fatal(err) + } + } + + c.Close(websocket.StatusNormalClosure, "") +} // This example demonstrates full stack chat with an automated test. func Example_fullStackChat() { - // https://github.com/nhooyr/websocket/tree/master/examples/chat + // https://github.com/nhooyr/websocket/tree/master/internal/examples/chat } // This example demonstrates a echo server. func Example_echo() { - // https://github.com/nhooyr/websocket/tree/master/examples/echo + // https://github.com/nhooyr/websocket/tree/master/internal/examples/echo } diff --git a/export_test.go b/export_test.go index 88b82c9f..a644d8f0 100644 --- a/export_test.go +++ b/export_test.go @@ -1,10 +1,17 @@ +//go:build !js // +build !js package websocket +import ( + "net" + + "nhooyr.io/websocket/internal/util" +) + func (c *Conn) RecordBytesWritten() *int { var bytesWritten int - c.bw.Reset(writerFunc(func(p []byte) (int, error) { + c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { bytesWritten += len(p) return c.rwc.Write(p) })) @@ -13,10 +20,19 @@ func (c *Conn) RecordBytesWritten() *int { func (c *Conn) RecordBytesRead() *int { var bytesRead int - c.br.Reset(readerFunc(func(p []byte) (int, error) { + c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { n, err := c.rwc.Read(p) bytesRead += n return n, err })) return &bytesRead } + +var ErrClosed = net.ErrClosed + +var ExportedDial = dial +var SecWebSocketAccept = secWebSocketAccept +var SecWebSocketKey = secWebSocketKey +var VerifyServerResponse = verifyServerResponse + +var CompressionModeOpts = CompressionMode.opts diff --git a/frame.go b/frame.go index 2a036f94..d5631863 100644 --- a/frame.go +++ b/frame.go @@ -1,3 +1,5 @@ +//go:build !js + package websocket import ( @@ -6,7 +8,6 @@ import ( "fmt" "io" "math" - "math/bits" "nhooyr.io/websocket/internal/errd" ) @@ -170,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { return nil } - -// mask applies the WebSocket masking algorithm to p -// with the given key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. - for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go index 76826248..bd626358 100644 --- a/frame_test.go +++ b/frame_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -11,10 +12,6 @@ import ( "strconv" "testing" "time" - _ "unsafe" - - "github.com/gobwas/ws" - _ "github.com/gorilla/websocket" "nhooyr.io/websocket/internal/test/assert" ) @@ -54,7 +51,7 @@ func TestHeader(t *testing.T) { r := rand.New(rand.NewSource(time.Now().UnixNano())) randBool := func() bool { - return r.Intn(1) == 0 + return r.Intn(2) == 0 } for i := 0; i < 10000; i++ { @@ -66,9 +63,11 @@ func TestHeader(t *testing.T) { opcode: opcode(r.Intn(16)), masked: randBool(), - maskKey: r.Uint32(), payloadLength: r.Int63(), } + if h.masked { + h.maskKey = r.Uint32() + } testHeader(t, h) } @@ -98,7 +97,7 @@ func Test_mask(t *testing.T) { key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) + gotKey32 := mask(p, key32) expP := []byte{0, 0, 0, 0x0d, 0x6} assert.Equal(t, "p", expP, p) @@ -106,87 +105,3 @@ func Test_mask(t *testing.T) { expKey32 := bits.RotateLeft32(key32, -8) assert.Equal(t, "key32", expKey32, gotKey32) } - -func basicMask(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes -func gorillaMaskBytes(key [4]byte, pos int, b []byte) int - -func Benchmark_mask(b *testing.B) { - sizes := []int{ - 2, - 3, - 4, - 8, - 16, - 32, - 128, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func(b *testing.B, key [4]byte, p []byte) - }{ - { - name: "basic", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - basicMask(key, 0, p) - } - }, - }, - - { - name: "nhooyr", - fn: func(b *testing.B, key [4]byte, p []byte) { - key32 := binary.LittleEndian.Uint32(key[:]) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mask(key32, p) - } - }, - }, - { - name: "gorilla", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - gorillaMaskBytes(key, 0, p) - } - }, - }, - { - name: "gobwas", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - ws.Cipher(p, key, 0) - } - }, - }, - } - - key := [4]byte{1, 2, 3, 4} - - for _, size := range sizes { - p := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.SetBytes(int64(size)) - - fn.fn(b, key, p) - }) - } - }) - } -} diff --git a/go.mod b/go.mod index c5f1a20f..715a9f7a 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,3 @@ module nhooyr.io/websocket -go 1.13 - -require ( - github.com/gin-gonic/gin v1.6.3 - github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect - github.com/gobwas/pool v0.2.0 // indirect - github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.5 - github.com/google/go-cmp v0.4.0 - github.com/gorilla/websocket v1.4.1 - github.com/klauspost/compress v1.10.3 - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 -) +go 1.19 diff --git a/go.sum b/go.sum index 155c3013..e69de29b 100644 --- a/go.sum +++ b/go.sum @@ -1,64 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= -github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= -github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= -github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= -github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= -github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= -github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= -github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/README.md b/internal/examples/README.md similarity index 100% rename from examples/README.md rename to internal/examples/README.md diff --git a/examples/chat/README.md b/internal/examples/chat/README.md similarity index 97% rename from examples/chat/README.md rename to internal/examples/chat/README.md index ca1024a0..574c6994 100644 --- a/examples/chat/README.md +++ b/internal/examples/chat/README.md @@ -5,7 +5,7 @@ This directory contains a full stack example of a simple chat webapp using nhooy ```bash $ cd examples/chat $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` Visit the printed URL to submit and view broadcasted messages in a browser. diff --git a/examples/chat/chat.go b/internal/examples/chat/chat.go similarity index 88% rename from examples/chat/chat.go rename to internal/examples/chat/chat.go index 532e50f5..8b1e30c1 100644 --- a/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -3,8 +3,9 @@ package main import ( "context" "errors" - "io/ioutil" + "io" "log" + "net" "net/http" "sync" "time" @@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) - if err != nil { - cs.logf("%v", err) - return - } - defer c.Close(websocket.StatusInternalError, "") - - err = cs.subscribe(r.Context(), c) + err := cs.subscribe(r.Context(), w, r) if errors.Is(err, context.Canceled) { return } @@ -98,7 +92,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { return } body := http.MaxBytesReader(w, r.Body, 8192) - msg, err := ioutil.ReadAll(body) + msg, err := io.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return @@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - +func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + var mu sync.Mutex + var c *websocket.Conn + var closed bool s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { - c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + mu.Lock() + defer mu.Unlock() + closed = true + if c != nil { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + } }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) + c2, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + mu.Lock() + if closed { + mu.Unlock() + return net.ErrClosed + } + c = c2 + mu.Unlock() + defer c.CloseNow() + + ctx = c.CloseRead(ctx) + for { select { case msg := <-s.msgs: diff --git a/examples/chat/chat_test.go b/internal/examples/chat/chat_test.go similarity index 100% rename from examples/chat/chat_test.go rename to internal/examples/chat/chat_test.go diff --git a/examples/chat/index.css b/internal/examples/chat/index.css similarity index 86% rename from examples/chat/index.css rename to internal/examples/chat/index.css index 73a8e0f3..ce27c378 100644 --- a/examples/chat/index.css +++ b/internal/examples/chat/index.css @@ -54,7 +54,7 @@ body { margin: 0 0 0 10px; } -#publish-form input[type="text"] { +#publish-form input[type='text'] { flex-grow: 1; -moz-appearance: none; @@ -64,7 +64,7 @@ body { border: 1px solid #ccc; } -#publish-form input[type="submit"] { +#publish-form input[type='submit'] { color: white; background-color: black; border-radius: 5px; @@ -72,10 +72,10 @@ body { border: none; } -#publish-form input[type="submit"]:hover { +#publish-form input[type='submit']:hover { background-color: red; } -#publish-form input[type="submit"]:active { +#publish-form input[type='submit']:active { background-color: red; } diff --git a/examples/chat/index.html b/internal/examples/chat/index.html similarity index 98% rename from examples/chat/index.html rename to internal/examples/chat/index.html index 76ae8370..64edd286 100644 --- a/examples/chat/index.html +++ b/internal/examples/chat/index.html @@ -1,4 +1,4 @@ - + diff --git a/examples/chat/index.js b/internal/examples/chat/index.js similarity index 65% rename from examples/chat/index.js rename to internal/examples/chat/index.js index 5868e7ca..2efca013 100644 --- a/examples/chat/index.js +++ b/internal/examples/chat/index.js @@ -6,21 +6,21 @@ function dial() { const conn = new WebSocket(`ws://${location.host}/subscribe`) - conn.addEventListener("close", ev => { + conn.addEventListener('close', ev => { appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) if (ev.code !== 1001) { - appendLog("Reconnecting in 1s", true) + appendLog('Reconnecting in 1s', true) setTimeout(dial, 1000) } }) - conn.addEventListener("open", ev => { - console.info("websocket connected") + conn.addEventListener('open', ev => { + console.info('websocket connected') }) // This is where we handle messages received. - conn.addEventListener("message", ev => { - if (typeof ev.data !== "string") { - console.error("unexpected message type", typeof ev.data) + conn.addEventListener('message', ev => { + if (typeof ev.data !== 'string') { + console.error('unexpected message type', typeof ev.data) return } const p = appendLog(ev.data) @@ -32,38 +32,38 @@ } dial() - const messageLog = document.getElementById("message-log") - const publishForm = document.getElementById("publish-form") - const messageInput = document.getElementById("message-input") + const messageLog = document.getElementById('message-log') + const publishForm = document.getElementById('publish-form') + const messageInput = document.getElementById('message-input') // appendLog appends the passed text to messageLog. function appendLog(text, error) { - const p = document.createElement("p") + const p = document.createElement('p') // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` if (error) { - p.style.color = "red" - p.style.fontStyle = "bold" + p.style.color = 'red' + p.style.fontStyle = 'bold' } messageLog.append(p) return p } - appendLog("Submit a message to get started!") + appendLog('Submit a message to get started!') // onsubmit publishes the message from the user when the form is submitted. publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value - if (msg === "") { + if (msg === '') { return } - messageInput.value = "" + messageInput.value = '' expectingMessage = true try { - const resp = await fetch("/publish", { - method: "POST", + const resp = await fetch('/publish', { + method: 'POST', body: msg, }) if (resp.status !== 202) { diff --git a/examples/chat/main.go b/internal/examples/chat/main.go similarity index 95% rename from examples/chat/main.go rename to internal/examples/chat/main.go index 3fcec6be..e3432984 100644 --- a/examples/chat/main.go +++ b/internal/examples/chat/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) cs := newChatServer() s := &http.Server{ diff --git a/examples/echo/README.md b/internal/examples/echo/README.md similarity index 94% rename from examples/echo/README.md rename to internal/examples/echo/README.md index 7f42c3c5..ac03f640 100644 --- a/examples/echo/README.md +++ b/internal/examples/echo/README.md @@ -5,7 +5,7 @@ This directory contains a echo server example using nhooyr.io/websocket. ```bash $ cd examples/echo $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages diff --git a/examples/echo/main.go b/internal/examples/echo/main.go similarity index 95% rename from examples/echo/main.go rename to internal/examples/echo/main.go index 16d78a79..47e30d05 100644 --- a/examples/echo/main.go +++ b/internal/examples/echo/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) s := &http.Server{ Handler: echoServer{ diff --git a/examples/echo/server.go b/internal/examples/echo/server.go similarity index 95% rename from examples/echo/server.go rename to internal/examples/echo/server.go index e9f70f03..246ad582 100644 --- a/examples/echo/server.go +++ b/internal/examples/echo/server.go @@ -28,7 +28,7 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logf("%v", err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") diff --git a/examples/echo/server_test.go b/internal/examples/echo/server_test.go similarity index 100% rename from examples/echo/server_test.go rename to internal/examples/echo/server_test.go diff --git a/internal/examples/go.mod b/internal/examples/go.mod new file mode 100644 index 00000000..c98b81ce --- /dev/null +++ b/internal/examples/go.mod @@ -0,0 +1,10 @@ +module nhooyr.io/websocket/examples + +go 1.19 + +replace nhooyr.io/websocket => ../.. + +require ( + golang.org/x/time v0.3.0 + nhooyr.io/websocket v0.0.0-00010101000000-000000000000 +) diff --git a/internal/examples/go.sum b/internal/examples/go.sum new file mode 100644 index 00000000..f8a07e82 --- /dev/null +++ b/internal/examples/go.sum @@ -0,0 +1,2 @@ +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 6eaf7fc3..1b90cc9f 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -1,29 +1,19 @@ package assert import ( + "errors" "fmt" "reflect" "strings" "testing" - - "github.com/golang/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" ) -// Diff returns a human readable diff between v1 and v2 -func Diff(v1, v2 interface{}) string { - return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { - return true - }), cmp.Comparer(proto.Equal)) -} - // Equal asserts exp == act. -func Equal(t testing.TB, name string, exp, act interface{}) { +func Equal(t testing.TB, name string, exp, got interface{}) { t.Helper() - if diff := Diff(exp, act); diff != "" { - t.Fatalf("unexpected %v: %v", name, diff) + if !reflect.DeepEqual(exp, got) { + t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got) } } @@ -54,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) { t.Fatalf("expected %q to contain %q", s, sub) } } + +// ErrorIs asserts errors.Is(got, exp) +func ErrorIs(t testing.TB, exp, got error) { + t.Helper() + + if !errors.Is(got, exp) { + t.Fatalf("expected %v but got %v", exp, got) + } +} diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go index 8f4e47c8..dc21a8f0 100644 --- a/internal/test/wstest/echo.go +++ b/internal/test/wstest/echo.go @@ -8,7 +8,6 @@ import ( "time" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" ) @@ -21,7 +20,7 @@ func EchoLoop(ctx context.Context, c *websocket.Conn) error { c.SetReadLimit(1 << 30) - ctx, cancel := context.WithTimeout(ctx, time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() b := make([]byte, 32<<10) @@ -76,7 +75,7 @@ func Echo(ctx context.Context, c *websocket.Conn, max int) error { } if !bytes.Equal(msg, act) { - return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act)) + return fmt.Errorf("unexpected msg read: %#v", act) } return nil diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 1534f316..8e1deb47 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package wstest @@ -24,7 +25,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } - dialOpts = &*dialOpts + _dialOpts := *dialOpts + dialOpts = &_dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go index 8de1ede8..9bfb39ce 100644 --- a/internal/test/xrand/xrand.go +++ b/internal/test/xrand/xrand.go @@ -2,6 +2,7 @@ package xrand import ( "crypto/rand" + "encoding/base64" "fmt" "math/big" "strings" @@ -45,3 +46,8 @@ func Int(max int) int { } return int(x.Int64()) } + +// Base64 returns a randomly generated base64 string of length n. +func Base64(n int) string { + return base64.StdEncoding.EncodeToString(Bytes(n)) +} diff --git a/internal/thirdparty/doc.go b/internal/thirdparty/doc.go new file mode 100644 index 00000000..e756d09f --- /dev/null +++ b/internal/thirdparty/doc.go @@ -0,0 +1,2 @@ +// Package thirdparty contains third party benchmarks and tests. +package thirdparty diff --git a/internal/thirdparty/frame_test.go b/internal/thirdparty/frame_test.go new file mode 100644 index 00000000..89042e53 --- /dev/null +++ b/internal/thirdparty/frame_test.go @@ -0,0 +1,134 @@ +package thirdparty + +import ( + "encoding/binary" + "runtime" + "strconv" + "testing" + _ "unsafe" + + "github.com/gobwas/ws" + _ "github.com/gorilla/websocket" + _ "github.com/lesismal/nbio/nbhttp/websocket" + + _ "nhooyr.io/websocket" +) + +func basicMask(b []byte, maskKey [4]byte, pos int) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +//go:linkname maskGo nhooyr.io/websocket.maskGo +func maskGo(b []byte, key32 uint32) int + +//go:linkname maskAsm nhooyr.io/websocket.maskAsm +func maskAsm(b *byte, len int, key32 uint32) uint32 + +//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR +func nbioMaskBytes(b, key []byte) int + +//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes +func gorillaMaskBytes(key [4]byte, pos int, b []byte) int + +func Benchmark_mask(b *testing.B) { + b.Run(runtime.GOARCH, benchmark_mask) +} + +func benchmark_mask(b *testing.B) { + sizes := []int{ + 8, + 16, + 32, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + } + + fns := []struct { + name string + fn func(b *testing.B, key [4]byte, p []byte) + }{ + { + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basicMask(p, key, 0) + } + }, + }, + + { + name: "nhooyr-go", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + maskGo(p, key32) + } + }, + }, + { + name: "wdvxdr1123-asm", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + maskAsm(&p[0], len(p), key32) + } + }, + }, + + { + name: "gorilla", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + gorillaMaskBytes(key, 0, p) + } + }, + }, + { + name: "gobwas", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + ws.Cipher(p, key, 0) + } + }, + }, + { + name: "nbio", + fn: func(b *testing.B, key [4]byte, p []byte) { + keyb := key[:] + for i := 0; i < b.N; i++ { + nbioMaskBytes(p, keyb) + } + }, + }, + } + + key := [4]byte{1, 2, 3, 4} + + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + for _, size := range sizes { + p := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.SetBytes(int64(size)) + + fn.fn(b, key, p) + }) + } + }) + } +} diff --git a/internal/thirdparty/gin_test.go b/internal/thirdparty/gin_test.go new file mode 100644 index 00000000..6d59578d --- /dev/null +++ b/internal/thirdparty/gin_test.go @@ -0,0 +1,75 @@ +package thirdparty + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/wsjson" +) + +func TestGin(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.GET("/", func(ginCtx *gin.Context) { + err := echoServer(ginCtx.Writer, ginCtx.Request, nil) + if err != nil { + t.Error(err) + } + }) + + s := httptest.NewServer(r) + defer s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + c, _, err := websocket.Dial(ctx, s.URL, nil) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + err = wsjson.Write(ctx, c, "hello") + assert.Success(t, err) + + var v interface{} + err = wsjson.Read(ctx, c, &v) + assert.Success(t, err) + assert.Equal(t, "read msg", "hello", v) + + err = c.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) +} + +func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { + defer errd.Wrap(&err, "echo server failed") + + c, err := websocket.Accept(w, r, opts) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wstest.EchoLoop(r.Context(), c) + return assertCloseStatus(websocket.StatusNormalClosure, err) +} + +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return fmt.Errorf("expected close status %v but got %v", exp, err) + } + return nil +} diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod new file mode 100644 index 00000000..d991dd64 --- /dev/null +++ b/internal/thirdparty/go.mod @@ -0,0 +1,43 @@ +module nhooyr.io/websocket/internal/thirdparty + +go 1.19 + +replace nhooyr.io/websocket => ../.. + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/gobwas/ws v1.3.0 + github.com/gorilla/websocket v1.5.0 + github.com/lesismal/nbio v1.3.18 + nhooyr.io/websocket v0.0.0-00010101000000-000000000000 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/lesismal/llib v1.1.12 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum new file mode 100644 index 00000000..1f542103 --- /dev/null +++ b/internal/thirdparty/go.sum @@ -0,0 +1,129 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= +github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lesismal/llib v1.1.12 h1:KJFB8bL02V+QGIvILEw/w7s6bKj9Ps9Px97MZP2EOk0= +github.com/lesismal/llib v1.1.12/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= +github.com/lesismal/nbio v1.3.18 h1:kmJZlxjQpVfuCPYcXdv0Biv9LHVViJZet5K99Xs3RAs= +github.com/lesismal/nbio v1.3.18/go.mod h1:KWlouFT5cgDdW5sMX8RsHASUMGniea9X0XIellZ0B38= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 00000000..aa210703 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,15 @@ +package util + +// WriterFunc is used to implement one off io.Writers. +type WriterFunc func(p []byte) (int, error) + +func (f WriterFunc) Write(p []byte) (int, error) { + return f(p) +} + +// ReaderFunc is used to implement one off io.Readers. +type ReaderFunc func(p []byte) (int, error) + +func (f ReaderFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/internal/wsjs/wsjs_js.go b/internal/wsjs/wsjs_js.go index 26ffb456..11eb59cb 100644 --- a/internal/wsjs/wsjs_js.go +++ b/internal/wsjs/wsjs_js.go @@ -1,3 +1,4 @@ +//go:build js // +build js // Package wsjs implements typed access to the browser javascript WebSocket API. @@ -118,8 +119,6 @@ func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { Data: data, } fn(me) - - return }) } diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 7a61f27f..5229b12a 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -2,6 +2,7 @@ package xsync import ( "fmt" + "runtime/debug" ) // Go allows running a function in another goroutine @@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error { r := recover() if r != nil { select { - case errs <- fmt.Errorf("panic in go fn: %v", r): + case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): default: } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..2b93bb18 --- /dev/null +++ b/main_test.go @@ -0,0 +1,30 @@ +package websocket_test + +import ( + "fmt" + "os" + "runtime" + "testing" +) + +func goroutineStacks() []byte { + buf := make([]byte, 512) + for { + m := runtime.Stack(buf, true) + if m < len(buf) { + return buf[:m] + } + buf = make([]byte, len(buf)*2) + } +} + +func TestMain(m *testing.M) { + code := m.Run() + if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 || + runtime.GOOS == "js" && runtime.NumGoroutine() != 2 { + fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine()) + fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks()) + os.Exit(1) + } + os.Exit(code) +} diff --git a/make.sh b/make.sh new file mode 100755 index 00000000..170d00a8 --- /dev/null +++ b/make.sh @@ -0,0 +1,12 @@ +#!/bin/sh +set -eu +cd -- "$(dirname "$0")" + +echo "=== fmt.sh" +./ci/fmt.sh +echo "=== lint.sh" +./ci/lint.sh +echo "=== test.sh" +./ci/test.sh "$@" +echo "=== bench.sh" +./ci/bench.sh diff --git a/mask.go b/mask.go new file mode 100644 index 00000000..7bc0c8d5 --- /dev/null +++ b/mask.go @@ -0,0 +1,128 @@ +package websocket + +import ( + "encoding/binary" + "math/bits" +) + +// maskGo applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func maskGo(b []byte, key uint32) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/mask_amd64.s b/mask_amd64.s new file mode 100644 index 00000000..bd42be31 --- /dev/null +++ b/mask_amd64.s @@ -0,0 +1,127 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // AX = b + // CX = len (left length) + // SI = key (uint32) + // DI = uint64(SI) | uint64(SI)<<32 + MOVQ b+0(FP), AX + MOVQ len+8(FP), CX + MOVL key+16(FP), SI + + // calculate the DI + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + CMPQ CX, $15 + JLE less_than_16 + CMPQ CX, $63 + JLE less_than_64 + CMPQ CX, $128 + JLE sse + TESTQ $31, AX + JNZ unaligned + +unaligned_loop_1byte: + XORB SI, (AX) + INCQ AX + DECQ CX + ROLL $24, SI + TESTQ $7, AX + JNZ unaligned_loop_1byte + + // calculate DI again since SI was modified + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + TESTQ $31, AX + JZ sse + +unaligned: + TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b. + JNZ unaligned_loop_1byte + +unaligned_loop: + // we don't need to check the CX since we know it's above 128 + XORQ DI, (AX) + ADDQ $8, AX + SUBQ $8, CX + TESTQ $31, AX + JNZ unaligned_loop + JMP sse + +sse: + CMPQ CX, $0x40 + JL less_than_64 + MOVQ DI, X0 + PUNPCKLQDQ X0, X0 + +sse_loop: + MOVOU 0*16(AX), X1 + MOVOU 1*16(AX), X2 + MOVOU 2*16(AX), X3 + MOVOU 3*16(AX), X4 + PXOR X0, X1 + PXOR X0, X2 + PXOR X0, X3 + PXOR X0, X4 + MOVOU X1, 0*16(AX) + MOVOU X2, 1*16(AX) + MOVOU X3, 2*16(AX) + MOVOU X4, 3*16(AX) + ADDQ $0x40, AX + SUBQ $0x40, CX + CMPQ CX, $0x40 + JAE sse_loop + +less_than_64: + TESTQ $32, CX + JZ less_than_32 + XORQ DI, (AX) + XORQ DI, 8(AX) + XORQ DI, 16(AX) + XORQ DI, 24(AX) + ADDQ $32, AX + +less_than_32: + TESTQ $16, CX + JZ less_than_16 + XORQ DI, (AX) + XORQ DI, 8(AX) + ADDQ $16, AX + +less_than_16: + TESTQ $8, CX + JZ less_than_8 + XORQ DI, (AX) + ADDQ $8, AX + +less_than_8: + TESTQ $4, CX + JZ less_than_4 + XORL SI, (AX) + ADDQ $4, AX + +less_than_4: + TESTQ $2, CX + JZ less_than_2 + XORW SI, (AX) + ROLL $16, SI + ADDQ $2, AX + +less_than_2: + TESTQ $1, CX + JZ done + XORB SI, (AX) + ROLL $24, SI + +done: + MOVL SI, ret+24(FP) + RET diff --git a/mask_arm64.s b/mask_arm64.s new file mode 100644 index 00000000..e494b43a --- /dev/null +++ b/mask_arm64.s @@ -0,0 +1,72 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // R0 = b + // R1 = len + // R3 = key (uint32) + // R2 = uint64(key)<<32 | uint64(key) + MOVD b_ptr+0(FP), R0 + MOVD b_len+8(FP), R1 + MOVWU key+16(FP), R3 + MOVD R3, R2 + ORR R2<<32, R2, R2 + VDUP R2, V0.D2 + CMP $64, R1 + BLT less_than_64 + +loop_64: + VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VEOR V3.B16, V0.B16, V3.B16 + VEOR V4.B16, V0.B16, V4.B16 + VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) + SUBS $64, R1 + CMP $64, R1 + BGE loop_64 + +less_than_64: + CBZ R1, end + TBZ $5, R1, less_than_32 + VLD1 (R0), [V1.B16, V2.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VST1.P [V1.B16, V2.B16], 32(R0) + +less_than_32: + TBZ $4, R1, less_than_16 + LDP (R0), (R11, R12) + EOR R11, R2, R11 + EOR R12, R2, R12 + STP.P (R11, R12), 16(R0) + +less_than_16: + TBZ $3, R1, less_than_8 + MOVD (R0), R11 + EOR R2, R11, R11 + MOVD.P R11, 8(R0) + +less_than_8: + TBZ $2, R1, less_than_4 + MOVWU (R0), R11 + EORW R2, R11, R11 + MOVWU.P R11, 4(R0) + +less_than_4: + TBZ $1, R1, less_than_2 + MOVHU (R0), R11 + EORW R3, R11, R11 + MOVHU.P R11, 2(R0) + RORW $16, R3 + +less_than_2: + TBZ $0, R1, end + MOVBU (R0), R11 + EORW R3, R11, R11 + MOVBU.P R11, 1(R0) + RORW $8, R3 + +end: + MOVWU R3, ret+24(FP) + RET diff --git a/mask_asm.go b/mask_asm.go new file mode 100644 index 00000000..f9484b5b --- /dev/null +++ b/mask_asm.go @@ -0,0 +1,26 @@ +//go:build amd64 || arm64 + +package websocket + +func mask(b []byte, key uint32) uint32 { + // TODO: Will enable in v1.9.0. + return maskGo(b, key) + /* + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key + */ +} + +// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this +// function are perfect. There are almost certainly missing optimizations or +// opportunities for simplification. I'm confident there are no bugs though. +// For example, the arm64 implementation doesn't align memory like the amd64. +// Or the amd64 implementation could use AVX512 instead of just AVX2. +// The AVX2 code I had to disable anyway as it wasn't performing as expected. +// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 +// +//go:noescape +//lint:ignore U1000 disabled till v1.9.0 +func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/mask_asm_test.go b/mask_asm_test.go new file mode 100644 index 00000000..416cbc43 --- /dev/null +++ b/mask_asm_test.go @@ -0,0 +1,11 @@ +//go:build amd64 || arm64 + +package websocket + +import "testing" + +func TestMaskASM(t *testing.T) { + t.Parallel() + + testMask(t, "maskASM", mask) +} diff --git a/mask_go.go b/mask_go.go new file mode 100644 index 00000000..b29435e9 --- /dev/null +++ b/mask_go.go @@ -0,0 +1,7 @@ +//go:build !amd64 && !arm64 && !js + +package websocket + +func mask(b []byte, key uint32) uint32 { + return maskGo(b, key) +} diff --git a/mask_test.go b/mask_test.go new file mode 100644 index 00000000..54f55e43 --- /dev/null +++ b/mask_test.go @@ -0,0 +1,73 @@ +package websocket + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "math/big" + "math/bits" + "testing" + + "nhooyr.io/websocket/internal/test/assert" +) + +func basicMask(b []byte, key uint32) uint32 { + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + return key +} + +func basicMask2(b []byte, key uint32) uint32 { + keyb := binary.LittleEndian.AppendUint32(nil, key) + pos := 0 + for i := range b { + b[i] ^= keyb[pos&3] + pos++ + } + return bits.RotateLeft32(key, (pos&3)*-8) +} + +func TestMask(t *testing.T) { + t.Parallel() + + testMask(t, "basicMask", basicMask) + testMask(t, "maskGo", maskGo) + testMask(t, "basicMask2", basicMask2) +} + +func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) { + t.Run(name, func(t *testing.T) { + t.Parallel() + for i := 0; i < 9999; i++ { + keyb := make([]byte, 4) + _, err := rand.Read(keyb) + assert.Success(t, err) + key := binary.LittleEndian.Uint32(keyb) + + n, err := rand.Int(rand.Reader, big.NewInt(1<<16)) + assert.Success(t, err) + + b := make([]byte, 1+n.Int64()) + _, err = rand.Read(b) + assert.Success(t, err) + + b2 := make([]byte, len(b)) + copy(b2, b) + b3 := make([]byte, len(b)) + copy(b3, b) + + key2 := basicMask(b2, key) + key3 := fn(b3, key) + + if key2 != key3 { + t.Errorf("expected key %X but got %X", key2, key3) + } + if !bytes.Equal(b2, b3) { + t.Error("bad bytes") + return + } + } + }) +} diff --git a/netconn.go b/netconn.go index 64aadf0b..86f7dadb 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,7 @@ import ( "io" "math" "net" - "sync" + "sync/atomic" "time" ) @@ -28,30 +28,64 @@ import ( // // Close will close the *websocket.Conn with StatusNormalClosure. // -// When a deadline is hit, the connection will be closed. This is -// different from most net.Conn implementations where only the -// reading/writing goroutines are interrupted but the connection is kept alive. +// When a deadline is hit and there is an active read or write goroutine, the +// connection will be closed. This is different from most net.Conn implementations +// where only the reading/writing goroutines are interrupted but the connection +// is kept alive. // -// The Addr methods will return a mock net.Addr that returns "websocket" for Network -// and "websocket/unknown-addr" for String. +// The Addr methods will return the real addresses for connections obtained +// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr +// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for +// String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the +// full net.Conn to us. +// +// When running as WASM, the Addr methods will always return the mock address described above. // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. +// +// Furthermore, the ReadLimit is set to -1 to disable it. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + c.SetReadLimit(-1) + nc := &netConn{ c: c, msgType: msgType, + readMu: newMu(c), + writeMu: newMu(c), } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) + nc.readCtx, nc.readCancel = context.WithCancel(ctx) + + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.writeMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active write goroutine and so we should cancel the context. + nc.writeCancel() + return + } + defer nc.writeMu.unlock() + + // Prevents future writes from writing until the deadline is reset. + atomic.StoreInt64(&nc.writeExpired, 1) + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.readMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active read goroutine and so we should cancel the context. + nc.readCancel() + return + } + defer nc.readMu.unlock() + + // Prevents future reads from reading until the deadline is reset. + atomic.StoreInt64(&nc.readExpired, 1) + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -60,63 +94,98 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { + // These must be first to be aligned on 32 bit platforms. + // https://github.com/nhooyr/websocket/pull/438 + readExpired int64 + writeExpired int64 + c *Conn msgType MessageType - writeTimer *time.Timer - writeContext context.Context - - readTimer *time.Timer - readContext context.Context - - readMu sync.Mutex - eofed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} -func (c *netConn) Close() error { - return c.c.Close(StatusNormalClosure, "") +func (nc *netConn) Close() error { + nc.writeTimer.Stop() + nc.writeCancel() + nc.readTimer.Stop() + nc.readCancel() + return nc.c.Close(StatusNormalClosure, "") } -func (c *netConn) Write(p []byte) (int, error) { - err := c.c.Write(c.writeContext, c.msgType, p) +func (nc *netConn) Write(p []byte) (int, error) { + nc.writeMu.forceLock() + defer nc.writeMu.unlock() + + if atomic.LoadInt64(&nc.writeExpired) == 1 { + return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) + } + + err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } -func (c *netConn) Read(p []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() +func (nc *netConn) Read(p []byte) (int, error) { + nc.readMu.forceLock() + defer nc.readMu.unlock() - if c.eofed { + for { + n, err := nc.read(p) + if err != nil { + return n, err + } + if n == 0 { + continue + } + return n, nil + } +} + +func (nc *netConn) read(p []byte) (int, error) { + if atomic.LoadInt64(&nc.readExpired) == 1 { + return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) + } + + if nc.readEOFed { return 0, io.EOF } - if c.reader == nil { - typ, r, err := c.c.Reader(c.readContext) + if nc.reader == nil { + typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: - c.eofed = true + nc.readEOFed = true return 0, io.EOF } return 0, err } - if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) - c.c.Close(StatusUnsupportedData, err.Error()) + if typ != nc.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) + nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } - c.reader = r + nc.reader = r } - n, err := c.reader.Read(p) + n, err := nc.reader.Read(p) if err == io.EOF { - c.reader = nil + nc.reader = nil err = nil } return n, err @@ -133,34 +202,36 @@ func (a websocketAddr) String() string { return "websocket/unknown-addr" } -func (c *netConn) RemoteAddr() net.Addr { - return websocketAddr{} -} - -func (c *netConn) LocalAddr() net.Addr { - return websocketAddr{} -} - -func (c *netConn) SetDeadline(t time.Time) error { - c.SetWriteDeadline(t) - c.SetReadDeadline(t) +func (nc *netConn) SetDeadline(t time.Time) error { + nc.SetWriteDeadline(t) + nc.SetReadDeadline(t) return nil } -func (c *netConn) SetWriteDeadline(t time.Time) error { +func (nc *netConn) SetWriteDeadline(t time.Time) error { + atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { - c.writeTimer.Stop() + nc.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + dur := time.Until(t) + if dur <= 0 { + dur = 1 + } + nc.writeTimer.Reset(dur) } return nil } -func (c *netConn) SetReadDeadline(t time.Time) error { +func (nc *netConn) SetReadDeadline(t time.Time) error { + atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { - c.readTimer.Stop() + nc.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + dur := time.Until(t) + if dur <= 0 { + dur = 1 + } + nc.readTimer.Reset(dur) } return nil } diff --git a/netconn_js.go b/netconn_js.go new file mode 100644 index 00000000..ccc8c89f --- /dev/null +++ b/netconn_js.go @@ -0,0 +1,11 @@ +package websocket + +import "net" + +func (nc *netConn) RemoteAddr() net.Addr { + return websocketAddr{} +} + +func (nc *netConn) LocalAddr() net.Addr { + return websocketAddr{} +} diff --git a/netconn_notjs.go b/netconn_notjs.go new file mode 100644 index 00000000..f3eb0d66 --- /dev/null +++ b/netconn_notjs.go @@ -0,0 +1,20 @@ +//go:build !js +// +build !js + +package websocket + +import "net" + +func (nc *netConn) RemoteAddr() net.Addr { + if unc, ok := nc.c.rwc.(net.Conn); ok { + return unc.RemoteAddr() + } + return websocketAddr{} +} + +func (nc *netConn) LocalAddr() net.Addr { + if unc, ok := nc.c.rwc.(net.Conn); ok { + return unc.LocalAddr() + } + return websocketAddr{} +} diff --git a/read.go b/read.go index ae05cf93..a59e71d9 100644 --- a/read.go +++ b/read.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -8,15 +9,16 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "net" "strings" "time" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" "nhooyr.io/websocket/internal/xsync" ) -// Reader reads from the connection until until there is a WebSocket +// Reader reads from the connection until there is a WebSocket // data message to be read. It will handle ping, pong and close frames as appropriate. // // It returns the type of the message and an io.Reader to read it. @@ -26,6 +28,11 @@ import ( // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. +// +// If you need a separate timeout on the Reader call and the Read itself, +// use time.AfterFunc to cancel the context passed in. +// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 +// Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } @@ -38,7 +45,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, err } - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) return typ, b, err } @@ -53,12 +60,28 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadDone = make(chan struct{}) + c.closeReadMu.Unlock() + go func() { + defer close(c.closeReadDone) defer cancel() - c.Reader(ctx) - c.Close(StatusPolicyViolation, "unexpected data message") + defer c.close() + _, _, err := c.Reader(ctx) + if err == nil { + c.Close(StatusPolicyViolation, "unexpected data message") + } }() return ctx } @@ -69,10 +92,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. +// +// Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { - // We add read one more byte than the limit in case - // there is a fin frame that needs to be read. - c.msgReader.limitReader.limit.Store(n + 1) + if n >= 0 { + // We read one more byte than the limit in case + // there is a fin frame that needs to be read. + n++ + } + + c.msgReader.limitReader.limit.Store(n) } const defaultReadLimit = 32768 @@ -90,13 +119,20 @@ func newMsgReader(c *Conn) *msgReader { func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { + if mr.dict == nil { + mr.dict = &slidingWindow{} + } mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } - mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + if mr.flateContextTakeover() { + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + } else { + mr.flateReader = getFlateReader(mr.flateBufio, nil) + } mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } @@ -111,7 +147,10 @@ func (mr *msgReader) putFlateReader() { func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() - mr.dict.close() + if mr.dict != nil { + mr.dict.close() + mr.dict = nil + } if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } @@ -181,7 +220,7 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case c.readTimeout <- ctx: } @@ -189,18 +228,17 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { if err != nil { select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case <-ctx.Done(): return header{}, ctx.Err() default: - c.close(err) return header{}, err } } select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case c.readTimeout <- context.Background(): } @@ -210,7 +248,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: - return 0, c.closeErr + return 0, net.ErrClosed case c.readTimeout <- ctx: } @@ -218,19 +256,17 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { if err != nil { select { case <-c.closed: - return n, c.closeErr + return n, net.ErrClosed case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err + return n, fmt.Errorf("failed to read frame payload: %w", err) } } select { case <-c.closed: - return n, c.closeErr + return n, net.ErrClosed case c.readTimeout <- context.Background(): } @@ -260,7 +296,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } if h.masked { - mask(h.maskKey, b) + mask(b, h.maskKey) } switch h.opcode { @@ -279,9 +315,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return nil } - defer func() { - c.readCloseFrameErr = err - }() + // opClose ce, err := parseClosePayload(b) if err != nil { @@ -291,9 +325,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.close(err) + c.readMu.unlock() + c.close() return err } @@ -307,9 +341,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -337,14 +369,14 @@ type msgReader struct { flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader - dict slidingWindow + dict *slidingWindow fin bool payloadLength int64 maskKey uint32 - // readerFunc(mr.Read) to avoid continuous allocations. - readFunc readerFunc + // util.ReaderFunc(mr.Read) to avoid continuous allocations. + readFunc util.ReaderFunc } func (mr *msgReader) reset(ctx context.Context, h header) { @@ -382,10 +414,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { @@ -424,7 +455,7 @@ func (mr *msgReader) read(p []byte) (int, error) { mr.payloadLength -= int64(n) if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + mr.maskKey = mask(p, mr.maskKey) } return n, nil @@ -453,7 +484,11 @@ func (lr *limitReader) reset(r io.Reader) { } func (lr *limitReader) Read(p []byte) (int, error) { - if lr.n <= 0 { + if lr.n < 0 { + return lr.r.Read(p) + } + + if lr.n == 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err @@ -464,11 +499,8 @@ func (lr *limitReader) Read(p []byte) (int, error) { } n, err := lr.r.Read(p) lr.n -= int64(n) + if lr.n < 0 { + lr.n = 0 + } return n, err } - -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} diff --git a/write.go b/write.go index 2210cf81..d7222f2d 100644 --- a/write.go +++ b/write.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -10,11 +11,13 @@ import ( "errors" "fmt" "io" + "net" "time" - "github.com/klauspost/compress/flate" + "compress/flate" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write @@ -36,7 +39,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // // See the Writer method if you want to stream a message. // -// If compression is disabled or the threshold is not met, then it +// If compression is disabled or the compression threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) @@ -47,41 +50,22 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { - mw *msgWriterState - closed bool -} - -func (mw *msgWriter) Write(p []byte) (int, error) { - if mw.closed { - return 0, errors.New("cannot use closed writer") - } - return mw.mw.Write(p) -} - -func (mw *msgWriter) Close() error { - if mw.closed { - return errors.New("cannot use closed writer") - } - mw.closed = true - return mw.mw.Close() -} - -type msgWriterState struct { c *Conn mu *mu writeMu *mu + closed bool ctx context.Context opcode opcode flate bool - trimWriter *trimLastFourBytesWriter - dict slidingWindow + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } -func newMsgWriterState(c *Conn) *msgWriterState { - mw := &msgWriterState{ +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ c: c, mu: newMu(c), writeMu: newMu(c), @@ -89,18 +73,20 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } -func (mw *msgWriterState) ensureFlate() { +func (mw *msgWriter) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), + w: util.WriterFunc(mw.write), } } - mw.dict.init(8192) + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } mw.flate = true } -func (mw *msgWriterState) flateContextTakeover() bool { +func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } @@ -108,14 +94,11 @@ func (mw *msgWriterState) flateContextTakeover() bool { } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.msgWriterState.reset(ctx, typ) + err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } - return &msgWriter{ - mw: c.msgWriterState, - closed: false, - }, nil + return c.msgWriter, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { @@ -125,8 +108,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error } if !c.flate() { - defer c.msgWriterState.mu.unlock() - return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + defer c.msgWriter.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -138,7 +121,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err @@ -147,24 +130,35 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false + mw.closed = false mw.trimWriter.reset() return nil } +func (mw *msgWriter) putFlateWriter() { + if mw.flateWriter != nil { + putFlateWriter(mw.flateWriter) + mw.flateWriter = nil + } +} + // Write writes the given bytes to the WebSocket connection. -func (mw *msgWriterState) Write(p []byte) (_ int, err error) { +func (mw *msgWriter) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -177,18 +171,13 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { } if mw.flate { - err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) - if err != nil { - return 0, err - } - mw.dict.write(p) - return len(p), nil + return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *msgWriterState) write(p []byte) (int, error) { +func (mw *msgWriter) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) @@ -198,7 +187,7 @@ func (mw *msgWriterState) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -func (mw *msgWriterState) Close() (err error) { +func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") err = mw.writeMu.lock(mw.ctx) @@ -207,26 +196,38 @@ func (mw *msgWriterState) Close() (err error) { } defer mw.writeMu.unlock() + if mw.closed { + return errors.New("writer already closed") + } + mw.closed = true + + if mw.flate { + err = mw.flateWriter.Flush() + if err != nil { + return fmt.Errorf("failed to flush flate: %w", err) + } + } + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { - mw.dict.close() + mw.putFlateWriter() } mw.mu.unlock() return nil } -func (mw *msgWriterState) close() { +func (mw *msgWriter) close() { if mw.c.client { mw.c.writeFrameMu.forceLock() putBufioWriter(mw.c.bw) } mw.writeMu.forceLock() - mw.dict.close() + mw.putFlateWriter() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { @@ -240,7 +241,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { @@ -248,26 +249,9 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, c.closeErr - } - } - select { case <-c.closed: - return 0, c.closeErr + return 0, net.ErrClosed case c.writeTimeout <- ctx: } @@ -275,11 +259,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { select { case <-c.closed: - err = c.closeErr + err = net.ErrClosed case <-ctx.Done(): err = ctx.Err() + default: } - c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() @@ -321,7 +305,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco select { case <-c.closed: - return n, c.closeErr + if opcode == opClose { + return n, nil + } + return n, net.ErrClosed case c.writeTimeout <- context.Background(): } @@ -358,7 +345,7 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, err } - maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey) p = p[j:] n += j @@ -367,17 +354,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, nil } -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte - bw.Reset(writerFunc(func(p2 []byte) (int, error) { + bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) @@ -391,7 +372,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index b87e32cd..02d61f28 100644 --- a/ws_js.go +++ b/ws_js.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "reflect" "runtime" @@ -18,15 +19,38 @@ import ( "nhooyr.io/websocket/internal/xsync" ) +// opcode represents a WebSocket opcode. +type opcode int + +// https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + // Conn provides a wrapper around the browser WebSocket API. type Conn struct { - ws wsjs.WebSocket + noCopy noCopy + ws wsjs.WebSocket // read limit for a message in bytes. msgReadLimit xsync.Int64 + closeReadMu sync.Mutex + closeReadCtx context.Context + closingMu sync.Mutex - isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -34,6 +58,7 @@ type Conn struct { closeWasClean bool releaseOnClose func() + releaseOnError func() releaseOnMessage func() readSignal chan struct{} @@ -71,9 +96,15 @@ func (c *Conn) init() { c.close(err, e.WasClean) c.releaseOnClose() + c.releaseOnError() c.releaseOnMessage() }) + c.releaseOnError = c.ws.OnError(func(v js.Value) { + c.setCloseErr(errors.New(v.Get("message").String())) + c.closeWithInternal() + }) + c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { c.readBufMu.Lock() defer c.readBufMu.Unlock() @@ -100,7 +131,10 @@ func (c *Conn) closeWithInternal() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - if c.isReadClosed.Load() == 1 { + c.closeReadMu.Lock() + closedRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } @@ -108,7 +142,8 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) } - if int64(len(p)) > c.msgReadLimit.Load() { + readLimit := c.msgReadLimit.Load() + if readLimit >= 0 && int64(len(p)) > readLimit { err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err @@ -123,7 +158,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, ctx.Err() case <-c.readSignal: case <-c.closed: - return 0, nil, c.closeErr + return 0, nil, net.ErrClosed } c.readBufMu.Lock() @@ -177,7 +212,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { if c.isClosed() { - return c.closeErr + return net.ErrClosed } switch typ { case MessageBinary: @@ -201,19 +236,28 @@ func (c *Conn) Close(code StatusCode, reason string) error { return nil } +// CloseNow closes the WebSocket connection without attempting a close handshake. +// Use when you do not want the overhead of the close handshake. +// +// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close +// a WebSocket without the close handshake. +func (c *Conn) CloseNow() error { + return c.Close(StatusGoingAway, "") +} + func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() + if c.isClosed() { + return net.ErrClosed + } + ce := fmt.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) - if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) - } - c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { @@ -284,7 +328,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp StatusCode: http.StatusSwitchingProtocols, }, nil case <-c.closed: - return nil, nil, c.closeErr + return nil, nil, net.ErrClosed } } @@ -302,7 +346,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { // It buffers the entire message in memory and then sends it when the writer // is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - return writer{ + return &writer{ c: c, ctx: ctx, typ: typ, @@ -320,7 +364,7 @@ type writer struct { b *bytes.Buffer } -func (w writer) Write(p []byte) (int, error) { +func (w *writer) Write(p []byte) (int, error) { if w.closed { return 0, errors.New("cannot write to closed writer") } @@ -331,7 +375,7 @@ func (w writer) Write(p []byte) (int, error) { return n, nil } -func (w writer) Close() error { +func (w *writer) Close() error { if w.closed { return errors.New("cannot close closed writer") } @@ -347,13 +391,23 @@ func (w writer) Close() error { // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() + go func() { defer cancel() - c.read(ctx) - c.Close(StatusPolicyViolation, "unexpected data message") + defer c.CloseNow() + _, _, err := c.read(ctx) + if err != nil { + c.Close(StatusPolicyViolation, "unexpected data message") + } }() return ctx } @@ -377,3 +431,168 @@ func (c *Conn) isClosed() bool { return false } } + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + OriginPatterns []string + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, errors.New("unimplemented") +} + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// Works in all browsers except Safari which does not implement the deflate extension. +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} + +type noCopy struct{} + +func (*noCopy) Lock() {} diff --git a/ws_js_test.go b/ws_js_test.go index e6be6181..ba98b9a0 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -36,3 +36,19 @@ func TestWasm(t *testing.T) { err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } + +func TestWasmDialTimeout(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + beforeDial := time.Now() + _, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + }) + assert.Error(t, err) + if time.Since(beforeDial) >= time.Second { + t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial)) + } +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 2000a77a..7c986a0d 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -9,6 +9,7 @@ import ( "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" ) // Read reads a JSON message from c into v. @@ -51,17 +52,17 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. - err = json.NewEncoder(w).Encode(v) + err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { + err := c.Write(ctx, websocket.MessageText, p) + if err != nil { + return 0, err + } + return len(p), nil + })).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } - - return w.Close() + return nil } diff --git a/wsjson/wsjson_test.go b/wsjson/wsjson_test.go new file mode 100644 index 00000000..080ab38d --- /dev/null +++ b/wsjson/wsjson_test.go @@ -0,0 +1,53 @@ +package wsjson_test + +import ( + "encoding/json" + "io" + "strconv" + "testing" + + "nhooyr.io/websocket/internal/test/xrand" +) + +func BenchmarkJSON(b *testing.B) { + sizes := []int{ + 8, + 16, + 32, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + } + + b.Run("json.Encoder", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.NewEncoder(io.Discard).Encode(msg) + } + }) + } + }) + b.Run("json.Marshal", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.Marshal(msg) + } + }) + } + }) +} diff --git a/wspb/wspb.go b/wspb/wspb.go deleted file mode 100644 index e43042d5..00000000 --- a/wspb/wspb.go +++ /dev/null @@ -1,73 +0,0 @@ -// Package wspb provides helpers for reading and writing protobuf messages. -package wspb // import "nhooyr.io/websocket/wspb" - -import ( - "bytes" - "context" - "fmt" - - "github.com/golang/protobuf/proto" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" - "nhooyr.io/websocket/internal/errd" -) - -// Read reads a protobuf message from c into v. -// It will reuse buffers in between calls to avoid allocations. -func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - return read(ctx, c, v) -} - -func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to read protobuf message") - - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - if typ != websocket.MessageBinary { - c.Close(websocket.StatusUnsupportedData, "expected binary message") - return fmt.Errorf("expected binary message for protobuf but got: %v", typ) - } - - b := bpool.Get() - defer bpool.Put(b) - - _, err = b.ReadFrom(r) - if err != nil { - return err - } - - err = proto.Unmarshal(b.Bytes(), v) - if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) - } - - return nil -} - -// Write writes the protobuf message v to c. -// It will reuse buffers in between calls to avoid allocations. -func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - return write(ctx, c, v) -} - -func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to write protobuf message") - - b := bpool.Get() - pb := proto.NewBuffer(b.Bytes()) - defer func() { - bpool.Put(bytes.NewBuffer(pb.Bytes())) - }() - - err = pb.Marshal(v) - if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) - } - - return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) -} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy