diff --git a/accept.go b/accept.go index 428abba4..6e1f494e 100644 --- a/accept.go +++ b/accept.go @@ -163,13 +163,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } - if !headerContainsToken(r.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } - if !headerContainsToken(r.Header, "Upgrade", "websocket") { + if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) @@ -313,11 +313,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com return copts, nil } -func headerContainsToken(h http.Header, key, token string) bool { - token = strings.ToLower(token) - +func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { - if t == token { + if strings.EqualFold(t, token) { return true } } @@ -358,7 +356,6 @@ func headerTokens(h http.Header, key string) []string { for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { - t = strings.ToLower(t) t = strings.TrimSpace(t) tokens = append(tokens, t) } diff --git a/accept_test.go b/accept_test.go index f7bc6693..d19f54e1 100644 --- a/accept_test.go +++ b/accept_test.go @@ -226,6 +226,12 @@ func Test_selectSubprotocol(t *testing.T) { serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, + { + name: "clientCasePresered", + clientProtocols: []string{"Echo1"}, + serverProtocols: []string{"echo1"}, + negotiated: "Echo1", + }, } for _, tc := range testCases { diff --git a/ci/test.sh b/ci/test.sh index 95ef7101..bd68b80e 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -5,9 +5,9 @@ main() { cd "$(dirname "$0")/.." go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... - sed -i '/stringer\.go/d' ci/out/coverage.prof - sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof - sed -i '/examples/d' ci/out/coverage.prof + sed -i.bak '/stringer\.go/d' ci/out/coverage.prof + sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i.bak '/examples/d' ci/out/coverage.prof # Last line is the total coverage. go tool cover -func ci/out/coverage.prof | tail -n1 diff --git a/dial.go b/dial.go index d5d2266e..a79b55e6 100644 --- a/dial.go +++ b/dial.go @@ -8,7 +8,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" "fmt" "io" "io/ioutil" @@ -47,18 +46,27 @@ type DialOptions struct { CompressionThreshold int } -func (opts *DialOptions) cloneWithDefaults() *DialOptions { +func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { + var cancel context.CancelFunc + var o DialOptions if opts != nil { o = *opts } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient + } else if opts.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + + newClient := *opts.HTTPClient + newClient.Timeout = 0 + opts.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } - return &o + + return ctx, cancel, &o } // Dial performs a WebSocket handshake on url. @@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") - opts = opts.cloneWithDefaults() + var cancel context.CancelFunc + ctx, cancel, opts = opts.cloneWithDefaults(ctx) + if cancel != nil { + defer cancel() + } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { @@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { - if opts.HTTPClient.Timeout > 0 { - return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) @@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } - if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } - if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } @@ -242,7 +250,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } - copts = &*copts + _copts := *copts + copts = &_copts for _, p := range ext.params { switch p { diff --git a/dial_test.go b/dial_test.go index 7f13a934..28c255c6 100644 --- a/dial_test.go +++ b/dial_test.go @@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) { name: "badURLScheme", url: "ftp://nhooyr.io", }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", diff --git a/examples/echo/server.go b/examples/echo/server.go index 308c4a5e..e9f70f03 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -16,7 +16,6 @@ import ( // It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. type echoServer struct { - // logf controls where logs are sent. logf func(f string, v ...interface{}) } diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 1534f316..f3d4c517 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -24,7 +24,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } - dialOpts = &*dialOpts + _dialOpts := *dialOpts + dialOpts = &_dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, }
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: