From cf7ca43b4887bc3ad2f2edece7e406d9a48de98f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 15 Nov 2024 10:54:33 -0600 Subject: [PATCH 1/3] chore: scim auth header case insensitive for 'bearer' Fixes status codes to return more than 500 --- enterprise/coderd/scim.go | 42 ++++++++++++++++------------- enterprise/coderd/scim/scimtypes.go | 41 +++++++++++++++++++++++++++- enterprise/coderd/scim_test.go | 25 ++++++++++++++--- 3 files changed, 86 insertions(+), 22 deletions(-) diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 01d04626a6948..a7bb502a300eb 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -1,6 +1,7 @@ package coderd import ( + "bytes" "crypto/subtle" "database/sql" "encoding/json" @@ -26,16 +27,21 @@ import ( ) func (api *API) scimVerifyAuthHeader(r *http.Request) bool { - bearer := []byte("Bearer ") + bearer := []byte("bearer ") hdr := []byte(r.Header.Get("Authorization")) - if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 { + // Use toLower to make the comparison case-insensitive. + if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 { hdr = hdr[len(bearer):] } return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1 } +func scimUnauthorized(rw http.ResponseWriter) { + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization"))) +} + // scimServiceProviderConfig returns a static SCIM service provider configuration. // // @Summary SCIM 2.0: Service Provider Config @@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques //nolint:revive func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { //nolint:revive func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } - _ = handlerutil.WriteError(rw, spec.ErrNotFound) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404"))) } // We currently use our own struct instead of using the SCIM package. This was @@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } @@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } if email == "" { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided"))) return } @@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { Username: sUser.UserName, }) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } if err == nil { @@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.New = newUser @@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { //nolint:gocritic // SCIM operations are a system user orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database) if err != nil { - _ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err)) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err))) return } if orgSync.AssignDefault { //nolint:gocritic // SCIM operations are a system user defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err))) return } organizations = append(organizations, defaultOrganization.ID) @@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { SkipNotifications: true, }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err))) return } aReq.New = dbUser @@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } sUser.ID = id uid, err := uuid.Parse(id) if err != nil { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) return } //nolint:gocritic // needed for SCIM dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.Old = dbUser @@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } dbUser = userNew diff --git a/enterprise/coderd/scim/scimtypes.go b/enterprise/coderd/scim/scimtypes.go index e78b70b3e9f3f..39e022aa24e05 100644 --- a/enterprise/coderd/scim/scimtypes.go +++ b/enterprise/coderd/scim/scimtypes.go @@ -1,6 +1,11 @@ package scim -import "time" +import ( + "encoding/json" + "time" + + "github.com/imulab/go-scim/pkg/v2/spec" +) type ServiceProviderConfig struct { Schemas []string `json:"schemas"` @@ -44,3 +49,37 @@ type AuthenticationScheme struct { SpecURI string `json:"specUri"` DocURI string `json:"documentationUri"` } + +// HTTPError wraps a *spec.Error for correct usage with +// 'handlerutil.WriteError'. This error type is cursed to be +// absolutely strange and specific to the SCIM library we use. +// +// The library expects *spec.Error to be returned on unwrap, and the +// internal error description to be returned by a json.Marshal of the +// top level error. +type HTTPError struct { + scim *spec.Error + internal error +} + +func NewHTTPError(status int, eType string, err error) *HTTPError { + return &HTTPError{ + scim: &spec.Error{ + Status: status, + Type: eType, + }, + internal: err, + } +} + +func (e HTTPError) Error() string { + return e.internal.Error() +} + +func (e HTTPError) MarshalJSON() ([]byte, error) { + return json.Marshal(e.internal) +} + +func (e HTTPError) Unwrap() error { + return e.scim +} diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index 3e5c22f7e9461..3afaa19b11e60 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -6,9 +6,12 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "testing" "github.com/golang-jwt/jwt/v4" + "github.com/imulab/go-scim/pkg/v2/handlerutil" + "github.com/imulab/go-scim/pkg/v2/spec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +25,7 @@ import ( "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/enterprise/coderd/scim" "github.com/coder/coder/v2/testutil" ) @@ -59,7 +63,8 @@ func setScimAuth(key []byte) func(*http.Request) { func setScimAuthBearer(key []byte) func(*http.Request) { return func(r *http.Request) { - r.Header.Set("Authorization", "Bearer "+string(key)) + // Do strange casing to ensure it's case-insensitive + r.Header.Set("Authorization", "beAreR "+string(key)) } } @@ -111,7 +116,7 @@ func TestScim(t *testing.T) { res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) require.NoError(t, err) defer res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -454,7 +459,7 @@ func TestScim(t *testing.T) { require.NoError(t, err) _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -585,3 +590,17 @@ func TestScim(t *testing.T) { }) }) } + +func TestScimError(t *testing.T) { + t.Parallel() + + // Demonstrates that we cannot use the standard errors + rw := httptest.NewRecorder() + _ = handlerutil.WriteError(rw, spec.ErrNotFound) + require.Equal(t, http.StatusInternalServerError, rw.Result().StatusCode) + + // Our error wrapper works + rw = httptest.NewRecorder() + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, fmt.Errorf("not found"))) + require.Equal(t, http.StatusNotFound, rw.Result().StatusCode) +} From 4f224e2fdc4fcc65e8a2540c483dd426e032b75d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 15 Nov 2024 11:12:55 -0600 Subject: [PATCH 2/3] linting --- enterprise/coderd/scim_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index 3afaa19b11e60..806c29a8e0fef 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -597,10 +597,14 @@ func TestScimError(t *testing.T) { // Demonstrates that we cannot use the standard errors rw := httptest.NewRecorder() _ = handlerutil.WriteError(rw, spec.ErrNotFound) - require.Equal(t, http.StatusInternalServerError, rw.Result().StatusCode) + resp := rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) // Our error wrapper works rw = httptest.NewRecorder() _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, fmt.Errorf("not found"))) - require.Equal(t, http.StatusNotFound, rw.Result().StatusCode) + resp = rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusNotFound, resp.StatusCode) } From d12ab225ec4a45d545cc1139b1de70a16038e9bd Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 15 Nov 2024 11:25:02 -0600 Subject: [PATCH 3/3] linting --- enterprise/coderd/scim_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index 806c29a8e0fef..1f9d230bf7f2d 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -14,6 +14,7 @@ import ( "github.com/imulab/go-scim/pkg/v2/spec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" @@ -603,7 +604,7 @@ func TestScimError(t *testing.T) { // Our error wrapper works rw = httptest.NewRecorder() - _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, fmt.Errorf("not found"))) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found"))) resp = rw.Result() defer resp.Body.Close() require.Equal(t, http.StatusNotFound, resp.StatusCode) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy