Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Consistently use room_id from federation request body #8776

Merged
merged 3 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8776.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body.
23 changes: 10 additions & 13 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
Expand Down Expand Up @@ -391,7 +392,7 @@ async def _process_edu(edu_dict):
TRANSACTION_CONCURRENCY_LIMIT,
)

async def on_context_state_request(
async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
Expand Down Expand Up @@ -514,11 +515,12 @@ async def on_invite_request(
return {"event": ret_pdu.get_pdu_json(time_now)}

async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self, origin: str, content: JsonDict
) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content)

room_version = await self.store.get_room_version(room_id)
assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)

origin_host, _ = parse_server_name(origin)
Expand Down Expand Up @@ -547,12 +549,11 @@ async def on_make_leave_request(
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}

async def on_send_leave_request(
self, origin: str, content: JsonDict, room_id: str
) -> dict:
async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
logger.debug("on_send_leave_request: content: %s", content)

room_version = await self.store.get_room_version(room_id)
assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)

origin_host, _ = parse_server_name(origin)
Expand Down Expand Up @@ -748,12 +749,8 @@ async def exchange_third_party_invite(
)
return ret

async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: Dict
):
ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
async def on_exchange_third_party_invite_request(self, event_dict: Dict):
ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret

async def check_server_matches_acl(self, server_name: str, room_id: str):
Expand Down
68 changes: 33 additions & 35 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,13 @@ async def on_GET(self, origin, content, query, event_id):


class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
PATH = "/state/(?P<room_id>[^/]*)/?"

# This is when someone asks for all data for a given context.
async def on_GET(self, origin, content, query, context):
return await self.handler.on_context_state_request(
# This is when someone asks for all data for a given room.
async def on_GET(self, origin, content, query, room_id):
return await self.handler.on_room_state_request(
origin,
context,
room_id,
parse_string_from_args(query, "event_id", None, required=False),
)

Expand All @@ -463,16 +463,16 @@ async def on_GET(self, origin, content, query, room_id):


class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
PATH = "/backfill/(?P<room_id>[^/]*)/?"

async def on_GET(self, origin, content, query, context):
async def on_GET(self, origin, content, query, room_id):
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)

if not limit:
return 400, {"error": "Did not include limit param"}

return await self.handler.on_backfill_request(origin, context, versions, limit)
return await self.handler.on_backfill_request(origin, room_id, versions, limit)


class FederationQueryServlet(BaseFederationServlet):
Expand All @@ -487,9 +487,9 @@ async def on_GET(self, origin, content, query, query_type):


class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"

async def on_GET(self, origin, _content, query, context, user_id):
async def on_GET(self, origin, _content, query, room_id, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
Expand All @@ -511,24 +511,24 @@ async def on_GET(self, origin, _content, query, context, user_id):
supported_versions = ["1"]

content = await self.handler.on_make_join_request(
origin, context, user_id, supported_versions=supported_versions
origin, room_id, user_id, supported_versions=supported_versions
)
return 200, content


class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"

async def on_GET(self, origin, content, query, context, user_id):
content = await self.handler.on_make_leave_request(origin, context, user_id)
async def on_GET(self, origin, content, query, room_id, user_id):
content = await self.handler.on_make_leave_request(origin, room_id, user_id)
return 200, content


class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
content = await self.handler.on_send_leave_request(origin, content)
return 200, (200, content)


Expand All @@ -538,43 +538,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX

async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
content = await self.handler.on_send_leave_request(origin, content)
return 200, content


class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

async def on_GET(self, origin, content, query, context, event_id):
return await self.handler.on_event_auth(origin, context, event_id)
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)


class FederationV1SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
richvdh marked this conversation as resolved.
Show resolved Hide resolved

async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
Comment on lines +556 to 557
Copy link
Member

@clokep clokep Nov 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this comment is now obsolete since the room_id is unused. (Also true for FederationV2SendJoinServlet and maybe a few other spots.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

content = await self.handler.on_send_join_request(origin, content, context)
content = await self.handler.on_send_join_request(origin, content)
return 200, (200, content)


class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

PREFIX = FEDERATION_V2_PREFIX

async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
content = await self.handler.on_send_join_request(origin, content)
return 200, content


class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

async def on_PUT(self, origin, content, query, context, event_id):
async def on_PUT(self, origin, content, query, room_id, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
Expand All @@ -589,12 +589,12 @@ async def on_PUT(self, origin, content, query, context, event_id):


class FederationV2InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

PREFIX = FEDERATION_V2_PREFIX

async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content

room_version = content["room_version"]
Expand All @@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"

async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(
room_id, content
)
content = await self.handler.on_exchange_third_party_invite_request(content)
return 200, content


Expand Down
10 changes: 5 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
Expand Down Expand Up @@ -2688,20 +2689,19 @@ async def exchange_third_party_invite(
)

async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: JsonDict
self, event_dict: JsonDict
) -> None:
"""Handle an exchange_third_party_invite request from a remote server

The remote server will call this when it wants to turn a 3pid invite
into a normal m.room.member invite.

Args:
room_id: The ID of the room.

event_dict (dict[str, Any]): Dictionary containing the event body.
event_dict: Dictionary containing the event body.

"""
room_version = await self.store.get_room_version_id(room_id)
assert_params_in_dict(event_dict, ["room_id"])
room_version = await self.store.get_room_version_id(event_dict["room_id"])

# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
Expand Down
1 change: 0 additions & 1 deletion tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_exchange_revoked_invite(self):
)

d = self.handler.on_exchange_third_party_invite_request(
room_id=room_id,
event_dict={
"type": EventTypes.Member,
"room_id": room_id,
Expand Down