Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert some of the general database methods to async #8100

Merged
merged 5 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8100.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
23 changes: 9 additions & 14 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ def is_running(self):
"""
return self._db_pool.running

@defer.inlineCallbacks
def _check_safe_to_upsert(self):
async def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?

Expand All @@ -342,7 +341,7 @@ def _check_safe_to_upsert(self):

If the background updates have not completed, wait 15 sec and check again.
"""
updates = yield self.simple_select_list(
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
Expand Down Expand Up @@ -614,8 +613,7 @@ def interaction(txn):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.

@defer.inlineCallbacks
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.

Args:
Expand All @@ -631,7 +629,7 @@ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
`or_ignore` is True
"""
try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
Expand Down Expand Up @@ -684,8 +682,7 @@ def simple_insert_many_txn(txn, table, values):

txn.executemany(sql, vals)

@defer.inlineCallbacks
def simple_upsert(
async def simple_upsert(
self,
table,
keyvalues,
Expand Down Expand Up @@ -714,14 +711,14 @@ def simple_upsert(
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(None or bool): Native upserts always return None. Emulated
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
Expand All @@ -730,7 +727,6 @@ def simple_upsert(
insertion_values,
lock=lock,
)
return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
Expand Down Expand Up @@ -1121,8 +1117,7 @@ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):

return cls.cursor_to_dict(txn)

@defer.inlineCallbacks
def simple_select_many_batch(
async def simple_select_many_batch(
self,
table,
column,
Expand Down Expand Up @@ -1156,7 +1151,7 @@ def simple_select_many_batch(
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def set_appservice_state(self, service, state):
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
A Deferred which resolves when the state was set successfully.
An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
Expand Down
16 changes: 9 additions & 7 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,13 +847,15 @@ def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
)

return {r["event_id"] for r in rows}
Expand Down
8 changes: 3 additions & 5 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

import logging
import re
from typing import Dict, List, Optional

from twisted.internet.defer import Deferred
from typing import Awaitable, Dict, List, Optional

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
Expand Down Expand Up @@ -563,7 +561,7 @@ def add_user_bound_threepid(self, user_id, medium, address, id_server):
id_server (str)

Returns:
Deferred
Awaitable
"""
# We need to use an upsert, in case they user had already bound the
# threepid
Expand Down Expand Up @@ -1084,7 +1082,7 @@ def _register_user(

def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Deferred:
) -> Awaitable:
"""Record a mapping from an external user id to a mxid

Args:
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,13 +767,13 @@ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:

return set(room_ids)

def get_membership_from_event_ids(
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""

return self.db_pool.simple_select_many_batch(
return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def register_query_handler(query_type, handler):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")

yield self.store.create_profile(self.frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))

self.handler = hs.get_profile_handler()
self.hs = hs
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_get_other_name(self):

@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield self.store.create_profile("caroline")
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")

response = yield defer.ensureDeferred(
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_users_in_room(room_id):
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)

Expand Down
16 changes: 12 additions & 4 deletions tests/storage/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def test_get_appservices_by_state_none(self):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
Expand All @@ -219,9 +221,15 @@ def test_set_appservices_state_down(self):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
Expand Down
16 changes: 10 additions & 6 deletions tests/storage/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ def runWithConnection(func, *args, **kwargs):
def test_insert_1col(self):
self.mock_txn.rowcount = 1

yield self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
)

self.mock_txn.execute.assert_called_with(
Expand All @@ -78,10 +80,12 @@ def test_insert_1col(self):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1

yield self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
)
)

self.mock_txn.execute.assert_called_with(
Expand Down
30 changes: 16 additions & 14 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,22 @@ def _mark_read(stream, depth):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
return self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
return defer.ensureDeferred(
self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
)
)

# start with the base case where there are no events in the table
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUp(self):
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)

users, total = yield self.store.get_users_paginate(
Expand Down
4 changes: 2 additions & 2 deletions tests/storage/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUp(self):

@defer.inlineCallbacks
def test_displayname(self):
yield self.store.create_profile(self.u_frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))

yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")

Expand All @@ -43,7 +43,7 @@ def test_displayname(self):

@defer.inlineCallbacks
def test_avatar_url(self):
yield self.store.create_profile(self.u_frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))

yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
Expand Down