Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to async/await #101

Merged
merged 11 commits into from
Sep 2, 2020
109 changes: 52 additions & 57 deletions ldap_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, threads
from twisted.internet import threads


import ldap3
Expand Down Expand Up @@ -88,8 +88,7 @@ def __init__(self, config, account_handler):
def get_supported_login_types(self):
return {'m.login.password': ('password',)}

@defer.inlineCallbacks
def check_auth(self, username, login_type, login_dict):
async def check_auth(self, username, login_type, login_dict):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.

Expand All @@ -103,7 +102,7 @@ def check_auth(self, username, login_type, login_dict):
# an anonymous authorization state and not suitable for user
# authentication.
if not password:
defer.returnValue(False)
return False

if username.startswith("@") and ":" in username:
# username is of the form @foo:bar.com
Expand All @@ -122,7 +121,7 @@ def check_auth(self, username, login_type, login_dict):
uid_value = login + "@" + domain
default_display_name = login
except ActiveDirectoryUPNException:
defer.returnValue(False)
return False

try:
tls = ldap3.Tls(validate=ssl.CERT_REQUIRED)
Expand All @@ -140,7 +139,7 @@ def check_auth(self, username, login_type, login_dict):
value=uid_value,
base=self.ldap_base
)
result, conn = yield self._ldap_simple_bind(
result, conn = await self._ldap_simple_bind(
server=server, bind_dn=bind_dn, password=password
)
logger.debug(
Expand All @@ -150,10 +149,10 @@ def check_auth(self, username, login_type, login_dict):
conn
)
if not result:
defer.returnValue(False)
return False
elif self.ldap_mode == LDAPMode.SEARCH:
filters = [(self.ldap_attributes["uid"], uid_value)]
result, conn, _ = yield self._ldap_authenticated_search(
result, conn, _ = await self._ldap_authenticated_search(
server=server, password=password, filters=filters
)
logger.debug(
Expand All @@ -163,7 +162,7 @@ def check_auth(self, username, login_type, login_dict):
conn
)
if not result:
defer.returnValue(False)
return False
else: # pragma: no cover
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
Expand All @@ -181,17 +180,17 @@ def check_auth(self, username, login_type, login_dict):
"Authentication method yielded no LDAP connection, "
"aborting!"
)
defer.returnValue(False)
return False

# Get full user id from localpart
user_id = self.account_handler.get_qualified_user_id(localpart)

# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
if await self.account_handler.check_user_exists(user_id):
# exists, authentication complete
if hasattr(conn, "unbind"):
yield threads.deferToThread(conn.unbind)
defer.returnValue(user_id)
await threads.deferToThread(conn.unbind)
return user_id

else:
# does not exist, register
Expand All @@ -200,7 +199,7 @@ def check_auth(self, username, login_type, login_dict):
# existing ldap connection
filters = [(self.ldap_attributes['uid'], uid_value)]

result, conn, response = yield self._ldap_authenticated_search(
result, conn, response = await self._ldap_authenticated_search(
server=server, password=password, filters=filters,
)

Expand All @@ -222,18 +221,17 @@ def check_auth(self, username, login_type, login_dict):
mail = None

# Register the user
user_id = yield self.register_user(localpart, display_name, mail)
user_id = await self.register_user(localpart, display_name, mail)

defer.returnValue(user_id)
return user_id

defer.returnValue(False)
return False

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during ldap authentication: %s", e)
defer.returnValue(False)
return False

@defer.inlineCallbacks
def check_3pid_auth(self, medium, address, password):
async def check_3pid_auth(self, medium, address, password):
""" Handle authentication against thirdparty login types, such as email

Args:
Expand All @@ -248,11 +246,11 @@ def check_3pid_auth(self, medium, address, password):
if self.ldap_mode != LDAPMode.SEARCH:
logger.debug("3PID LDAP login/register attempted but LDAP search mode "
"not enabled. Bailing.")
defer.returnValue(None)
return None

# We currently only support email
if medium != "email":
defer.returnValue(None)
return None

# Talk to LDAP and check if this email/password combo is correct
try:
Expand All @@ -265,7 +263,7 @@ def check_3pid_auth(self, medium, address, password):
)

search_filter = [(self.ldap_attributes["mail"], address)]
result, conn, response = yield self._ldap_authenticated_search(
result, conn, response = await self._ldap_authenticated_search(
server=server, password=password, filters=search_filter,
)

Expand All @@ -279,10 +277,10 @@ def check_3pid_auth(self, medium, address, password):

# Close connection
if hasattr(conn, "unbind"):
yield threads.deferToThread(conn.unbind)
await threads.deferToThread(conn.unbind)

if not result:
defer.returnValue(None)
return None

# Extract the username from the search response from the LDAP server
localpart = response["attributes"].get(
Expand All @@ -306,16 +304,15 @@ def check_3pid_auth(self, medium, address, password):
givenName = givenName[0] if len(givenName) == 1 else localpart

# Register the user
user_id = yield self.register_user(localpart, givenName, address)
user_id = await self.register_user(localpart, givenName, address)

defer.returnValue(user_id)
return user_id

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during ldap authentication: %s", e)
raise

@defer.inlineCallbacks
def register_user(self, localpart, name, email_address):
async def register_user(self, localpart, name, email_address):
"""Register a Synapse user, first checking if they exist.

Args:
Expand All @@ -329,9 +326,9 @@ def register_user(self, localpart, name, email_address):
# Get full user id from localpart
user_id = self.account_handler.get_qualified_user_id(localpart)

if (yield self.account_handler.check_user_exists(user_id)):
if await self.account_handler.check_user_exists(user_id):
# exists, authentication complete
defer.returnValue(user_id)
return user_id

# register an email address if one exists
emails = [email_address] if email_address is not None else []
Expand All @@ -341,14 +338,14 @@ def register_user(self, localpart, name, email_address):
# from password providers
if parse_version(synapse.__version__) <= parse_version("0.99.3"):
user_id, access_token = (
yield self.account_handler.register(
await self.account_handler.register(
localpart=localpart, displayname=name,
)
)
else:
# If Synapse has support, bind emails
user_id, access_token = (
yield self.account_handler.register(
await self.account_handler.register(
localpart=localpart, displayname=name, emails=emails,
)
)
Expand All @@ -358,7 +355,7 @@ def register_user(self, localpart, name, email_address):
user_id,
)

defer.returnValue(user_id)
return user_id

@staticmethod
def parse_config(config):
Expand Down Expand Up @@ -407,8 +404,7 @@ class _LdapConfig(object):

return ldap_config

@defer.inlineCallbacks
def _ldap_simple_bind(self, server, bind_dn, password):
async def _ldap_simple_bind(self, server, bind_dn, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.

Expand All @@ -420,7 +416,7 @@ def _ldap_simple_bind(self, server, bind_dn, password):

try:
# bind with the the local user's ldap credentials
conn = yield threads.deferToThread(
conn = await threads.deferToThread(
ldap3.Connection,
server, bind_dn, password,
authentication=LDAP_AUTH_SIMPLE,
Expand All @@ -432,33 +428,32 @@ def _ldap_simple_bind(self, server, bind_dn, password):
)

if self.ldap_start_tls:
yield threads.deferToThread(conn.open)
yield threads.deferToThread(conn.start_tls)
await threads.deferToThread(conn.open)
await threads.deferToThread(conn.start_tls)
logger.debug(
"Upgraded LDAP connection in simple bind mode through "
"StartTLS: %s",
conn
)

if (yield threads.deferToThread(conn.bind)):
if await threads.deferToThread(conn.bind):
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
defer.returnValue((True, conn))
return (True, conn)

# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
bind_dn, conn.result['description']
)
yield threads.deferToThread(conn.unbind)
defer.returnValue((False, None))
await threads.deferToThread(conn.unbind)
return (False, None)

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during LDAP authentication: %s", e)
raise

@defer.inlineCallbacks
def _ldap_authenticated_search(self, server, password, filters):
async def _ldap_authenticated_search(self, server, password, filters):
"""Attempt to login with the preconfigured bind_dn and then continue
searching and filtering within the base_dn.

Expand All @@ -480,7 +475,7 @@ def _ldap_authenticated_search(self, server, password, filters):
"""

try:
conn = yield threads.deferToThread(
conn = await threads.deferToThread(
ldap3.Connection,
server,
self.ldap_bind_dn,
Expand All @@ -493,21 +488,21 @@ def _ldap_authenticated_search(self, server, password, filters):
)

if self.ldap_start_tls:
yield threads.deferToThread(conn.open)
yield threads.deferToThread(conn.start_tls)
await threads.deferToThread(conn.open)
await threads.deferToThread(conn.start_tls)
logger.debug(
"Upgraded LDAP connection in search mode through "
"StartTLS: %s",
conn
)

if not (yield threads.deferToThread(conn.bind)):
if not await threads.deferToThread(conn.bind):
logger.warning(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
yield threads.deferToThread(conn.unbind)
defer.returnValue((False, None, None))
await threads.deferToThread(conn.unbind)
return (False, None, None)

# Construct search filter
query = ""
Expand All @@ -529,7 +524,7 @@ def _ldap_authenticated_search(self, server, password, filters):
"LDAP search filter: %s",
query
)
yield threads.deferToThread(
await threads.deferToThread(
conn.search,
search_base=self.ldap_base,
search_filter=query,
Expand All @@ -555,12 +550,12 @@ def _ldap_authenticated_search(self, server, password, filters):
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
yield threads.deferToThread(conn.unbind)
result, conn = yield self._ldap_simple_bind(
await threads.deferToThread(conn.unbind)
result, conn = await self._ldap_simple_bind(
server=server, bind_dn=user_dn, password=password
)

defer.returnValue((result, conn, responses[0]))
return (result, conn, responses[0])
else:
# BAD: found 0 or > 1 results, abort!
if len(responses) == 0:
Expand All @@ -573,9 +568,9 @@ def _ldap_authenticated_search(self, server, password, filters):
"LDAP search returned too many (%s) results for '%s'",
len(responses), filters
)
yield threads.deferToThread(conn.unbind)
await threads.deferToThread(conn.unbind)

defer.returnValue((False, None, None))
return (False, None, None)

except ldap3.core.exceptions.LDAPException as e:
logger.warning("Error during LDAP authentication: %s", e)
Expand Down
Loading