From 6a903828c7933946bce9aad4c5f0e4576405ffa0 Mon Sep 17 00:00:00 2001 From: Corentin Henry Date: Thu, 16 Jan 2020 15:12:22 +0100 Subject: [PATCH] replace the CLI arguments with a config file References ========== https://xainag.atlassian.net/browse/XP-456 https://xainag.atlassian.net/browse/XP-512 Rationale ========= The CLI is getting complex so it is worth loading the configuration from a file instead. Implementation details ====================== TOML ---- We decided to use TOML for the following reasons: - it is human friendly, ie easy to read and write - our configuration has a pretty flat structure which makes TOML quite adapted - it is well specified and has lots of implementation - it is well known The other options we considered: - INI: it is quite frequent in the Python ecosystem to use INI for config files, and the standard library even provides support for this. However, INI is not as powerful as TOML and does not have a specification - JSON: it is very popular but is not human friendly. For instance, it does not support comments, is very verbose, and breaks easily (if a trailing comma is forgotten at the end of a list for instance) - YAML: another popular choice, but is in my opinion more complex than TOML. Validation ---------- We use the third-party `schema` library to validate the configuration. It provides a convenient way to: - declare a schema to validate our config - leverage third-party libraries to validate some inputs (we use the `idna` library to validate hostnames) - define our own validators - transform data after it has been validated: this can be useful to turn a relative path into an absolute one for example - provide user friendly error message when the configuration is invalid The `Config` class ------------------ By default, the `schema` library returns a dictionary containing a valid configuration, but that is not convenient to manipulate in Python. 
Therefore, we dynamically create a `Config` class from the configuration schema, and instantiate a `Config` object from the data returned by the `schema` validator. Package re-organization ----------------------- We moved the command line and config file logic into its own `config` sub-package, and moved the former `xain_fl.cli.main` entrypoint into the `xain_fl.__main__` module. Docker infrastructure --------------------- - Cache the xain_fl dependencies. This considerably reduces "edit->build-> debug" cycle, since installing the dependencies takes about 30 minutes. - Move all the docker related files into the `docker/` directory --- Dockerfile.dev | 12 - README.md | 23 +- docker/dev/Dockerfile | 30 ++ .../dev/docker-compose.yml | 16 +- .../dev/initial_weights.npy | Bin docker/dev/xain-fl.toml | 29 ++ Dockerfile => docker/release/Dockerfile | 0 docker/{ => release}/entrypoint.sh | 0 setup.py | 5 +- tests/store.py | 10 +- tests/test_config.py | 173 +++++++++ xain_fl/__main__.py | 38 ++ xain_fl/cli.py | 237 ------------ xain_fl/config/__init__.py | 15 + xain_fl/config/cli.py | 19 + xain_fl/config/schema.py | 358 ++++++++++++++++++ xain_fl/coordinator/store.py | 15 +- 17 files changed, 703 insertions(+), 277 deletions(-) delete mode 100644 Dockerfile.dev create mode 100644 docker/dev/Dockerfile rename docker-compose-dev.yml => docker/dev/docker-compose.yml (76%) rename test_array.npy => docker/dev/initial_weights.npy (100%) create mode 100644 docker/dev/xain-fl.toml rename Dockerfile => docker/release/Dockerfile (100%) rename docker/{ => release}/entrypoint.sh (100%) create mode 100644 tests/test_config.py create mode 100644 xain_fl/__main__.py delete mode 100644 xain_fl/cli.py create mode 100644 xain_fl/config/__init__.py create mode 100644 xain_fl/config/cli.py create mode 100644 xain_fl/config/schema.py diff --git a/Dockerfile.dev b/Dockerfile.dev deleted file mode 100644 index 4ea14a35c..000000000 --- a/Dockerfile.dev +++ /dev/null @@ -1,12 +0,0 @@ -FROM 
python:3.6-alpine - -RUN apk update && apk add python3-dev build-base git - -WORKDIR /app -COPY setup.py . -COPY xain_fl xain_fl/ -COPY README.md . - -RUN pip install -v -e . - -CMD ["python3", "setup.py", "--fullname"] diff --git a/README.md b/README.md index 70a9e5009..e3a4040af 100644 --- a/README.md +++ b/README.md @@ -89,14 +89,27 @@ There are two docker-compose files, one for development and one for release. To run the coordinator's development image, first build the Docker image: ```shell -$ docker build -t xain-fl-dev -f Dockerfile.dev . +docker build -t xain-fl-dev -f docker/dev/Dockerfile . ``` Then run the image, mounting the directory as a Docker volume, and call the entrypoint: ```shell -$ docker run -v $(pwd):/app -v '/app/xain_fl.egg-info' xain-fl-dev coordinator +docker run -v $(pwd):/app -v '/app/xain_fl.egg-info' xain-fl-dev +``` + +You can pass a custom config file and custom initial weights by +mounting volumes: + +```shell +docker run \ + -v $(pwd)/custom_config.toml:/custom_config.toml \ + -v $(pwd)/custom_initial_weights.npy:/custom_initial_weights.npy \ + -v $(pwd):/app \ + -v '/app/xain_fl.egg-info' \ + xain-fl-dev \ + coordinator --config /custom_config.toml ``` #### Release image @@ -118,7 +131,11 @@ $ docker run -p 50051:50051 xain-fl #### Development ```shell -$ docker-compose -f docker-compose-dev.yml up +# First, build the docker image +docker build -t xain-fl-dev -f docker/dev/Dockerfile . + +# Start the services from the root of the repo +docker-compose -f docker/dev/docker-compose.yml up ``` #### Release diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile new file mode 100644 index 000000000..58de28f78 --- /dev/null +++ b/docker/dev/Dockerfile @@ -0,0 +1,30 @@ +FROM python:3.6-alpine + +RUN apk update && apk add python3-dev build-base git + +WORKDIR /app + +# Some dependencies require a very long compilation: protobuf, numpy, +# grpcio. 
To avoid having to re-install these packages every time, we +# pre-install them so that they are cached. +# +# However, we cannot just pre-install a few packages, because we don't +# know exactly which version `pip` will pick for them. Instead, we +# give pip all the package's dependencies, and let it resolve and +# install them. +COPY setup.py . +RUN mkdir xain_fl && \ + printf '__version__ = "0.2.0"\n__short_version__ = "0.2"' > xain_fl/__version__.py && \ + touch README.md && \ + python setup.py egg_info && \ + cat *.egg-info/requires.txt | grep -v '^\[' | uniq | pip install -r /dev/stdin + +RUN rm -rf xain_fl README.md +COPY README.md . +COPY xain_fl xain_fl/ +RUN pip install -e . + +COPY docker/dev/xain-fl.toml /xain-fl.toml +COPY docker/dev/initial_weights.npy /initial_weights.npy + +CMD ["coordinator", "--config", "/xain-fl.toml"] diff --git a/docker-compose-dev.yml b/docker/dev/docker-compose.yml similarity index 76% rename from docker-compose-dev.yml rename to docker/dev/docker-compose.yml index c17628314..39c3d6f44 100644 --- a/docker-compose-dev.yml +++ b/docker/dev/docker-compose.yml @@ -48,20 +48,16 @@ services: " xain-fl-dev: - environment: - MINIO_ACCESS_KEY: minio - MINIO_SECRET_KEY: minio123 build: - context: . - dockerfile: Dockerfile.dev - command: sh -c "coordinator -f test_array.npy --storage-endpoint http://minio-dev:9000 --storage-key-id $${MINIO_ACCESS_KEY} --storage-secret-access-key $${MINIO_SECRET_KEY} --storage-bucket xain-fl-aggregated-weights" + context: ../.. 
+ dockerfile: docker/dev/Dockerfile volumes: - # don't use the local egg-info, if one exists - - /app/xain_fl.egg-info - - ./xain_fl:/app/xain_fl + - /app/xain_fl.egg-info # don't use the local egg-info, if one exists + - ${PWD}/xain_fl:/app/xain_fl - ${PWD}/setup.py:/app/setup.py - ${PWD}/README.md:/app/README.md - - ${PWD}/test_array.npy:/app/test_array.npy + - ${PWD}/docker/dev/xain-fl.toml:/xain-fl.toml + - ${PWD}/docker/dev/initial_weights.npy:/initial_weights.npy networks: - xain-fl-dev ports: diff --git a/test_array.npy b/docker/dev/initial_weights.npy similarity index 100% rename from test_array.npy rename to docker/dev/initial_weights.npy diff --git a/docker/dev/xain-fl.toml b/docker/dev/xain-fl.toml new file mode 100644 index 000000000..fe85db66b --- /dev/null +++ b/docker/dev/xain-fl.toml @@ -0,0 +1,29 @@ +[server] +# Address to listen on for incoming gRPC connections +host = "0.0.0.0" +# Port to listen on for incoming gRPC connections +port = 50051 + +[ai] +# Path to a file containing a numpy ndarray to use a initial model weights. +initial_weights = "/initial_weights.npy" +# Number of global rounds the model is going to be trained for. This +# must be a positive integer. +rounds = 1 +# Number of local epochs per round +epochs = 1 +# Minimum number of participants to be selected for a round. +min_participants = 1 +# Fraction of total clients that participate in a training round. This +# must be a float between 0 and 1. 
+fraction_participants = 1.0 + +[storage] +# URL to the storage service to use +endpoint = "http://minio-dev:9000" +# Name of the bucket for storing the aggregated models +bucket = "xain-fl-aggregated-weights" +# AWS access key ID to use to authenticate to the storage service +access_key_id = "minio" +# AWS secret access to use to authenticate to the storage service +secret_access_key = "minio123" \ No newline at end of file diff --git a/Dockerfile b/docker/release/Dockerfile similarity index 100% rename from Dockerfile rename to docker/release/Dockerfile diff --git a/docker/entrypoint.sh b/docker/release/entrypoint.sh similarity index 100% rename from docker/entrypoint.sh rename to docker/release/entrypoint.sh diff --git a/setup.py b/setup.py index ee2285dc0..6c1df9d10 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,9 @@ # TODO: change xain-proto requirement to "xain-proto==0.2.0" once it is released "xain-proto @ git+https://github.com/xainag/xain-proto.git@37fc05566da91d263c37d203c0ba70804960be9b#egg=xain_proto-0.1.0&subdirectory=python", # Apache License 2.0 "boto3==1.10.48", # Apache License 2.0 + "toml==0.10.0", # MIT + "schema==0.6.8", # MIT + "idna==2.8", # BSD ] dev_require = [ @@ -97,5 +100,5 @@ "docs": docs_require, "dev": dev_require + tests_require + docs_require, }, - entry_points={"console_scripts": ["coordinator=xain_fl.cli:main"]}, + entry_points={"console_scripts": ["coordinator=xain_fl.__main__:main"]}, ) diff --git a/tests/store.py b/tests/store.py index 9534ccb24..20a319aee 100644 --- a/tests/store.py +++ b/tests/store.py @@ -7,7 +7,8 @@ import numpy as np -from xain_fl.coordinator.store import Store, StoreConfig +from xain_fl.config import StorageConfig +from xain_fl.coordinator.store import Store class FakeS3Resource: @@ -71,7 +72,12 @@ class TestStore(Store): # # pylint: disable=super-init-not-called def __init__(self): - self.config =
StorageConfig( + endpoint="endpoint", + access_key_id="access_key_id", + secret_access_key="secret_access_key", + bucket="bucket", + ) self.s3 = FakeS3Resource() def assert_wrote(self, round: int, weights: np.ndarray): diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..cc2a487b6 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,173 @@ +# pylint: disable=missing-docstring,redefined-outer-name +import re + +import pytest + +from xain_fl.config import Config, InvalidConfig + + +@pytest.fixture +def server_sample(): + """ + Return a valid "server" section + """ + return {"host": "localhost", "port": 50051} + + +@pytest.fixture +def ai_sample(): + """ + Return a valid "ai" section + """ + return { + "initial_weights": "./test_array.npy", + "rounds": 1, + "epochs": 1, + "min_participants": 1, + "fraction_participants": 1.0, + } + + +@pytest.fixture +def storage_sample(): + """ + Return a valid "storage" section + """ + return { + "endpoint": "http://localhost:9000", + "bucket": "aggregated_weights", + "secret_access_key": "my-secret", + "access_key_id": "my-key-id", + } + + +@pytest.fixture +def config_sample(server_sample, ai_sample, storage_sample): + """ + Return a valid config + """ + return { + "ai": ai_sample, + "server": server_sample, + "storage": storage_sample, + } + + +def test_load_valid_config(config_sample): + """ + Check that a valid config is loaded correctly + """ + config = Config.from_unchecked_dict(config_sample) + + assert config.server.host == "localhost" + assert config.server.port == 50051 + + assert config.ai.initial_weights == "./test_array.npy" + assert config.ai.rounds == 1 + assert config.ai.epochs == 1 + assert config.ai.min_participants == 1 + assert config.ai.fraction_participants == 1.0 + + assert config.storage.endpoint == "http://localhost:9000" + assert config.storage.bucket == "aggregated_weights" + assert config.storage.secret_access_key == "my-secret" + assert 
config.storage.access_key_id == "my-key-id" + + +def test_server_config_ip_address(config_sample, server_sample): + """Check that the config is loaded correctly when the `server.host` + key is an IP address + + """ + # Ipv4 host + server_sample["host"] = "1.2.3.4" + config_sample["server"] = server_sample + config = Config.from_unchecked_dict(config_sample) + assert config.server.host == server_sample["host"] + + # Ipv6 host + server_sample["host"] = "::1" + config_sample["server"] = server_sample + config = Config.from_unchecked_dict(config_sample) + assert config.server.host == server_sample["host"] + + +def test_server_config_extra_key(config_sample, server_sample): + """Check that the config is rejected when the server section contains + an extra key + + """ + server_sample["extra-key"] = "foo" + config_sample["server"] = server_sample + + with AssertInvalid() as err: + Config.from_unchecked_dict(config_sample) + + err.check_section("server") + err.check_extra_key("extra-key") + + +def test_server_config_invalid_host(config_sample, server_sample): + """Check that the config is rejected when the `server.host` key is + invalid. 
+ + """ + server_sample["host"] = 1.0 + config_sample["server"] = server_sample + + with AssertInvalid() as err: + Config.from_unchecked_dict(config_sample) + + err.check_other( + re.compile("Invalid `server.host`: value must be a valid domain name or IP address") + ) + + +def test_server_config_valid_ipv6(config_sample, server_sample): + """Check some edge cases with IPv6 `server.host` key""" + server_sample["host"] = "::" + config_sample["server"] = server_sample + config = Config.from_unchecked_dict(config_sample) + assert config.server.host == server_sample["host"] + + server_sample["host"] = "fe80::" + config_sample["server"] = server_sample + config = Config.from_unchecked_dict(config_sample) + assert config.server.host == server_sample["host"] + + +# Adapted from unittest's assertRaises +class AssertInvalid: + """A context manager that check that an `xainfl.config.InvalidConfig` + exception is raised, and provides helpers to perform checks on the + exception. + + """ + + def __init__(self): + self.message = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, _tb): + if exc_type is None: + raise Exception("Did not get an exception") + if not isinstance(exc_value, InvalidConfig): + # let this un-expected exception be re-raised + return False + + self.message = str(exc_value) + + return True + + def check_section(self, section): + needle = re.compile(f"Key '{section}' error:") + assert re.search(needle, self.message) + + def check_extra_key(self, key): + needle = re.compile(f"Wrong keys '{key}' in") + assert re.search(needle, self.message) + + def check_other(self, needle): + assert re.search(needle, self.message) diff --git a/xain_fl/__main__.py b/xain_fl/__main__.py new file mode 100644 index 000000000..de68627ef --- /dev/null +++ b/xain_fl/__main__.py @@ -0,0 +1,38 @@ +"""This module is the entrypoint to start a new coordinator instance. 
+ +""" +import sys + +import numpy as np + +from xain_fl.config import Config, InvalidConfig, get_cmd_parameters +from xain_fl.coordinator.coordinator import Coordinator +from xain_fl.coordinator.store import Store +from xain_fl.serve import serve + + +def main(): + """Start a coordinator instance + """ + + args = get_cmd_parameters() + try: + config = Config.load(args.config) + except InvalidConfig as err: + print(err, file=sys.stderr) + sys.exit(1) + + coordinator = Coordinator( + weights=list(np.load(config.ai.initial_weights, allow_pickle=True)), + num_rounds=config.ai.rounds, + epochs=config.ai.epochs, + minimum_participants_in_round=config.ai.min_participants, + fraction_of_participants=config.ai.fraction_participants, + ) + + store = Store(config.storage) + + serve(coordinator=coordinator, store=store, host=config.server.host, port=config.server.port) + + +main() diff --git a/xain_fl/cli.py b/xain_fl/cli.py deleted file mode 100644 index c04f2698e..000000000 --- a/xain_fl/cli.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Module implementing the networked coordinator using gRPC. - -This module implements the Coordinator state machine, the Coordinator gRPC -service and helper class to keep state about the Participants. -""" -import argparse - -import numpy as np - -from xain_fl.coordinator.coordinator import Coordinator -from xain_fl.coordinator.store import Store, StoreConfig -from xain_fl.serve import serve - - -def type_num_rounds(value): - """[summary] - - .. 
todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - - Args: - value ([type]): [description] - - Returns: - [type]: [description] - - Raises: - ~argparse.ArgumentTypeError: [description] - ~argparse.ArgumentTypeError: [description] - """ - - ivalue = int(value) - - if ivalue <= 0: - raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) - - if ivalue > 1_000: - raise argparse.ArgumentTypeError("%s More than 1_000 rounds is not supported" % value) - - return ivalue - - -def type_num_epochs(value): - """[summary] - - .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - - Args: - value ([type]): [description] - - Returns: - [type]: [description] - - Raises: - ~argparse.ArgumentTypeError: [description] - ~argparse.ArgumentTypeError: [description] - """ - - ivalue = int(value) - - if ivalue < 0: - raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) - - if ivalue > 10_000: - raise argparse.ArgumentTypeError("%s More than 10_000 epochs is not supported" % value) - - return ivalue - - -def type_min_num_participants_in_round(value): - """[summary] - - .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - - Args: - value ([type]): [description] - - Returns: - [type]: [description] - - Raises: - ~argparse.ArgumentTypeError: [description] - ~argparse.ArgumentTypeError: [description] - """ - - ivalue = int(value) - - if ivalue <= 0: - raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) - - if ivalue > 1_000_000: - raise argparse.ArgumentTypeError( - "%s More than 1_000_000 participants is currently not supported" % value - ) - - return ivalue - - -def type_fraction(value): - """[summary] - - .. 
todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - - Args: - value ([type]): [description] - - Returns: - [type]: [description] - - Raises: - ~argparse.ArgumentTypeError: [description] - ~argparse.ArgumentTypeError: [description] - """ - - ivalue = float(value) - - if ivalue <= 0: - raise argparse.ArgumentTypeError("%s is an invalid positive float value" % value) - - if ivalue > 1: - raise argparse.ArgumentTypeError( - "%s is not a valid fraction of the total participant count." % value - ) - - return ivalue - - -def get_cmd_parameters(): - """[summary] - - .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - - Returns: - [type]: [description] - """ - - # Allow various parameters to be passed via the commandline - parser = argparse.ArgumentParser(description="Coordinator CLI") - - parser.add_argument("--host", dest="host", default="[::]", type=str, help="Host") - parser.add_argument("--port", dest="port", default=50051, type=int, help="Port") - - parser.add_argument( - "-f", - dest="file", - required=True, - help="Path to numpy ndarray file containing model weights", - ) - - parser.add_argument( - "-r", - dest="num_rounds", - default=1, - type=type_num_rounds, - help="Number of global rounds the model is going to be trained for.", - ) - - parser.add_argument( - "-e", - dest="num_epochs", - default=1, - type=type_num_epochs, - help="Number of local epochs per round.", - ) - - parser.add_argument( - "-p", - dest="min_num_participants_in_round", - default=1, - type=type_min_num_participants_in_round, - help="Minimum number of participants to be selected for a round.", - ) - - parser.add_argument( - "-c", - dest="fraction", - default=1.0, - type=type_fraction, - help="Fraction of total clients that participate in a training round. 
\ - A float between 0 and 1", - ) - - parser.add_argument( - "--storage-endpoint", required=True, type=str, help="URL to the storage service to use", - ) - - parser.add_argument( - "--storage-bucket", - required=True, - type=str, - help="Name of the bucket for storing the aggregated models", - ) - - parser.add_argument( - "--storage-key-id", - required=True, - type=str, - help="AWS access key ID to use to authenticate to the storage service", - ) - - parser.add_argument( - "--storage-secret-access-key", - required=True, - type=str, - help="AWS secret access to use to authenticate to the storage service", - ) - return parser.parse_args() - - -def main(): - """[summary] - - .. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425) - """ - - parameters = get_cmd_parameters() - - coordinator = Coordinator( - weights=list(np.load(parameters.file, allow_pickle=True)), - num_rounds=parameters.num_rounds, - epochs=parameters.num_epochs, - minimum_participants_in_round=parameters.min_num_participants_in_round, - fraction_of_participants=parameters.fraction, - ) - - store_config = StoreConfig( - parameters.storage_endpoint, - parameters.storage_key_id, - parameters.storage_secret_access_key, - parameters.storage_bucket, - ) - store = Store(store_config) - - serve(coordinator=coordinator, store=store, host=parameters.host, port=parameters.port) - - -if __name__ == "__main__": - main() diff --git a/xain_fl/config/__init__.py b/xain_fl/config/__init__.py new file mode 100644 index 000000000..7723a155e --- /dev/null +++ b/xain_fl/config/__init__.py @@ -0,0 +1,15 @@ +"""This package provides the logic for reading and validating the +various configuration options exposed by the CLI and the toml config +file.""" + +from xain_fl.config.cli import get_cmd_parameters +from xain_fl.config.schema import AiConfig, Config, InvalidConfig, ServerConfig, StorageConfig + +__all__ = [ + "get_cmd_parameters", + "Config", + "AiConfig", + "StorageConfig", + "ServerConfig", + 
"InvalidConfig", +] diff --git a/xain_fl/config/cli.py b/xain_fl/config/cli.py new file mode 100644 index 000000000..8c9c391b4 --- /dev/null +++ b/xain_fl/config/cli.py @@ -0,0 +1,19 @@ +"""This module provides helpers for parsing the CLI arguments. +""" +import argparse + + +def get_cmd_parameters() -> argparse.Namespace: + """Parse the command arguments + + Returns: + ~argparse.Namespace: the parsed command arguments + """ + parser = argparse.ArgumentParser(description="Coordinator CLI") + parser.add_argument( + "--config", + dest="config", + default="xain-fl.toml", + help="Path to the coordinator configuration file", + ) + return parser.parse_args() diff --git a/xain_fl/config/schema.py b/xain_fl/config/schema.py new file mode 100644 index 000000000..b5d2a15c0 --- /dev/null +++ b/xain_fl/config/schema.py @@ -0,0 +1,358 @@ +"""This module provides helpers for reading and validating the TOML +configuration. + +""" +from collections import namedtuple +import ipaddress +from typing import Any, Mapping, NamedTuple, Type, TypeVar +import urllib + +import idna +from schema import And, Optional, Or, Schema, SchemaError, Use +import toml + + +def error(key: str, description: str) -> str: + """Return an error message for the given configuration item and + description of the expected value type. + + Args: + + key (str): key of the configuration item + description (str): description of the expected type of value + for this configuration item + """ + return f"Invalid `{key}`: value must be {description}" + + +def strictly_positive_integer( + key: str, expected_value: str = "a strictly positive integer" +) -> Schema: + """Return a validator for strictly positive integers for the given + configuration item. 
+ + Args: + + key (str): key of the configuration item + expected_value (str): description of the expected type of + value for this configuration item + """ + return And(int, lambda value: value > 0, error=error(key, expected_value)) + + +def positive_integer(key: str, expected_value: str = "a positive integer") -> Schema: + """Return a positive integer validator for the given configuration + item. + + Args: + + key: key of the configuration item + expected_value: description of the expected type of + value for this configuration item + + """ + return And(int, lambda value: value >= 0, error=error(key, expected_value)) + + +def url(key: str, expected_value: str = "a valid URL") -> Schema: + """Return a URL validator for the given configuration item. + + Args: + + key: key of the configuration item + expected_value: description of the expected type of + value for this configuration item + + """ + + def is_valid_url(value): + try: + parsed = urllib.parse.urlparse(value) + except (ValueError, urllib.error.URLError): + return False + # A URL is considered valid if it has at least a scheme and a + # network location. 
+ return all([parsed.scheme, parsed.netloc]) + + return And(str, is_valid_url, error=error(key, expected_value)) + + +def is_valid_hostname(value: str) -> bool: + """Return whether the given string is a valid hostname + + Args: + + value: string to check + + Returns: + + `True` if the given value is a valid hostname, `False` + otherwise + """ + try: + idna.encode(value) + except idna.IDNAError: + return False + return True + + +def is_valid_ip_address(value: str) -> bool: + """Return whether the given string is a valid IP address + + Args: + + value: string to check + + Returns: + + `True` if the given value is a valid IP address, `False` + otherwise + """ + try: + ipaddress.ip_address(value) + except ValueError: + return False + return True + + +def hostname_or_ip_address( + key: str, expected_value: str = "a valid domain name or IP address" +) -> Schema: + """Return a hostname or IP address validator for the given + configuration item. + + Args: + + key: key of the configuration item + expected_value: description of the expected type of + value for this configuration item + + """ + return And(str, Or(is_valid_hostname, is_valid_ip_address), error=error(key, expected_value),) + + +SERVER_SCHEMA = Schema( + { + Optional("host", default="localhost"): hostname_or_ip_address("server.host"), + Optional("port", default=50051): Use( + int, error=error("server.port", "a valid port number") + ), + } +) + +AI_SCHEMA = Schema( + { + "initial_weights": Use(str, error=error("ai.initial_weights", "a valid path"),), + "rounds": strictly_positive_integer("ai.rounds"), + "epochs": positive_integer("ai.epochs"), + Optional("min_participants", default=1): strictly_positive_integer("ai.min_participants"), + Optional("fraction_participants", default=1.0): And( + Or(int, float), + lambda value: 0 < value <= 1.0, + error=error("ai.fraction_participants", "a float between 0 and 1.0"), + ), + } +) + +STORAGE_SCHEMA = Schema( + { + "endpoint": And(str, url("storage.endpoint"), 
error=error("storage.endpoint", "a valid URL")), + "bucket": Use(str, error=error("storage.bucket", "an S3 bucket name")), + "secret_access_key": Use( + str, error=error("storage.secret_access_key", "a valid utf-8 string") + ), + "access_key_id": Use(str, error=error("storage.access_key_id", "a valid utf-8 string")), + } +) + +CONFIG_SCHEMA = Schema( + { + Optional("server", default=SERVER_SCHEMA.validate({})): SERVER_SCHEMA, + "ai": AI_SCHEMA, + "storage": STORAGE_SCHEMA, + } +) + + +# pylint: disable=protected-access +def create_class_from_schema(class_name: str, schema: Schema) -> Any: + """Create a class named `class_name` from the given `schema`, where + the attributes of the new class are the schema's keys. + + Args: + + class_name: name of the class to create + schema: schema from which to create the class + + Returns: + + A new class whose attributes are the given schema's keys + """ + keys = schema._schema.keys() + attributes = list(map(lambda key: key._schema if isinstance(key, Schema) else key, keys)) + return namedtuple(class_name, attributes) + + +# pylint: disable=invalid-name +AiConfig = create_class_from_schema("AiConfig", AI_SCHEMA) +AiConfig.__doc__ = "FL configuration: number of participants in each training round, etc." + +ServerConfig = create_class_from_schema("ServerConfig", SERVER_SCHEMA) +ServerConfig.__doc__ = "The server configuration: TLS, addresses for incoming connections, etc." + +StorageConfig = create_class_from_schema("StorageConfig", STORAGE_SCHEMA) +StorageConfig.__doc__ = "Storage related configuration: storage endpoints and credentials, etc." + +T = TypeVar("T", bound="Config") + + +class Config: + """The coordinator configuration. + + Configuration is split in three sections: `Config.ai` for items + directly related to the FL protocol, `Config.server` for the + server infrastructure, and `Config.storage` for storage related + items.
+ + The configuration is usually loaded from a dictionary; the `Config` + attributes map to the dictionary keys. + + Args: + + ai: the federated learning related configuration + storage: the storage related configuration + server: the server related configuration + + :Example: + + Here is a valid configuration: + + .. code-block:: toml + + # This section corresponds to the `Config.server` attribute + [server] + + # Address to listen on for incoming gRPC connections + host = "[::]" + # Port to listen on for incoming gRPC connections + port = 50051 + + + # This section corresponds to the `Config.ai` attribute + [ai] + + # Path to a file containing a numpy ndarray to use as initial model + # weights. The path can be absolute or relative. Relative paths are + # relative to the config file directory. + initial_weights = "./test_array.npy" + + # Number of global rounds the model is going to be trained for. This + # must be a positive integer. + rounds = 1 + + # Number of local epochs per round + epochs = 1 + + # Minimum number of participants to be selected for a round. + min_participants = 1 + + # Fraction of total clients that participate in a training round. This + # must be a float between 0 and 1. + fraction_participants = 1.0 + + # This section corresponds to the `Config.storage` attribute + [storage] + + # URL to the storage service to use + endpoint = "http://localhost:9000" + + # Name of the bucket for storing the aggregated models + bucket = "aggregated_weights" + + # AWS secret access to use to authenticate to the storage service + secret_access_key = "my-secret" + + # AWS access key ID to use to authenticate to the storage service + access_key_id = "my-key-id" + + This config file can be loaded and used as follows: + + ..
code-block:: python + + from xain_fl.config import Config + + config = Config.load("example_config.toml") + + assert config.server.host == "[::]" + assert config.server.port == 50051 + + assert config.ai.initial_weights == "./test_array.npy" + assert config.ai.rounds == 1 + assert config.ai.epochs == 1 + assert config.ai.min_participants == 1 + assert config.ai.fraction_participants == 1.0 + + assert config.storage.endpoint == "http://localhost:9000" + assert config.storage.bucket == "aggregated_weights" + assert config.storage.secret_access_key == "my-secret" + assert config.storage.access_key_id == "my-key-id" + """ + + def __init__(self, ai: NamedTuple, storage: NamedTuple, server: NamedTuple): + self.ai = ai + self.storage = storage + self.server = server + + @classmethod + def from_unchecked_dict(cls: Type[T], dictionary: Mapping[str, Any]) -> T: + """Check if the given dictionary contains a valid configuration, and + if so, create a `Config` instance from it. + + Args: + + dictionary: a dictionary containing the configuration + """ + try: + valid_config = CONFIG_SCHEMA.validate(dictionary) + except SchemaError as err: + raise InvalidConfig(err) + return cls.from_valid_dict(valid_config) + + @classmethod + def from_valid_dict(cls: Type[T], dictionary: Mapping[str, Any]) -> T: + """Create a `Config` instance for the given dictionary, assuming it + contains a valid configuration + + Args: + + dictionary: a dictionary containing the configuration + + """ + return cls( + AiConfig(**dictionary["ai"]), + StorageConfig(**dictionary["storage"]), + ServerConfig(**dictionary["server"]), + ) + + @classmethod + def load(cls: Type[T], path: str) -> T: + """Read the config file from the given path, check that it contains a + valid configuration, and return the corresponding `Config` + instance.
+ + Args: + + path: path to a configuration file + """ + with open(path, "r") as f: + raw_config = toml.load(f) + return cls.from_unchecked_dict(raw_config) + + +class InvalidConfig(ValueError): + """ + Exception raised upon trying to load an invalid configuration + """ + + def __init__(self, err, *args, **kwargs): + super().__init__(f"Invalid configuration:\n{err.code}", *args, **kwargs) diff --git a/xain_fl/coordinator/store.py b/xain_fl/coordinator/store.py index 564dae44e..05da9fd44 100644 --- a/xain_fl/coordinator/store.py +++ b/xain_fl/coordinator/store.py @@ -5,25 +5,16 @@ import boto3 from numpy import ndarray - -# pylint: disable=too-few-public-methods -class StoreConfig: - def __init__( - self, endpoint_url: str, access_key_id: str, secret_access_key: str, bucket: str, - ): - self.endpoint_url = endpoint_url - self.access_key_id = access_key_id - self.secret_access_key = secret_access_key - self.bucket = bucket +from xain_fl.config import StorageConfig class Store: - def __init__(self, config: StoreConfig): + def __init__(self, config: StorageConfig): self.config = config # pylint: disable=invalid-name self.s3 = boto3.resource( "s3", - endpoint_url=self.config.endpoint_url, + endpoint_url=self.config.endpoint, aws_access_key_id=self.config.access_key_id, aws_secret_access_key=self.config.secret_access_key, # FIXME: not sure what this should be for now