Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Restructure the CAS code to be more like SAML/OIDC.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Dec 8, 2020
1 parent a9e5a2a commit cde552e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
61 changes: 34 additions & 27 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def _build_service_param(self, args: Dict[str, str]) -> str:

async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]:
) -> Tuple[str, Dict[str, Optional[str]]]:
"""
Validate a CAS ticket with the server, parse the response, and return the user and display name.
Validate a CAS ticket with the server, parse the response, and return the user and other attributes.
Args:
ticket: The CAS ticket from the client.
Expand All @@ -97,22 +97,7 @@ async def _validate_ticket(
# even if that's being used old-http style to signal end-of-data
body = pde.response

user, attributes = self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)

for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

return user, displayname
return self._parse_cas_response(body)

def _parse_cas_response(
self, cas_response_body: bytes
Expand Down Expand Up @@ -208,7 +193,7 @@ async def handle_ticket(
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
username, attributes = await self._validate_ticket(ticket, args)

# first check if we're doing a UIA
if session:
Expand All @@ -218,14 +203,36 @@ async def handle_ticket(

# otherwise, we're handling a login request.

# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)

# Get the matrix ID from the CAS username.
try:
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
username, attributes, user_agent, ip_address
)
except MappingException as e:
logger.exception("Could not map user")
Expand All @@ -242,7 +249,7 @@ async def handle_ticket(
async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
attributes: Dict[str, Optional[str]],
user_agent: str,
ip_address: str,
) -> str:
Expand All @@ -251,7 +258,7 @@ async def _map_cas_user_to_matrix_user(
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
attributes: Additional attributes from the CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Expand All @@ -262,12 +269,14 @@ async def _map_cas_user_to_matrix_user(
The user ID associated with this response.
"""

# Note that CAS does not support a mapping provider, so the logic is hard-coded.
localpart = map_username_to_mxid_localpart(remote_user_id)
display_name = attributes.pop(self._cas_displayname_attribute, None)

async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Map from CAS attributes to user attributes.
"""
localpart = map_username_to_mxid_localpart(remote_user_id)

# Due to the grandfathering logic matching any previously registered
# mxids it isn't expected for there to be any failures.
if failures:
Expand All @@ -278,9 +287,7 @@ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
async def grandfather_existing_users() -> Optional[str]:
# Since CAS did not used to support storing data into the user_external_ids
# tables, we need to attempt to map to existing users.
user_id = UserID(
map_username_to_mxid_localpart(remote_user_id), self._hostname
).to_string()
user_id = UserID(localpart, self._hostname).to_string()

logger.debug(
"Looking for existing account based on mapped %s", user_id,
Expand Down
14 changes: 7 additions & 7 deletions tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def make_homeserver(self, reactor, clock):
def test_map_cas_user_to_user(self):
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
cas_user_id = "test_user"
display_name = ""
attributes = {}
mxid = self.get_success(
self.handler._map_cas_user_to_matrix_user(
cas_user_id, display_name, "user-agent", "10.10.10.10"
cas_user_id, attributes, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
Expand All @@ -63,29 +63,29 @@ def test_map_cas_user_to_existing_user(self):

# Map a user via SSO.
cas_user_id = "test_user"
display_name = ""
attributes = {}
mxid = self.get_success(
self.handler._map_cas_user_to_matrix_user(
cas_user_id, display_name, "user-agent", "10.10.10.10"
cas_user_id, attributes, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")

# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_cas_user_to_matrix_user(
cas_user_id, display_name, "user-agent", "10.10.10.10"
cas_user_id, attributes, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")

def test_map_cas_user_to_invalid_localpart(self):
"""CAS automaps invalid characters to base-64 encoding."""
cas_user_id = "föö"
display_name = ""
attributes = {}
mxid = self.get_success(
self.handler._map_cas_user_to_matrix_user(
cas_user_id, display_name, "user-agent", "10.10.10.10"
cas_user_id, attributes, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@f=c3=b6=c3=b6:test")

0 comments on commit cde552e

Please sign in to comment.