diff --git a/google/externalaccount/basecredentials.go b/google/externalaccount/basecredentials.go index 400aa0a07..6c81a6872 100644 --- a/google/externalaccount/basecredentials.go +++ b/google/externalaccount/basecredentials.go @@ -471,11 +471,12 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { AccessToken: stsResp.AccessToken, TokenType: stsResp.TokenType, } - if stsResp.ExpiresIn < 0 { + + // The RFC8693 doesn't define the explicit 0 of "expires_in" field behavior. + if stsResp.ExpiresIn <= 0 { return nil, fmt.Errorf("oauth2/google/externalaccount: got invalid expiry from security token service") - } else if stsResp.ExpiresIn >= 0 { - accessToken.Expiry = now().Add(time.Duration(stsResp.ExpiresIn) * time.Second) } + accessToken.Expiry = now().Add(time.Duration(stsResp.ExpiresIn) * time.Second) if stsResp.RefreshToken != "" { accessToken.RefreshToken = stsResp.RefreshToken diff --git a/google/externalaccount/basecredentials_test.go b/google/externalaccount/basecredentials_test.go index 33314c3f0..8f165cdb0 100644 --- a/google/externalaccount/basecredentials_test.go +++ b/google/externalaccount/basecredentials_test.go @@ -6,6 +6,7 @@ package externalaccount import ( "context" + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -101,15 +102,18 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T return ts.Token() } -func validateToken(t *testing.T, tok *oauth2.Token) { - if got, want := tok.AccessToken, correctAT; got != want { +func validateToken(t *testing.T, tok *oauth2.Token, expectToken *oauth2.Token) { + if expectToken == nil { + return + } + if got, want := tok.AccessToken, expectToken.AccessToken; got != want { t.Errorf("Unexpected access token: got %v, but wanted %v", got, want) } - if got, want := tok.TokenType, "Bearer"; got != want { + if got, want := tok.TokenType, expectToken.TokenType; got != want { t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want) } - if got, want := tok.Expiry, testNow().Add(time.Duration(3600)*time.Second); got != want { + if got, want := tok.Expiry, expectToken.Expiry; got != want { t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want) } } @@ -173,30 +177,91 @@ func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetim } func TestToken(t *testing.T) { - config := Config{ - Audience: "32555940559.apps.googleusercontent.com", - SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", - ClientSecret: "notsosecret", - ClientID: "rbrgnognrhongo3bi4gb9ghg9g", - CredentialSource: &testBaseCredSource, - Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + type MockSTSResponse struct { + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int32 `json:"expires_in,omitempty"` + Scope string `json:"scopre,omitenpty"` } - server := testExchangeTokenServer{ - url: "/", - authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", - contentType: "application/x-www-form-urlencoded", - metricsHeader: getExpectedMetricsHeader("file", false, false), - body: baseCredsRequestBody, - response: baseCredsResponseBody, + testCases := []struct { + name string + responseBody MockSTSResponse + expectToken *oauth2.Token + expectErrorMsg string + }{ + { + name: "happy case", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: &oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + }, + }, + { + name: "no expiry time on token", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: nil, + expectErrorMsg: "oauth2/google/externalaccount: got invalid expiry from security token service", + }, + { + name: "negative expiry time", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: -1, + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: nil, + expectErrorMsg: "oauth2/google/externalaccount: got invalid expiry from security token service", + }, } - tok, err := run(t, &config, &server) + for _, testCase := range testCases { + config := Config{ + Audience: "32555940559.apps.googleusercontent.com", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + ClientSecret: "notsosecret", + ClientID: "rbrgnognrhongo3bi4gb9ghg9g", + CredentialSource: &testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + } - if err != nil { - t.Fatalf("Unexpected error: %e", err) + responseBody, err := json.Marshal(testCase.responseBody) + if err != nil { + t.Errorf("Invalid response received.") + } + + server := testExchangeTokenServer{ + url: "/", + authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", + contentType: "application/x-www-form-urlencoded", + metricsHeader: getExpectedMetricsHeader("file", false, false), + body: baseCredsRequestBody, + response: string(responseBody), + } + + tok, err := run(t, &config, &server) + + if err != nil && err.Error() != testCase.expectErrorMsg { + t.Errorf("Error not as expected: got = %v, and want = %v", err, testCase.expectErrorMsg) + } + validateToken(t, tok, testCase.expectToken) } - validateToken(t, tok) } func TestWorkforcePoolTokenWithClientID(t *testing.T) { @@ -224,7 +289,12 @@ func TestWorkforcePoolTokenWithClientID(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %e", err) } - validateToken(t, tok) + expectToken := oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + } + validateToken(t, tok, &expectToken) } func TestWorkforcePoolTokenWithoutClientID(t *testing.T) { @@ -251,7 +321,12 @@ func TestWorkforcePoolTokenWithoutClientID(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %e", err) } - validateToken(t, tok) + expectToken := oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + } + validateToken(t, tok, &expectToken) } func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {