Skip to content

Commit d3c2c68

Browse files
committed
fix: stop extending API key access if OIDC refresh is available
1 parent ca5a78a commit d3c2c68

File tree

4 files changed

+204
-42
lines changed

4 files changed

+204
-42
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
307307
// WithLogging is optional, but will log some HTTP calls made to the IDP.
308308
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
309309
return func(f *FakeIDP) {
310-
f.logger = slogtest.Make(t, options)
310+
f.logger = slogtest.Make(t, options).Named("fakeidp")
311311
}
312312
}
313313

@@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
794794
func (f *FakeIDP) newRefreshTokens(email string) string {
795795
refreshToken := uuid.NewString()
796796
f.refreshTokens.Store(refreshToken, email)
797+
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
797798
return refreshToken
798799
}
799800

@@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10031004
return
10041005
}
10051006

1007+
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
10061008
_, ok := f.refreshTokens.Load(refreshToken)
10071009
if !assert.True(t, ok, "invalid refresh_token") {
10081010
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
@@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10261028
f.refreshTokensUsed.Store(refreshToken, true)
10271029
// Always invalidate the refresh token after it is used.
10281030
f.refreshTokens.Delete(refreshToken)
1031+
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
10291032
case "urn:ietf:params:oauth:grant-type:device_code":
10301033
// Device flow
10311034
var resp externalauth.ExchangeDeviceCodeResponse

coderd/httpmw/apikey.go

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,17 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
238238
// Tracks if the API key has properties updated
239239
changed = false
240240
)
241+
242+
if key.ExpiresAt.Before(now) {
243+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
244+
Message: SignedOutErrorMessage,
245+
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
246+
})
247+
}
248+
249+
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
250+
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
251+
// refreshing the OIDC token.
241252
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
242253
var err error
243254
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
@@ -258,7 +269,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
258269
})
259270
}
260271
// Check if the OAuth token is expired
261-
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
272+
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
262273
if cfg.OAuth2Configs.IsZero() {
263274
return write(http.StatusInternalServerError, codersdk.Response{
264275
Message: internalErrorMessage,
@@ -267,12 +278,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
267278
})
268279
}
269280

281+
var friendlyName string
270282
var oauthConfig promoauth.OAuth2Config
271283
switch key.LoginType {
272284
case database.LoginTypeGithub:
273285
oauthConfig = cfg.OAuth2Configs.Github
286+
friendlyName = "GitHub"
274287
case database.LoginTypeOIDC:
275288
oauthConfig = cfg.OAuth2Configs.OIDC
289+
friendlyName = "OpenID Connect"
276290
default:
277291
return write(http.StatusInternalServerError, codersdk.Response{
278292
Message: internalErrorMessage,
@@ -292,37 +306,51 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
292306
})
293307
}
294308

295-
// If it is, let's refresh it from the provided config
309+
if link.OAuthRefreshToken == "" {
310+
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
311+
Message: SignedOutErrorMessage,
312+
Detail: fmt.Sprintf("%s session expired at %q.", friendlyName, link.OAuthExpiry.String()),
313+
})
314+
}
315+
// We have a refresh token, so let's try it
296316
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
297317
AccessToken: link.OAuthAccessToken,
298318
RefreshToken: link.OAuthRefreshToken,
299319
Expiry: link.OAuthExpiry,
300320
}).Token()
301321
if err != nil {
302322
return write(http.StatusUnauthorized, codersdk.Response{
303-
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
304-
Detail: err.Error(),
323+
Message: fmt.Sprintf(
324+
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
325+
friendlyName),
326+
Detail: err.Error(),
305327
})
306328
}
307329
link.OAuthAccessToken = token.AccessToken
308330
link.OAuthRefreshToken = token.RefreshToken
309331
link.OAuthExpiry = token.Expiry
310-
key.ExpiresAt = token.Expiry
311-
changed = true
332+
//nolint:gocritic // system needs to update user link
333+
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
334+
UserID: link.UserID,
335+
LoginType: link.LoginType,
336+
OAuthAccessToken: link.OAuthAccessToken,
337+
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
338+
OAuthRefreshToken: link.OAuthRefreshToken,
339+
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
340+
OAuthExpiry: link.OAuthExpiry,
341+
// Refresh should keep the same debug context because we use
342+
// the original claims for the group/role sync.
343+
Claims: link.Claims,
344+
})
345+
if err != nil {
346+
return write(http.StatusInternalServerError, codersdk.Response{
347+
Message: internalErrorMessage,
348+
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
349+
})
350+
}
312351
}
313352
}
314353

315-
// Checking if the key is expired.
316-
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
317-
// the users token has expired. If you change the text here, make sure to update it
318-
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
319-
if key.ExpiresAt.Before(now) {
320-
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
321-
Message: SignedOutErrorMessage,
322-
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
323-
})
324-
}
325-
326354
// Only update LastUsed once an hour to prevent database spam.
327355
if now.Sub(key.LastUsed) > time.Hour {
328356
key.LastUsed = now
@@ -363,29 +391,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
363391
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
364392
})
365393
}
366-
// If the API Key is associated with a user_link (e.g. Github/OIDC)
367-
// then we want to update the relevant oauth fields.
368-
if link.UserID != uuid.Nil {
369-
//nolint:gocritic // system needs to update user link
370-
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
371-
UserID: link.UserID,
372-
LoginType: link.LoginType,
373-
OAuthAccessToken: link.OAuthAccessToken,
374-
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
375-
OAuthRefreshToken: link.OAuthRefreshToken,
376-
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
377-
OAuthExpiry: link.OAuthExpiry,
378-
// Refresh should keep the same debug context because we use
379-
// the original claims for the group/role sync.
380-
Claims: link.Claims,
381-
})
382-
if err != nil {
383-
return write(http.StatusInternalServerError, codersdk.Response{
384-
Message: internalErrorMessage,
385-
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
386-
})
387-
}
388-
}
389394

390395
// We only want to update this occasionally to reduce DB write
391396
// load. We update alongside the UserLink and APIKey since it's

coderd/httpmw/apikey_test.go

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,99 @@ func TestAPIKey(t *testing.T) {
508508
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
509509
})
510510

511+
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
512+
t.Parallel()
513+
var (
514+
db = dbmem.New()
515+
user = dbgen.User(t, db, database.User{})
516+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
517+
UserID: user.ID,
518+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
519+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
520+
LoginType: database.LoginTypeOIDC,
521+
})
522+
_ = dbgen.UserLink(t, db, database.UserLink{
523+
UserID: user.ID,
524+
LoginType: database.LoginTypeOIDC,
525+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
526+
})
527+
528+
r = httptest.NewRequest("GET", "/", nil)
529+
rw = httptest.NewRecorder()
530+
)
531+
r.Header.Set(codersdk.SessionTokenHeader, token)
532+
533+
oauthToken := &oauth2.Token{
534+
AccessToken: "wow",
535+
RefreshToken: "moo",
536+
Expiry: dbtime.Now().AddDate(0, 0, 1),
537+
}
538+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
539+
DB: db,
540+
OAuth2Configs: &httpmw.OAuth2Configs{
541+
OIDC: &testutil.OAuth2Config{
542+
Token: oauthToken,
543+
},
544+
},
545+
RedirectToLogin: false,
546+
})(successHandler).ServeHTTP(rw, r)
547+
res := rw.Result()
548+
defer res.Body.Close()
549+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
550+
551+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
552+
require.NoError(t, err)
553+
554+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
555+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
556+
})
557+
558+
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
559+
t.Parallel()
560+
var (
561+
db = dbmem.New()
562+
user = dbgen.User(t, db, database.User{})
563+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
564+
UserID: user.ID,
565+
LastUsed: dbtime.Now().AddDate(0, 0, -1),
566+
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
567+
LoginType: database.LoginTypeOIDC,
568+
})
569+
_ = dbgen.UserLink(t, db, database.UserLink{
570+
UserID: user.ID,
571+
LoginType: database.LoginTypeOIDC,
572+
})
573+
574+
r = httptest.NewRequest("GET", "/", nil)
575+
rw = httptest.NewRecorder()
576+
)
577+
r.Header.Set(codersdk.SessionTokenHeader, token)
578+
579+
oauthToken := &oauth2.Token{
580+
AccessToken: "wow",
581+
RefreshToken: "moo",
582+
Expiry: dbtime.Now().AddDate(0, 0, 1),
583+
}
584+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
585+
DB: db,
586+
OAuth2Configs: &httpmw.OAuth2Configs{
587+
OIDC: &testutil.OAuth2Config{
588+
Token: oauthToken,
589+
},
590+
},
591+
RedirectToLogin: false,
592+
})(successHandler).ServeHTTP(rw, r)
593+
res := rw.Result()
594+
defer res.Body.Close()
595+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
596+
597+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
598+
require.NoError(t, err)
599+
600+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
601+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
602+
})
603+
511604
t.Run("OAuthRefresh", func(t *testing.T) {
512605
t.Parallel()
513606
var (
@@ -553,7 +646,67 @@ func TestAPIKey(t *testing.T) {
553646
require.NoError(t, err)
554647

555648
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
556-
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
649+
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
650+
// APIKey
651+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
652+
653+
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
654+
UserID: user.ID,
655+
LoginType: database.LoginTypeGithub,
656+
})
657+
require.NoError(t, err)
658+
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
659+
})
660+
661+
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
662+
t.Parallel()
663+
var (
664+
ctx = testutil.Context(t, testutil.WaitShort)
665+
db = dbmem.New()
666+
user = dbgen.User(t, db, database.User{})
667+
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
668+
UserID: user.ID,
669+
LastUsed: dbtime.Now(),
670+
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
671+
LoginType: database.LoginTypeGithub,
672+
})
673+
674+
r = httptest.NewRequest("GET", "/", nil)
675+
rw = httptest.NewRecorder()
676+
)
677+
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
678+
UserID: user.ID,
679+
LoginType: database.LoginTypeGithub,
680+
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
681+
OAuthAccessToken: "letmein",
682+
})
683+
require.NoError(t, err)
684+
685+
r.Header.Set(codersdk.SessionTokenHeader, token)
686+
687+
oauthToken := &oauth2.Token{
688+
AccessToken: "wow",
689+
RefreshToken: "moo",
690+
Expiry: dbtime.Now().AddDate(0, 0, 1),
691+
}
692+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
693+
DB: db,
694+
OAuth2Configs: &httpmw.OAuth2Configs{
695+
Github: &testutil.OAuth2Config{
696+
Token: oauthToken,
697+
},
698+
},
699+
RedirectToLogin: false,
700+
})(successHandler).ServeHTTP(rw, r)
701+
res := rw.Result()
702+
defer res.Body.Close()
703+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
704+
705+
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
706+
require.NoError(t, err)
707+
708+
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
709+
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
557710
})
558711

559712
t.Run("RemoteIPUpdates", func(t *testing.T) {

coderd/oauthpki/okidcpki_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
144144
return values, nil
145145
}),
146146
oidctest.WithServing(),
147+
oidctest.WithLogging(t, nil),
147148
)
148149
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
149150
cfg.AllowSignups = true

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy