diff --git a/changelog.d/16155.bugfix b/changelog.d/16155.bugfix new file mode 100644 index 000000000000..8b2dc0400672 --- /dev/null +++ b/changelog.d/16155.bugfix @@ -0,0 +1 @@ +Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch). diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 804cc6e81e00..3c7b31a16066 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -23,9 +23,11 @@ import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.endpoints import HostnameEndpoint +from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory +from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable from synapse.types import ISynapseReactor @@ -37,6 +39,9 @@ _is_old_twisted = parse_version(twisted.__version__) < parse_version("21") +# We assign the name ESMTPTLSClientFactory, to be able to redefine it in tests +ESMTPTLSClientFactory = TLSMemoryBIOFactory + class _NoTLSESMTPSender(ESMTPSender): """Extend ESMTPSender to disable TLS @@ -97,6 +102,7 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: **kwargs, ) + factory: IProtocolFactory if _is_old_twisted: # before twisted 21.2, we have to override the ESMTPSender protocol to disable # TLS @@ -109,23 +115,13 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: # set to enable TLS. factory = build_sender_factory(hostname=smtphost if enable_tls else None) + endpoint = HostnameEndpoint( + reactor, smtphost, smtpport, timeout=30, bindAddress=None + ) if force_tls: - reactor.connectSSL( - smtphost, - smtpport, - factory, - optionsForClientTLS(smtphost), - timeout=30, - bindAddress=None, - ) - else: - reactor.connectTCP( - smtphost, - smtpport, - factory, - timeout=30, - bindAddress=None, - ) + factory = ESMTPTLSClientFactory(optionsForClientTLS(smtphost), True, factory) + + await make_deferred_yieldable(endpoint.connect(factory)) await make_deferred_yieldable(d) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index 8b6e4a40b620..3a5462cb4dc5 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,19 +13,42 @@ # limitations under the License. -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Type, Union from zope.interface import implementer from twisted.internet import defer -from twisted.internet.address import IPv4Address +from twisted.internet._sslverify import ClientTLSOptions +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import ensureDeferred +from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.ssl import ContextFactory from twisted.mail import interfaces, smtp +from twisted.protocols.tls import TLSMemoryBIOFactory + +import synapse.handlers.send_email from tests.server import FakeTransport from tests.unittest import HomeserverTestCase, override_config +def TestingESMTPTLSClientFactory( + contextFactory: ContextFactory, + _connectWrapped: bool, + wrappedProtocol: IProtocolFactory, +) -> IProtocolFactory: + """We use this to pass through in testing without using TLS, but + saving the context information to check that it would have happened. + + Note that this is what the MemoryReactor does on connectSSL. + It only saves the contextFactory, but starts the connection with the + underlying Factory. + See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" + + wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] + return wrappedProtocol + + @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: def __init__(self) -> None: @@ -75,7 +98,20 @@ def connectionLost(self) -> None: pass -class SendEmailHandlerTestCase(HomeserverTestCase): +class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): + ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address + + def setUp(self) -> None: + HomeserverTestCase.lookups["localhost"] = HomeserverTestCase.lookups.get( + "localhost", "127.0.0.1" + ) + super().setUp() + + def tearDown(self) -> None: + # Restore ESMTPTLSClientFactory + synapse.handlers.send_email.ESMTPTLSClientFactory = TLSMemoryBIOFactory + super().setUp() + def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() @@ -89,7 +125,7 @@ def test_send_email(self) -> None: (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ 0 ] - self.assertEqual(host, "localhost") + self.assertEqual(host, self.lookups["localhost"]) self.assertEqual(port, 25) # wire it up to an SMTP server @@ -105,7 +141,7 @@ def test_send_email(self) -> None: FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class("TCP", self.lookups["localhost"], 1234), ) ) @@ -128,6 +164,8 @@ def test_send_email(self) -> None: ) def test_send_email_force_tls(self) -> None: """Happy-path test that we can send email to an Implicit TLS server.""" + # Patch ESMTPTLSClientFactory + synapse.handlers.send_email.ESMTPTLSClientFactory = TestingESMTPTLSClientFactory # type: ignore[assignment] h = self.hs.get_send_email_handler() d = ensureDeferred( h.send_email( @@ -135,17 +173,23 @@ def test_send_email_force_tls(self) -> None: ) ) # there should be an attempt to connect to localhost:465 - self.assertEqual(len(self.reactor.sslClients), 1) + self.assertEqual(len(self.reactor.tcpClients), 1) ( host, port, client_factory, - contextFactory, _timeout, _bindAddress, - ) = self.reactor.sslClients[0] - self.assertEqual(host, "localhost") + ) = self.reactor.tcpClients[0] + self.assertEqual(host, self.lookups["localhost"]) self.assertEqual(port, 465) + # We need to make sure that TLS is happenning + self.assertIsInstance( + client_factory._wrappedFactory._testingContextFactory, + ClientTLSOptions, + ) + # And since we use endpoints, they go through reactor.connectTCP + # which works differently to connectSSL on the testing reactor # wire it up to an SMTP server message_delivery = _DummyMessageDelivery() @@ -160,7 +204,7 @@ def test_send_email_force_tls(self) -> None: FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class("TCP", self.lookups["localhost"], 1234), ) ) @@ -172,3 +216,11 @@ def test_send_email_force_tls(self) -> None: user, msg = message_delivery.messages.pop() self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + + +class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): + ip_class = IPv6Address + + def setUp(self) -> None: + HomeserverTestCase.lookups["localhost"] = "::1" + super().setUp() diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index 05d5e39cabd4..ca45dfd274d4 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -16,7 +16,7 @@ import json import os import re -from typing import Any, Dict, Optional, Sequence, Tuple, Type +from typing import Dict, List, Optional, Sequence, Tuple, Type, Union from urllib.parse import quote, urlencode from twisted.internet._resolver import HostResolution @@ -48,6 +48,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): skip = "url preview feature requires lxml" hijack_auth = True + lookups: Dict[ # type: ignore[misc, assignment] + str, List[Tuple[Union[Type[IPv4Address], Type[IPv6Address]], str]] + ] user_id = "@test:user" end_content = ( b"" @@ -120,7 +123,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo_resource = hs.get_media_repository_resource() self.preview_url = media_repo_resource.children[b"preview_url"] - self.lookups: Dict[str, Any] = {} + self.lookups = {} class Resolver: def resolveHostName( diff --git a/tests/server.py b/tests/server.py index 481fe34c5caa..ecc3987d6666 100644 --- a/tests/server.py +++ b/tests/server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import hashlib +import ipaddress import json import logging import os @@ -45,7 +46,7 @@ from typing_extensions import ParamSpec from zope.interface import implementer -from twisted.internet import address, threads, udp +from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError @@ -567,6 +568,8 @@ def connectTCP( conn = super().connectTCP( host, port, factory, timeout=timeout, bindAddress=None ) + if self.lookups and host in self.lookups: + validate_connector(conn, self.lookups[host]) callback = self._tcp_callbacks.get((host, port)) if callback: @@ -599,6 +602,54 @@ def advance(self, amount: float) -> None: super().advance(0) +def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: + """Try to validate the obtained connector as it would happen when + synapse is running and the conection will be established. + + This method will raise a useful exception when necessary, else it will + just do nothing. + + This is in order to help catch quirks related to reactor.connectTCP, + since when called directly, the connector's destination will be of type + IPv4Address, with the hostname as the literal host that was given (which + could be an IPv6-only host or an IPv6 literal). + + But when called from reactor.connectTCP *through* e.g. an Endpoint, the + connector's destination will contain the specific IP address with the + correct network stack class. + + Note that testing code paths that use connectTCP directly should not be + affected by this check, unless they specifically add a test with a + matching HomeserverTestCase.lookups[HOSTNAME] = "IPv6Literal". + For an example of implementing such tests, see test/handlers/send_email.py. + """ + destination = connector.getDestination() + + def check_ip( + cls: Union[Type[ipaddress.IPv4Address], Type[ipaddress.IPv6Address]] + ) -> None: + """With this class we produce a more informative error if needed""" + try: + cls(expected_ip) + except Exception: + raise ValueError( + "Invalid IP type and resolution, got %s, expected: %s %s" + % (destination, expected_ip, cls) + ) + + # We use address.IPv{4,6}Address to check what the reactor thinks it is + # is sending but check for validity with with ipaddress.IPv{4,6}Address + # because they fail with IPs on the wrong network stack. + if isinstance(destination, address.IPv4Address): + check_ip(ipaddress.IPv4Address) + elif isinstance(destination, address.IPv6Address): + check_ip(ipaddress.IPv6Address) + else: + raise ValueError( + "Unknown address type %s for %s" % (type(destination), destination) + ) + + class ThreadPool: """ Threadless thread pool. diff --git a/tests/unittest.py b/tests/unittest.py index b0721e060c40..0f3bf831cae6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -319,6 +319,7 @@ class HomeserverTestCase(TestCase): hijack_auth: ClassVar[bool] = True needs_threadpool: ClassVar[bool] = False servlets: ClassVar[List[RegisterServletsFunc]] = [] + lookups: ClassVar[Dict[str, str]] = {} def __init__(self, methodName: str): super().__init__(methodName) @@ -334,6 +335,12 @@ def setUp(self) -> None: calling the prepare function. """ self.reactor, self.clock = get_clock() + + # The home server will start connecting places as soon as it starts, + # so we must update some fake DNS entries early. + if HomeserverTestCase.lookups: + self.reactor.lookups.update(HomeserverTestCase.lookups) + self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) @@ -410,6 +417,8 @@ async def get_requester(*args: Any, **kwargs: Any) -> Requester: def tearDown(self) -> None: # Reset to not use frozen dicts. events.USE_FROZEN_DICTS = False + # Clear any possible forced lookups + HomeserverTestCase.lookups.clear() def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None: """