diff --git a/accept.go b/accept.go index 479138fc..47e20b52 100644 --- a/accept.go +++ b/accept.go @@ -9,10 +9,11 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/textproto" "net/url" - "strconv" + "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" @@ -25,18 +26,27 @@ type AcceptOptions struct { // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string - // InsecureSkipVerify disables Accept's origin verification behaviour. By default, - // the connection will only be accepted if the request origin is equal to the request - // host. + // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // - // This is only required if you want javascript served from a different domain - // to access your WebSocket server. + // Deprecated: Use OriginPatterns with a match all pattern of * instead to control + // origin authorization yourself. + InsecureSkipVerify bool + + // OriginPatterns lists the host patterns for authorized origins. + // The request host is always authorized. + // Use this to enable cross origin WebSockets. + // + // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. + // In such a case, example.com is the origin and chat.example.com is the request host. + // One would set this field to []string{"example.com"} to authorize example.com to connect. // - // See https://stackoverflow.com/a/37837709/4283659 + // Each pattern is matched case insensitively against the request origin host + // with filepath.Match. + // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. - InsecureSkipVerify bool + OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. @@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) + err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } @@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return 0, nil } -func authenticateOrigin(r *http.Request) error { +func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) + if origin == "" { + return nil + } + + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + + if strings.EqualFold(r.Host, u.Host) { + return nil + } + + for _, hostPattern := range originHosts { + matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if matched { + return nil } } - return nil + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) +} + +func match(pattern, s string) (bool, error) { + return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { @@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi return copts, nil } -// parseExtensionParameter parses the value in the extension parameter p. -func parseExtensionParameter(p string) (int, bool) { - ps := strings.Split(p, "=") - if len(ps) == 1 { - return 0, false - } - i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) - return i, e == nil -} - func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. diff --git a/accept_test.go b/accept_test.go index 49667799..40a7b40c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - host string - success bool + name string + origin string + host string + originPatterns []string + success bool }{ { name: "none", @@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) { host: "example.com", success: true, }, + { + name: "originPatterns", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "*.example.com", + "bar.com", + }, + success: true, + }, + { + name: "originPatternsUnauthorized", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "exam3.com", + "bar.com", + }, + success: false, + }, } for _, tc := range testCases { @@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) { r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r) + err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { diff --git a/example_test.go b/example_test.go index 666914d2..c56e53f3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,7 +6,6 @@ import ( "context" "log" "net/http" - "net/url" "time" "nhooyr.io/websocket" @@ -121,17 +120,8 @@ func Example_writeOnly() { // from the origin example.com. func Example_crossOrigin() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) - if err != nil || u.Host != "example.com" { - http.Error(w, "bad origin header", http.StatusForbidden) - return - } - } - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, + OriginPatterns: []string{"example.com"}, }) if err != nil { log.Println(err)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: