Skip to content

Commit

Permalink
Local Authorization Check for RBAC (#157)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan Gore <lgore@stytch.com>
  • Loading branch information
nikhil-stytch and logan-stytch authored Dec 8, 2023
1 parent 02f3a41 commit e07cfc3
Show file tree
Hide file tree
Showing 26 changed files with 413 additions and 56 deletions.
20 changes: 12 additions & 8 deletions stytch/b2b/b2bstytchapi/b2bstytchapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type API struct {
OTPs *b2b.OTPsClient
Organizations *b2b.OrganizationsClient
Passwords *b2b.PasswordsClient
RBAC *b2b.RBACClient
SSO *b2b.SSOClient
Sessions *b2b.SessionsClient
}
Expand Down Expand Up @@ -129,21 +130,24 @@ func NewClient(projectID string, secret string, opts ...Option) (*API, error) {
o(a)
}

policyCache := b2b.NewPolicyCache(b2b.NewRBACClient(a.client))

// Set up JWKS for local session authentication
jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient())
if err != nil {
return nil, fmt.Errorf("fetch JWKS from URL: %w", err)
}

a.Discovery = b2b.NewDiscoveryClient(a.client)
a.M2M = consumer.NewM2MClient(a.client)
a.M2M = consumer.NewM2MClient(a.client, jwks)
a.MagicLinks = b2b.NewMagicLinksClient(a.client)
a.OAuth = b2b.NewOAuthClient(a.client)
a.OTPs = b2b.NewOTPsClient(a.client)
a.Organizations = b2b.NewOrganizationsClient(a.client)
a.Passwords = b2b.NewPasswordsClient(a.client)
a.RBAC = b2b.NewRBACClient(a.client)
a.SSO = b2b.NewSSOClient(a.client)
a.Sessions = b2b.NewSessionsClient(a.client)
// Set up JWKS for local session authentication
jwks, err := a.instantiateJWKSClient(a.client.GetHTTPClient())
if err != nil {
return nil, fmt.Errorf("fetch JWKS from URL: %w", err)
}
a.M2M.JWKS = jwks
a.Sessions = b2b.NewSessionsClient(a.client, jwks, policyCache)

return a, nil
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ type DiscoveryClient struct {

func NewDiscoveryClient(c stytch.Client) *DiscoveryClient {
return &DiscoveryClient{
C: c,
C: c,

IntermediateSessions: NewDiscoveryIntermediateSessionsClient(c),
Organizations: NewDiscoveryOrganizationsClient(c),
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/magiclinks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ type MagicLinksClient struct {

func NewMagicLinksClient(c stytch.Client) *MagicLinksClient {
return &MagicLinksClient{
C: c,
C: c,

Email: NewMagicLinksEmailClient(c),
Discovery: NewMagicLinksDiscoveryClient(c),
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/magiclinks_email.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type MagicLinksEmailClient struct {

func NewMagicLinksEmailClient(c stytch.Client) *MagicLinksEmailClient {
return &MagicLinksEmailClient{
C: c,
C: c,

Discovery: NewMagicLinksEmailDiscoveryClient(c),
}
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ type OAuthClient struct {

func NewOAuthClient(c stytch.Client) *OAuthClient {
return &OAuthClient{
C: c,
C: c,

Discovery: NewOAuthDiscoveryClient(c),
}
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ type OrganizationsClient struct {

func NewOrganizationsClient(c stytch.Client) *OrganizationsClient {
return &OrganizationsClient{
C: c,
C: c,

Members: NewOrganizationsMembersClient(c),
}
}
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ type OTPsClient struct {

func NewOTPsClient(c stytch.Client) *OTPsClient {
return &OTPsClient{
C: c,
C: c,

Sms: NewOTPsSmsClient(c),
}
}
3 changes: 2 additions & 1 deletion stytch/b2b/passwords.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ type PasswordsClient struct {

func NewPasswordsClient(c stytch.Client) *PasswordsClient {
return &PasswordsClient{
C: c,
C: c,

Email: NewPasswordsEmailClient(c),
Sessions: NewPasswordsSessionsClient(c),
ExistingPassword: NewPasswordsExistingPasswordClient(c),
Expand Down
70 changes: 70 additions & 0 deletions stytch/b2b/rbac.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package b2b

// !!!
// WARNING: This file is autogenerated
// Only modify code within MANUAL() sections
// or your changes may be overwritten later!
// !!!

import (
"context"
"time"

"github.com/stytchauth/stytch-go/v11/stytch"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac"
)

type RBACClient struct {
C stytch.Client
}

func NewRBACClient(c stytch.Client) *RBACClient {
return &RBACClient{
C: c,
}
}

func (c *RBACClient) Policy(
ctx context.Context,
body *rbac.PolicyParams,
) (*rbac.PolicyResponse, error) {
var retVal rbac.PolicyResponse
err := c.C.NewRequest(
ctx,
"GET",
"/v1/b2b/rbac/policy",
nil,
nil,
&retVal,

Check failure on line 38 in stytch/b2b/rbac.go

View workflow job for this annotation

GitHub Actions / test (1.18)

not enough arguments in call to c.C.NewRequest

Check failure on line 38 in stytch/b2b/rbac.go

View workflow job for this annotation

GitHub Actions / test (1.19)

not enough arguments in call to c.C.NewRequest
)
return &retVal, err
}

// MANUAL(PolicyCache)(TYPES)

type PolicyCache struct {
rbacClient *RBACClient
policy *rbac.Policy
lastUpdatedAt time.Time
}

const refreshCadence = 5 * time.Minute

func NewPolicyCache(rbacClient *RBACClient) *PolicyCache {
return &PolicyCache{rbacClient: rbacClient}
}

func (pc *PolicyCache) Get(ctx context.Context) (*rbac.Policy, error) {
if pc.policy == nil || time.Since(pc.lastUpdatedAt) > refreshCadence {
policyResp, err := pc.rbacClient.Policy(ctx, &rbac.PolicyParams{})
if err != nil {
return nil, err
}

pc.policy = policyResp.Policy
pc.lastUpdatedAt = time.Now()
}
return pc.policy, nil
}

// ENDMANUAL(PolicyCache)
35 changes: 35 additions & 0 deletions stytch/b2b/rbac/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package rbac

// !!!
// WARNING: This file is autogenerated
// Only modify code within MANUAL() sections
// or your changes may be overwritten later!
// !!!

type Policy struct {
Roles []PolicyRole `json:"roles,omitempty"`
Resources []PolicyResource `json:"resources,omitempty"`
}
type (
PolicyParams struct{}
PolicyResource struct {
ResourceID string `json:"resource_id,omitempty"`
Description string `json:"description,omitempty"`
Actions []string `json:"actions,omitempty"`
}
)

type PolicyRole struct {
RoleID string `json:"role_id,omitempty"`
Description string `json:"description,omitempty"`
Permissions []PolicyRolePermission `json:"permissions,omitempty"`
}
type PolicyRolePermission struct {
ResourceID string `json:"resource_id,omitempty"`
Actions []string `json:"actions,omitempty"`
}
type PolicyResponse struct {
RequestID string `json:"request_id,omitempty"`
StatusCode int32 `json:"status_code,omitempty"`
Policy *Policy `json:"policy,omitempty"`
}
40 changes: 33 additions & 7 deletions stytch/b2b/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/mitchellh/mapstructure"
"github.com/stytchauth/stytch-go/v11/stytch"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/rbac"
"github.com/stytchauth/stytch-go/v11/stytch/b2b/sessions"
"github.com/stytchauth/stytch-go/v11/stytch/shared"
"github.com/stytchauth/stytch-go/v11/stytch/stytcherror"
)

type SessionsClient struct {
C stytch.Client
JWKS *keyfunc.JWKS
C stytch.Client
JWKS *keyfunc.JWKS
PolicyCache *PolicyCache
}

func NewSessionsClient(c stytch.Client) *SessionsClient {
func NewSessionsClient(c stytch.Client, jwks *keyfunc.JWKS, policyCache *PolicyCache) *SessionsClient {
return &SessionsClient{
C: c,
C: c,
JWKS: jwks,
PolicyCache: policyCache,
}
}

Expand Down Expand Up @@ -286,7 +291,7 @@ func (c *SessionsClient) AuthenticateJWT(
return c.Authenticate(ctx, params.Body)
}

session, err := c.AuthenticateJWTLocal(params.Body.SessionJWT, params.MaxTokenAge)
session, err := c.AuthenticateJWTLocal(ctx, params.Body.SessionJWT, params.MaxTokenAge, params.Body.AuthorizationCheck)
if err != nil {
// JWT couldn't be verified locally. Check with the Stytch API.
return c.Authenticate(ctx, params.Body)
Expand All @@ -307,7 +312,7 @@ func (c *SessionsClient) AuthenticateJWTWithClaims(
return c.AuthenticateWithClaims(ctx, body, claims)
}

session, err := c.AuthenticateJWTLocal(body.SessionJWT, maxTokenAge)
session, err := c.AuthenticateJWTLocal(ctx, body.SessionJWT, maxTokenAge, body.AuthorizationCheck)
if err != nil {
// JWT couldn't be verified locally. Check with the Stytch API.
return c.Authenticate(ctx, body)
Expand All @@ -318,9 +323,12 @@ func (c *SessionsClient) AuthenticateJWTWithClaims(
}, nil
}

// ADDIMPORT: "github.com/stytchauth/stytch-go/v11/stytch/shared"
func (c *SessionsClient) AuthenticateJWTLocal(
ctx context.Context,
token string,
maxTokenAge time.Duration,
authorizationCheck *sessions.AuthorizationCheck,
) (*sessions.MemberSession, error) {
if c.JWKS == nil {
return nil, stytcherror.ErrJWKSNotInitialized
Expand All @@ -341,7 +349,25 @@ func (c *SessionsClient) AuthenticateJWTLocal(
return nil, sessions.ErrJWTTooOld
}

return marshalJWTIntoSession(claims)
memberSession, err := marshalJWTIntoSession(claims)
if err != nil {
return nil, fmt.Errorf("failed to marshal JWT into session: %w", err)
}

var policy *rbac.Policy
if authorizationCheck != nil {
policy, err = c.PolicyCache.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get cached policy: %w", err)
}

err = shared.PerformAuthorizationCheck(policy, claims.Roles, memberSession.OrganizationID, authorizationCheck)
if err != nil {
return nil, err
}
}

return memberSession, nil
}

func marshalJWTIntoSession(claims sessions.Claims) (*sessions.MemberSession, error) {
Expand Down
9 changes: 8 additions & 1 deletion stytch/b2b/sessions/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ type AuthenticateParams struct {
// delete a key, supply a null value. Custom claims made with reserved claims (`iss`, `sub`, `aud`,
// `exp`, `nbf`, `iat`, `jti`) will be ignored.
// Total custom claims size cannot exceed four kilobytes.
SessionCustomClaims map[string]any `json:"session_custom_claims,omitempty"`
SessionCustomClaims map[string]any `json:"session_custom_claims,omitempty"`
AuthorizationCheck *AuthorizationCheck `json:"authorization_check,omitempty"`
}
type AuthorizationCheck struct {
OrganizationID string `json:"organization_id,omitempty"`
ResourceID string `json:"resource_id,omitempty"`
Action string `json:"action,omitempty"`
}

// ExchangeParams: Request type for `Sessions.Exchange`.
Expand Down Expand Up @@ -283,6 +289,7 @@ type OrgClaim struct {
type Claims struct {
Session sessions.SessionClaim `json:"https://stytch.com/session"`
Organization OrgClaim `json:"https://stytch.com/organization"`
Roles []string `json:"https://stytch.com/roles"`
jwt.RegisteredClaims
}

Expand Down
19 changes: 12 additions & 7 deletions stytch/b2b/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {
keyID: keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{Algorithm: "RS256"}),
})

sessionClient := b2b.NewSessionsClient(client)
sessionClient.JWKS = jwks
policyCache := b2b.NewPolicyCache(b2b.NewRBACClient(client))
sessionClient := b2b.NewSessionsClient(client, jwks, policyCache)

t.Run("expired JWT", func(t *testing.T) {
iat := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
Expand All @@ -53,7 +53,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {
claims := sandboxClaims(t, iat, exp)
token := signJWT(t, keyID, key, claims)

s, err := sessionClient.AuthenticateJWTLocal(token, 10*time.Minute)
ctx := context.Background()
s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 10*time.Minute, nil)
assert.ErrorIs(t, err, jwt.ErrTokenExpired)
assert.Nil(t, s)
})
Expand All @@ -65,7 +66,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {
claims := sandboxClaims(t, iat, exp)
token := signJWT(t, keyID, key, claims)

s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute)
ctx := context.Background()
s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil)
assert.ErrorIs(t, err, sessions.ErrJWTTooOld)
assert.Nil(t, s)
})
Expand All @@ -79,7 +81,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {

token := signJWT(t, keyID, key, claims)

s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute)
ctx := context.Background()
s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil)
assert.ErrorIs(t, err, jwt.ErrTokenInvalidAudience)
assert.Nil(t, s)
})
Expand All @@ -93,7 +96,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {

token := signJWT(t, keyID, key, claims)

s, err := sessionClient.AuthenticateJWTLocal(token, 1*time.Minute)
ctx := context.Background()
s, err := sessionClient.AuthenticateJWTLocal(ctx, token, 1*time.Minute, nil)
assert.ErrorIs(t, err, jwt.ErrTokenInvalidIssuer)
assert.Nil(t, s)
})
Expand All @@ -105,7 +109,8 @@ func TestAuthenticateJWTLocal(t *testing.T) {
claims := sandboxClaims(t, iat, exp)
token := signJWT(t, keyID, key, claims)

session, err := sessionClient.AuthenticateJWTLocal(token, 3*time.Minute)
ctx := context.Background()
session, err := sessionClient.AuthenticateJWTLocal(ctx, token, 3*time.Minute, nil)
require.NoError(t, err)

expected := &sessions.MemberSession{
Expand Down
3 changes: 2 additions & 1 deletion stytch/b2b/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ type SSOClient struct {

func NewSSOClient(c stytch.Client) *SSOClient {
return &SSOClient{
C: c,
C: c,

OIDC: NewSSOOIDCClient(c),
SAML: NewSSOSAMLClient(c),
}
Expand Down
Loading

0 comments on commit e07cfc3

Please sign in to comment.