diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 01d04626a6948..a7bb502a300eb 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -1,6 +1,7 @@ package coderd import ( + "bytes" "crypto/subtle" "database/sql" "encoding/json" @@ -26,16 +27,21 @@ import ( ) func (api *API) scimVerifyAuthHeader(r *http.Request) bool { - bearer := []byte("Bearer ") + bearer := []byte("bearer ") hdr := []byte(r.Header.Get("Authorization")) - if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 { + // Use toLower to make the comparison case-insensitive. + if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 { hdr = hdr[len(bearer):] } return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1 } +func scimUnauthorized(rw http.ResponseWriter) { + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization"))) +} + // scimServiceProviderConfig returns a static SCIM service provider configuration. // // @Summary SCIM 2.0: Service Provider Config @@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques //nolint:revive func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { //nolint:revive func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } - _ = handlerutil.WriteError(rw, spec.ErrNotFound) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404"))) } // We currently use our own struct instead of using the SCIM package. This was @@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } @@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } if email == "" { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided"))) return } @@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { Username: sUser.UserName, }) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } if err == nil { @@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.New = newUser @@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { //nolint:gocritic // SCIM operations are a system user orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database) if err != nil { - _ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err)) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err))) return } if orgSync.AssignDefault { //nolint:gocritic // SCIM operations are a system user defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err))) return } organizations = append(organizations, defaultOrganization.ID) @@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { SkipNotifications: true, }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err))) return } aReq.New = dbUser @@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } sUser.ID = id uid, err := uuid.Parse(id) if err != nil { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) return } //nolint:gocritic // needed for SCIM dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.Old = dbUser @@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } dbUser = userNew diff --git a/enterprise/coderd/scim/scimtypes.go b/enterprise/coderd/scim/scimtypes.go index e78b70b3e9f3f..39e022aa24e05 100644 --- a/enterprise/coderd/scim/scimtypes.go +++ b/enterprise/coderd/scim/scimtypes.go @@ -1,6 +1,11 @@ package scim -import "time" +import ( + "encoding/json" + "time" + + "github.com/imulab/go-scim/pkg/v2/spec" +) type ServiceProviderConfig struct { Schemas []string `json:"schemas"` @@ -44,3 +49,37 @@ type AuthenticationScheme struct { SpecURI string `json:"specUri"` DocURI string `json:"documentationUri"` } + +// HTTPError wraps a *spec.Error for correct usage with +// 'handlerutil.WriteError'. This error type is cursed to be +// absolutely strange and specific to the SCIM library we use. +// +// The library expects *spec.Error to be returned on unwrap, and the +// internal error description to be returned by a json.Marshal of the +// top level error. +type HTTPError struct { + scim *spec.Error + internal error +} + +func NewHTTPError(status int, eType string, err error) *HTTPError { + return &HTTPError{ + scim: &spec.Error{ + Status: status, + Type: eType, + }, + internal: err, + } +} + +func (e HTTPError) Error() string { + return e.internal.Error() +} + +func (e HTTPError) MarshalJSON() ([]byte, error) { + return json.Marshal(e.internal) +} + +func (e HTTPError) Unwrap() error { + return e.scim +} diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index 3e5c22f7e9461..1f9d230bf7f2d 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -6,11 +6,15 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "testing" "github.com/golang-jwt/jwt/v4" + "github.com/imulab/go-scim/pkg/v2/handlerutil" + "github.com/imulab/go-scim/pkg/v2/spec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" @@ -22,6 +26,7 @@ import ( "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/enterprise/coderd/scim" "github.com/coder/coder/v2/testutil" ) @@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) { func setScimAuthBearer(key []byte) func(*http.Request) { return func(r *http.Request) { - r.Header.Set("Authorization", "Bearer "+string(key)) + // Do strange casing to ensure it's case-insensitive + r.Header.Set("Authorization", "beAreR "+string(key)) } } @@ -111,7 +117,7 @@ func TestScim(t *testing.T) { res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) require.NoError(t, err) defer res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -454,7 +460,7 @@ func TestScim(t *testing.T) { require.NoError(t, err) _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -585,3 +591,21 @@ func TestScim(t *testing.T) { }) }) } + +func TestScimError(t *testing.T) { + t.Parallel() + + // Demonstrates that we cannot use the standard errors + rw := httptest.NewRecorder() + _ = handlerutil.WriteError(rw, spec.ErrNotFound) + resp := rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + // Our error wrapper works + rw = httptest.NewRecorder() + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found"))) + resp = rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusNotFound, resp.StatusCode) +}
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: