Skip to content

Commit

Permalink
Convert to async/await (#101)
Browse files Browse the repository at this point in the history
Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
  • Loading branch information
reivilibre and clokep authored Sep 2, 2020
1 parent 5a9b52c commit 2e9d781
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 118 deletions.
109 changes: 52 additions & 57 deletions ldap_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, threads
from twisted.internet import threads


import ldap3
Expand Down Expand Up @@ -88,8 +88,7 @@ def __init__(self, config, account_handler):
def get_supported_login_types(self):
return {'m.login.password': ('password',)}

@defer.inlineCallbacks
def check_auth(self, username, login_type, login_dict):
async def check_auth(self, username, login_type, login_dict):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Expand All @@ -103,7 +102,7 @@ def check_auth(self, username, login_type, login_dict):
# an anonymous authorization state and not suitable for user
# authentication.
if not password:
defer.returnValue(False)
return False

if username.startswith("@") and ":" in username:
# username is of the form @foo:bar.com
Expand All @@ -122,7 +121,7 @@ def check_auth(self, username, login_type, login_dict):
uid_value = login + "@" + domain
default_display_name = login
except ActiveDirectoryUPNException:
defer.returnValue(False)
return False

try:
tls = ldap3.Tls(validate=ssl.CERT_REQUIRED)
Expand All @@ -140,7 +139,7 @@ def check_auth(self, username, login_type, login_dict):
value=uid_value,
base=self.ldap_base
)
result, conn = yield self._ldap_simple_bind(
result, conn = await self._ldap_simple_bind(
server=server, bind_dn=bind_dn, password=password
)
logger.debug(
Expand All @@ -150,10 +149,10 @@ def check_auth(self, username, login_type, login_dict):
conn
)
if not result:
defer.returnValue(False)
return False
elif self.ldap_mode == LDAPMode.SEARCH:
filters = [(self.ldap_attributes["uid"], uid_value)]
result, conn, _ = yield self._ldap_authenticated_search(
result, conn, _ = await self._ldap_authenticated_search(
server=server, password=password, filters=filters
)
logger.debug(
Expand All @@ -163,7 +162,7 @@ def check_auth(self, username, login_type, login_dict):
conn
)
if not result:
defer.returnValue(False)
return False
else: # pragma: no cover
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
Expand All @@ -181,17 +180,17 @@ def check_auth(self, username, login_type, login_dict):
"Authentication method yielded no LDAP connection, "
"aborting!"
)
defer.returnValue(False)
return False

# Get full user id from localpart
user_id = self.account_handler.get_qualified_user_id(localpart)

# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
if await self.account_handler.check_user_exists(user_id):
# exists, authentication complete
if hasattr(conn, "unbind"):
yield threads.deferToThread(conn.unbind)
defer.returnValue(user_id)
await threads.deferToThread(conn.unbind)
return user_id

else:
# does not exist, register
Expand All @@ -200,7 +199,7 @@ def check_auth(self, username, login_type, login_dict):
# existing ldap connection
filters = [(self.ldap_attributes['uid'], uid_value)]

result, conn, response = yield self._ldap_authenticated_search(
result, conn, response = await self._ldap_authenticated_search(
server=server, password=password, filters=filters,
)

Expand All @@ -222,18 +221,17 @@ def check_auth(self, username, login_type, login_dict):
mail = None

# Register the user
user_id = yield self.register_user(localpart, display_name, mail)
user_id = await self.register_user(localpart, display_name, mail)

defer.returnValue(user_id)
return user_id

defer.returnValue(False)
return False

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during ldap authentication: %s", e)
defer.returnValue(False)
return False

@defer.inlineCallbacks
def check_3pid_auth(self, medium, address, password):
async def check_3pid_auth(self, medium, address, password):
""" Handle authentication against thirdparty login types, such as email
Args:
Expand All @@ -248,11 +246,11 @@ def check_3pid_auth(self, medium, address, password):
if self.ldap_mode != LDAPMode.SEARCH:
logger.debug("3PID LDAP login/register attempted but LDAP search mode "
"not enabled. Bailing.")
defer.returnValue(None)
return None

# We currently only support email
if medium != "email":
defer.returnValue(None)
return None

# Talk to LDAP and check if this email/password combo is correct
try:
Expand All @@ -265,7 +263,7 @@ def check_3pid_auth(self, medium, address, password):
)

search_filter = [(self.ldap_attributes["mail"], address)]
result, conn, response = yield self._ldap_authenticated_search(
result, conn, response = await self._ldap_authenticated_search(
server=server, password=password, filters=search_filter,
)

Expand All @@ -279,10 +277,10 @@ def check_3pid_auth(self, medium, address, password):

# Close connection
if hasattr(conn, "unbind"):
yield threads.deferToThread(conn.unbind)
await threads.deferToThread(conn.unbind)

if not result:
defer.returnValue(None)
return None

# Extract the username from the search response from the LDAP server
localpart = response["attributes"].get(
Expand All @@ -306,16 +304,15 @@ def check_3pid_auth(self, medium, address, password):
givenName = givenName[0] if len(givenName) == 1 else localpart

# Register the user
user_id = yield self.register_user(localpart, givenName, address)
user_id = await self.register_user(localpart, givenName, address)

defer.returnValue(user_id)
return user_id

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during ldap authentication: %s", e)
raise

@defer.inlineCallbacks
def register_user(self, localpart, name, email_address):
async def register_user(self, localpart, name, email_address):
"""Register a Synapse user, first checking if they exist.
Args:
Expand All @@ -329,9 +326,9 @@ def register_user(self, localpart, name, email_address):
# Get full user id from localpart
user_id = self.account_handler.get_qualified_user_id(localpart)

if (yield self.account_handler.check_user_exists(user_id)):
if await self.account_handler.check_user_exists(user_id):
# exists, authentication complete
defer.returnValue(user_id)
return user_id

# register an email address if one exists
emails = [email_address] if email_address is not None else []
Expand All @@ -341,14 +338,14 @@ def register_user(self, localpart, name, email_address):
# from password providers
if parse_version(synapse.__version__) <= parse_version("0.99.3"):
user_id, access_token = (
yield self.account_handler.register(
await self.account_handler.register(
localpart=localpart, displayname=name,
)
)
else:
# If Synapse has support, bind emails
user_id, access_token = (
yield self.account_handler.register(
await self.account_handler.register(
localpart=localpart, displayname=name, emails=emails,
)
)
Expand All @@ -358,7 +355,7 @@ def register_user(self, localpart, name, email_address):
user_id,
)

defer.returnValue(user_id)
return user_id

@staticmethod
def parse_config(config):
Expand Down Expand Up @@ -407,8 +404,7 @@ class _LdapConfig(object):

return ldap_config

@defer.inlineCallbacks
def _ldap_simple_bind(self, server, bind_dn, password):
async def _ldap_simple_bind(self, server, bind_dn, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Expand All @@ -420,7 +416,7 @@ def _ldap_simple_bind(self, server, bind_dn, password):

try:
# bind with the the local user's ldap credentials
conn = yield threads.deferToThread(
conn = await threads.deferToThread(
ldap3.Connection,
server, bind_dn, password,
authentication=LDAP_AUTH_SIMPLE,
Expand All @@ -432,33 +428,32 @@ def _ldap_simple_bind(self, server, bind_dn, password):
)

if self.ldap_start_tls:
yield threads.deferToThread(conn.open)
yield threads.deferToThread(conn.start_tls)
await threads.deferToThread(conn.open)
await threads.deferToThread(conn.start_tls)
logger.debug(
"Upgraded LDAP connection in simple bind mode through "
"StartTLS: %s",
conn
)

if (yield threads.deferToThread(conn.bind)):
if await threads.deferToThread(conn.bind):
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
defer.returnValue((True, conn))
return (True, conn)

# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
bind_dn, conn.result['description']
)
yield threads.deferToThread(conn.unbind)
defer.returnValue((False, None))
await threads.deferToThread(conn.unbind)
return (False, None)

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during LDAP authentication: %s", e)
raise

@defer.inlineCallbacks
def _ldap_authenticated_search(self, server, password, filters):
async def _ldap_authenticated_search(self, server, password, filters):
"""Attempt to login with the preconfigured bind_dn and then continue
searching and filtering within the base_dn.
Expand All @@ -480,7 +475,7 @@ def _ldap_authenticated_search(self, server, password, filters):
"""

try:
conn = yield threads.deferToThread(
conn = await threads.deferToThread(
ldap3.Connection,
server,
self.ldap_bind_dn,
Expand All @@ -493,21 +488,21 @@ def _ldap_authenticated_search(self, server, password, filters):
)

if self.ldap_start_tls:
yield threads.deferToThread(conn.open)
yield threads.deferToThread(conn.start_tls)
await threads.deferToThread(conn.open)
await threads.deferToThread(conn.start_tls)
logger.debug(
"Upgraded LDAP connection in search mode through "
"StartTLS: %s",
conn
)

if not (yield threads.deferToThread(conn.bind)):
if not await threads.deferToThread(conn.bind):
logger.warning(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
yield threads.deferToThread(conn.unbind)
defer.returnValue((False, None, None))
await threads.deferToThread(conn.unbind)
return (False, None, None)

# Construct search filter
query = ""
Expand All @@ -529,7 +524,7 @@ def _ldap_authenticated_search(self, server, password, filters):
"LDAP search filter: %s",
query
)
yield threads.deferToThread(
await threads.deferToThread(
conn.search,
search_base=self.ldap_base,
search_filter=query,
Expand All @@ -555,12 +550,12 @@ def _ldap_authenticated_search(self, server, password, filters):
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
yield threads.deferToThread(conn.unbind)
result, conn = yield self._ldap_simple_bind(
await threads.deferToThread(conn.unbind)
result, conn = await self._ldap_simple_bind(
server=server, bind_dn=user_dn, password=password
)

defer.returnValue((result, conn, responses[0]))
return (result, conn, responses[0])
else:
# BAD: found 0 or > 1 results, abort!
if len(responses) == 0:
Expand All @@ -573,9 +568,9 @@ def _ldap_authenticated_search(self, server, password, filters):
"LDAP search returned too many (%s) results for '%s'",
len(responses), filters
)
yield threads.deferToThread(conn.unbind)
await threads.deferToThread(conn.unbind)

defer.returnValue((False, None, None))
return (False, None, None)

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during LDAP authentication: %s", e)
Expand Down
Loading

0 comments on commit 2e9d781

Please sign in to comment.