Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Request & follow redirects for /media/v3/download #16701

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16701.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Follow redirects when downloading media over federation (per [MSC3860](https://github.com/matrix-org/matrix-spec-proposals/pull/3860)).
38 changes: 38 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TYPE_CHECKING,
AbstractSet,
Awaitable,
BinaryIO,
Callable,
Collection,
Container,
Expand Down Expand Up @@ -1862,6 +1863,43 @@ def filter_user_id(user_id: str) -> bool:

return filtered_statuses, filtered_failures

async def download_media(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Download a piece of remote media, preferring the v3 endpoint.

    The v3 download endpoint is tried first. If the remote server answers
    with an error that `is_unknown_endpoint` classifies as "endpoint not
    recognised", the same request is retried against the r0 endpoint. Any
    other `HttpResponseException` is propagated unchanged.

    Args:
        destination: The remote server to request the media from.
        media_id: The ID of the media to download.
        output_stream: Handed through to the transport layer's download call.
        max_size: Handed through to the transport layer's download call.
        max_timeout_ms: Handed through to the transport layer's download call.

    Returns:
        Whatever the transport layer's download call returns; callers
        unpack it as `(length, headers)`.
    """
    transport = self.transport_layer

    try:
        return await transport.download_media_v3(
            destination,
            media_id,
            output_stream=output_stream,
            max_size=max_size,
            max_timeout_ms=max_timeout_ms,
        )
    except HttpResponseException as exc:
        # Only an "unrecognised endpoint" error justifies falling back to
        # the r0 endpoint; anything else is a legitimate failure and is
        # re-raised as-is.
        if not is_unknown_endpoint(exc):
            raise

        logger.debug(
            "Couldn't download media %s/%s with the v3 API, falling back to the r0 API",
            destination,
            media_id,
        )

    return await transport.download_media_r0(
        destination,
        media_id,
        output_stream=output_stream,
        max_size=max_size,
        max_timeout_ms=max_timeout_ms,
    )


@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:
Expand Down
53 changes: 53 additions & 0 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Collection,
Dict,
Expand Down Expand Up @@ -804,6 +805,58 @@ async def get_account_status(
destination=destination, path=path, data={"user_ids": user_ids}
)

async def download_media_r0(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Request media from a remote server via the r0 download endpoint.

    Args:
        destination: The server to send the request to; also used as the
            server name component of the request path.
        media_id: The ID of the media, used as the last path component.
        output_stream: Passed through to `get_file`.
        max_size: Passed through to `get_file`.
        max_timeout_ms: Sent to the remote as the `timeout_ms` query
            parameter.

    Returns:
        The result of `get_file`.
    """
    query_args = {
        # Ask the remote server to 404 when it does not recognise the
        # server_name, so that we cannot end up in a routing loop.
        "allow_remote": "false",
        "timeout_ms": str(max_timeout_ms),
    }

    return await self.client.get_file(
        destination,
        f"/_matrix/media/r0/download/{destination}/{media_id}",
        output_stream=output_stream,
        max_size=max_size,
        args=query_args,
    )

async def download_media_v3(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Request media from a remote server via the v3 download endpoint.

    Unlike the r0 variant, this advertises `allow_redirect` to the remote
    and asks the HTTP client to follow a redirect response.

    Args:
        destination: The server to send the request to; also used as the
            server name component of the request path.
        media_id: The ID of the media, used as the last path component.
        output_stream: Passed through to `get_file`.
        max_size: Passed through to `get_file`.
        max_timeout_ms: Sent to the remote as the `timeout_ms` query
            parameter.

    Returns:
        The result of `get_file`.
    """
    query_args = {
        # Ask the remote server to 404 when it does not recognise the
        # server_name, so that we cannot end up in a routing loop.
        "allow_remote": "false",
        "timeout_ms": str(max_timeout_ms),
        # Matrix 1.7 allows the remote to redirect to another URL. An old
        # homeserver should simply ignore this parameter, so it is always
        # provided.
        "allow_redirect": "true",
    }

    return await self.client.get_file(
        destination,
        f"/_matrix/media/v3/download/{destination}/{media_id}",
        output_stream=output_stream,
        max_size=max_size,
        args=query_args,
        follow_redirects=True,
    )


def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Expand Down
77 changes: 57 additions & 20 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,18 @@ class MatrixFederationRequest:
"""Query arguments.
"""

txn_id: Optional[str] = None
"""Unique ID for this request (for logging)
txn_id: str = attr.ib(init=False)
"""Unique ID for this request (for logging), this is autogenerated.
"""

uri: bytes = attr.ib(init=False)
"""The URI of this request
uri: bytes = b""
"""The URI of this request, usually generated from the above information.
"""

_generate_uri: bool = True
"""True to automatically generate the uri field based on the above information.

Set to False if manually configuring the URI.
"""
Comment on lines +164 to 168
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't figure out a better way to force attrs to do what I want. Using a default/factory works if you're either just creating the instance or evolving it and updating the URI, but not when evolving it and wanting to generate the URI again. I figured being explicit was best.


def __attrs_post_init__(self) -> None:
Expand All @@ -168,22 +174,23 @@ def __attrs_post_init__(self) -> None:

object.__setattr__(self, "txn_id", txn_id)

destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
query_bytes = encode_query_args(self.query)

# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
(
b"matrix-federation",
destination_bytes,
path_bytes,
None,
query_bytes,
b"",
if self._generate_uri:
destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
query_bytes = encode_query_args(self.query)

# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
(
b"matrix-federation",
destination_bytes,
path_bytes,
None,
query_bytes,
b"",
)
)
)
object.__setattr__(self, "uri", uri)
object.__setattr__(self, "uri", uri)

def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
Expand Down Expand Up @@ -513,6 +520,7 @@ async def _send_request(
ignore_backoff: bool = False,
backoff_on_404: bool = False,
backoff_on_all_error_codes: bool = False,
follow_redirects: bool = False,
) -> IResponse:
"""
Sends a request to the given server.
Expand Down Expand Up @@ -555,6 +563,9 @@ async def _send_request(
backoff_on_404: Back off if we get a 404
backoff_on_all_error_codes: Back off if we get any error response

follow_redirects: True to follow the Location header of 307/308 redirect
responses. This does not recurse.

Returns:
Resolves with the HTTP response object on success.

Expand Down Expand Up @@ -714,6 +725,26 @@ async def _send_request(
response.code,
response_phrase,
)
elif (
response.code in (307, 308)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about wrapping the agent in a RedirectAgent, but that handles additional codes sadly.

and follow_redirects
and response.headers.hasHeader("Location")
):
# The Location header *might* be relative so resolve it.
location = response.headers.getRawHeaders(b"Location")[0]
new_uri = urllib.parse.urljoin(request.uri, location)

return await self._send_request(
attr.evolve(request, uri=new_uri, generate_uri=False),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's spelled _generate_uri in the struct definition. Is this some attrs way of declaring a private attribute that you can initialise via the constructor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this some attrs way of declaring a private attribute that you can initialise via the constructor?

Yes, pretty much. attrs strips the preceding underscores to let you have 'private' attributes that can be initialized.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

retry_on_dns_fail,
timeout,
long_retries,
ignore_backoff,
backoff_on_404,
backoff_on_all_error_codes,
# Do not continue following redirects.
follow_redirects=False,
)
else:
logger.info(
"{%s} [%s] Got response headers: %d %s",
Expand Down Expand Up @@ -1383,6 +1414,7 @@ async def get_file(
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
Expand All @@ -1392,6 +1424,8 @@ async def get_file(
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
follow_redirects: True to follow the Location header of 307/308 redirect
responses. This does not recurse.

Returns:
Resolves with an (int,dict) tuple of
Expand All @@ -1412,7 +1446,10 @@ async def get_file(
)

response = await self._send_request(
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
follow_redirects=follow_redirects,
)

headers = dict(response.headers.getAllRawHeaders())
Expand Down
17 changes: 4 additions & 13 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class MediaRepository:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
self.client = hs.get_federation_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastores().main
Expand Down Expand Up @@ -644,22 +644,13 @@ async def _download_remote_file(
file_info = FileInfo(server_name=server_name, file_id=file_id)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join(
("/_matrix/media/r0/download", server_name, media_id)
)
try:
length, headers = await self.client.get_file(
length, headers = await self.client.download_media(
server_name,
request_path,
media_id,
output_stream=f,
max_size=self.max_upload_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
max_timeout_ms=max_timeout_ms,
)
except RequestSendFailed as e:
logger.warning(
Expand Down
62 changes: 58 additions & 4 deletions tests/media/test_media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

from synapse.api.errors import Codes
from synapse.api.errors import Codes, HttpResponseException
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
Expand Down Expand Up @@ -247,6 +248,7 @@ def get_file(
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file."""

Expand All @@ -257,10 +259,15 @@ def write_to(
output_stream.write(data)
return response

def write_err(f: Failure) -> Failure:
    """Errback: write the error's response to the output stream, then propagate."""
    # trap() lets an HttpResponseException through and re-raises any other
    # failure so that a later errback can deal with it.
    f.trap(HttpResponseException)

    error_response = f.value.response
    output_stream.write(error_response)

    # Returning the Failure keeps the Deferred in its error state.
    return f
Comment on lines +262 to +265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one was new to me: https://docs.twisted.org/en/stable/api/twisted.python.failure.Failure.html#trap

TL;DR the trap call is a no-op if f contains an HttpResponseException; otherwise the trap raises immediately so that the next errback can handle this Failure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd missed that this was in test code though. Why do we need to add this as an errback all of a sudden?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add it because we now sometimes call errback on the list of Deferreds so that we can resolve a request with an error instead of a response.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh, I think I see: we never called errback on the mock until now!


d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallback(write_to)
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)

# Mock out the homeserver's MatrixFederationHttpClient
Expand Down Expand Up @@ -316,10 +323,11 @@ def _req(
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
)
self.assertEqual(
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
self.fetches[0][3],
{"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"},
)

headers = {
Expand Down Expand Up @@ -671,6 +679,52 @@ def test_cross_origin_resource_policy_header(self) -> None:
[b"cross-origin"],
)

def test_unknown_v3_endpoint(self) -> None:
    """
    A v3 download rejected as an unknown endpoint is retried against r0.
    """
    channel = self.make_request(
        "GET",
        f"/_matrix/media/v3/download/{self.media_id}",
        shorthand=False,
        await_result=False,
    )
    self.pump()

    # Exactly one outbound fetch so far: to example.com, on the v3 path.
    self.assertEqual(len(self.fetches), 1)
    v3_deferred, v3_destination, v3_path, _ = self.fetches[0]
    self.assertEqual(v3_destination, "example.com")
    self.assertEqual(v3_path, "/_matrix/media/v3/download/" + self.media_id)

    # Fail the v3 fetch with an error reporting that the endpoint is unknown.
    unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
    v3_deferred.errback(HttpResponseException(404, "NOT FOUND", unknown_endpoint))

    self.pump()

    # That failure should have triggered a second fetch, on the r0 path.
    self.assertEqual(len(self.fetches), 2)
    r0_deferred, r0_destination, r0_path, _ = self.fetches[1]
    self.assertEqual(r0_destination, "example.com")
    self.assertEqual(r0_path, f"/_matrix/media/r0/download/{self.media_id}")

    # Resolve the r0 fetch successfully with the test image.
    image_bytes = self.test_image.data
    image_length = len(image_bytes)
    headers = {
        b"Content-Length": [b"%d" % (image_length)],
    }
    r0_deferred.callback((image_bytes, (image_length, headers)))

    self.pump()
    self.assertEqual(channel.code, 200)


class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes
Expand Down
Loading
Loading