Skip to content

Commit

Permalink
Support local JWT auth in B2B sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-stytch committed Oct 4, 2023
1 parent 371c1bd commit b40f54f
Show file tree
Hide file tree
Showing 3 changed files with 594 additions and 1 deletion.
119 changes: 118 additions & 1 deletion stytch/b2b/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ import (
"context"
"encoding/json"
"fmt"
"time"

"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/mitchellh/mapstructure"
"github.com/stytchauth/stytch-go/v11/stytch"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions"
"github.com/stytchauth/stytch-go/v11/stytch/stytcherror"
)

type SessionsClient struct {
C stytch.Client
C stytch.Client
JWKS *keyfunc.JWKS
}

func NewSessionsClient(c stytch.Client) *SessionsClient {
Expand Down Expand Up @@ -234,3 +238,116 @@ func (c *SessionsClient) GetJWKS(
)
return &retVal, err
}

// MANUAL(AuthenticateJWT)(SERVICE_METHOD)
// ADDIMPORT: "encoding/json"
// ADDIMPORT: "time"
// ADDIMPORT: "github.com/golang-jwt/jwt/v5"
// ADDIMPORT: "github.com/MicahParks/keyfunc/v2"
// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/stytcherror"

func (c *SessionsClient) AuthenticateJWT(
ctx context.Context,
maxTokenAge time.Duration,
body *sessions.AuthenticateParams,
) (*sessions.AuthenticateResponse, error) {
if body.SessionJWT == "" || maxTokenAge == time.Duration(0) {
return c.Authenticate(ctx, body)
}

session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge)
if err != nil {
// JWT couldn't be verified locally. Check with the Stytch API.
return c.Authenticate(ctx, body)
}

return &sessions.AuthenticateResponse{
MemberSession: *session,
}, nil
}

func (c *SessionsClient) AuthenticateJWTWithClaims(
ctx context.Context,
maxTokenAge time.Duration,
body *sessions.AuthenticateParams,
claims map[string]any,
) (*sessions.AuthenticateResponse, error) {
if body.SessionJWT == "" || maxTokenAge == time.Duration(0) {
return c.AuthenticateWithClaims(ctx, body, claims)
}

session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge)
if err != nil {
// JWT couldn't be verified locally. Check with the Stytch API.
return c.Authenticate(ctx, body)
}

return &sessions.AuthenticateResponse{
MemberSession: *session,
}, nil
}

func (c *SessionsClient) AuthenticateJWTLocal(
token string,
maxTokenAge time.Duration,
) (*sessions.MemberSession, error) {
if c.JWKS == nil {
return nil, stytcherror.ErrJWKSNotInitialized
}

var claims sessions.Claims

aud := c.C.GetConfig().ProjectID
iss := fmt.Sprintf("stytch.com/%s", c.C.GetConfig().ProjectID)

_, err := jwt.ParseWithClaims(token, &claims, c.JWKS.Keyfunc, jwt.WithAudience(aud), jwt.WithIssuer(iss))
if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}

if claims.RegisteredClaims.IssuedAt.Add(maxTokenAge).Before(time.Now()) {
// The JWT is valid, but older than the tolerable maximum age.
return nil, sessions.ErrJWTTooOld
}

return marshalJWTIntoSession(claims)
}

func marshalJWTIntoSession(claims sessions.Claims) (*sessions.MemberSession, error) {
// For JWTs that include it, prefer the inner expires_at claim.
expiresAt := claims.Session.ExpiresAt
if expiresAt == "" {
expiresAt = claims.RegisteredClaims.ExpiresAt.Time.Format(time.RFC3339)
}

started, err := time.Parse(time.RFC3339, claims.Session.StartedAt)
if err != nil {
return nil, err
}
started = started.UTC()

accessed, err := time.Parse(time.RFC3339, claims.Session.LastAccessedAt)
if err != nil {
return nil, err
}
accessed = accessed.UTC()

expires, err := time.Parse(time.RFC3339, expiresAt)
if err != nil {
return nil, err
}
expires = expires.UTC()

// TODO
return &sessions.MemberSession{
MemberSessionID: claims.Session.ID,
MemberID: claims.RegisteredClaims.Subject,
StartedAt: &started,
LastAccessedAt: &accessed,
ExpiresAt: &expires,
AuthenticationFactors: claims.Session.AuthenticationFactors,
OrganizationID: claims.Organization.ID,
}, nil
}

// ENDMANUAL(AuthenticateJWT)
31 changes: 31 additions & 0 deletions stytch/b2b/sessions/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package sessions
// !!!

import (
"errors"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/mfa"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/organizations"
"github.com/stytchauth/stytch-go/v11/stytch/consumer/sessions"
Expand Down Expand Up @@ -258,3 +260,32 @@ const (
ExchangeRequestLocaleEs ExchangeRequestLocale = "es"
ExchangeRequestLocalePtbr ExchangeRequestLocale = "pt-br"
)

// MANUAL(Types)(TYPES)
// ADDIMPORT: "errors"
// ADDIMPORT: "strings"
// ADDIMPORT: "github.com/golang-jwt/jwt/v5"
// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/consumer/sessions"

var ErrJWTTooOld = errors.New("JWT too old")

type OrgClaim struct {
ID string `json:"organization_id"`
Slug string `json:"slug"`
}

type Claims struct {
Session sessions.SessionClaim `json:"https://stytch.com/session"`
Organization OrgClaim `json:"https://stytch.com/organization"`
jwt.RegisteredClaims
}

type ClaimsWrapper struct {
Claims map[string]any `json:"custom_claims"`
}

type SessionWrapper struct {
Session ClaimsWrapper `json:"session"`
}

// ENDMANUAL(Types)
Loading

0 comments on commit b40f54f

Please sign in to comment.