From e07cfc3d80abf586b608798d62857921e0ca1205 Mon Sep 17 00:00:00 2001 From: nikhil-stytch <80711262+nikhil-stytch@users.noreply.github.com> Date: Thu, 7 Dec 2023 21:28:33 -0500 Subject: [PATCH] Local Authorization Check for RBAC (#157) Co-authored-by: Logan Gore --- stytch/b2b/b2bstytchapi/b2bstytchapi.go | 20 ++-- stytch/b2b/discovery.go | 3 +- stytch/b2b/magiclinks.go | 3 +- stytch/b2b/magiclinks_email.go | 3 +- stytch/b2b/oauth.go | 3 +- stytch/b2b/organizations.go | 3 +- stytch/b2b/otp.go | 3 +- stytch/b2b/passwords.go | 3 +- stytch/b2b/rbac.go | 70 +++++++++++ stytch/b2b/rbac/types.go | 35 ++++++ stytch/b2b/sessions.go | 40 +++++-- stytch/b2b/sessions/types.go | 9 +- stytch/b2b/sessions_test.go | 19 +-- stytch/b2b/sso.go | 3 +- stytch/config/version.go | 2 +- stytch/consumer/m2m.go | 6 +- stytch/consumer/m2m_clients.go | 3 +- stytch/consumer/m2m_test.go | 7 +- stytch/consumer/magiclinks.go | 3 +- stytch/consumer/otp.go | 3 +- stytch/consumer/passwords.go | 3 +- stytch/consumer/sessions.go | 5 +- stytch/consumer/sessions_test.go | 3 +- stytch/consumer/stytchapi/stytchapi.go | 18 ++- stytch/shared/rbac_local.go | 46 +++++++ stytch/shared/rbac_local_test.go | 153 ++++++++++++++++++++++++ 26 files changed, 413 insertions(+), 56 deletions(-) create mode 100644 stytch/b2b/rbac.go create mode 100644 stytch/b2b/rbac/types.go create mode 100644 stytch/shared/rbac_local.go create mode 100644 stytch/shared/rbac_local_test.go diff --git a/stytch/b2b/b2bstytchapi/b2bstytchapi.go b/stytch/b2b/b2bstytchapi/b2bstytchapi.go index d5890f1..60a7ce9 100644 --- a/stytch/b2b/b2bstytchapi/b2bstytchapi.go +++ b/stytch/b2b/b2bstytchapi/b2bstytchapi.go @@ -41,6 +41,7 @@ type API struct { OTPs *b2b.OTPsClient Organizations *b2b.OrganizationsClient Passwords *b2b.PasswordsClient + RBAC *b2b.RBACClient SSO *b2b.SSOClient Sessions *b2b.SessionsClient } @@ -129,21 +130,24 @@ func NewClient(projectID string, secret string, opts ...Option) (*API, error) { o(a) } + policyCache := b2b.NewPolicyCache(b2b.NewRBACClient(a.client)) + + // Set up JWKS for local session authentication + jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient()) + if err != nil { + return nil, fmt.Errorf("fetch JWKS from URL: %w", err) + } + a.Discovery = b2b.NewDiscoveryClient(a.client) - a.M2M = consumer.NewM2MClient(a.client) + a.M2M = consumer.NewM2MClient(a.client, jwks) a.MagicLinks = b2b.NewMagicLinksClient(a.client) a.OAuth = b2b.NewOAuthClient(a.client) a.OTPs = b2b.NewOTPsClient(a.client) a.Organizations = b2b.NewOrganizationsClient(a.client) a.Passwords = b2b.NewPasswordsClient(a.client) + a.RBAC = b2b.NewRBACClient(a.client) a.SSO = b2b.NewSSOClient(a.client) - a.Sessions = b2b.NewSessionsClient(a.client) - // Set up JWKS for local session authentication - jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient()) - if err != nil { - return nil, fmt.Errorf("fetch JWKS from URL: %w", err) - } - a.M2M.JWKS = jwks + a.Sessions = b2b.NewSessionsClient(a.client, jwks, policyCache) return a, nil } diff --git a/stytch/b2b/discovery.go b/stytch/b2b/discovery.go index 628e873..96e8601 100644 --- a/stytch/b2b/discovery.go +++ b/stytch/b2b/discovery.go @@ -18,7 +18,8 @@ type DiscoveryClient struct { func NewDiscoveryClient(c stytch.Client) *DiscoveryClient { return &DiscoveryClient{ - C: c, + C: c, + IntermediateSessions: NewDiscoveryIntermediateSessionsClient(c), Organizations: NewDiscoveryOrganizationsClient(c), } diff --git a/stytch/b2b/magiclinks.go b/stytch/b2b/magiclinks.go index 2cfc210..72401b2 100644 --- a/stytch/b2b/magiclinks.go +++ b/stytch/b2b/magiclinks.go @@ -25,7 +25,8 @@ type MagicLinksClient struct { func NewMagicLinksClient(c stytch.Client) *MagicLinksClient { return &MagicLinksClient{ - C: c, + C: c, + Email: NewMagicLinksEmailClient(c), Discovery: NewMagicLinksDiscoveryClient(c), } diff --git a/stytch/b2b/magiclinks_email.go b/stytch/b2b/magiclinks_email.go index 65532bb..814e65d 100644 --- a/stytch/b2b/magiclinks_email.go +++ b/stytch/b2b/magiclinks_email.go @@ -22,7 +22,8 @@ type MagicLinksEmailClient struct { func NewMagicLinksEmailClient(c stytch.Client) *MagicLinksEmailClient { return &MagicLinksEmailClient{ - C: c, + C: c, + Discovery: NewMagicLinksEmailDiscoveryClient(c), } } diff --git a/stytch/b2b/oauth.go b/stytch/b2b/oauth.go index 6ceb3a0..7629034 100644 --- a/stytch/b2b/oauth.go +++ b/stytch/b2b/oauth.go @@ -24,7 +24,8 @@ type OAuthClient struct { func NewOAuthClient(c stytch.Client) *OAuthClient { return &OAuthClient{ - C: c, + C: c, + Discovery: NewOAuthDiscoveryClient(c), } } diff --git a/stytch/b2b/organizations.go b/stytch/b2b/organizations.go index 8d76673..4a003e7 100644 --- a/stytch/b2b/organizations.go +++ b/stytch/b2b/organizations.go @@ -23,7 +23,8 @@ type OrganizationsClient struct { func NewOrganizationsClient(c stytch.Client) *OrganizationsClient { return &OrganizationsClient{ - C: c, + C: c, + Members: NewOrganizationsMembersClient(c), } } diff --git a/stytch/b2b/otp.go b/stytch/b2b/otp.go index 3e7fd50..e5a65e7 100644 --- a/stytch/b2b/otp.go +++ b/stytch/b2b/otp.go @@ -17,7 +17,8 @@ type OTPsClient struct { func NewOTPsClient(c stytch.Client) *OTPsClient { return &OTPsClient{ - C: c, + C: c, + Sms: NewOTPsSmsClient(c), } } diff --git a/stytch/b2b/passwords.go b/stytch/b2b/passwords.go index 84610a9..04c7655 100644 --- a/stytch/b2b/passwords.go +++ b/stytch/b2b/passwords.go @@ -26,7 +26,8 @@ type PasswordsClient struct { func NewPasswordsClient(c stytch.Client) *PasswordsClient { return &PasswordsClient{ - C: c, + C: c, + Email: NewPasswordsEmailClient(c), Sessions: NewPasswordsSessionsClient(c), ExistingPassword: NewPasswordsExistingPasswordClient(c), diff --git a/stytch/b2b/rbac.go b/stytch/b2b/rbac.go new file mode 100644 index 0000000..1506e4c --- /dev/null +++ b/stytch/b2b/rbac.go @@ -0,0 +1,70 @@ +package b2b + +// !!! +// WARNING: This file is autogenerated +// Only modify code within MANUAL() sections +// or your changes may be overwritten later! +// !!! + +import ( + "context" + "time" + + "github.com/stytchauth/stytch-go/v11/stytch" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac" +) + +type RBACClient struct { + C stytch.Client +} + +func NewRBACClient(c stytch.Client) *RBACClient { + return &RBACClient{ + C: c, + } +} + +func (c *RBACClient) Policy( + ctx context.Context, + body *rbac.PolicyParams, +) (*rbac.PolicyResponse, error) { + var retVal rbac.PolicyResponse + err := c.C.NewRequest( + ctx, + "GET", + "/v1/b2b/rbac/policy", + nil, + nil, + &retVal, + ) + return &retVal, err +} + +// MANUAL(PolicyCache)(TYPES) + +type PolicyCache struct { + rbacClient *RBACClient + policy *rbac.Policy + lastUpdatedAt time.Time +} + +const refreshCadence = 5 * time.Minute + +func NewPolicyCache(rbacClient *RBACClient) *PolicyCache { + return &PolicyCache{rbacClient: rbacClient} +} + +func (pc *PolicyCache) Get(ctx context.Context) (*rbac.Policy, error) { + if pc.policy == nil || time.Since(pc.lastUpdatedAt) > refreshCadence { + policyResp, err := pc.rbacClient.Policy(ctx, &rbac.PolicyParams{}) + if err != nil { + return nil, err + } + + pc.policy = policyResp.Policy + pc.lastUpdatedAt = time.Now() + } + return pc.policy, nil +} + +// ENDMANUAL(PolicyCache) diff --git a/stytch/b2b/rbac/types.go b/stytch/b2b/rbac/types.go new file mode 100644 index 0000000..4183b0c --- /dev/null +++ b/stytch/b2b/rbac/types.go @@ -0,0 +1,35 @@ +package rbac + +// !!! +// WARNING: This file is autogenerated +// Only modify code within MANUAL() sections +// or your changes may be overwritten later! +// !!! + +type Policy struct { + Roles []PolicyRole `json:"roles,omitempty"` + Resources []PolicyResource `json:"resources,omitempty"` +} +type ( + PolicyParams struct{} + PolicyResource struct { + ResourceID string `json:"resource_id,omitempty"` + Description string `json:"description,omitempty"` + Actions []string `json:"actions,omitempty"` + } +) + +type PolicyRole struct { + RoleID string `json:"role_id,omitempty"` + Description string `json:"description,omitempty"` + Permissions []PolicyRolePermission `json:"permissions,omitempty"` +} +type PolicyRolePermission struct { + ResourceID string `json:"resource_id,omitempty"` + Actions []string `json:"actions,omitempty"` +} +type PolicyResponse struct { + RequestID string `json:"request_id,omitempty"` + StatusCode int32 `json:"status_code,omitempty"` + Policy *Policy `json:"policy,omitempty"` +} diff --git a/stytch/b2b/sessions.go b/stytch/b2b/sessions.go index da12542..4e4bfc9 100644 --- a/stytch/b2b/sessions.go +++ b/stytch/b2b/sessions.go @@ -16,18 +16,23 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/mitchellh/mapstructure" "github.com/stytchauth/stytch-go/v11/stytch" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac" "github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions" + "github.com/stytchauth/stytch-go/v11/stytch/shared" "github.com/stytchauth/stytch-go/v11/stytch/stytcherror" ) type SessionsClient struct { - C stytch.Client - JWKS *keyfunc.JWKS + C stytch.Client + JWKS *keyfunc.JWKS + PolicyCache *PolicyCache } -func NewSessionsClient(c stytch.Client) *SessionsClient { +func NewSessionsClient(c stytch.Client, jwks *keyfunc.JWKS, policyCache *PolicyCache) *SessionsClient { return &SessionsClient{ - C: c, + C: c, + JWKS: jwks, + PolicyCache: policyCache, } } @@ -286,7 +291,7 @@ func (c *SessionsClient) AuthenticateJWT( return c.Authenticate(ctx, params.Body) } - session, err := c.AuthenticateJWTLocal(params.Body.SessionJWT, params.MaxTokenAge) + session, err := c.AuthenticateJWTLocal(ctx, params.Body.SessionJWT, params.MaxTokenAge, params.Body.AuthorizationCheck) if err != nil { // JWT couldn't be verified locally. Check with the Stytch API. return c.Authenticate(ctx, params.Body) @@ -307,7 +312,7 @@ func (c *SessionsClient) AuthenticateJWTWithClaims( return c.AuthenticateWithClaims(ctx, body, claims) } - session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge) + session, err := c.AuthenticateJWTLocal(ctx, body.SessionJWT, maxTokenAge, body.AuthorizationCheck) if err != nil { // JWT couldn't be verified locally. Check with the Stytch API. return c.Authenticate(ctx, body) @@ -318,9 +323,12 @@ func (c *SessionsClient) AuthenticateJWTWithClaims( }, nil } +// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/shared" func (c *SessionsClient) AuthenticateJWTLocal( + ctx context.Context, token string, maxTokenAge time.Duration, + authorizationCheck *sessions.AuthorizationCheck, ) (*sessions.MemberSession, error) { if c.JWKS == nil { return nil, stytcherror.ErrJWKSNotInitialized @@ -341,7 +349,25 @@ func (c *SessionsClient) AuthenticateJWTLocal( return nil, sessions.ErrJWTTooOld } - return marshalJWTIntoSession(claims) + memberSession, err := marshalJWTIntoSession(claims) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWT into session: %w", err) + } + + var policy *rbac.Policy + if authorizationCheck != nil { + policy, err = c.PolicyCache.Get(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get cached policy: %w", err) + } + + err = shared.PerformAuthorizationCheck(policy, claims.Roles, memberSession.OrganizationID, authorizationCheck) + if err != nil { + return nil, err + } + } + + return memberSession, nil } func marshalJWTIntoSession(claims sessions.Claims) (*sessions.MemberSession, error) { diff --git a/stytch/b2b/sessions/types.go b/stytch/b2b/sessions/types.go index d4768d3..4ccce9b 100644 --- a/stytch/b2b/sessions/types.go +++ b/stytch/b2b/sessions/types.go @@ -44,7 +44,13 @@ type AuthenticateParams struct { // delete a key, supply a null value. Custom claims made with reserved claims (`iss`, `sub`, `aud`, // `exp`, `nbf`, `iat`, `jti`) will be ignored. // Total custom claims size cannot exceed four kilobytes. - SessionCustomClaims map[string]any `json:"session_custom_claims,omitempty"` + SessionCustomClaims map[string]any `json:"session_custom_claims,omitempty"` + AuthorizationCheck *AuthorizationCheck `json:"authorization_check,omitempty"` +} +type AuthorizationCheck struct { + OrganizationID string `json:"organization_id,omitempty"` + ResourceID string `json:"resource_id,omitempty"` + Action string `json:"action,omitempty"` } // ExchangeParams: Request type for `Sessions.Exchange`. @@ -283,6 +289,7 @@ type OrgClaim struct { type Claims struct { Session sessions.SessionClaim `json:"https://stytch.com/session"` Organization OrgClaim `json:"https://stytch.com/organization"` + Roles []string `json:"https://stytch.com/roles"` jwt.RegisteredClaims } diff --git a/stytch/b2b/sessions_test.go b/stytch/b2b/sessions_test.go index 8925be4..b085842 100644 --- a/stytch/b2b/sessions_test.go +++ b/stytch/b2b/sessions_test.go @@ -43,8 +43,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}), }) - sessionClient := b2b.NewSessionsClient(client) - sessionClient.JWKS = jwks + policyCache := b2b.NewPolicyCache(b2b.NewRBACClient(client)) + sessionClient := b2b.NewSessionsClient(client, jwks, policyCache) t.Run("expired JWT", func(t *testing.T) { iat := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) @@ -53,7 +53,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { claims := sandboxClaims(t, iat, exp) token := signJWT(t, keyID, key, claims) - s, err := sessionClient.AuthenticateJWTLocal(token, 10*time.Minute) + ctx := context.Background() + s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 10*time.Minute, nil) assert.ErrorIs(t, err, jwt.ErrTokenExpired) assert.Nil(t, s) }) @@ -65,7 +66,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { claims := sandboxClaims(t, iat, exp) token := signJWT(t, keyID, key, claims) - s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + ctx := context.Background() + s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil) assert.ErrorIs(t, err, sessions.ErrJWTTooOld) assert.Nil(t, s) }) @@ -79,7 +81,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { token := signJWT(t, keyID, key, claims) - s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + ctx := context.Background() + s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil) assert.ErrorIs(t, err, jwt.ErrTokenInvalidAudience) assert.Nil(t, s) }) @@ -93,7 +96,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { token := signJWT(t, keyID, key, claims) - s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute) + ctx := context.Background() + s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil) assert.ErrorIs(t, err, jwt.ErrTokenInvalidIssuer) assert.Nil(t, s) }) @@ -105,7 +109,8 @@ func TestAuthenticateJWTLocal(t *testing.T) { claims := sandboxClaims(t, iat, exp) token := signJWT(t, keyID, key, claims) - session, err := sessionClient.AuthenticateJWTLocal(token, 3*time.Minute) + ctx := context.Background() + session, err := sessionClient.AuthenticateJWTLocal(ctx, token, 3*time.Minute, nil) require.NoError(t, err) expected := &sessions.MemberSession{ diff --git a/stytch/b2b/sso.go b/stytch/b2b/sso.go index 37273d7..c6db807 100644 --- a/stytch/b2b/sso.go +++ b/stytch/b2b/sso.go @@ -25,7 +25,8 @@ type SSOClient struct { func NewSSOClient(c stytch.Client) *SSOClient { return &SSOClient{ - C: c, + C: c, + OIDC: NewSSOOIDCClient(c), SAML: NewSSOSAMLClient(c), } diff --git a/stytch/config/version.go b/stytch/config/version.go index 2c3db83..6e00298 100644 --- a/stytch/config/version.go +++ b/stytch/config/version.go @@ -1,3 +1,3 @@ package config -const APIVersion = "11.5.2" +const APIVersion = "11.6.0" diff --git a/stytch/consumer/m2m.go b/stytch/consumer/m2m.go index ec122aa..dc97182 100644 --- a/stytch/consumer/m2m.go +++ b/stytch/consumer/m2m.go @@ -30,9 +30,11 @@ type M2MClient struct { JWKS *keyfunc.JWKS } -func NewM2MClient(c stytch.Client) *M2MClient { +func NewM2MClient(c stytch.Client, jwks *keyfunc.JWKS) *M2MClient { return &M2MClient{ - C: c, + C: c, + JWKS: jwks, + Clients: NewM2MClientsClient(c), } } diff --git a/stytch/consumer/m2m_clients.go b/stytch/consumer/m2m_clients.go index aa5e951..d33b55c 100644 --- a/stytch/consumer/m2m_clients.go +++ b/stytch/consumer/m2m_clients.go @@ -23,7 +23,8 @@ type M2MClientsClient struct { func NewM2MClientsClient(c stytch.Client) *M2MClientsClient { return &M2MClientsClient{ - C: c, + C: c, + Secrets: NewM2MClientsSecretsClient(c), } } diff --git a/stytch/consumer/m2m_test.go b/stytch/consumer/m2m_test.go index 6460235..9cf94ac 100644 --- a/stytch/consumer/m2m_test.go +++ b/stytch/consumer/m2m_test.go @@ -57,7 +57,7 @@ func TestM2MClient_Token(t *testing.T) { client.Config.BaseURI = config.BaseURI(svr.URL) - res, err := consumer.NewM2MClient(client).Token(context.Background(), &m2m.TokenParams{ + res, err := consumer.NewM2MClient(client, nil).Token(context.Background(), &m2m.TokenParams{ ClientID: expectedClientID, ClientSecret: expectedClientSecret, Scopes: scopes, @@ -81,7 +81,7 @@ func TestM2MClient_Token(t *testing.T) { client.Config.BaseURI = config.BaseURI(svr.URL) - res, err := consumer.NewM2MClient(client).Token(context.Background(), &m2m.TokenParams{ + res, err := consumer.NewM2MClient(client, nil).Token(context.Background(), &m2m.TokenParams{ ClientID: expectedClientID, ClientSecret: expectedClientSecret, Scopes: scopes, @@ -111,8 +111,7 @@ func TestM2MClient_AuthenticateToken(t *testing.T) { keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}), }) - m2mClient := consumer.NewM2MClient(client) - m2mClient.JWKS = jwks + m2mClient := consumer.NewM2MClient(client, jwks) t.Run("expired JWT", func(t *testing.T) { iat := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) diff --git a/stytch/consumer/magiclinks.go b/stytch/consumer/magiclinks.go index 58a958c..48cb887 100644 --- a/stytch/consumer/magiclinks.go +++ b/stytch/consumer/magiclinks.go @@ -24,7 +24,8 @@ type MagicLinksClient struct { func NewMagicLinksClient(c stytch.Client) *MagicLinksClient { return &MagicLinksClient{ - C: c, + C: c, + Email: NewMagicLinksEmailClient(c), } } diff --git a/stytch/consumer/otp.go b/stytch/consumer/otp.go index 6521ff4..3fdfa4e 100644 --- a/stytch/consumer/otp.go +++ b/stytch/consumer/otp.go @@ -26,7 +26,8 @@ type OTPsClient struct { func NewOTPsClient(c stytch.Client) *OTPsClient { return &OTPsClient{ - C: c, + C: c, + Sms: NewOTPsSmsClient(c), Whatsapp: NewOTPsWhatsappClient(c), Email: NewOTPsEmailClient(c), diff --git a/stytch/consumer/passwords.go b/stytch/consumer/passwords.go index 2157ffc..c8101e4 100644 --- a/stytch/consumer/passwords.go +++ b/stytch/consumer/passwords.go @@ -26,7 +26,8 @@ type PasswordsClient struct { func NewPasswordsClient(c stytch.Client) *PasswordsClient { return &PasswordsClient{ - C: c, + C: c, + Email: NewPasswordsEmailClient(c), ExistingPassword: NewPasswordsExistingPasswordClient(c), Sessions: NewPasswordsSessionsClient(c), diff --git a/stytch/consumer/sessions.go b/stytch/consumer/sessions.go index 38fb431..5512649 100644 --- a/stytch/consumer/sessions.go +++ b/stytch/consumer/sessions.go @@ -25,9 +25,10 @@ type SessionsClient struct { JWKS *keyfunc.JWKS } -func NewSessionsClient(c stytch.Client) *SessionsClient { +func NewSessionsClient(c stytch.Client, jwks *keyfunc.JWKS) *SessionsClient { return &SessionsClient{ - C: c, + C: c, + JWKS: jwks, } } diff --git a/stytch/consumer/sessions_test.go b/stytch/consumer/sessions_test.go index 848f7c7..1a147e1 100644 --- a/stytch/consumer/sessions_test.go +++ b/stytch/consumer/sessions_test.go @@ -43,8 +43,7 @@ func TestAuthenticateJWTLocal(t *testing.T) { keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}), }) - sessionClient := consumer.NewSessionsClient(client) - sessionClient.JWKS = jwks + sessionClient := consumer.NewSessionsClient(client, jwks) t.Run("expired JWT", func(t *testing.T) { iat := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) diff --git a/stytch/consumer/stytchapi/stytchapi.go b/stytch/consumer/stytchapi/stytchapi.go index aa6d52f..d76f264 100644 --- a/stytch/consumer/stytchapi/stytchapi.go +++ b/stytch/consumer/stytchapi/stytchapi.go @@ -129,24 +129,22 @@ func NewClient(projectID string, secret string, opts ...Option) (*API, error) { o(a) } + // Set up JWKS for local session authentication + jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient()) + if err != nil { + return nil, fmt.Errorf("fetch JWKS from URL: %w", err) + } + a.CryptoWallets = consumer.NewCryptoWalletsClient(a.client) - a.M2M = consumer.NewM2MClient(a.client) + a.M2M = consumer.NewM2MClient(a.client, jwks) a.MagicLinks = consumer.NewMagicLinksClient(a.client) a.OAuth = consumer.NewOAuthClient(a.client) a.OTPs = consumer.NewOTPsClient(a.client) a.Passwords = consumer.NewPasswordsClient(a.client) - a.Sessions = consumer.NewSessionsClient(a.client) + a.Sessions = consumer.NewSessionsClient(a.client, jwks) a.TOTPs = consumer.NewTOTPsClient(a.client) a.Users = consumer.NewUsersClient(a.client) a.WebAuthn = consumer.NewWebAuthnClient(a.client) - // Set up JWKS for local session authentication - jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient()) - if err != nil { - return nil, fmt.Errorf("fetch JWKS from URL: %w", err) - } - a.M2M.JWKS = jwks - - a.Sessions.JWKS = jwks return a, nil } diff --git a/stytch/shared/rbac_local.go b/stytch/shared/rbac_local.go new file mode 100644 index 0000000..6c167b3 --- /dev/null +++ b/stytch/shared/rbac_local.go @@ -0,0 +1,46 @@ +package shared + +import ( + "github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions" + "github.com/stytchauth/stytch-go/v11/stytch/stytcherror" +) + +func PerformAuthorizationCheck( + policy *rbac.Policy, + subjectRoles []string, + subjectOrgID string, + authorizationCheck *sessions.AuthorizationCheck, +) error { + if authorizationCheck == nil { + return nil + } + + if subjectOrgID != authorizationCheck.OrganizationID { + return stytcherror.NewClientLibraryError("Subject organization ID does not match ID from request") + } + + for _, role := range policy.Roles { + if contains(subjectRoles, role.RoleID) { + for _, permission := range role.Permissions { + hasMatchingAction := contains(permission.Actions, "*") || + contains(permission.Actions, authorizationCheck.Action) + hasMatchingResource := permission.ResourceID == authorizationCheck.ResourceID + if hasMatchingAction && hasMatchingResource { + return nil + } + } + } + } + + return stytcherror.NewClientLibraryError("Member is not authorized") +} + +func contains(stringList []string, target string) bool { + for _, s := range stringList { + if target == s { + return true + } + } + return false +} diff --git a/stytch/shared/rbac_local_test.go b/stytch/shared/rbac_local_test.go new file mode 100644 index 0000000..4825b26 --- /dev/null +++ b/stytch/shared/rbac_local_test.go @@ -0,0 +1,153 @@ +package shared_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac" + "github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions" + "github.com/stytchauth/stytch-go/v11/stytch/shared" + "github.com/stytchauth/stytch-go/v11/stytch/stytcherror" +) + +const orgID = "organization-1234" + +func TestPerformAuthorizationCheck(t *testing.T) { + policy := &rbac.Policy{ + Roles: []rbac.PolicyRole{ + { + RoleID: "stytch_member", + Description: "member", + Permissions: []rbac.PolicyRolePermission{ + { + ResourceID: "document", + Actions: []string{"read", "write"}, + }, + { + ResourceID: "program", + Actions: []string{"read"}, + }, + }, + }, + { + RoleID: "stytch_editor", + Description: "member", + Permissions: []rbac.PolicyRolePermission{ + { + ResourceID: "document", + Actions: []string{"read", "write"}, + }, + { + ResourceID: "program", + Actions: []string{"read", "execute"}, + }, + }, + }, + { + RoleID: "stytch_admin", + Description: "admin", + Permissions: []rbac.PolicyRolePermission{ + { + ResourceID: "document", + Actions: []string{"read", "write", "delete"}, + }, + { + ResourceID: "program", + Actions: []string{"read", "edit", "execute"}, + }, + }, + }, + }, + Resources: []rbac.PolicyResource{ + { + ResourceID: "document", + Description: "All documents", + Actions: []string{"read", "write", "delete"}, + }, + { + ResourceID: "program", + Description: "An executable program", + Actions: []string{"read", "write", "execute"}, + }, + }, + } + + t.Run("tenancy mismatch", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_member"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: "different-organization-id", + ResourceID: "document", + Action: "read", + }, + ) + assert.ErrorIs(t, err, stytcherror.NewClientLibraryError("Subject organization ID does not match ID from request")) + }) + t.Run("action exists but resource does not", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_member"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: orgID, + ResourceID: "resource_that_doesnt_exist", + Action: "read", + }, + ) + assert.ErrorIs(t, err, stytcherror.NewClientLibraryError("Member is not authorized")) + }) + t.Run("resource exists but action does not", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_member"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: orgID, + ResourceID: "document", + Action: "action_that_doesnt_exist", + }, + ) + assert.ErrorIs(t, err, stytcherror.NewClientLibraryError("Member is not authorized")) + }) + t.Run("member has this action but on a different resource", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_member"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: orgID, + ResourceID: "program", + Action: "write", + }, + ) + assert.ErrorIs(t, err, stytcherror.NewClientLibraryError("Member is not authorized")) + }) + t.Run("another authorization check for a member with more elevated privileges", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_editor"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: orgID, + ResourceID: "program", + Action: "edit", + }, + ) + assert.ErrorIs(t, err, stytcherror.NewClientLibraryError("Member is not authorized")) + }) + t.Run("no error when the member is authorized", func(t *testing.T) { + err := shared.PerformAuthorizationCheck( + policy, + []string{"stytch_admin"}, + orgID, + &sessions.AuthorizationCheck{ + OrganizationID: orgID, + ResourceID: "document", + Action: "delete", + }, + ) + assert.NoError(t, err) + }) +}