Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix use_ssl: True on Python 3.10 #1496

Merged
merged 6 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ tox-env-clean:

lint: check-venv
@. $(VENV_ACTIVATE_FILE); find esrally benchmarks scripts tests it setup.py -name "*.py" -exec pylint -j0 -rn --rcfile=$(CURDIR)/.pylintrc \{\} +
@. $(VENV_ACTIVATE_FILE); black --check esrally benchmarks scripts tests it setup.py
@. $(VENV_ACTIVATE_FILE); isort --check esrally benchmarks scripts tests it setup.py
@. $(VENV_ACTIVATE_FILE); black --check --diff esrally benchmarks scripts tests it setup.py
@. $(VENV_ACTIVATE_FILE); isort --check --diff esrally benchmarks scripts tests it setup.py

format: check-venv
@. $(VENV_ACTIVATE_FILE); black esrally benchmarks scripts tests it setup.py
Expand Down
26 changes: 21 additions & 5 deletions esrally/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import certifi
import urllib3
from urllib3.connection import is_ipaddress

from esrally import doc_link, exceptions
from esrally.utils import console, convert
Expand Down Expand Up @@ -135,17 +136,15 @@ def __init__(self, hosts, client_options):
self.logger.info("SSL support: on")
self.client_options["scheme"] = "https"

# ssl.Purpose.CLIENT_AUTH allows presenting client certs and can only be enabled during instantiation
# but can be disabled via the verify_mode property later on.
self.ssl_context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH, cafile=self.client_options.pop("ca_certs", certifi.where())
ssl.Purpose.SERVER_AUTH, cafile=self.client_options.pop("ca_certs", certifi.where())
)

if not self.client_options.pop("verify_certs", True):
self.logger.info("SSL certificate verification: off")
# order matters to avoid ValueError: check_hostname needs a SSL context with either CERT_OPTIONAL or CERT_REQUIRED
self.ssl_context.verify_mode = ssl.CERT_NONE
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE

self.logger.warning(
"User has enabled SSL but disabled certificate verification. This is dangerous but may be ok for a "
Expand All @@ -156,8 +155,9 @@ def __init__(self, hosts, client_options):
# advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings"
urllib3.disable_warnings()
else:
# check_hostname should not be set when host is an IP address
self.ssl_context.check_hostname = self._only_hostnames(hosts)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.check_hostname = True
self.logger.info("SSL certificate verification: on")

# When using SSL_context, all SSL related kwargs in client options get ignored
Expand Down Expand Up @@ -209,6 +209,22 @@ def __init__(self, hosts, client_options):
if self._is_set(self.client_options, "enable_cleanup_closed"):
self.client_options["enable_cleanup_closed"] = convert.to_bool(self.client_options.pop("enable_cleanup_closed"))

@staticmethod
def _only_hostnames(hosts):
has_ip = False
has_hostname = False
for host in hosts:
is_ip = is_ipaddress(host["host"])
if is_ip:
has_ip = True
else:
has_hostname = True

if has_ip and has_hostname:
raise exceptions.SystemSetupError("Cannot verify certs with mixed IP addresses and hostnames")

return has_hostname

def _is_set(self, client_opts, k):
try:
return client_opts[k]
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def str_from_file(name):
# urllib3: MIT
# aiohttp: Apache 2.0
"elasticsearch[async]==7.14.0",
"urllib3==1.26.9",
# License: BSD
"psutil==5.8.0",
# License: MIT
Expand Down Expand Up @@ -104,6 +105,8 @@ def str_from_file(name):
"pylint==2.6.0",
"black==22.3.0",
"isort==5.8.0",
"trustme==0.9.0",
"pytest-httpserver==1.0.4",
]

python_version_classifiers = ["Programming Language :: Python :: {}.{}".format(major, minor) for major, minor in supported_python_versions]
Expand Down
118 changes: 111 additions & 7 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import asyncio
import contextlib
import logging
import os
import random
Expand All @@ -25,7 +26,9 @@

import elasticsearch
import pytest
import trustme
import urllib3.exceptions
from pytest_httpserver import HTTPServer

from esrally import client, doc_link, exceptions
from esrally.async_connection import AIOHttpConnection
Expand All @@ -36,7 +39,7 @@ class TestEsClientFactory:
cwd = os.path.dirname(__file__)

def test_create_http_connection(self):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {}
# make a copy so we can verify later that the factory did not modify it
original_client_options = dict(client_options)
Expand All @@ -52,7 +55,7 @@ def test_create_http_connection(self):

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_verify_server(self, mocked_load_cert_chain):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
Expand Down Expand Up @@ -90,7 +93,7 @@ def test_create_https_connection_verify_server(self, mocked_load_cert_chain):

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_verify_self_signed_server_and_client_certificate(self, mocked_load_cert_chain):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
Expand Down Expand Up @@ -134,7 +137,7 @@ def test_create_https_connection_verify_self_signed_server_and_client_certificat

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_only_verify_self_signed_server_certificate(self, mocked_load_cert_chain):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
Expand Down Expand Up @@ -171,7 +174,7 @@ def test_create_https_connection_only_verify_self_signed_server_certificate(self
assert client_options == original_client_options

def test_raises_error_when_only_one_of_client_cert_and_client_key_defined(self):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
Expand Down Expand Up @@ -205,7 +208,7 @@ def test_raises_error_when_only_one_of_client_cert_and_client_key_defined(self):

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_unverified_certificate(self, mocked_load_cert_chain):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": False,
Expand Down Expand Up @@ -245,7 +248,7 @@ def test_create_https_connection_unverified_certificate(self, mocked_load_cert_c

@mock.patch.object(ssl.SSLContext, "load_cert_chain")
def test_create_https_connection_unverified_certificate_present_client_certificates(self, mocked_load_cert_chain):
hosts = [{"host": "127.0.0.1", "port": 9200}]
hosts = [{"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": False,
Expand Down Expand Up @@ -287,6 +290,107 @@ def test_create_https_connection_unverified_certificate_present_client_certifica

assert client_options == original_client_options

def test_raises_error_when_verify_ssl_with_mixed_hosts(self):
hosts = [{"host": "127.0.0.1", "port": 9200}, {"host": "localhost", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
"http_auth": ("user", "password"),
}

with pytest.raises(
exceptions.SystemSetupError,
match="Cannot verify certs with mixed IP addresses and hostnames",
):
client.EsClientFactory(hosts, client_options)

def test_check_hostname_false_when_host_is_ip(self):
hosts = [{"host": "127.0.0.1", "port": 9200}]
client_options = {
"use_ssl": True,
"verify_certs": True,
"http_auth": ("user", "password"),
}

f = client.EsClientFactory(hosts, client_options)
assert f.hosts == hosts
assert f.ssl_context.check_hostname is False
assert f.ssl_context.verify_mode == ssl.CERT_REQUIRED


@contextlib.contextmanager
def _build_server(tmpdir, host):
ca = trustme.CA()
ca_cert_path = str(tmpdir / "ca.pem")
ca.cert_pem.write_to_path(ca_cert_path)

server_cert = ca.issue_cert(host)
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_crt = server_cert.cert_chain_pems[0]
server_key = server_cert.private_key_pem
with server_crt.tempfile() as crt_file, server_key.tempfile() as key_file:
context.load_cert_chain(crt_file, key_file)

server = HTTPServer(ssl_context=context)
# Fake what the client expects from Elasticsearch
server.expect_request("/").respond_with_json(
headers={
"x-elastic-product": "Elasticsearch",
},
response_json={
"version": {
"number": "8.0.0",
}
},
)
server.start()

yield server, ca, ca_cert_path

server.clear()
if server.is_running():
server.stop()


class TestEsClientAgainstHTTPSServer:
def test_ip_address(self, tmp_path_factory: pytest.TempPathFactory):
tmpdir = tmp_path_factory.mktemp("certs")
with _build_server(tmpdir, "127.0.0.1") as cfg:
server, _ca, ca_cert_path = cfg
hosts = [{"host": "127.0.0.1", "port": server.port}]
client_options = {
"use_ssl": True,
"verify_certs": True,
"ca_certs": ca_cert_path,
}
f = client.EsClientFactory(hosts, client_options)
es = f.create()
assert es.info() == {"version": {"number": "8.0.0"}}

def test_client_cert(self, tmp_path_factory: pytest.TempPathFactory):
tmpdir = tmp_path_factory.mktemp("certs")
with _build_server(tmpdir, "localhost") as cfg:
server, ca, ca_cert_path = cfg
client_cert = ca.issue_cert("localhost")
client_cert_path = str(tmpdir / "client.pem")
client_key_path = str(tmpdir / "client.key")
client_cert.cert_chain_pems[0].write_to_path(client_cert_path)
client_cert.private_key_pem.write_to_path(client_key_path)

hosts = [
{"host": "localhost", "port": server.port},
]
client_options = {
"use_ssl": True,
"verify_certs": True,
"ca_certs": ca_cert_path,
"client_cert": client_cert_path,
"client_key": client_key_path,
}
f = client.EsClientFactory(hosts, client_options)
es = f.create()
assert es.info() == {"version": {"number": "8.0.0"}}


class TestRequestContextManager:
@pytest.mark.asyncio
Expand Down