Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix a bug with thread summaries when the latest event is edited #11992

Merged
merged 7 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11992.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.
69 changes: 45 additions & 24 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,33 @@ def serialize_event(

return serialized_event

def _apply_edit(
self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
) -> None:
"""Replace the content, preserving existing relations of the serialized event.

Args:
orig_event: The original event.
serialized_event: The original event, serialized. This is modified.
edit: The event which edits the above.
"""

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
edit_content = edit.content.copy()

# Unfreeze the event content if necessary, so that we may modify it below
edit_content = unfreeze(edit_content)
serialized_event["content"] = edit_content.get("m.new_content", {})

# Check for existing relations
relates_to = orig_event.content.get("m.relates_to")
if relates_to:
# Keep the relations, ensuring we use a dict copy of the original
serialized_event["content"]["m.relates_to"] = relates_to.copy()
else:
serialized_event["content"].pop("m.relates_to", None)

def _inject_bundled_aggregations(
self,
event: EventBase,
Expand All @@ -450,26 +477,11 @@ def _inject_bundled_aggregations(
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references

if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
# If there is an edit, apply it to the event.
edit = aggregations.replace
self._apply_edit(event, serialized_event, edit)

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
edit_content = edit.content.copy()

# Unfreeze the event content if necessary, so that we may modify it below
edit_content = unfreeze(edit_content)
serialized_event["content"] = edit_content.get("m.new_content", {})

# Check for existing relations
relates_to = event.content.get("m.relates_to")
if relates_to:
# Keep the relations, ensuring we use a dict copy of the original
serialized_event["content"]["m.relates_to"] = relates_to.copy()
else:
serialized_event["content"].pop("m.relates_to", None)

# Include information about it in the relations dict.
serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
Expand All @@ -478,13 +490,22 @@ def _inject_bundled_aggregations(

# If this event is the start of a thread, include a summary of the replies.
if aggregations.thread:
thread = aggregations.thread

# Don't bundle aggregations as this could recurse forever.
serialized_latest_event = self.serialize_event(
thread.latest_event, time_now, bundle_aggregations=None
)
# Manually apply an edit, if one exists.
if thread.latest_edit:
self._apply_edit(
thread.latest_event, serialized_latest_event, thread.latest_edit
)

serialized_aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": self.serialize_event(
aggregations.thread.latest_event, time_now, bundle_aggregations=None
),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
"latest_event": serialized_latest_event,
"count": thread.count,
"current_user_participated": thread.current_user_participated,
}

# Include the bundled aggregations in the event.
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ async def get_events(
include the previous states content in the unsigned field.

allow_rejected: If True, return rejected events. Otherwise,
omits rejeted events from the response.
omits rejected events from the response.

Returns:
A mapping from event_id to event.
Expand Down
12 changes: 9 additions & 3 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
latest_event: EventBase
latest_edit: Optional[EventBase]
count: int
current_user_participated: bool

Expand Down Expand Up @@ -461,7 +462,7 @@ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]:
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Get the number of threaded replies and the latest reply (if any) for the given event.

Args:
Expand Down Expand Up @@ -558,6 +559,9 @@ def _get_thread_summaries_txn(

latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]

# Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values())

# Map to the event IDs to the thread summary.
#
# There might not be a summary due to there not being a thread or
Expand All @@ -568,7 +572,8 @@ def _get_thread_summaries_txn(

summary = None
if latest_event:
summary = (counts[parent_event_id], latest_event)
latest_edit = latest_edits.get(latest_event_id)
summary = (counts[parent_event_id], latest_event, latest_edit)
summaries[parent_event_id] = summary

return summaries
Expand Down Expand Up @@ -828,11 +833,12 @@ async def get_bundled_aggregations(
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event = summary
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
Expand Down
42 changes: 42 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,48 @@ def test_edit_reply(self):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self):
"""Test that editing a thread works."""

# Create a thread and edit the last event.
channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
self.assertEquals(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]

new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
self.assertEquals(200, channel.code, channel.json_body)

# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)

# We expect that the edit message appears in the thread summary in the
# unsigned relations section.
relations_dict = channel.json_body["unsigned"].get("m.relations")
self.assertIn(RelationTypes.THREAD, relations_dict)

thread_summary = relations_dict[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event_in_thread = thread_summary["latest_event"]
self.assertEquals(
latest_event_in_thread["content"]["body"], "I've been edited!"
)

def test_edit_edit(self):
"""Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"}
Expand Down