diff --git a/msrest/__init__.py b/msrest/__init__.py index 87271ec8c3..7cb2a517cd 100644 --- a/msrest/__init__.py +++ b/msrest/__init__.py @@ -25,12 +25,13 @@ # -------------------------------------------------------------------------- from .configuration import Configuration -from .service_client import ServiceClient +from .service_client import ServiceClient, SDKClient from .serialization import Serializer, Deserializer from .version import msrest_version __all__ = [ "ServiceClient", + "SDKClient", "Serializer", "Deserializer", "Configuration" diff --git a/msrest/authentication.py b/msrest/authentication.py index 7cb5cedfd0..b446e7f743 100644 --- a/msrest/authentication.py +++ b/msrest/authentication.py @@ -36,13 +36,17 @@ class Authentication(object): header = "Authorization" - def signed_session(self): - """Create requests session with any required auth headers - applied. + def signed_session(self, session=None): + """Create requests session with any required auth headers applied. + + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - return requests.Session() + return session or requests.Session() class BasicAuthentication(Authentication): @@ -57,13 +61,18 @@ def __init__(self, username, password): self.username = username self.password = password - def signed_session(self): + def signed_session(self, session=None): """Create requests session with any required auth headers applied. + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - session = super(BasicAuthentication, self).signed_session() + session = super(BasicAuthentication, self).signed_session(session) session.auth = HTTPBasicAuth(self.username, self.password) return session @@ -87,13 +96,18 @@ def set_token(self): """ pass - def signed_session(self): + def signed_session(self, session=None): """Create requests session with any required auth headers applied. + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - session = super(BasicTokenAuthentication, self).signed_session() + session = super(BasicTokenAuthentication, self).signed_session(session) header = "{} {}".format(self.scheme, self.token['access_token']) session.headers['Authorization'] = header return session @@ -101,6 +115,7 @@ def signed_session(self): class OAuthTokenAuthentication(BasicTokenAuthentication): """OAuth Token Authentication. + Requires that supplied token contains an expires_in field. :param str client_id: Account Client ID. @@ -108,9 +123,8 @@ class OAuthTokenAuthentication(BasicTokenAuthentication): """ def __init__(self, client_id, token): - self.scheme = 'Bearer' + super(OAuthTokenAuthentication, self).__init__(token) self.id = client_id - self.token = token self.store_key = self.id def construct_auth(self): @@ -120,20 +134,32 @@ def construct_auth(self): """ return "{} {}".format(self.scheme, self.token) - def refresh_session(self): + def refresh_session(self, session=None): """Return updated session if token has expired, attempts to refresh using refresh token. + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - return self.signed_session() + return self.signed_session(session) - def signed_session(self): + def signed_session(self, session=None): """Create requests session with any required auth headers applied. + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - return oauth.OAuth2Session(self.id, token=self.token) + session = session or requests.Session() # Don't call super on purpose, let's "auth" manage the headers. + session.auth = oauth.OAuth2(self.id, token=self.token) + return session class ApiKeyCredentials(Authentication): """Represent the ApiKey feature of Swagger. @@ -144,6 +170,7 @@ class ApiKeyCredentials(Authentication): :param dict[str,str] in_query: ApiKey in the query as parameters """ def __init__(self, in_headers=None, in_query=None): + super(ApiKeyCredentials, self).__init__() if in_headers is None: in_headers = {} if in_query is None: @@ -155,12 +182,17 @@ def __init__(self, in_headers=None, in_query=None): self.in_headers = in_headers self.in_query = in_query - def signed_session(self): + def signed_session(self, session=None): """Create requests session with ApiKey. + If a session object is provided, configure it directly. Otherwise, + create a new session and return it. + + :param session: The session to configure for authentication + :type session: requests.Session :rtype: requests.Session """ - session = super(ApiKeyCredentials, self).signed_session() + session = super(ApiKeyCredentials, self).signed_session(session) session.headers.update(self.in_headers) session.params.update(self.in_query) return session diff --git a/msrest/configuration.py b/msrest/configuration.py index 81eef67f82..4b5e68bd78 100644 --- a/msrest/configuration.py +++ b/msrest/configuration.py @@ -97,6 +97,9 @@ def __init__(self, base_url, filepath=None): self.session_configuration_callback = default_session_configuration_callback + # If set to True, ServiceClient will own the sessionn + self.keep_alive = False + self._config = configparser.ConfigParser() self._config.optionxform = str diff --git a/msrest/pipeline.py b/msrest/pipeline.py index a42d7db6a5..c069131c4b 100644 --- a/msrest/pipeline.py +++ b/msrest/pipeline.py @@ -46,22 +46,6 @@ class ClientRequest(requests.Request): """Wrapper for requests.Request object.""" - def add_header(self, header, value): - """Add a header to the single request. - - :param str header: The header name. - :param str value: The header value. - """ - self.headers[header] = value - - def add_headers(self, headers): - """Add multiple headers to the single request. - - :param dict headers: A dictionary of headers. - """ - for key, value in headers.items(): - self.add_header(key, value) - def format_parameters(self, params): """Format parameters into a valid query string. It's assumed all parameters have already been quoted as diff --git a/msrest/service_client.py b/msrest/service_client.py index 1e4fd876b9..7611f8d8ac 100644 --- a/msrest/service_client.py +++ b/msrest/service_client.py @@ -46,6 +46,23 @@ _LOGGER = logging.getLogger(__name__) +class SDKClient(object): + """The base class of all generated SDK client. + """ + def __init__(self, creds, config): + self._client = ServiceClient(creds, config) + + def close(self): + """Close the client if keep_alive is True. + """ + self._client.close() + + def __enter__(self): + self._client.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._client.__exit__(exc_type, exc_val, exc_tb) class ServiceClient(object): """REST Service Client. @@ -61,6 +78,22 @@ def __init__(self, creds, config): self.config = config self.creds = creds if creds else Authentication() self._headers = {} + self._session = None + + def __enter__(self): + self.config.keep_alive = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self.config.keep_alive = False + + def close(self): + """Close the session if keep_alive is True. + """ + if self._session: + self._session.close() + self._session = None def _format_data(self, data): """Format field data according to whether it is a stream or @@ -100,39 +133,48 @@ def _configure_session(self, session, **config): :param requests.Session session: Current request session. :param config: Specific configuration overrides. + :rtype: dict + :return: A dict that will be kwarg-send to session.request """ kwargs = self.config.connection() for opt in ['timeout', 'verify', 'cert']: kwargs[opt] = config.get(opt, kwargs[opt]) - for opt in ['cookies', 'files']: - kwargs[opt] = config.get(opt) + kwargs.update({k:config[k] for k in ['cookies', 'files'] if k in config}) kwargs['allow_redirects'] = config.get( 'allow_redirects', bool(self.config.redirect_policy)) - session.headers.update(self._headers) - session.headers['User-Agent'] = self.config.user_agent - session.headers['Accept'] = 'application/json' - session.max_redirects = config.get( - 'max_redirects', self.config.redirect_policy()) - session.proxies = config.get( - 'proxies', self.config.proxies()) - session.trust_env = config.get( - 'use_env_proxies', self.config.proxies.use_env_settings) - redirect_logic = session.resolve_redirects - - def wrapped_redirect(resp, req, **kwargs): - attempt = self.config.redirect_policy.check_redirect(resp, req) - return redirect_logic(resp, req, **kwargs) if attempt else [] - - session.resolve_redirects = wrapped_redirect + kwargs['headers'] = dict(self._headers) + kwargs['headers']['User-Agent'] = self.config.user_agent + kwargs['headers']['Accept'] = 'application/json' + proxies = config.get('proxies', self.config.proxies()) + if proxies: + kwargs['proxies'] = proxies + + kwargs['stream'] = config.get('stream', True) + + session.max_redirects = config.get('max_redirects', self.config.redirect_policy()) + session.trust_env = config.get('use_env_proxies', self.config.proxies.use_env_settings) + + # Patch the redirect method directly *if not done already* + if not getattr(session.resolve_redirects, 'is_mrest_patched', False): + redirect_logic = session.resolve_redirects + + def wrapped_redirect(resp, req, **kwargs): + attempt = self.config.redirect_policy.check_redirect(resp, req) + return redirect_logic(resp, req, **kwargs) if attempt else [] + wrapped_redirect.is_mrest_patched = True + + session.resolve_redirects = wrapped_redirect + # if "enable_http_logger" is defined at the operation level, take the value. # if not, take the one in the client config # if not, disable http_logger + hooks = [] if config.get("enable_http_logger", self.config.enable_http_logger): def log_hook(r, *args, **kwargs): log_request(None, r.request) log_response(None, r.request, r, result=r) - session.hooks['response'].append(log_hook) + hooks.append(log_hook) def make_user_hook_cb(user_hook, session): def user_hook_cb(r, *args, **kwargs): @@ -141,13 +183,15 @@ def user_hook_cb(r, *args, **kwargs): return user_hook_cb for user_hook in self.config.hooks: - session.hooks['response'].append(make_user_hook_cb(user_hook, session)) + hooks.append(make_user_hook_cb(user_hook, session)) + + if hooks: + kwargs['hooks'] = {'response': hooks} - max_retries = config.get( - 'retries', self.config.retry_policy()) + # Change max_retries in current all installed adapters + max_retries = config.get('retries', self.config.retry_policy()) for protocol in self._protocols: - session.mount(protocol, - requests.adapters.HTTPAdapter(max_retries=max_retries)) + session.adapters[protocol].max_retries=max_retries output_kwargs = self.config.session_configuration_callback(session, self.config, config, **kwargs) if output_kwargs is not None: @@ -155,7 +199,7 @@ def user_hook_cb(r, *args, **kwargs): return kwargs - def send_formdata(self, request, headers=None, content=None, stream=True, **config): + def send_formdata(self, request, headers=None, content=None, **config): """Send data as a multipart form-data request. We only deal with file-like objects or strings at this point. The requests is not yet streamed. @@ -163,7 +207,6 @@ def send_formdata(self, request, headers=None, content=None, stream=True, **conf :param ClientRequest request: The request object to be sent. :param dict headers: Any headers to add to the request. :param dict content: Dictionary of the fields of the formdata. - :param bool stream: Is the session in stream mode. True by default for compat. :param config: Any specific config overrides. """ if content is None: @@ -173,35 +216,44 @@ def send_formdata(self, request, headers=None, content=None, stream=True, **conf if content_type and content_type.lower() == 'application/x-www-form-urlencoded': # Do NOT use "add_content" that assumes input is JSON request.data = {f: d for f, d in content.items() if d is not None} - return self.send(request, headers, None, stream=stream, **config) + return self.send(request, headers, None, **config) else: # Assume "multipart/form-data" file_data = {f: self._format_data(d) for f, d in content.items() if d is not None} - return self.send(request, headers, None, files=file_data, stream=stream, **config) + return self.send(request, headers, None, files=file_data, **config) - def send(self, request, headers=None, content=None, stream=True, **config): + def send(self, request, headers=None, content=None, **config): """Prepare and send request object according to configuration. :param ClientRequest request: The request object to be sent. :param dict headers: Any headers to add to the request. :param content: Any body data to add to the request. - :param bool stream: Is the session in stream mode. True by default for compat. :param config: Any specific config overrides """ - response = None - session = self.creds.signed_session() + if self.config.keep_alive and self._session is None: + self._session = requests.Session() + try: + session = self.creds.signed_session(self._session) + except TypeError: # Credentials does not support session injection + session = self.creds.signed_session() + if self._session is not None: + _LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.") + kwargs = self._configure_session(session, **config) - kwargs['stream'] = stream + if headers: + request.headers.update(headers) - request.add_headers(headers if headers else {}) if not kwargs.get('files'): request.add_content(content) - try: + if request.data: + kwargs['data']=request.data + kwargs['headers'].update(request.headers) + response = None + try: try: response = session.request( - request.method, request.url, - data=request.data, - headers=request.headers, + request.method, + request.url, **kwargs) return response @@ -212,12 +264,14 @@ def send(self, request, headers=None, content=None, stream=True, **config): try: session = self.creds.refresh_session() - kwargs = self._configure_session(session) + kwargs = self._configure_session(session, **config) + if request.data: + kwargs['data']=request.data + kwargs['headers'].update(request.headers) response = session.request( - request.method, request.url, - request.data, - request.headers, + request.method, + request.url, **kwargs) return response except (oauth2.rfc6749.errors.InvalidGrantError, @@ -230,8 +284,15 @@ def send(self, request, headers=None, content=None, stream=True, **config): msg = "Error occurred in request." raise_with_traceback(ClientRequestError, msg, err) finally: - if not response or not stream: - session.close() + self._close_local_session_if_necessary(response, session, kwargs['stream']) + + def _close_local_session_if_necessary(self, response, session, stream): + # Do NOT close session if session is self._session. No exception. + if self._session is session: + return + # Here, it's a local session, I might close it. + if not response or not stream: + session.close() def stream_download(self, data, callback): """Generator for streaming request body data. diff --git a/tests/test_auth.py b/tests/test_auth.py index a976c03f65..c7dcf61e0b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -46,7 +46,7 @@ TopicCredentials ) -from requests import Request +from requests import Request, PreparedRequest class TestAuthentication(unittest.TestCase): @@ -91,11 +91,16 @@ def test_basic_token_auth(self): def test_token_auth(self): - token = {"my_token":123} + token = { + 'access_token': '123456789' + } auth = OAuthTokenAuthentication("client_id", token) session = auth.signed_session() - self.assertEqual(session.token, token) + request = PreparedRequest() + request.prepare("GET", "https://example.org") + session.auth(request) + assert request.headers == {'Authorization': 'Bearer 123456789'} def test_apikey_auth(self): auth = ApiKeyCredentials( diff --git a/tests/test_client.py b/tests/test_client.py index c9143259a7..241e2cabc4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -33,10 +33,11 @@ import mock import requests +from requests.adapters import HTTPAdapter from oauthlib import oauth2 -from msrest import ServiceClient -from msrest.authentication import OAuthTokenAuthentication +from msrest import ServiceClient, SDKClient +from msrest.authentication import OAuthTokenAuthentication, Authentication from msrest import Configuration from msrest.exceptions import ClientRequestError, TokenExpiredError @@ -66,36 +67,159 @@ def callback(session, global_config, local_config, **kwargs): output_kwargs = client._configure_session(local_session, **{"test": True}) self.assertTrue(output_kwargs['used_callback']) + def test_sdk_context_manager(self): + cfg = Configuration("http://127.0.0.1/") + + class Creds(Authentication): + def __init__(self): + self.first_session = None + self.called = 0 + + def signed_session(self, session=None): + self.called += 1 + assert session is not None + if self.first_session: + assert self.first_session is session + else: + self.first_session = session + creds = Creds() + + with SDKClient(creds, cfg) as client: + assert cfg.keep_alive + + req = client._client.get() + try: + client._client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + + try: + client._client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + + assert not cfg.keep_alive + assert creds.called == 2 + + def test_context_manager(self): + + cfg = Configuration("http://127.0.0.1/") + + class Creds(Authentication): + def __init__(self): + self.first_session = None + self.called = 0 + + def signed_session(self, session=None): + self.called += 1 + assert session is not None + if self.first_session: + assert self.first_session is session + else: + self.first_session = session + creds = Creds() + + with ServiceClient(creds, cfg) as client: + assert cfg.keep_alive + + req = client.get() + try: + client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + + try: + client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + assert client._session # Still alive + + assert not cfg.keep_alive + assert creds.called == 2 + assert client._session is None # Dead + + def test_keep_alive(self): + + cfg = Configuration("http://127.0.0.1/") + cfg.keep_alive = True + + class Creds(Authentication): + def __init__(self): + self.first_session = None + self.called = 0 + + def signed_session(self, session=None): + self.called += 1 + assert session is not None + if self.first_session: + assert self.first_session is session + else: + self.first_session = session + creds = Creds() + + client = ServiceClient(creds, cfg) + req = client.get() + try: + client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + + try: + client.send(req) # Will fail, I don't care, that's not the point of the test + except Exception: + pass + + assert creds.called == 2 + assert client._session # Still alive + # Manually close the client in "keep_alive" mode + client.close() + assert client._session is None # Dead + + def test_max_retries_on_default_adapter(self): + # max_retries must be applied only on the default adapters of requests + # If the user adds its own adapter, don't touch it + client = ServiceClient(self.creds, self.cfg) + + max_retries = self.cfg.retry_policy() + + local_session = requests.Session() + local_session.mount('http://example.org', HTTPAdapter()) + client._configure_session(local_session) + assert local_session.adapters["http://"].max_retries is max_retries + assert local_session.adapters["https://"].max_retries is max_retries + assert local_session.adapters['http://example.org'].max_retries is not max_retries + + def test_no_log(self): client = ServiceClient(self.creds, self.cfg) # By default, no log handler for HTTP local_session = requests.Session() - client._configure_session(local_session) - self.assertEqual(len(local_session.hooks["response"]), 0) + kwargs = client._configure_session(local_session) + assert 'hooks' not in kwargs # I can enable it per request local_session = requests.Session() - client._configure_session(local_session, **{"enable_http_logger": True}) - self.assertEqual(len(local_session.hooks["response"]), 1) + kwargs = client._configure_session(local_session, **{"enable_http_logger": True}) + assert 'hooks' in kwargs # I can enable it per request (bool value should be honored) local_session = requests.Session() - client._configure_session(local_session, **{"enable_http_logger": False}) - self.assertEqual(len(local_session.hooks["response"]), 0) + kwargs = client._configure_session(local_session, **{"enable_http_logger": False}) + assert 'hooks' not in kwargs # I can enable it globally client.config.enable_http_logger = True local_session = requests.Session() - client._configure_session(local_session) - self.assertEqual(len(local_session.hooks["response"]), 1) + kwargs = client._configure_session(local_session) + assert 'hooks' in kwargs # I can enable it globally and override it locally client.config.enable_http_logger = True local_session = requests.Session() - client._configure_session(local_session, **{"enable_http_logger": False}) - self.assertEqual(len(local_session.hooks["response"]), 0) + kwargs = client._configure_session(local_session, **{"enable_http_logger": False}) + assert 'hooks' not in kwargs def test_client_request(self): @@ -147,103 +271,148 @@ def test_format_url(self): url = "/bool/test true" - mock_client = mock.create_autospec(ServiceClient) - mock_client.config = mock.Mock(base_url="http://localhost:3000") + client = mock.create_autospec(ServiceClient) + client.config = mock.Mock(base_url="http://localhost:3000") - formatted = ServiceClient.format_url(mock_client, url) + formatted = ServiceClient.format_url(client, url) self.assertEqual(formatted, "http://localhost:3000/bool/test true") - mock_client.config = mock.Mock(base_url="http://localhost:3000/") - formatted = ServiceClient.format_url(mock_client, url, foo=123, bar="value") + client.config = mock.Mock(base_url="http://localhost:3000/") + formatted = ServiceClient.format_url(client, url, foo=123, bar="value") self.assertEqual(formatted, "http://localhost:3000/bool/test true") url = "https://absolute_url.com/my/test/path" - formatted = ServiceClient.format_url(mock_client, url) + formatted = ServiceClient.format_url(client, url) self.assertEqual(formatted, "https://absolute_url.com/my/test/path") - formatted = ServiceClient.format_url(mock_client, url, foo=123, bar="value") + formatted = ServiceClient.format_url(client, url, foo=123, bar="value") self.assertEqual(formatted, "https://absolute_url.com/my/test/path") url = "test" - formatted = ServiceClient.format_url(mock_client, url) + formatted = ServiceClient.format_url(client, url) self.assertEqual(formatted, "http://localhost:3000/test") - mock_client.config = mock.Mock(base_url="http://{hostname}:{port}/{foo}/{bar}") - formatted = ServiceClient.format_url(mock_client, url, hostname="localhost", port="3000", foo=123, bar="value") + client.config = mock.Mock(base_url="http://{hostname}:{port}/{foo}/{bar}") + formatted = ServiceClient.format_url(client, url, hostname="localhost", port="3000", foo=123, bar="value") self.assertEqual(formatted, "http://localhost:3000/123/value/test") - mock_client.config = mock.Mock(base_url="https://my_endpoint.com/") - formatted = ServiceClient.format_url(mock_client, url, foo=123, bar="value") + client.config = mock.Mock(base_url="https://my_endpoint.com/") + formatted = ServiceClient.format_url(client, url, foo=123, bar="value") self.assertEqual(formatted, "https://my_endpoint.com/test") def test_client_send(self): - mock_client = mock.create_autospec(ServiceClient) - mock_client.config = self.cfg - mock_client.creds = self.creds - mock_client._configure_session.return_value = {} + client = ServiceClient(self.creds, self.cfg) session = mock.create_autospec(requests.Session) - mock_client.creds.signed_session.return_value = session - mock_client.creds.refresh_session.return_value = session + session.adapters = { + "http://": HTTPAdapter(), + "https://": HTTPAdapter(), + } + client.creds.signed_session.return_value = session + client.creds.refresh_session.return_value = session + # Be sure the mock does not trick me + assert not hasattr(session.resolve_redirects, 'is_mrest_patched') request = ClientRequest('GET') - ServiceClient.send(mock_client, request, stream=False) + client.send(request, stream=False) session.request.call_count = 0 - mock_client._configure_session.assert_called_with(session) - session.request.assert_called_with('GET', None, data=[], headers={}, stream=False) + session.request.assert_called_with( + 'GET', + None, + allow_redirects=True, + cert=None, + headers={ + 'User-Agent': self.cfg.user_agent, + 'Accept': 'application/json' + }, + stream=False, + timeout=100, + verify=True + ) + assert session.resolve_redirects.is_mrest_patched session.close.assert_called_with() - ServiceClient.send(mock_client, request, headers={'id':'1234'}, content={'Test':'Data'}, stream=False) - mock_client._configure_session.assert_called_with(session) - session.request.assert_called_with('GET', None, data='{"Test": "Data"}', headers={'Content-Length': '16', 'id':'1234'}, stream=False) + client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, stream=False) + session.request.assert_called_with( + 'GET', + None, + data='{"Test": "Data"}', + allow_redirects=True, + cert=None, + headers={ + 'User-Agent': self.cfg.user_agent, + 'Accept': 'application/json', + 'Content-Length': '16', + 'id':'1234' + }, + stream=False, + timeout=100, + verify=True + ) self.assertEqual(session.request.call_count, 1) session.request.call_count = 0 + assert session.resolve_redirects.is_mrest_patched session.close.assert_called_with() session.request.side_effect = requests.RequestException("test") with self.assertRaises(ClientRequestError): - ServiceClient.send(mock_client, request, headers={'id':'1234'}, content={'Test':'Data'}, test='value', stream=False) - mock_client._configure_session.assert_called_with(session, test='value') - session.request.assert_called_with('GET', None, data='{"Test": "Data"}', headers={'Content-Length': '16', 'id':'1234'}, stream=False) + client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value', stream=False) + session.request.assert_called_with( + 'GET', + None, + data='{"Test": "Data"}', + allow_redirects=True, + cert=None, + headers={ + 'User-Agent': self.cfg.user_agent, + 'Accept': 'application/json', + 'Content-Length': '16', + 'id':'1234' + }, + stream=False, + timeout=100, + verify=True + ) self.assertEqual(session.request.call_count, 1) session.request.call_count = 0 + assert session.resolve_redirects.is_mrest_patched session.close.assert_called_with() session.request.side_effect = oauth2.rfc6749.errors.InvalidGrantError("test") with self.assertRaises(TokenExpiredError): - ServiceClient.send(mock_client, request, headers={'id':'1234'}, content={'Test':'Data'}, test='value') + client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value') self.assertEqual(session.request.call_count, 2) session.request.call_count = 0 session.close.assert_called_with() session.request.side_effect = ValueError("test") with self.assertRaises(ValueError): - ServiceClient.send(mock_client, request, headers={'id':'1234'}, content={'Test':'Data'}, test='value') + client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value') session.close.assert_called_with() def test_client_formdata_send(self): - mock_client = mock.create_autospec(ServiceClient) - mock_client._format_data.return_value = "formatted" + client = mock.create_autospec(ServiceClient) + client._format_data.return_value = "formatted" request = ClientRequest('GET') - ServiceClient.send_formdata(mock_client, request) - mock_client.send.assert_called_with(request, None, None, files={}, stream=True) + ServiceClient.send_formdata(client, request) + client.send.assert_called_with(request, None, None, files={}) - ServiceClient.send_formdata(mock_client, request, {'id':'1234'}, {'Test':'Data'}) - mock_client.send.assert_called_with(request, {'id':'1234'}, None, files={'Test':'formatted'}, stream=True) + ServiceClient.send_formdata(client, request, {'id':'1234'}, {'Test':'Data'}) + client.send.assert_called_with(request, {'id':'1234'}, None, files={'Test':'formatted'}) - ServiceClient.send_formdata(mock_client, request, {'Content-Type':'1234'}, {'1':'1', '2':'2'}) - mock_client.send.assert_called_with(request, {}, None, files={'1':'formatted', '2':'formatted'}, stream=True) + ServiceClient.send_formdata(client, request, {'Content-Type':'1234'}, {'1':'1', '2':'2'}) + client.send.assert_called_with(request, {}, None, files={'1':'formatted', '2':'formatted'}) - ServiceClient.send_formdata(mock_client, request, {'Content-Type':'1234'}, {'1':'1', '2':None}) - mock_client.send.assert_called_with(request, {}, None, files={'1':'formatted'}, stream=True) + ServiceClient.send_formdata(client, request, {'Content-Type':'1234'}, {'1':'1', '2':None}) + client.send.assert_called_with(request, {}, None, files={'1':'formatted'}) - ServiceClient.send_formdata(mock_client, request, {'Content-Type':'application/x-www-form-urlencoded'}, {'1':'1', '2':'2'}) - mock_client.send.assert_called_with(request, {}, None, stream=True) + ServiceClient.send_formdata(client, request, {'Content-Type':'application/x-www-form-urlencoded'}, {'1':'1', '2':'2'}) + client.send.assert_called_with(request, {}, None) self.assertEqual(request.data, {'1':'1', '2':'2'}) - ServiceClient.send_formdata(mock_client, request, {'Content-Type':'application/x-www-form-urlencoded'}, {'1':'1', '2':None}) - mock_client.send.assert_called_with(request, {}, None, stream=True) + ServiceClient.send_formdata(client, request, {'Content-Type':'application/x-www-form-urlencoded'}, {'1':'1', '2':None}) + client.send.assert_called_with(request, {}, None) self.assertEqual(request.data, {'1':'1'}) def test_format_data(self): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f971ae484b..c4c45c0fc2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,14 +44,6 @@ class TestClientRequest(unittest.TestCase): - def test_request_headers(self): - - request = ClientRequest() - request.add_header("a", 1) - request.add_headers({'b':2, 'c':3}) - - self.assertEqual(request.headers, {'a':1, 'b':2, 'c':3}) - def test_request_data(self): request = ClientRequest()