diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 5cc235fbdacb9..90c9c386628f1 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -775,7 +775,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { if f.hookWellKnown != nil { err := f.hookWellKnown(r, &cpy) if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + httpError(rw, http.StatusInternalServerError, err) return } } @@ -792,7 +792,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { clientID := r.URL.Query().Get("client_id") if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") { - http.Error(rw, "invalid client_id", http.StatusBadRequest) + httpError(rw, http.StatusBadRequest, xerrors.New("invalid client_id")) return } @@ -818,7 +818,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { err := f.hookValidRedirectURL(redirectURI) if err != nil { t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) - http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("invalid redirect_uri: %w", err)) return } @@ -853,7 +853,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { )...) if err != nil { - http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, err) return } getEmail := func(claims jwt.MapClaims) string { @@ -914,7 +914,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims = idTokenClaims err := f.hookOnRefresh(getEmail(claims)) if err != nil { - http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("refresh hook blocked refresh: %w", err)) return } @@ -1036,7 +1036,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims, err := f.hookUserInfo(email) if err != nil { - http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("user info hook returned error: %w", err)) return } _ = json.NewEncoder(rw).Encode(claims) @@ -1499,13 +1499,33 @@ func slogRequestFields(r *http.Request) []any { } } -func httpErrorCode(defaultCode int, err error) int { - var statusErr statusHookError +// httpError handles better formatted custom errors. +func httpError(rw http.ResponseWriter, defaultCode int, err error) { status := defaultCode + + var statusErr statusHookError if errors.As(err, &statusErr) { status = statusErr.HTTPStatusCode } - return status + + var oauthErr *oauth2.RetrieveError + if errors.As(err, &oauthErr) { + if oauthErr.Response.StatusCode != 0 { + status = oauthErr.Response.StatusCode + } + + rw.Header().Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + form := url.Values{ + "error": {oauthErr.ErrorCode}, + "error_description": {oauthErr.ErrorDescription}, + "error_uri": {oauthErr.ErrorURI}, + } + rw.WriteHeader(status) + _, _ = rw.Write([]byte(form.Encode())) + return + } + + http.Error(rw, err.Error(), status) } type fakeRoundTripper struct { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 4845ff22288fe..669ab42546777 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3319,6 +3319,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } +func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { + fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) { + return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg) +} + func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { // This is a system function to clear user groups in group sync. if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2eb75f8b738c4..6570b51f263b8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1269,6 +1269,14 @@ func (s *MethodTestSuite) TestUser() { UserID: u.ID, }).Asserts(u, policy.ActionUpdatePersonal) })) + s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) + check.Args(database.RemoveRefreshTokenParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + UpdatedAt: link.UpdatedAt, + }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal) + })) s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) check.Args(database.UpdateExternalAuthLinkParams{ diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index aed57e9284b3a..1891d62510351 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8512,6 +8512,29 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } +func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + for index, gitAuthLink := range q.externalAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { + continue + } + if gitAuthLink.UserID != arg.UserID { + continue + } + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthRefreshToken = "" + q.externalAuthLinks[index] = gitAuthLink + + return nil + } + return sql.ErrNoRows +} + func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 32d3cce658525..4b3431199ec48 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2086,6 +2086,13 @@ func (m queryMetricsStore) RegisterWorkspaceProxy(ctx context.Context, arg datab return proxy, err } +func (m queryMetricsStore) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { + start := time.Now() + r0 := m.s.RemoveRefreshToken(ctx, arg) + m.queryLatencies.WithLabelValues("RemoveRefreshToken").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { start := time.Now() r0 := m.s.RemoveUserFromAllGroups(ctx, userID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d6c34411f8208..28571a824e8c3 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4448,6 +4448,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } +// RemoveRefreshToken mocks base method. +func (m *MockStore) RemoveRefreshToken(arg0 context.Context, arg1 database.RemoveRefreshTokenParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveRefreshToken", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveRefreshToken indicates an expected call of RemoveRefreshToken. +func (mr *MockStoreMockRecorder) RemoveRefreshToken(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRefreshToken", reflect.TypeOf((*MockStore)(nil).RemoveRefreshToken), arg0, arg1) +} + // RemoveUserFromAllGroups mocks base method. func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 49ba6fbf8496a..c19851f5d067f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -423,6 +423,10 @@ type sqlcQuerier interface { OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) + // Removing the refresh token disables the refresh behavior for a given + // auth token. If a refresh token is marked invalid, it is better to remove it + // then continually attempt to refresh the token. + RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 09dd4c1fbc488..2bec1b5188b36 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1194,6 +1194,29 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter return i, err } +const removeRefreshToken = `-- name: RemoveRefreshToken :exec +UPDATE + external_auth_links +SET + oauth_refresh_token = '', + updated_at = $1 +WHERE provider_id = $2 AND user_id = $3 +` + +type RemoveRefreshTokenParams struct { + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +// Removing the refresh token disables the refresh behavior for a given +// auth token. If a refresh token is marked invalid, it is better to remove it +// then continually attempt to refresh the token. +func (q *sqlQuerier) RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error { + _, err := q.db.ExecContext(ctx, removeRefreshToken, arg.UpdatedAt, arg.ProviderID, arg.UserID) + return err +} + const updateExternalAuthLink = `-- name: UpdateExternalAuthLink :one UPDATE external_auth_links SET updated_at = $3, diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index 8470c44ea9125..cd223bd792a2a 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -42,3 +42,14 @@ UPDATE external_auth_links SET oauth_expiry = $8, oauth_extra = $9 WHERE provider_id = $1 AND user_id = $2 RETURNING *; + +-- name: RemoveRefreshToken :exec +-- Removing the refresh token disables the refresh behavior for a given +-- auth token. If a refresh token is marked invalid, it is better to remove it +-- then continually attempt to refresh the token. +UPDATE + external_auth_links +SET + oauth_refresh_token = '', + updated_at = @updated_at +WHERE provider_id = @provider_id AND user_id = @user_id; diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 2ad2761e80b46..1ce850c9cec03 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -118,7 +118,7 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu // This is true for github, which has no expiry. !externalAuthLink.OAuthExpiry.IsZero() && externalAuthLink.OAuthExpiry.Before(dbtime.Now()) { - return externalAuthLink, InvalidTokenError("token expired, refreshing is disabled") + return externalAuthLink, InvalidTokenError("token expired, refreshing is either disabled or refreshing failed and will not be retried") } // This is additional defensive programming. Because TokenSource is an interface, @@ -130,16 +130,41 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu refreshToken = "" } - token, err := c.TokenSource(ctx, &oauth2.Token{ + existingToken := &oauth2.Token{ AccessToken: externalAuthLink.OAuthAccessToken, RefreshToken: refreshToken, Expiry: externalAuthLink.OAuthExpiry, - }).Token() + } + + token, err := c.TokenSource(ctx, existingToken).Token() if err != nil { - // Even if the token fails to be obtained, do not return the error as an error. + // TokenSource can fail for numerous reasons. If it fails because of + // a bad refresh token, then the refresh token is invalid, and we should + // get rid of it. Keeping it around will cause additional refresh + // attempts that will fail and cost us api rate limits. + if isFailedRefresh(existingToken, err) { + dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{ + UpdatedAt: dbtime.Now(), + ProviderID: externalAuthLink.ProviderID, + UserID: externalAuthLink.UserID, + }) + if dbExecErr != nil { + // This error should be rare. + return externalAuthLink, InvalidTokenError(fmt.Sprintf("refresh token failed: %q, then removing refresh token failed: %q", err.Error(), dbExecErr.Error())) + } + // The refresh token was cleared + externalAuthLink.OAuthRefreshToken = "" + } + + // Unfortunately have to match exactly on the error message string. + // Improve the error message to account refresh tokens are deleted if + // invalid on our end. + if err.Error() == "oauth2: token expired and refresh token is not set" { + return externalAuthLink, InvalidTokenError("token expired, refreshing is either disabled or refreshing failed and will not be retried") + } + // TokenSource(...).Token() will always return the current token if the token is not expired. - // If it is expired, it will attempt to refresh the token, and if it cannot, it will fail with - // an error. This error is a reason the token is invalid. + // So this error is only returned if a refresh of the token failed. return externalAuthLink, InvalidTokenError(fmt.Sprintf("refresh token: %s", err.Error())) } @@ -973,3 +998,50 @@ func IsGithubDotComURL(str string) bool { } return ghURL.Host == "github.com" } + +// isFailedRefresh returns true if the error returned by the TokenSource.Token() +// is due to a failed refresh. The failure being the refresh token itself. +// If this returns true, no amount of retries will fix the issue. +// +// Notes: Provider responses are not uniform. Here are some examples: +// Github +// - Returns a 200 with Code "bad_refresh_token" and Description "The refresh token passed is incorrect or expired." +// +// Gitea [TODO: get an expired refresh token] +// - [Bad JWT] Returns 400 with Code "unauthorized_client" and Description "unable to parse refresh token" +// +// Gitlab +// - Returns 400 with Code "invalid_grant" and Description "The provided authorization grant is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client." +func isFailedRefresh(existingToken *oauth2.Token, err error) bool { + if existingToken.RefreshToken == "" { + return false // No refresh token, so this cannot be refreshed + } + + if existingToken.Valid() { + return false // Valid tokens are not refreshed + } + + var oauthErr *oauth2.RetrieveError + if xerrors.As(err, &oauthErr) { + switch oauthErr.ErrorCode { + // Known error codes that indicate a failed refresh. + // 'Spec' means the code is defined in the spec. + case "bad_refresh_token", // Github + "invalid_grant", // Gitlab & Spec + "unauthorized_client", // Gitea & Spec + "unsupported_grant_type": // Spec, refresh not supported + return true + } + + switch oauthErr.Response.StatusCode { + case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden, http.StatusOK: + // Status codes that indicate the request was processed, and rejected. + return true + case http.StatusInternalServerError, http.StatusTooManyRequests: + // These do not indicate a failed refresh, but could be a temporary issue. + return false + } + } + + return false +} diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index fbc1cab4b7091..84bded9856572 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -17,6 +17,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -25,6 +26,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" @@ -62,7 +64,7 @@ func TestRefreshToken(t *testing.T) { _, err := config.RefreshToken(ctx, nil, link) require.Error(t, err) require.True(t, externalauth.IsInvalidTokenError(err)) - require.Contains(t, err.Error(), "refreshing is disabled") + require.Contains(t, err.Error(), "refreshing is either disabled or refreshing failed") }) // NoRefreshNoExpiry tests that an oauth token without an expiry is always valid. @@ -141,6 +143,73 @@ func TestRefreshToken(t *testing.T) { require.True(t, validated, "token should have been attempted to be validated") }) + // RefreshRetries tests that refresh token retry behavior works as expected. + // If a refresh token fails because the token itself is invalid, no more + // refresh attempts should ever happen. An invalid refresh token does + // not magically become valid at some point in the future. + t.Run("RefreshRetries", func(t *testing.T) { + t.Parallel() + + var refreshErr *oauth2.RetrieveError + + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + refreshCount := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + refreshCount++ + return refreshErr + }), + // The IDP should not be contacted since the token is expired and + // refresh attempts will fail. + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + t.Error("token was validated, but it was expired and this should never have happened.") + return nil, xerrors.New("should not be called") + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) {}, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Expire the link + link.OAuthExpiry = expired + + // Make the failure a server internal error. Not related to the token + refreshErr = &oauth2.RetrieveError{ + Response: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + ErrorCode: "internal_error", + } + _, err := config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, refreshCount, 1) + + // Try again with a bad refresh token error + // Expect DB call to remove the refresh token + mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) + refreshErr = &oauth2.RetrieveError{ // github error + Response: &http.Response{ + StatusCode: http.StatusOK, + }, + ErrorCode: "bad_refresh_token", + } + _, err = config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, refreshCount, 2) + + // When the refresh token is empty, no api calls should be made + link.OAuthRefreshToken = "" // mock'd db, so manually set the token to '' + _, err = config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, refreshCount, 2) + }) + // ValidateFailure tests if the token is no longer valid with a 401 response. t.Run("ValidateFailure", func(t *testing.T) { t.Parallel()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: