diff --git a/stytch/b2b/sessions.go b/stytch/b2b/sessions.go index d2911fe..ca4754c 100644 --- a/stytch/b2b/sessions.go +++ b/stytch/b2b/sessions.go @@ -10,7 +10,10 @@ import ( "context" "encoding/json" "fmt" + "time" + "github.com/MicahParks/keyfunc/v2" + "github.com/golang-jwt/jwt/v5" "github.com/mitchellh/mapstructure" "github.com/stytchauth/stytch-go/v11/stytch" "github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions" @@ -18,7 +21,8 @@ import ( ) type SessionsClient struct { - C stytch.Client + C stytch.Client + JWKS *keyfunc.JWKS } func NewSessionsClient(c stytch.Client) *SessionsClient { @@ -234,3 +238,116 @@ func (c *SessionsClient) GetJWKS( ) return &retVal, err } + +// MANUAL(AuthenticateJWT)(SERVICE_METHOD) +// ADDIMPORT: "encoding/json" +// ADDIMPORT: "time" +// ADDIMPORT: "github.com/golang-jwt/jwt/v5" +// ADDIMPORT: "github.com/MicahParks/keyfunc/v2" +// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/stytcherror" + +func (c *SessionsClient) AuthenticateJWT( + ctx context.Context, + maxTokenAge time.Duration, + body *sessions.AuthenticateParams, +) (*sessions.AuthenticateResponse, error) { + if body.SessionJWT == "" || maxTokenAge == time.Duration(0) { + return c.Authenticate(ctx, body) + } + + session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge) + if err != nil { + // JWT couldn't be verified locally. Check with the Stytch API. + return c.Authenticate(ctx, body) + } + + return &sessions.AuthenticateResponse{ + MemberSession: *session, + }, nil +} + +func (c *SessionsClient) AuthenticateJWTWithClaims( + ctx context.Context, + maxTokenAge time.Duration, + body *sessions.AuthenticateParams, + claims map[string]any, +) (*sessions.AuthenticateResponse, error) { + if body.SessionJWT == "" || maxTokenAge == time.Duration(0) { + return c.AuthenticateWithClaims(ctx, body, claims) + } + + session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge) + if err != nil { + // JWT couldn't be verified locally. Check with the Stytch API. + return c.Authenticate(ctx, body) + } + + return &sessions.AuthenticateResponse{ + MemberSession: *session, + }, nil +} + +func (c *SessionsClient) AuthenticateJWTLocal( + token string, + maxTokenAge time.Duration, +) (*sessions.MemberSession, error) { + if c.JWKS == nil { + return nil, stytcherror.ErrJWKSNotInitialized + } + + var claims sessions.Claims + + aud := c.C.GetConfig().ProjectID + iss := fmt.Sprintf("stytch.com/%s", c.C.GetConfig().ProjectID) + + _, err := jwt.ParseWithClaims(token, &claims, c.JWKS.Keyfunc, jwt.WithAudience(aud), jwt.WithIssuer(iss)) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + if claims.RegisteredClaims.IssuedAt.Add(maxTokenAge).Before(time.Now()) { + // The JWT is valid, but older than the tolerable maximum age. + return nil, sessions.ErrJWTTooOld + } + + return marshalJWTIntoSession(claims) +} + +func marshalJWTIntoSession(claims sessions.Claims) (*sessions.MemberSession, error) { + // For JWTs that include it, prefer the inner expires_at claim. + expiresAt := claims.Session.ExpiresAt + if expiresAt == "" { + expiresAt = claims.RegisteredClaims.ExpiresAt.Time.Format(time.RFC3339) + } + + started, err := time.Parse(time.RFC3339, claims.Session.StartedAt) + if err != nil { + return nil, err + } + started = started.UTC() + + accessed, err := time.Parse(time.RFC3339, claims.Session.LastAccessedAt) + if err != nil { + return nil, err + } + accessed = accessed.UTC() + + expires, err := time.Parse(time.RFC3339, expiresAt) + if err != nil { + return nil, err + } + expires = expires.UTC() + + // TODO + return &sessions.MemberSession{ + MemberSessionID: claims.Session.ID, + MemberID: claims.RegisteredClaims.Subject, + StartedAt: &started, + LastAccessedAt: &accessed, + ExpiresAt: &expires, + AuthenticationFactors: claims.Session.AuthenticationFactors, + OrganizationID: claims.Organization.ID, + }, nil +} + +// ENDMANUAL(AuthenticateJWT) diff --git a/stytch/b2b/sessions/types.go b/stytch/b2b/sessions/types.go index c187561..6212c37 100644 --- a/stytch/b2b/sessions/types.go +++ b/stytch/b2b/sessions/types.go @@ -7,8 +7,10 @@ package sessions // !!! import ( + "errors" "time" + "github.com/golang-jwt/jwt/v5" "github.com/stytchauth/stytch-go/v11/stytch/b2b/mfa" "github.com/stytchauth/stytch-go/v11/stytch/b2b/organizations" "github.com/stytchauth/stytch-go/v11/stytch/consumer/sessions" @@ -258,3 +260,32 @@ const ( ExchangeRequestLocaleEs ExchangeRequestLocale = "es" ExchangeRequestLocalePtbr ExchangeRequestLocale = "pt-br" ) + +// MANUAL(Types)(TYPES) +// ADDIMPORT: "errors" +// ADDIMPORT: "strings" +// ADDIMPORT: "github.com/golang-jwt/jwt/v5" +// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/consumer/sessions" + +var ErrJWTTooOld = errors.New("JWT too old") + +type OrgClaim struct { + ID string `json:"organization_id"` + Slug string `json:"slug"` +} + +type Claims struct { + Session sessions.SessionClaim `json:"https://stytch.com/session"` + Organization OrgClaim `json:"https://stytch.com/organization"` + jwt.RegisteredClaims +} + +type ClaimsWrapper struct { + Claims map[string]any `json:"custom_claims"` +} + +type SessionWrapper struct { + Session ClaimsWrapper `json:"session"` +} + +// ENDMANUAL(Types) diff --git a/stytch/b2b/sessions_test.go b/stytch/b2b/sessions_test.go new file mode 100644 index 0000000..48ac96e --- /dev/null +++ b/stytch/b2b/sessions_test.go @@ -0,0 +1,445 @@ +package b2b_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stytchauth/stytch-go/v11/stytch/b2b" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/b2bstytchapi" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions" + consumersessions "github.com/stytchauth/stytch-go/v11/stytch/consumer/sessions" + + "github.com/MicahParks/keyfunc/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stytchauth/stytch-go/v11/stytch" + "github.com/stytchauth/stytch-go/v11/stytch/config" +) + +func TestAuthenticateJWTLocal(t *testing.T) { + client := &stytch.DefaultClient{ + Config: &config.Config{ + Env: config.EnvTest, + BaseURI: "https://example.test/v1/b2b/", + ProjectID: "project-test-00000000-0000-0000-0000-000000000000", + Secret: "secret-test-11111111-1111-1111-1111-111111111111", + }, + // In these tests, the keyset has already been downloaded, so no other network requests + // should be made. + HTTPClient: nil, + } + + key := rsaKey(t) + keyID := "jwk-test-22222222-2222-2222-2222-222222222222" + jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{ + keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}), + }) + + sessionClient := b2b.NewSessionsClient(client) + sessionClient.JWKS = jwks + + t.Run("expired JWT", func(t *testing.T) { + iat := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) + exp := iat.Add(time.Minute) + + claims := sandboxClaims(t, iat, exp) + token := signJWT(t, keyID, key, claims) + + s, err := sessionClient.AuthenticateJWTLocal(token, 10*time.Minute) + assert.ErrorIs(t, err, jwt.ErrTokenExpired) + assert.Nil(t, s) + }) + + t.Run("stale JWT", func(t *testing.T) { + iat := time.Now().UTC().Add(-3 * time.Minute).Truncate(time.Second) + exp := iat.Add(time.Hour) + + claims := sandboxClaims(t, iat, exp) + token := signJWT(t, keyID, key, claims) + + s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + assert.ErrorIs(t, err, sessions.ErrJWTTooOld) + assert.Nil(t, s) + }) + + t.Run("incorrect audience", func(t *testing.T) { + iat := time.Now().UTC().Truncate(time.Second) + exp := iat.Add(time.Hour) + + claims := sandboxClaims(t, iat, exp) + claims.Audience = jwt.ClaimStrings{"not this project"} + + token := signJWT(t, keyID, key, claims) + + s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + assert.ErrorIs(t, err, jwt.ErrTokenInvalidAudience) + assert.Nil(t, s) + }) + + t.Run("incorrect issuer", func(t *testing.T) { + iat := time.Now().UTC().Truncate(time.Second) + exp := iat.Add(time.Hour) + + claims := sandboxClaims(t, iat, exp) + claims.Issuer = "not this project" + + token := signJWT(t, keyID, key, claims) + + s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + assert.ErrorIs(t, err, jwt.ErrTokenInvalidIssuer) + assert.Nil(t, s) + }) + + t.Run("valid JWT", func(t *testing.T) { + iat := time.Now().UTC().Truncate(time.Second) + exp := iat.Add(time.Hour) + + claims := sandboxClaims(t, iat, exp) + token := signJWT(t, keyID, key, claims) + + session, err := sessionClient.AuthenticateJWTLocal(token, 3*time.Minute) + require.NoError(t, err) + + expected := &sessions.MemberSession{ + MemberSessionID: "session-live-e26a0ccb-0dc0-4edb-a4bb-e70210f43555", + MemberID: "member-live-fde03dd1-fff7-4b3c-9b31-ead3fbc224de", + StartedAt: &iat, + LastAccessedAt: &iat, + ExpiresAt: &exp, + AuthenticationFactors: []consumersessions.AuthenticationFactor{ + { + Type: "magic_link", + DeliveryMethod: "email", + LastAuthenticatedAt: &iat, + EmailFactor: &consumersessions.EmailFactor{ + EmailAddress: "sandbox@stytch.com", + EmailID: "email-live-cca9d7d0-11b6-4167-9385-d7e0c9a77418", + }, + }, + }, + } + assert.Equal(t, expected, session) + }) + + t.Run("valid JWT (old format)", func(t *testing.T) { + iat := time.Now().UTC().Truncate(time.Second) + exp := iat.Add(time.Hour) + sessionExp := iat.Add(5 * time.Minute) + + claims := sandboxClaims(t, iat, exp) + claims.Session.ExpiresAt = "" + token := signJWT(t, keyID, key, claims) + + session, err := sessionClient.AuthenticateJWTLocal(token, 3*time.Minute) + require.NoError(t, err) + + expected := &sessions.MemberSession{ + MemberSessionID: "session-live-e26a0ccb-0dc0-4edb-a4bb-e70210f43555", + MemberID: "member-live-fde03dd1-fff7-4b3c-9b31-ead3fbc224de", + StartedAt: &iat, + LastAccessedAt: &iat, + ExpiresAt: &sessionExp, + AuthenticationFactors: []consumersessions.AuthenticationFactor{ + { + Type: "magic_link", + DeliveryMethod: "email", + LastAuthenticatedAt: &iat, + EmailFactor: &consumersessions.EmailFactor{ + EmailAddress: "sandbox@stytch.com", + EmailID: "email-live-cca9d7d0-11b6-4167-9385-d7e0c9a77418", + }, + }, + }, + } + assert.Equal(t, expected, session) + }) +} + +func TestAuthenticateWithClaims(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle the async JWKS fetch. + if strings.HasPrefix(r.URL.Path, "/v1/b2b/sessions/jwks/") { + _, _ = w.Write([]byte(`{"keys": []}`)) + return + } + + // This is the test request + if r.URL.Path == "/v1/b2b/sessions/authenticate" { + // There are many other fields in this response, but these are the only ones we need + // for this test. + _, _ = w.Write([]byte(`{ + "member_session": { + "expires_at": "2022-06-29T19:53:48Z", + "last_accessed_at": "2022-06-29T17:54:13Z", + "member_session_id": "session-test-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "started_at": "2022-06-29T17:53:48Z", + "member_id": "user-test-00000000-0000-0000-0000-000000000000", + + "custom_claims": { + "https://my-app.example.net/custom-claim": { + "number": 1, + "array": [1, "foo", null], + "nested": { + "data": "here" + } + } + } + }, + "member": {}, + "organization": {} + }`)) + return + } + + http.Error(w, "Bad Request", http.StatusBadRequest) + })) + + client, err := b2bstytchapi.NewClient( + "project-test-00000000-0000-0000-0000-000000000000", + "secret-test-11111111-1111-1111-1111-111111111111", + b2bstytchapi.WithBaseURI(srv.URL), + ) + require.NoError(t, err) + + req := &sessions.AuthenticateParams{ + SessionToken: "fake session token", + } + + t.Run("marshaling claims into a map", func(t *testing.T) { + var claims map[string]any + _, err := client.Sessions.AuthenticateWithClaims(context.Background(), req, &claims) + require.NoError(t, err) + + type object = map[string]any + expected := object{ + "https://my-app.example.net/custom-claim": object{ + // Remember that numbers without specified types unmarshal as float64. + "number": float64(1), + "array": []interface{}{float64(1), "foo", nil}, + "nested": object{ + "data": "here", + }, + }, + } + assert.Equal(t, expected, claims) + }) + t.Run("marshaling claims into a struct", func(t *testing.T) { + type MyAppClaims struct { + Number int + Array []interface{} + Nested struct { + Data string + } + } + + type Claims struct { + MyApp MyAppClaims `json:"https://my-app.example.net/custom-claim"` + } + + { + var claims Claims + _, err = client.Sessions.AuthenticateWithClaims(context.Background(), req, &claims) + require.NoError(t, err) + expected := Claims{ + MyApp: MyAppClaims{ + Number: 1, + // Remember that numbers without specified types unmarshal as float64. + Array: []interface{}{float64(1), "foo", nil}, + Nested: struct{ Data string }{Data: "here"}, + }, + } + assert.Equal(t, expected, claims) + } + }) +} + +func ExampleSessionsClient_AuthenticateWithClaims_map() { + // If we know that our claims will follow this exact map structure, we can marshal the + // custom claims from the response into it + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle the async JWKS fetch. + if strings.HasPrefix(r.URL.Path, "/v1/b2b/sessions/jwks/") { + _, _ = w.Write([]byte(`{"keys": []}`)) + return + } + + // This is the test request + if r.URL.Path == "/v1/b2b/sessions/authenticate" { + // There are many other fields in this response, but these are the only ones we need + // for this test. + _, _ = w.Write([]byte(`{ + "member_session": { + "expires_at": "2022-06-29T19:53:48Z", + "last_accessed_at": "2022-06-29T17:54:13Z", + "member_session_id": "session-test-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "started_at": "2022-06-29T17:53:48Z", + "member_id": "user-test-00000000-0000-0000-0000-000000000000", + + "custom_claims": { + "https://my-app.example.net/custom-claim": { + "claim1": 1, + "claim2": 2, + "claim3": 3 + } + } + }, + "member": {}, + "organization": {} + }`)) + return + } + + http.Error(w, "Bad Request", http.StatusBadRequest) + })) + + client, _ := b2bstytchapi.NewClient( + "project-test-00000000-0000-0000-0000-000000000000", + "secret-test-11111111-1111-1111-1111-111111111111", + b2bstytchapi.WithBaseURI(srv.URL), + ) + + // Expecting a map where all the values are maps from strings to integers + var mapClaims map[string]map[string]int32 + _, _ = client.Sessions.AuthenticateWithClaims( + context.Background(), + &sessions.AuthenticateParams{ + SessionToken: "fake session token", + }, + &mapClaims, + ) + + fmt.Println(mapClaims) + // Output: map[https://my-app.example.net/custom-claim:map[claim1:1 claim2:2 claim3:3]] +} + +func ExampleSessionsClient_AuthenticateWithClaims_struct() { + // When we define a struct that follows the shape of our claims, we can marshal the + // custom claims from the response into it + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle the async JWKS fetch. + if strings.HasPrefix(r.URL.Path, "/v1/b2b/sessions/jwks/") { + _, _ = w.Write([]byte(`{"keys": []}`)) + return + } + + // This is the test request + if r.URL.Path == "/v1/b2b/sessions/authenticate" { + // There are many other fields in this response, but these are the only ones we need + // for this test. + _, _ = w.Write([]byte(`{ + "member_session": { + "expires_at": "2022-06-29T19:53:48Z", + "last_accessed_at": "2022-06-29T17:54:13Z", + "member_session_id": "session-test-aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "started_at": "2022-06-29T17:53:48Z", + "member_id": "user-test-00000000-0000-0000-0000-000000000000", + + "custom_claims": { + "https://my-app.example.net/custom-claim": { + "number": 1, + "array": [1, "foo", null], + "nested": { + "data": "here" + } + } + } + }, + "member": {}, + "organization": {} + }`)) + return + } + + http.Error(w, "Bad Request", http.StatusBadRequest) + })) + + client, _ := b2bstytchapi.NewClient( + "project-test-00000000-0000-0000-0000-000000000000", + "secret-test-11111111-1111-1111-1111-111111111111", + b2bstytchapi.WithBaseURI(srv.URL), + ) + + // Expecting claims to follow this exact data structure + type MyAppClaims struct { + Number int + Array []interface{} + Nested struct { + Data string + } + } + type StructClaims struct { + MyApp MyAppClaims `json:"https://my-app.example.net/custom-claim"` + } + + var structClaims StructClaims + _, _ = client.Sessions.AuthenticateWithClaims( + context.Background(), + &sessions.AuthenticateParams{ + SessionToken: "fake session token", + }, + &structClaims, + ) + + fmt.Println(structClaims) + // Output: {{1 [1 foo ] {here}}} +} + +func rsaKey(t *testing.T) *rsa.PrivateKey { + // This short key length is fine for test data. We won't actually use the keys for anything. + // + // #nosec G403 + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatalf("generate test RSA key: %s", err) + } + return key +} + +func signJWT(t *testing.T, keyID string, key *rsa.PrivateKey, claims jwt.Claims) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + + signed, err := token.SignedString(key) + if err != nil { + t.Fatalf("sign JWT: %s", err) + } + return signed +} + +func sandboxClaims(t *testing.T, iat, exp time.Time) sessions.Claims { + return sessions.Claims{ + Session: consumersessions.SessionClaim{ + ID: "session-live-e26a0ccb-0dc0-4edb-a4bb-e70210f43555", + StartedAt: iat.Format(time.RFC3339), + LastAccessedAt: iat.Format(time.RFC3339), + ExpiresAt: exp.Format(time.RFC3339), + AuthenticationFactors: []consumersessions.AuthenticationFactor{ + { + Type: "magic_link", + DeliveryMethod: "email", + LastAuthenticatedAt: &iat, + EmailFactor: &consumersessions.EmailFactor{ + EmailAddress: "sandbox@stytch.com", + EmailID: "email-live-cca9d7d0-11b6-4167-9385-d7e0c9a77418", + }, + }, + }, + }, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "stytch.com/project-test-00000000-0000-0000-0000-000000000000", + Audience: []string{"project-test-00000000-0000-0000-0000-000000000000"}, + Subject: "member-live-fde03dd1-fff7-4b3c-9b31-ead3fbc224de", + IssuedAt: jwt.NewNumericDate(iat), + NotBefore: jwt.NewNumericDate(iat), + ExpiresAt: jwt.NewNumericDate(iat.Add(5 * time.Minute)), + }, + } +}