diff --git a/google/externalaccount/basecredentials.go b/google/externalaccount/basecredentials.go index fc106347d..aa0bba2eb 100644 --- a/google/externalaccount/basecredentials.go +++ b/google/externalaccount/basecredentials.go @@ -263,7 +263,7 @@ const ( fileTypeJSON = "json" ) -// Format contains information needed to retireve a subject token for URL or File sourced credentials. +// Format contains information needed to retrieve a subject token for URL or File sourced credentials. type Format struct { // Type should be either "text" or "json". This determines whether the file or URL sourced credentials // expect a simple text subject token or if the subject token will be contained in a JSON object. diff --git a/google/externalaccount/basecredentials_test.go b/google/externalaccount/basecredentials_test.go index 8f165cdb0..d52f6a789 100644 --- a/google/externalaccount/basecredentials_test.go +++ b/google/externalaccount/basecredentials_test.go @@ -347,12 +347,12 @@ func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) { t.Fatalf("Expected error but found none") } if got, want := err.Error(), "oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials"; got != want { - t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got) + t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got) } } func TestWorkforcePoolCreation(t *testing.T) { - var audienceValidatyTests = []struct { + var audienceValidityTests = []struct { audience string expectSuccess bool }{ @@ -371,7 +371,7 @@ func TestWorkforcePoolCreation(t *testing.T) { } ctx := context.Background() - for _, tt := range audienceValidatyTests { + for _, tt := range audienceValidityTests { t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. config := testConfig config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL diff --git a/google/externalaccount/executablecredsource_test.go b/google/externalaccount/executablecredsource_test.go index 69ec21ae1..3ecc05f92 100644 --- a/google/externalaccount/executablecredsource_test.go +++ b/google/externalaccount/executablecredsource_test.go @@ -654,7 +654,7 @@ func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) { if _, err = base.subjectToken(); err == nil { t.Fatalf("Expected error but found none") } else if got, want := err.Error(), jsonParsingError(outputFileSource, "tokentokentoken").Error(); got != want { - t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got) + t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got) } _, deadlineSet := te.getDeadline() @@ -801,7 +801,7 @@ func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) { if _, err = ecs.subjectToken(); err == nil { t.Errorf("Expected error but found none") } else if got, want := err.Error(), tt.expectedErr.Error(); got != want { - t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got) + t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got) } if _, deadlineSet := te.getDeadline(); deadlineSet { @@ -923,7 +923,7 @@ func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) { } if got, want := out, "tokentokentoken"; got != want { - t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got) + t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got) } }) } @@ -1012,7 +1012,7 @@ func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) { if out, err := ecs.subjectToken(); err != nil { t.Errorf("retrieveSubjectToken() failed: %v", err) } else if got, want := out, "tokentokentoken"; got != want { - t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got) + t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got) } if _, deadlineSet := te.getDeadline(); deadlineSet { diff --git a/google/google_test.go b/google/google_test.go index 7078d429f..5aa5e2845 100644 --- a/google/google_test.go +++ b/google/google_test.go @@ -72,7 +72,7 @@ func TestConfigFromJSON(t *testing.T) { t.Errorf("ClientSecret = %q; want %q", got, want) } if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want { - t.Errorf("RedictURL = %q; want %q", got, want) + t.Errorf("RedirectURL = %q; want %q", got, want) } if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want { t.Errorf("Scopes = %q; want %q", got, want) diff --git a/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go b/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go index 94bfee3d6..1bbbbac19 100644 --- a/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go +++ b/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go @@ -38,7 +38,7 @@ type testRefreshTokenServer struct { server *httptest.Server } -func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) { +func TestExternalAccountAuthorizedUser_JustToken(t *testing.T) { config := &Config{ Token: "AAAAAAA", Expiry: now().Add(time.Hour), @@ -57,7 +57,7 @@ func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) { } } -func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) { +func TestExternalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) { server := &testRefreshTokenServer{ URL: "/", Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", @@ -99,7 +99,7 @@ func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t } } -func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) { +func TestExternalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) { server := &testRefreshTokenServer{ URL: "/", Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", @@ -187,7 +187,7 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) { }, }, { - name: "missing client secrect", + name: "missing client secret", config: Config{ RefreshToken: "BBBBBBBBB", TokenURL: url, diff --git a/google/internal/stsexchange/sts_exchange_test.go b/google/internal/stsexchange/sts_exchange_test.go index 895b9bcf9..ff9a9ad08 100644 --- a/google/internal/stsexchange/sts_exchange_test.go +++ b/google/internal/stsexchange/sts_exchange_test.go @@ -142,7 +142,7 @@ func TestExchangeToken_Opts(t *testing.T) { } strOpts, ok := data["options"] if !ok { - t.Errorf("Server didn't recieve an \"options\" field.") + t.Errorf("Server didn't receive an \"options\" field.") } else if len(strOpts) < 1 { t.Errorf("\"options\" field has length 0.") } diff --git a/jws/jws.go b/jws/jws.go index 6f03a49d3..27ab06139 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) { // Decode decodes a claim set from a JWS payload. func Decode(payload string) (*ClaimSet, error) { // decode returned id token to get expiry - s := strings.Split(payload, ".") - if len(s) < 2 { + _, claims, _, ok := parseToken(payload) + if !ok { // TODO(jbd): Provide more context about the error. return nil, errors.New("jws: invalid token received") } - decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + decoded, err := base64.RawURLEncoding.DecodeString(claims) if err != nil { return nil, err } @@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { // Verify tests whether the provided JWT token's signature was produced by the private key // associated with the supplied public key. func Verify(token string, key *rsa.PublicKey) error { - if strings.Count(token, ".") != 2 { + header, claims, sig, ok := parseToken(token) + if !ok { return errors.New("jws: invalid token received, token must have 3 parts") } - - parts := strings.SplitN(token, ".", 3) - signedContent := parts[0] + "." + parts[1] - signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) + signatureString, err := base64.RawURLEncoding.DecodeString(sig) if err != nil { return err } h := sha256.New() - h.Write([]byte(signedContent)) + h.Write([]byte(header + tokenDelim + claims)) return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) } + +func parseToken(s string) (header, claims, sig string, ok bool) { + header, s, ok = strings.Cut(s, tokenDelim) + if !ok { // no period found + return "", "", "", false + } + claims, s, ok = strings.Cut(s, tokenDelim) + if !ok { // only one period found + return "", "", "", false + } + sig, _, ok = strings.Cut(s, tokenDelim) + if ok { // three periods found + return "", "", "", false + } + return header, claims, sig, true +} + +const tokenDelim = "." diff --git a/jws/jws_test.go b/jws/jws_test.go index 39a136a29..1776f56b8 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -7,6 +7,8 @@ package jws import ( "crypto/rand" "crypto/rsa" + "net/http" + "strings" "testing" ) @@ -39,8 +41,57 @@ func TestSignAndVerify(t *testing.T) { } func TestVerifyFailsOnMalformedClaim(t *testing.T) { - err := Verify("abc.def", nil) - if err == nil { - t.Error("got no errors; want improperly formed JWT not to be verified") + cases := []struct { + desc string + token string + }{ + { + desc: "no periods", + token: "aa", + }, { + desc: "only one period", + token: "a.a", + }, { + desc: "more than two periods", + token: "a.a.a.a", + }, + } + for _, tc := range cases { + f := func(t *testing.T) { + err := Verify(tc.token, nil) + if err == nil { + t.Error("got no errors; want improperly formed JWT not to be verified") + } + } + t.Run(tc.desc, f) + } +} + +func BenchmarkVerify(b *testing.B) { + cases := []struct { + desc string + token string + }{ + { + desc: "full of periods", + token: strings.Repeat(".", http.DefaultMaxHeaderBytes), + }, { + desc: "two trailing periods", + token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..", + }, + } + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + b.Fatal(err) + } + for _, bc := range cases { + f := func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for range b.N { + Verify(bc.token, &privateKey.PublicKey) + } + } + b.Run(bc.desc, f) } } diff --git a/oauth2.go b/oauth2.go index 74f052aa9..eacdd7fd9 100644 --- a/oauth2.go +++ b/oauth2.go @@ -288,7 +288,7 @@ func (tf *tokenRefresher) Token() (*Token, error) { if tf.refreshToken != tk.RefreshToken { tf.refreshToken = tk.RefreshToken } - return tk, err + return tk, nil } // reuseTokenSource is a TokenSource that holds a single token in memory @@ -356,11 +356,15 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { if src == nil { return internal.ContextClient(ctx) } + cc := internal.ContextClient(ctx) return &http.Client{ Transport: &Transport{ - Base: internal.ContextClient(ctx).Transport, + Base: cc.Transport, Source: ReuseTokenSource(nil, src), }, + CheckRedirect: cc.CheckRedirect, + Jar: cc.Jar, + Timeout: cc.Timeout, } } diff --git a/token.go b/token.go index 109997d77..8c31136c4 100644 --- a/token.go +++ b/token.go @@ -169,7 +169,7 @@ func tokenFromInternal(t *internal.Token) *Token { // retrieveToken takes a *Config and uses that to retrieve an *internal.Token. // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along -// with an error.. +// with an error. func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { diff --git a/transport_test.go b/transport_test.go index faa87d514..a8e6ea236 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,12 +9,6 @@ import ( "time" ) -type tokenSource struct{ token *Token } - -func (t *tokenSource) Token() (*Token, error) { - return t.token, nil -} - func TestTransportNilTokenSource(t *testing.T) { tr := &Transport{} server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) @@ -88,13 +82,10 @@ func TestTransportCloseRequestBodySuccess(t *testing.T) { } func TestTransportTokenSource(t *testing.T) { - ts := &tokenSource{ - token: &Token{ - AccessToken: "abc", - }, - } tr := &Transport{ - Source: ts, + Source: StaticTokenSource(&Token{ + AccessToken: "abc", + }), } server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want { @@ -123,14 +114,11 @@ func TestTransportTokenSourceTypes(t *testing.T) { {key: "basic", val: val, want: "Basic abc"}, } for _, tc := range tests { - ts := &tokenSource{ - token: &Token{ + tr := &Transport{ + Source: StaticTokenSource(&Token{ AccessToken: tc.val, TokenType: tc.key, - }, - } - tr := &Transport{ - Source: ts, + }), } server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if got, want := r.Header.Get("Authorization"), tc.want; got != want {
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: