Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infinite backfill with MSC2716 #817

Merged
merged 13 commits into from
Oct 11, 2022
28 changes: 15 additions & 13 deletions mautrix_telegram/commands/telegram/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import cast
import base64
import codecs
import math
import re

from aiohttp import ClientSession, InvalidURL
Expand Down Expand Up @@ -427,6 +428,9 @@ async def backfill(evt: CommandEvent) -> None:
if not evt.is_portal:
await evt.reply("You can only use backfill in portal rooms")
return
elif not evt.config["bridge.backfill.enable"]:
await evt.reply("Backfilling is disabled in the bridge config")
return
try:
limit = int(evt.args[0])
except (ValueError, IndexError):
Expand All @@ -435,16 +439,14 @@ async def backfill(evt: CommandEvent) -> None:
if not evt.config["bridge.backfill.normal_groups"] and portal.peer_type == "chat":
await evt.reply("Backfilling normal groups is disabled in the bridge config")
return
try:
await portal.backfill(evt.sender, limit=limit)
except TakeoutInitDelayError:
msg = (
"Please accept the data export request from a mobile device, "
"then re-run the backfill command."
)
if portal.peer_type == "user":
from mautrix.appservice import IntentAPI

await portal.main_intent.send_notice(evt.room_id, msg)
else:
await evt.reply(msg)
if portal.backfill_msc2716:
messages_per_batch = evt.config["bridge.backfill.incremental.messages_per_batch"]
batches = math.ceil(limit / messages_per_batch)
rounded = ""
if batches * messages_per_batch != limit:
rounded = f" (rounded message limit to {batches}*{messages_per_batch})"
await portal.enqueue_backfill(evt.sender, priority=0, max_batches=batches)
await evt.reply(f"Backfill queued{rounded}")
else:
output = await portal.forward_backfill(evt.sender, initial=False, override_limit=limit)
await evt.reply(output)
16 changes: 11 additions & 5 deletions mautrix_telegram/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,18 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("bridge.bridge_matrix_leave")
copy("bridge.kick_on_logout")
copy("bridge.always_read_joined_telegram_notice")
copy("bridge.backfill.invite_own_puppet")
copy("bridge.backfill.takeout_limit")
copy("bridge.backfill.initial_limit")
copy("bridge.backfill.missed_limit")
copy("bridge.backfill.disable_notifications")
copy("bridge.backfill.enable")
copy("bridge.backfill.msc2716")
copy("bridge.backfill.double_puppet_backfill")
copy("bridge.backfill.normal_groups")
copy("bridge.backfill.forward.initial_limit")
copy("bridge.backfill.forward.sync_limit")
copy("bridge.backfill.incremental.messages_per_batch")
copy("bridge.backfill.incremental.post_batch_delay")
copy("bridge.backfill.incremental.max_batches.user")
copy("bridge.backfill.incremental.max_batches.normal_group")
copy("bridge.backfill.incremental.max_batches.supergroup")
copy("bridge.backfill.incremental.max_batches.channel")

copy("bridge.initial_power_level_overrides.group")
copy("bridge.initial_power_level_overrides.user")
Expand Down
3 changes: 3 additions & 0 deletions mautrix_telegram/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database

from .backfill_queue import Backfill
from .bot_chat import BotChat
from .disappearing_message import DisappearingMessage
from .message import Message
Expand All @@ -38,6 +39,7 @@ def init(db: Database) -> None:
BotChat,
PgSession,
DisappearingMessage,
Backfill,
):
table.db = db

Expand All @@ -54,4 +56,5 @@ def init(db: Database) -> None:
"BotChat",
"PgSession",
"DisappearingMessage",
"Backfill",
]
175 changes: 175 additions & 0 deletions mautrix_telegram/db/backfill_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2022 Tulir Asokan, Sumner Evans
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from datetime import datetime, timedelta

from asyncpg import Record
from attr import dataclass

from mautrix.types import RoomID, UserID
from mautrix.util.async_db import Database

# Only evaluated by static type checkers: gives the ``db`` class attribute a
# Database-typed default. At runtime TYPE_CHECKING is False, so this is None.
# NOTE(review): presumably the real Database is assigned to ``Backfill.db``
# during startup (db/__init__.py sets ``table.db = db``) — confirm.
fake_db = Database.create("") if TYPE_CHECKING else None


@dataclass
class Backfill:
    """One row of the ``backfill_queue`` table: a queued backfill request.

    Rows are keyed by ``queue_id`` (assigned by the database on insert) and are
    looked up per user and, optionally, per portal.
    """

    db: ClassVar[Database] = fake_db

    # None until the row has been inserted; filled in by insert().
    queue_id: int | None
    user_mxid: UserID
    # Rows are served in ascending ``priority`` order (see get_next / get).
    priority: int
    # (portal_tgid, portal_tg_receiver) together identify the portal.
    portal_tgid: int
    portal_tg_receiver: int
    messages_per_batch: int
    post_batch_delay: int
    max_batches: int
    # Written by mark_dispatched().
    dispatch_time: datetime | None
    # Written by mark_done().
    completed_at: datetime | None
    # Written by set_cooldown_timeout(); get_next() skips the row until it passes.
    cooldown_timeout: datetime | None

    @classmethod
    def new(
        cls,
        user_mxid: UserID,
        priority: int,
        portal_tgid: int,
        portal_tg_receiver: int,
        messages_per_batch: int,
        post_batch_delay: int = 0,
        max_batches: int = -1,
    ) -> Backfill:
        """Build a not-yet-inserted queue entry.

        ``queue_id`` and all timestamp fields start out as None; call
        :meth:`insert` to persist the entry and obtain its ``queue_id``.
        """
        return cls(
            queue_id=None,
            user_mxid=user_mxid,
            priority=priority,
            portal_tgid=portal_tgid,
            portal_tg_receiver=portal_tg_receiver,
            messages_per_batch=messages_per_batch,
            post_batch_delay=post_batch_delay,
            max_batches=max_batches,
            dispatch_time=None,
            completed_at=None,
            cooldown_timeout=None,
        )

    @classmethod
    def _from_row(cls, row: Record | None) -> Backfill | None:
        """Convert a database row into an instance, passing None through."""
        if row is None:
            return None
        return cls(**row)

    # Every column except the database-generated ``queue_id``, in the order
    # used by both the SELECT column list and the INSERT placeholders.
    columns = [
        "user_mxid",
        "priority",
        "portal_tgid",
        "portal_tg_receiver",
        "messages_per_batch",
        "post_batch_delay",
        "max_batches",
        "dispatch_time",
        "completed_at",
        "cooldown_timeout",
    ]
    columns_str = ",".join(columns)

    @classmethod
    async def get_next(cls, user_mxid: UserID) -> Backfill | None:
        """Return the next queue entry that should be processed for a user.

        An entry is eligible when it has never been dispatched, or when it was
        dispatched more than 15 minutes ago and has not been marked completed.
        Entries whose ``cooldown_timeout`` is still in the future are skipped.
        The entry with the lowest ``priority`` (then lowest ``queue_id``) wins.

        Returns None when no entry is eligible.
        """
        q = f"""
        SELECT queue_id, {cls.columns_str}
        FROM backfill_queue
        WHERE user_mxid=$1
            AND (
                dispatch_time IS NULL
                OR (
                    dispatch_time < $2
                    AND completed_at IS NULL
                )
            )
            AND (
                cooldown_timeout IS NULL
                OR cooldown_timeout < current_timestamp
            )
        ORDER BY priority, queue_id
        LIMIT 1
        """
        # NOTE(review): the dispatch cutoff uses the client's naive local clock
        # while the cooldown check uses the database's current_timestamp —
        # confirm both agree on timezone for the column types in use.
        stale_dispatch_cutoff = datetime.now() - timedelta(minutes=15)
        return cls._from_row(await cls.db.fetchrow(q, user_mxid, stale_dispatch_cutoff))

    @classmethod
    async def get(
        cls,
        user_mxid: UserID,
        portal_tgid: int,
        portal_tg_receiver: int,
    ) -> Backfill | None:
        """Return the first queue entry for a user and portal, if any.

        Unlike :meth:`get_next`, this does not filter on dispatch, completion
        or cooldown state; ordering is by ``priority`` then ``queue_id``.
        """
        q = f"""
        SELECT queue_id, {cls.columns_str}
        FROM backfill_queue
        WHERE user_mxid=$1
            AND portal_tgid=$2
            AND portal_tg_receiver=$3
        ORDER BY priority, queue_id
        LIMIT 1
        """
        return cls._from_row(await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver))

    @classmethod
    async def delete_all(cls, user_mxid: UserID) -> None:
        """Delete every queue entry belonging to the given user."""
        await cls.db.execute("DELETE FROM backfill_queue WHERE user_mxid=$1", user_mxid)

    @classmethod
    async def delete_for_portal(cls, tgid: int, tg_receiver: int) -> None:
        """Delete every queue entry for the given portal, across all users."""
        q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2"
        await cls.db.execute(q, tgid, tg_receiver)

    async def insert(self) -> None:
        """Insert this entry and store the generated ``queue_id`` on the instance."""
        placeholders = ",".join(f"${index}" for index in range(1, len(self.columns) + 1))
        q = f"""
        INSERT INTO backfill_queue ({self.columns_str})
        VALUES ({placeholders})
        RETURNING queue_id
        """
        # Derive the values from ``columns`` so the placeholders, the column
        # list and the bound values can never drift out of sync.
        values = [getattr(self, column) for column in self.columns]
        row = await self.db.fetchrow(q, *values)
        self.queue_id = row["queue_id"]

    async def mark_dispatched(self) -> None:
        """Record the current time as this entry's ``dispatch_time`` in the database."""
        q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
        await self.db.execute(q, datetime.now(), self.queue_id)

    async def mark_done(self) -> None:
        """Record the current time as this entry's ``completed_at`` in the database."""
        q = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
        await self.db.execute(q, datetime.now(), self.queue_id)

    async def set_cooldown_timeout(self, timeout: float) -> None:
        """
        Set the backfill request to cooldown for ``timeout`` seconds.

        The stored value is "now + timeout" on the client clock; get_next()
        ignores the entry until that moment has passed.
        """
        q = "UPDATE backfill_queue SET cooldown_timeout=$1 WHERE queue_id=$2"
        await self.db.execute(q, datetime.now() + timedelta(seconds=timeout), self.queue_id)
34 changes: 27 additions & 7 deletions mautrix_telegram/db/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from asyncpg import Record
from attr import dataclass
import attr

from mautrix.types import EventID, RoomID, UserID
from mautrix.util.async_db import Database, Scheme
Expand Down Expand Up @@ -122,6 +123,14 @@ async def find_last(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | Non
)
return cls._from_row(await cls.db.fetchrow(q, mx_room, tg_space))

@classmethod
async def find_first(cls, mx_room: RoomID, tg_space: TelegramID) -> Message | None:
    """Return the message with the lowest ``tgid`` in a room and Telegram space.

    Returns None when no message matches.
    """
    select_clause = f"SELECT {cls.columns} FROM message"
    filter_and_order = " WHERE mx_room=$1 AND tg_space=$2 ORDER BY tgid ASC LIMIT 1"
    row = await cls.db.fetchrow(select_clause + filter_and_order, mx_room, tg_space)
    return cls._from_row(row)

@classmethod
async def delete_all(cls, mx_room: RoomID) -> None:
await cls.db.execute("DELETE FROM message WHERE mx_room=$1", mx_room)
Expand Down Expand Up @@ -173,6 +182,23 @@ async def delete_temp_mxid(cls, temp_mxid: str, mx_room: RoomID) -> None:
q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2"
await cls.db.execute(q, temp_mxid, mx_room)

@classmethod
async def bulk_insert(cls, messages: list[Message]) -> None:
    """Insert many messages inside a single transaction.

    On Postgres the rows are written with ``copy_records_to_table``; on any
    other scheme they go through ``executemany`` with the shared insert query.
    """
    column_names = cls.columns.split(", ")
    rows = [attr.astuple(msg) for msg in messages]
    async with cls.db.acquire() as conn, conn.transaction():
        if cls.db.scheme != Scheme.POSTGRES:
            await conn.executemany(cls._insert_query, rows)
        else:
            await conn.copy_records_to_table("message", records=rows, columns=column_names)

# Parameterized single-row INSERT for the ``message`` table. The nine
# placeholders follow the column order listed in the statement itself.
_insert_query: ClassVar[
str
] = """
INSERT INTO message (mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash, sender_mxid, sender)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"""

@property
def _values(self):
return (
Expand All @@ -188,13 +214,7 @@ def _values(self):
)

async def insert(self) -> None:
q = """
INSERT INTO message (
mxid, mx_room, tgid, tg_space, edit_index, redacted, content_hash,
sender_mxid, sender
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"""
await self.db.execute(q, *self._values)
await self.db.execute(self._insert_query, *self._values)

async def delete(self) -> None:
q = "DELETE FROM message WHERE mxid=$1 AND mx_room=$2 AND tg_space=$3"
Expand Down
Loading