Skip to content

Commit

Permalink
Populate claims map in AuthenticateJWTWithClaims
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-stytch committed Jan 25, 2024
1 parent 6a20e19 commit c36ae3f
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 8 deletions.
39 changes: 31 additions & 8 deletions stytch/consumer/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,42 @@ func (c *SessionsClient) AuthenticateJWTWithClaims(
body *sessions.AuthenticateParams,
claims map[string]any,
) (*sessions.AuthenticateResponse, error) {
// This method has a different signature than AuthenticateWithClaims, which we can't change in
// this version of the library. For backward compatibility, populate the claims map by
// mutating it instead of replacing it like the non-JWT version does.
//
// TODO(v12.x): Change claims to `any`, also allow pointer-to-map and pointer-to-struct.
// TODO(v13): Remove support for populating a pre-existing map this way.

var resp *sessions.AuthenticateResponse

// Some special cases can force remote authentication. Otherwise, prefer local validation.
if body.SessionJWT == "" || maxTokenAge == time.Duration(0) {
return c.AuthenticateWithClaims(ctx, body, claims)
var err error
resp, err = c.AuthenticateWithClaims(ctx, body, nil)
if err != nil {
return nil, err
}
} else if session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge); err == nil {
resp = &sessions.AuthenticateResponse{
Session: *session,
}
} else {
// JWT couldn't be verified locally. Check with the Stytch API.
resp, err = c.Authenticate(ctx, body)
if err != nil {
return nil, err
}
}

session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge)
if err != nil {
// JWT couldn't be verified locally. Check with the Stytch API.
return c.Authenticate(ctx, body)
// Populate claims if possible.
if claims != nil {
for key, val := range resp.Session.CustomClaims {
claims[key] = val
}
}

return &sessions.AuthenticateResponse{
Session: *session,
}, nil
return resp, nil
}

func (c *SessionsClient) AuthenticateJWTLocal(
Expand Down
158 changes: 158 additions & 0 deletions stytch/consumer/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,164 @@ func TestAuthenticateJWTLocalWithClaims(t *testing.T) {
})
}

func TestAuthenticateJWTWithClaims(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/sessions/authenticate" {
// There are many other fields in this response, but these are the only ones we need
// for this test.
_, _ = w.Write([]byte(`{
"session": {
"expires_at": "2022-06-29T19:53:48Z",
"last_accessed_at": "2022-06-29T17:54:13Z",
"session_id": "session-test-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
"started_at": "2022-06-29T17:53:48Z",
"user_id": "user-test-00000000-0000-0000-0000-000000000000",
"custom_claims": {
"https://my-app.example.net/custom-claim": {
"number": 1,
"array": [1, "foo", null],
"nested": {
"data": "here"
}
}
}
}
}`))
return
}

http.Error(w, "Bad Request", http.StatusBadRequest)
}))

client := &stytch.DefaultClient{
Config: &config.Config{
Env: config.EnvTest,
BaseURI: config.BaseURI(srv.URL),
ProjectID: "project-test-00000000-0000-0000-0000-000000000000",
Secret: "secret-test-11111111-1111-1111-1111-111111111111",
},
// In these tests, the keyset has already been downloaded, so no other network requests
// should be made.
HTTPClient: srv.Client(),
}

key := rsaKey(t)
keyID := "jwk-test-22222222-2222-2222-2222-222222222222"
jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{
keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}),
})

sessionClient := consumer.NewSessionsClient(client, jwks)

expectedClaims := map[string]any{
"https://my-app.example.net/custom-claim": map[string]any{
// Remember that numbers without specified types unmarshal as float64.
"number": float64(1),
"array": []interface{}{float64(1), "foo", nil},
"nested": map[string]any{
"data": "here",
},
},
}

t.Run("populate claims map", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Minute).Truncate(time.Second)
exp := iat.Add(5 * time.Minute)

token := signJWT(t, keyID, key, sandboxClaimsCustom(t, iat, exp, expectedClaims))

claims := make(map[string]any)
resp, err := sessionClient.AuthenticateJWTWithClaims(
context.Background(),
10*time.Minute,
&sessions.AuthenticateParams{SessionJWT: token},
claims,
)
require.NoError(t, err)

assert.Equal(t, expectedClaims, claims)
assert.Equal(t, expectedClaims, resp.Session.CustomClaims)
})

t.Run("skip populating a nil map", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Minute).Truncate(time.Second)
exp := iat.Add(5 * time.Minute)

expected := map[string]any{"special": "val"}
token := signJWT(t, keyID, key, sandboxClaimsCustom(t, iat, exp, expected))

var claims map[string]any
assert.NotPanics(t, func() {
resp, err := sessionClient.AuthenticateJWTWithClaims(
context.Background(),
10*time.Minute,
&sessions.AuthenticateParams{SessionJWT: token},
claims,
)
require.NoError(t, err)
assert.Equal(t, expected, resp.Session.CustomClaims)
})
assert.Empty(t, claims)
})

t.Run("send remote request if needed", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Minute).Truncate(time.Second)
exp := iat.Add(5 * time.Minute)

token := signJWT(t, keyID, key, sandboxClaimsCustom(t, iat, exp, expectedClaims))

claims := make(map[string]any)
resp, err := sessionClient.AuthenticateJWTWithClaims(
context.Background(),
time.Nanosecond,
&sessions.AuthenticateParams{SessionJWT: token},
claims,
)
require.NoError(t, err)
assert.Equal(t, expectedClaims, resp.Session.CustomClaims)
assert.Equal(t, expectedClaims, claims)
})

t.Run("send remote request if forced, skip claims", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Minute).Truncate(time.Second)
exp := iat.Add(5 * time.Minute)

token := signJWT(t, keyID, key, sandboxClaimsCustom(t, iat, exp, expectedClaims))

var claims map[string]any
assert.NotPanics(t, func() {
resp, err := sessionClient.AuthenticateJWTWithClaims(
context.Background(),
0,
&sessions.AuthenticateParams{SessionJWT: token},
claims,
)
require.NoError(t, err)
assert.Equal(t, expectedClaims, resp.Session.CustomClaims)
})
assert.Empty(t, claims)
})

t.Run("send remote request if forced, populate claims", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Minute).Truncate(time.Second)
exp := iat.Add(5 * time.Minute)

token := signJWT(t, keyID, key, sandboxClaimsCustom(t, iat, exp, expectedClaims))

claims := make(map[string]any)
resp, err := sessionClient.AuthenticateJWTWithClaims(
context.Background(),
0,
&sessions.AuthenticateParams{SessionJWT: token},
claims,
)
require.NoError(t, err)
assert.Equal(t, expectedClaims, resp.Session.CustomClaims)
assert.Equal(t, expectedClaims, claims)
})
}

func TestAuthenticateWithClaims(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle the async JWKS fetch.
Expand Down

0 comments on commit c36ae3f

Please sign in to comment.