Skip to content

Commit 8e0a153

Browse files
authored
chore: implement device auth flow for fake idp (#11707)
* chore: implement device auth flow for fake idp
1 parent 16c6cef commit 8e0a153

File tree

4 files changed

+333
-23
lines changed

4 files changed

+333
-23
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 253 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13+
"math/rand"
14+
"mime"
1315
"net"
1416
"net/http"
1517
"net/http/cookiejar"
1618
"net/http/httptest"
1719
"net/url"
20+
"strconv"
1821
"strings"
1922
"testing"
2023
"time"
@@ -34,9 +37,11 @@ import (
3437
"cdr.dev/slog/sloggers/slogtest"
3538
"github.com/coder/coder/v2/coderd"
3639
"github.com/coder/coder/v2/coderd/externalauth"
40+
"github.com/coder/coder/v2/coderd/httpapi"
3741
"github.com/coder/coder/v2/coderd/promoauth"
3842
"github.com/coder/coder/v2/coderd/util/syncmap"
3943
"github.com/coder/coder/v2/codersdk"
44+
"github.com/coder/coder/v2/testutil"
4045
)
4146

4247
type token struct {
@@ -45,6 +50,13 @@ type token struct {
4550
exp time.Time
4651
}
4752

53+
type deviceFlow struct {
54+
// userInput is the expected input to authenticate the device flow.
55+
userInput string
56+
exp time.Time
57+
granted bool
58+
}
59+
4860
// FakeIDP is a functional OIDC provider.
4961
// It only supports 1 OIDC client.
5062
type FakeIDP struct {
@@ -77,6 +89,8 @@ type FakeIDP struct {
7789
refreshTokens *syncmap.Map[string, string]
7890
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
7991
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
92+
// Device flow
93+
deviceCode *syncmap.Map[string, deviceFlow]
8094

8195
// hooks
8296
// hookValidRedirectURL can be used to reject a redirect url from the
@@ -226,6 +240,8 @@ const (
226240
authorizePath = "/oauth2/authorize"
227241
keysPath = "/oauth2/keys"
228242
userInfoPath = "/oauth2/userinfo"
243+
deviceAuth = "/login/device/code"
244+
deviceVerify = "/login/device"
229245
)
230246

231247
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
@@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
246262
refreshTokensUsed: syncmap.New[string, bool](),
247263
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
248264
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
265+
deviceCode: syncmap.New[string, deviceFlow](),
249266
hookOnRefresh: func(_ string) error { return nil },
250267
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
251268
hookValidRedirectURL: func(redirectURL string) error { return nil },
@@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
288305
// ProviderJSON is the JSON representation of the OpenID Connect provider
289306
// These are all the urls that the IDP will respond to.
290307
f.provider = ProviderJSON{
291-
Issuer: issuer,
292-
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
293-
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
294-
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
295-
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
308+
Issuer: issuer,
309+
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
310+
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
311+
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
312+
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
313+
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
296314
Algorithms: []string{
297315
"RS256",
298316
},
@@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
467485
_ = res.Body.Close()
468486
}
469487

488+
// DeviceLogin does the oauth2 device flow for external auth providers.
489+
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
490+
// First we need to initiate the device flow. This will have Coder hit the
491+
// fake IDP and get a device code.
492+
device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
493+
require.NoError(t, err)
494+
495+
// Now the user needs to go to the fake IDP page and click "allow" and enter
496+
// the device code input. For our purposes, we just send an http request to
497+
// the verification url. No additional user input is needed.
498+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
499+
defer cancel()
500+
resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
501+
require.NoError(t, err)
502+
defer resp.Body.Close()
503+
504+
// Now we need to exchange the device code for an access token. We do this
505+
// in this method because it is the user that does the polling for the device
506+
// auth flow, not the backend.
507+
err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
508+
DeviceCode: device.DeviceCode,
509+
})
510+
require.NoError(t, err)
511+
}
512+
470513
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
471514
// unit tests, it's easier to skip this step sometimes. It does make an actual
472515
// request to the IDP, so it should be equivalent to doing this "manually" with
@@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
536579

537580
// ProviderJSON is the .well-known/configuration JSON
538581
type ProviderJSON struct {
539-
Issuer string `json:"issuer"`
540-
AuthURL string `json:"authorization_endpoint"`
541-
TokenURL string `json:"token_endpoint"`
542-
JWKSURL string `json:"jwks_uri"`
543-
UserInfoURL string `json:"userinfo_endpoint"`
544-
Algorithms []string `json:"id_token_signing_alg_values_supported"`
582+
Issuer string `json:"issuer"`
583+
AuthURL string `json:"authorization_endpoint"`
584+
TokenURL string `json:"token_endpoint"`
585+
JWKSURL string `json:"jwks_uri"`
586+
UserInfoURL string `json:"userinfo_endpoint"`
587+
DeviceCodeURL string `json:"device_authorization_endpoint"`
588+
Algorithms []string `json:"id_token_signing_alg_values_supported"`
545589
// This is custom
546590
ExternalAuthURL string `json:"external_auth_url"`
547591
}
@@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
709753
}))
710754

711755
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
712-
values, err := f.authenticateOIDCClientRequest(t, r)
756+
var values url.Values
757+
var err error
758+
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
759+
values = r.URL.Query()
760+
} else {
761+
values, err = f.authenticateOIDCClientRequest(t, r)
762+
}
713763
f.logger.Info(r.Context(), "http idp call token",
764+
slog.F("url", r.URL.String()),
714765
slog.F("valid", err == nil),
715766
slog.F("grant_type", values.Get("grant_type")),
716767
slog.F("values", values.Encode()),
@@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
784835
f.refreshTokensUsed.Store(refreshToken, true)
785836
// Always invalidate the refresh token after it is used.
786837
f.refreshTokens.Delete(refreshToken)
838+
case "urn:ietf:params:oauth:grant-type:device_code":
839+
// Device flow
840+
var resp externalauth.ExchangeDeviceCodeResponse
841+
deviceCode := values.Get("device_code")
842+
if deviceCode == "" {
843+
resp.Error = "invalid_request"
844+
resp.ErrorDescription = "missing device_code"
845+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
846+
return
847+
}
848+
849+
deviceFlow, ok := f.deviceCode.Load(deviceCode)
850+
if !ok {
851+
resp.Error = "invalid_request"
852+
resp.ErrorDescription = "device_code provided not found"
853+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
854+
return
855+
}
856+
857+
if !deviceFlow.granted {
858+
// Status code ok with the error as pending.
859+
resp.Error = "authorization_pending"
860+
resp.ErrorDescription = ""
861+
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
862+
return
863+
}
864+
865+
// Would be nice to get an actual email here.
866+
claims = jwt.MapClaims{
867+
"email": "unknown-dev-auth",
868+
}
787869
default:
788870
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
789871
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
@@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
807889
// Store the claims for the next refresh
808890
f.refreshIDTokenClaims.Store(refreshToken, claims)
809891

810-
rw.Header().Set("Content-Type", "application/json")
811-
_ = json.NewEncoder(rw).Encode(token)
892+
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
893+
if mediaType == "application/x-www-form-urlencoded" {
894+
// This val encode might not work for some data structures.
895+
// It's good enough for now...
896+
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
897+
vals := url.Values{}
898+
for k, v := range token {
899+
vals.Set(k, fmt.Sprintf("%v", v))
900+
}
901+
_, _ = rw.Write([]byte(vals.Encode()))
902+
return
903+
}
904+
// Default to json since the oauth2 package doesn't use Accept headers.
905+
if mediaType == "application/json" || mediaType == "" {
906+
rw.Header().Set("Content-Type", "application/json")
907+
_ = json.NewEncoder(rw).Encode(token)
908+
return
909+
}
910+
911+
// If we get something we don't support, throw an error.
912+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
913+
Message: "'Accept' header contains unsupported media type",
914+
Detail: fmt.Sprintf("Found %q", mediaType),
915+
})
812916
}))
813917

814918
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
@@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
886990
_ = json.NewEncoder(rw).Encode(set)
887991
}))
888992

993+
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
994+
f.logger.Info(r.Context(), "http call device verify")
995+
996+
inputParam := "user_input"
997+
userInput := r.URL.Query().Get(inputParam)
998+
if userInput == "" {
999+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1000+
Message: "Invalid user input",
1001+
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
1002+
})
1003+
return
1004+
}
1005+
1006+
deviceCode := r.URL.Query().Get("device_code")
1007+
if deviceCode == "" {
1008+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1009+
Message: "Invalid device code",
1010+
Detail: "Hit this url again with ?device_code=<device_code>",
1011+
})
1012+
return
1013+
}
1014+
1015+
flow, ok := f.deviceCode.Load(deviceCode)
1016+
if !ok {
1017+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1018+
Message: "Invalid device code",
1019+
Detail: "Device code not found.",
1020+
})
1021+
return
1022+
}
1023+
1024+
if time.Now().After(flow.exp) {
1025+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1026+
Message: "Invalid device code",
1027+
Detail: "Device code expired.",
1028+
})
1029+
return
1030+
}
1031+
1032+
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
1033+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1034+
Message: "Invalid device code",
1035+
Detail: "user code does not match",
1036+
})
1037+
return
1038+
}
1039+
1040+
f.deviceCode.Store(deviceCode, deviceFlow{
1041+
userInput: flow.userInput,
1042+
exp: flow.exp,
1043+
granted: true,
1044+
})
1045+
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
1046+
Message: "Device authenticated!",
1047+
})
1048+
}))
1049+
1050+
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
1051+
f.logger.Info(r.Context(), "http call device auth")
1052+
1053+
p := httpapi.NewQueryParamParser()
1054+
p.Required("client_id")
1055+
clientID := p.String(r.URL.Query(), "", "client_id")
1056+
_ = p.String(r.URL.Query(), "", "scopes")
1057+
if len(p.Errors) > 0 {
1058+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1059+
Message: "Invalid query params",
1060+
Validations: p.Errors,
1061+
})
1062+
return
1063+
}
1064+
1065+
if clientID != f.clientID {
1066+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1067+
Message: "Invalid client id",
1068+
})
1069+
return
1070+
}
1071+
1072+
deviceCode := uuid.NewString()
1073+
lifetime := time.Second * 900
1074+
flow := deviceFlow{
1075+
//nolint:gosec
1076+
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
1077+
}
1078+
f.deviceCode.Store(deviceCode, deviceFlow{
1079+
userInput: flow.userInput,
1080+
exp: time.Now().Add(lifetime),
1081+
})
1082+
1083+
verifyURL := f.issuerURL.ResolveReference(&url.URL{
1084+
Path: deviceVerify,
1085+
RawQuery: url.Values{
1086+
"device_code": {deviceCode},
1087+
"user_input": {flow.userInput},
1088+
}.Encode(),
1089+
}).String()
1090+
1091+
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
1092+
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
1093+
"device_code": deviceCode,
1094+
"user_code": flow.userInput,
1095+
"verification_uri": verifyURL,
1096+
"expires_in": int(lifetime.Seconds()),
1097+
"interval": 3,
1098+
})
1099+
return
1100+
}
1101+
1102+
// By default, GitHub form encodes these.
1103+
_, _ = fmt.Fprint(rw, url.Values{
1104+
"device_code": {deviceCode},
1105+
"user_code": {flow.userInput},
1106+
"verification_uri": {verifyURL},
1107+
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
1108+
"interval": {"3"},
1109+
}.Encode())
1110+
}))
1111+
8891112
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
8901113
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
8911114
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
@@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
9871210
// completely customize the response. It captures all routes under the /external-auth-validate/*
9881211
// so the caller can do whatever they want and even add routes.
9891212
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
1213+
1214+
UseDeviceAuth bool
9901215
}
9911216

9921217
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
@@ -1033,17 +1258,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
10331258
}
10341259
}
10351260
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
1261+
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
10361262
cfg := &externalauth.Config{
10371263
DisplayName: id,
1038-
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
1264+
InstrumentedOAuth2Config: oauthCfg,
10391265
ID: id,
10401266
// No defaults for these fields by omitting the type
10411267
Type: "",
10421268
DisplayIcon: f.WellknownConfig().UserInfoURL,
10431269
// Omit the /user for the validate so we can easily append to it when modifying
10441270
// the cfg for advanced tests.
10451271
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
1272+
DeviceAuth: &externalauth.DeviceAuth{
1273+
Config: oauthCfg,
1274+
ClientID: f.clientID,
1275+
TokenURL: f.provider.TokenURL,
1276+
Scopes: []string{},
1277+
CodeURL: f.provider.DeviceCodeURL,
1278+
},
1279+
}
1280+
1281+
if !custom.UseDeviceAuth {
1282+
cfg.DeviceAuth = nil
10461283
}
1284+
10471285
for _, opt := range opts {
10481286
opt(cfg)
10491287
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy