Skip to content

Commit

Permalink
feat(embedded): make guest token JWT audience callable or str (#18748)
Browse files Browse the repository at this point in the history
* feat(embedded): make guest token JWT audience callable

* reset GUEST_TOKEN_JWT_AUDIENCE after test

* helper method for get audience
  • Loading branch information
Lily Kuang authored Feb 16, 2022
1 parent c8df849 commit b2613f6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
3 changes: 2 additions & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,8 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument
GUEST_TOKEN_JWT_ALGO = "HS256"
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
GUEST_TOKEN_JWT_AUDIENCE = None
# Guest token audience for the embedded superset, either string or callable
GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None

# A SQL dataset health check. Note if enabled it is strongly advised that the callable
# be memoized to aid with performance, i.e.,
Expand Down
19 changes: 12 additions & 7 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,13 @@ def _get_current_epoch_time() -> float:
""" This is used so the tests can mock time """
return time.time()

@staticmethod
def _get_guest_token_jwt_audience() -> str:
audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host()
if callable(audience):
audience = audience()
return audience

def create_guest_access_token(
self,
user: GuestTokenUser,
Expand All @@ -1309,8 +1316,7 @@ def create_guest_access_token(
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"]
aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host()

audience = self._get_guest_token_jwt_audience()
# calculate expiration time
now = self._get_current_epoch_time()
exp = now + (exp_seconds * 1000)
Expand All @@ -1321,7 +1327,7 @@ def create_guest_access_token(
# standard jwt claims:
"iat": now, # issued at
"exp": exp, # expiration time
"aud": aud,
"aud": audience,
"type": "guest",
}
token = jwt.encode(claims, secret, algorithm=algo)
Expand Down Expand Up @@ -1363,17 +1369,16 @@ def get_guest_user_from_token(self, token: GuestToken) -> GuestUser:
token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])],
)

@staticmethod
def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]:
def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]:
"""
Parses a guest token. Raises an error if the jwt fails standard claims checks.
:param raw_token: the token gotten from the request
:return: the same token that was passed in, tested but unchanged
"""
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host()
return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud)
audience = self._get_guest_token_jwt_audience()
return jwt.decode(raw_token, secret, algorithms=[algo], audience=audience)

@staticmethod
def is_guest_user(user: Optional[Any] = None) -> bool:
Expand Down
22 changes: 22 additions & 0 deletions tests/integration_tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,3 +1299,25 @@ def test_get_guest_user_bad_audience(self):

self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience")
self.assertIsNone(guest_user)

@patch("superset.security.SupersetSecurityManager._get_current_epoch_time")
def test_create_guest_access_token_callable_audience(self, get_time_mock):
now = time.time()
get_time_mock.return_value = now
app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code")

user = {"username": "test_guest"}
resources = [{"some": "resource"}]
rls = [{"dataset": 1, "clause": "access = 1"}]
token = security_manager.create_guest_access_token(user, resources, rls)

decoded_token = jwt.decode(
token,
self.app.config["GUEST_TOKEN_JWT_SECRET"],
algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]],
audience="cool_code",
)
app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once()
self.assertEqual("cool_code", decoded_token["aud"])
self.assertEqual("guest", decoded_token["type"])
app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None

0 comments on commit b2613f6

Please sign in to comment.