diff --git a/superset/config.py b/superset/config.py index 2637c0032bef5..09ef4e71a2d54 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1316,7 +1316,8 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument GUEST_TOKEN_JWT_ALGO = "HS256" GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes -GUEST_TOKEN_JWT_AUDIENCE = None +# Guest token audience for the embedded superset, either string or callable +GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., diff --git a/superset/security/manager.py b/superset/security/manager.py index ac494a1837827..91b203e83f774 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1300,6 +1300,13 @@ def _get_current_epoch_time() -> float: """ This is used so the tests can mock time """ return time.time() + @staticmethod + def _get_guest_token_jwt_audience() -> str: + audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + if callable(audience): + audience = audience() + return audience + def create_guest_access_token( self, user: GuestTokenUser, @@ -1309,8 +1316,7 @@ def create_guest_access_token( secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - + audience = self._get_guest_token_jwt_audience() # calculate expiration time now = self._get_current_epoch_time() exp = now + (exp_seconds * 1000) @@ -1321,7 +1327,7 @@ def create_guest_access_token( # standard jwt claims: "iat": now, # issued at "exp": exp, # expiration time - "aud": aud, + "aud": audience, "type": "guest", } token = jwt.encode(claims, secret, algorithm=algo) @@ -1363,8 +1369,7 @@ def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) - @staticmethod - def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: + def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: """ Parses a guest token. Raises an error if the jwt fails standard claims checks. :param raw_token: the token gotten from the request @@ -1372,8 +1377,8 @@ def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - aud = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() - return jwt.decode(raw_token, secret, algorithms=[algo], audience=aud) + audience = self._get_guest_token_jwt_audience() + return jwt.decode(raw_token, secret, algorithms=[algo], audience=audience) @staticmethod def is_guest_user(user: Optional[Any] = None) -> bool: diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index efcd191ffafc1..9dca5ac51375c 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1299,3 +1299,25 @@ def test_get_guest_user_bad_audience(self): self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience") self.assertIsNone(guest_user) + + @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") + def test_create_guest_access_token_callable_audience(self, get_time_mock): + now = time.time() + get_time_mock.return_value = now + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code") + + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + rls = [{"dataset": 1, "clause": "access = 1"}] + token = security_manager.create_guest_access_token(user, resources, rls) + + decoded_token = jwt.decode( + token, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], + audience="cool_code", + ) + app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() + self.assertEqual("cool_code", decoded_token["aud"]) + self.assertEqual("guest", decoded_token["type"]) + app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None