Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Include aggregations in results.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Jan 27, 2022
1 parent bbd10e3 commit 4b3aebe
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
2 changes: 2 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def read_config(self, config: JsonDict, **kwargs):
self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
# MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
# MSC3666: including bundled relations in /search.
self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)

# MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
Expand Down
27 changes: 24 additions & 3 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self, hs: "HomeServer"):
self.state_store = self.storage.state
self.auth = hs.get_auth()

self._msc3666_enabled = hs.config.experimental.msc3666_enabled

async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
Expand Down Expand Up @@ -418,12 +420,29 @@ async def search(

time_now = self.clock.time_msec()

aggregations = None
if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
# The returned events.
allowed_events,
),
user.to_string(),
)

for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now # type: ignore[arg-type]
context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now # type: ignore[arg-type]
context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
)

state_results = {}
Expand All @@ -440,7 +459,9 @@ async def search(
results.append(
{
"rank": rank_map[e.event_id],
"result": self._event_serializer.serialize_event(e, time_now),
"result": self._event_serializer.serialize_event(
e, time_now, bundle_aggregations=aggregations
),
"context": contexts.get(e.event_id, {}),
}
)
Expand Down
39 changes: 38 additions & 1 deletion tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ def test_aggregation_must_be_annotation(self):
)
self.assertEquals(400, channel.code, channel.json_body)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
@unittest.override_config(
{"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
)
def test_bundled_aggregations(self):
"""
Test that annotations, references, and threads get correctly bundled.
Expand Down Expand Up @@ -579,6 +581,23 @@ def assert_bundle(event_json: JsonDict) -> None:
self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))

# Request search.
channel = self.make_request(
"POST",
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
chunk = [
result["result"]
for result in channel.json_body["search_categories"]["room_events"][
"results"
]
]
assert_bundle(self._find_event_in_chunk(chunk))

def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included
when directly requested.
Expand Down Expand Up @@ -759,6 +778,7 @@ def test_ignore_invalid_room(self):
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])

@unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
def test_edit(self):
"""Test that a simple edit works."""

Expand Down Expand Up @@ -825,6 +845,23 @@ def assert_bundle(event_json: JsonDict) -> None:
self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))

# Request search.
channel = self.make_request(
"POST",
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
chunk = [
result["result"]
for result in channel.json_body["search_categories"]["room_events"][
"results"
]
]
assert_bundle(self._find_event_in_chunk(chunk))

def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
Expand Down

0 comments on commit 4b3aebe

Please sign in to comment.