-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests for the HTTP Proxy support (asyncio and common only) (#151)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
- Loading branch information
1 parent
b2a74db
commit 53015c7
Showing
5 changed files
with
440 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add tests for HTTP Proxy support. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import logging | ||
import types | ||
from asyncio import AbstractEventLoop, transports | ||
from asyncio.protocols import BaseProtocol, Protocol | ||
from asyncio.transports import Transport | ||
from contextvars import Context | ||
from typing import Any, Callable, List, Optional, Tuple | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TimelessEventLoopWrapper: | ||
@property # type: ignore | ||
def __class__(self): | ||
""" | ||
Fakes isinstance(this, AbstractEventLoop) so we can set_event_loop | ||
without fail. | ||
""" | ||
return self._wrapped_loop.__class__ | ||
|
||
def __init__(self, wrapped_loop: AbstractEventLoop): | ||
self._wrapped_loop = wrapped_loop | ||
self._time = 0.0 | ||
self._to_be_called: List[Tuple[float, Any, Any, Any]] = [] | ||
|
||
def advance(self, time_delta: float): | ||
target_time = self._time + time_delta | ||
logger.debug( | ||
"advancing from %f by %f (%d in queue)", | ||
self._time, | ||
time_delta, | ||
len(self._to_be_called), | ||
) | ||
while self._time < target_time and self._to_be_called: | ||
# pop off the next callback from the queue | ||
next_time, next_callback, args, _context = self._to_be_called[0] | ||
if next_time > target_time: | ||
# this isn't allowed to run yet | ||
break | ||
logger.debug("callback at %f on %r", next_time, next_callback) | ||
self._to_be_called = self._to_be_called[1:] | ||
self._time = next_time | ||
next_callback(*args) | ||
|
||
# no more tasks can run now but advance to the time anyway | ||
self._time = target_time | ||
|
||
def __getattr__(self, item: str): | ||
""" | ||
We use this to delegate other method calls to the real EventLoop. | ||
""" | ||
value = getattr(self._wrapped_loop, item) | ||
if isinstance(value, types.MethodType): | ||
# rebind this method to be called on us | ||
# this makes the wrapped class use our overridden methods when | ||
# available. | ||
# we have to do this because methods are bound to the underlying | ||
# event loop, which will call `self.call_later` or something | ||
# which won't normally hit us because we are not an actual subtype. | ||
return types.MethodType(value.__func__, self) | ||
else: | ||
return value | ||
|
||
def call_later( | ||
self, | ||
delay: float, | ||
callback: Callable, | ||
*args: Any, | ||
context: Optional[Context] = None, | ||
): | ||
self.call_at(self._time + delay, callback, *args, context=context) | ||
|
||
def call_at( | ||
self, | ||
when: float, | ||
callback: Callable, | ||
*args: Any, | ||
context: Optional[Context] = None, | ||
): | ||
logger.debug(f"Calling {callback} at %f...", when) | ||
self._to_be_called.append((when, callback, args, context)) | ||
|
||
# re-sort list in ascending time order | ||
self._to_be_called.sort(key=lambda x: x[0]) | ||
|
||
def call_soon( | ||
self, callback: Callable, *args: Any, context: Optional[Context] = None | ||
): | ||
return self.call_later(0, callback, *args, context=context) | ||
|
||
def time(self) -> float: | ||
return self._time | ||
|
||
|
||
class MockTransport(Transport): | ||
""" | ||
A transport intended to be driven by tests. | ||
Stores received data into a buffer. | ||
""" | ||
|
||
def __init__(self): | ||
# Holds bytes received | ||
self.buffer = b"" | ||
|
||
# Whether we reached the end of file/stream | ||
self.eofed = False | ||
|
||
# Whether the connection was aborted | ||
self.aborted = False | ||
|
||
# The protocol attached to this transport | ||
self.protocol = None | ||
|
||
# Whether this transport was closed | ||
self.closed = False | ||
|
||
def reset_mock(self) -> None: | ||
self.buffer = b"" | ||
self.eofed = False | ||
self.aborted = False | ||
self.closed = False | ||
|
||
def is_reading(self) -> bool: | ||
return True | ||
|
||
def pause_reading(self) -> None: | ||
pass # NOP | ||
|
||
def resume_reading(self) -> None: | ||
pass # NOP | ||
|
||
def set_write_buffer_limits(self, high: int = None, low: int = None) -> None: | ||
pass # NOP | ||
|
||
def get_write_buffer_size(self) -> int: | ||
"""Return the current size of the write buffer.""" | ||
raise NotImplementedError | ||
|
||
def write(self, data: bytes) -> None: | ||
self.buffer += data | ||
|
||
def write_eof(self) -> None: | ||
self.eofed = True | ||
|
||
def can_write_eof(self) -> bool: | ||
return True | ||
|
||
def abort(self) -> None: | ||
self.aborted = True | ||
|
||
def pretend_to_receive(self, data: bytes) -> None: | ||
proto = self.get_protocol() | ||
assert isinstance(proto, Protocol) | ||
proto.data_received(data) | ||
|
||
def set_protocol(self, protocol: BaseProtocol) -> None: | ||
self.protocol = protocol | ||
|
||
def get_protocol(self) -> BaseProtocol: | ||
assert isinstance(self.protocol, BaseProtocol) | ||
return self.protocol | ||
|
||
def close(self) -> None: | ||
self.closed = True | ||
|
||
|
||
class MockProtocol(Protocol): | ||
""" | ||
A protocol intended to be driven by tests. | ||
Stores received data into a buffer. | ||
""" | ||
|
||
def __init__(self): | ||
self._to_transmit = b"" | ||
self.received_bytes = b"" | ||
self.transport = None | ||
|
||
def data_received(self, data: bytes) -> None: | ||
self.received_bytes += data | ||
|
||
def connection_made(self, transport: transports.BaseTransport) -> None: | ||
assert isinstance(transport, Transport) | ||
self.transport = transport | ||
if self._to_transmit: | ||
transport.write(self._to_transmit) | ||
|
||
def write(self, data: bytes) -> None: | ||
if self.transport: | ||
self.transport.write(data) | ||
else: | ||
self._to_transmit += data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2020 The Matrix.org Foundation C.I.C. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import asyncio | ||
from asyncio import AbstractEventLoop, BaseTransport, Protocol, Task | ||
from typing import Optional, Tuple, cast | ||
|
||
from sygnal.exceptions import ProxyConnectError | ||
from sygnal.helper.proxy.proxy_asyncio import HttpConnectProtocol | ||
|
||
from tests import testutils | ||
from tests.asyncio_test_helpers import ( | ||
MockProtocol, | ||
MockTransport, | ||
TimelessEventLoopWrapper, | ||
) | ||
|
||
|
||
class AsyncioHttpProxyTest(testutils.TestCase): | ||
def config_setup(self, config): | ||
super().config_setup(config) | ||
config["apps"]["com.example.spqr"] = { | ||
"type": "tests.test_pushgateway_api_v1.TestPushkin" | ||
} | ||
base_loop = asyncio.new_event_loop() | ||
augmented_loop = TimelessEventLoopWrapper(base_loop) # type: ignore | ||
asyncio.set_event_loop(cast(AbstractEventLoop, augmented_loop)) | ||
|
||
self.loop = augmented_loop | ||
|
||
def make_fake_proxy( | ||
self, host: str, port: int, proxy_credentials: Optional[Tuple[str, str]] | ||
) -> Tuple[MockProtocol, MockTransport, "Task[Tuple[BaseTransport, Protocol]]"]: | ||
# Task[Tuple[MockTransport, MockProtocol]] | ||
# make a fake proxy | ||
fake_proxy = MockTransport() | ||
# make a fake protocol that we fancy using through the proxy | ||
fake_protocol = MockProtocol() | ||
# create a HTTP CONNECT proxy client protocol | ||
http_connect_protocol = HttpConnectProtocol( | ||
target_hostport=(host, port), | ||
proxy_credentials=proxy_credentials, | ||
protocol_factory=lambda: fake_protocol, | ||
sslcontext=None, | ||
loop=None, | ||
) | ||
switch_over_task = asyncio.get_event_loop().create_task( | ||
http_connect_protocol.switch_over_when_ready() | ||
) | ||
# check the task is not somehow already marked as done before we even | ||
# receive anything. | ||
self.assertFalse(switch_over_task.done()) | ||
# connect the proxy client to the proxy | ||
fake_proxy.set_protocol(http_connect_protocol) | ||
http_connect_protocol.connection_made(fake_proxy) | ||
return fake_protocol, fake_proxy, switch_over_task | ||
|
||
def test_connect_no_credentials(self): | ||
""" | ||
Tests the proxy connection procedure when there is no basic auth. | ||
""" | ||
host = "example.org" | ||
port = 443 | ||
proxy_credentials = None | ||
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy( | ||
host, port, proxy_credentials | ||
) | ||
|
||
# Check that the proxy got the proper CONNECT request. | ||
self.assertEqual(fake_proxy.buffer, b"CONNECT example.org:443 HTTP/1.0\r\n\r\n") | ||
# Reset the proxy mock | ||
fake_proxy.reset_mock() | ||
|
||
# pretend we got a happy response with some dangling bytes from the | ||
# target protocol | ||
fake_proxy.pretend_to_receive( | ||
b"HTTP/1.0 200 Connection Established\r\n\r\n" | ||
b"begin beep boop\r\n\r\n~~ :) ~~" | ||
) | ||
|
||
# advance event loop because we have to let coroutines be executed | ||
self.loop.advance(1.0) | ||
|
||
# *now* we should have switched over from the HTTP CONNECT protocol | ||
# to the user protocol (in our case, a MockProtocol). | ||
self.assertTrue(switch_over_task.done()) | ||
|
||
transport, protocol = switch_over_task.result() | ||
|
||
# check it was our protocol that was returned | ||
self.assertIs(protocol, fake_protocol) | ||
|
||
# check our protocol received exactly the bytes meant for it | ||
self.assertEqual( | ||
fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~" | ||
) | ||
|
||
def test_connect_correct_credentials(self): | ||
""" | ||
Tests the proxy connection procedure when there is basic auth. | ||
""" | ||
host = "example.org" | ||
port = 443 | ||
proxy_credentials = ("user", "secret") | ||
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy( | ||
host, port, proxy_credentials | ||
) | ||
|
||
# Check that the proxy got the proper CONNECT request with the | ||
# correctly-encoded credentials | ||
self.assertEqual( | ||
fake_proxy.buffer, | ||
b"CONNECT example.org:443 HTTP/1.0\r\n" | ||
b"Proxy-Authorization: basic dXNlcjpzZWNyZXQ=\r\n\r\n", | ||
) | ||
# Reset the proxy mock | ||
fake_proxy.reset_mock() | ||
|
||
# pretend we got a happy response with some dangling bytes from the | ||
# target protocol | ||
fake_proxy.pretend_to_receive( | ||
b"HTTP/1.0 200 Connection Established\r\n\r\n" | ||
b"begin beep boop\r\n\r\n~~ :) ~~" | ||
) | ||
|
||
# advance event loop because we have to let coroutines be executed | ||
self.loop.advance(1.0) | ||
|
||
# *now* we should have switched over from the HTTP CONNECT protocol | ||
# to the user protocol (in our case, a MockProtocol). | ||
self.assertTrue(switch_over_task.done()) | ||
|
||
transport, protocol = switch_over_task.result() | ||
|
||
# check it was our protocol that was returned | ||
self.assertIs(protocol, fake_protocol) | ||
|
||
# check our protocol received exactly the bytes meant for it | ||
self.assertEqual( | ||
fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~" | ||
) | ||
|
||
def test_connect_failure(self): | ||
""" | ||
Test that our task fails properly when we cannot make a connection through | ||
the proxy. | ||
""" | ||
host = "example.org" | ||
port = 443 | ||
proxy_credentials = ("user", "secret") | ||
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy( | ||
host, port, proxy_credentials | ||
) | ||
|
||
# Check that the proxy got the proper CONNECT request with the | ||
# correctly-encoded credentials. | ||
self.assertEqual( | ||
fake_proxy.buffer, | ||
b"CONNECT example.org:443 HTTP/1.0\r\n" | ||
b"Proxy-Authorization: basic dXNlcjpzZWNyZXQ=\r\n\r\n", | ||
) | ||
# Reset the proxy mock | ||
fake_proxy.reset_mock() | ||
|
||
# For the sake of this test, pretend the credentials are incorrect so | ||
# send a sad response with a HTML error page | ||
fake_proxy.pretend_to_receive( | ||
b"HTTP/1.0 401 Unauthorised\r\n\r\n<HTML>... some error here ...</HTML>" | ||
) | ||
|
||
# advance event loop because we have to let coroutines be executed | ||
self.loop.advance(1.0) | ||
|
||
# *now* this future should have completed | ||
self.assertTrue(switch_over_task.done()) | ||
|
||
# but we should have failed | ||
self.assertIsInstance(switch_over_task.exception(), ProxyConnectError) | ||
|
||
# check our protocol did not receive anything, because it was an HTTP- | ||
# level error, not actually a connection to our target. | ||
self.assertEqual(fake_protocol.received_bytes, b"") |
Oops, something went wrong.