This repository has been archived by the owner on Aug 30, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
replace the CLI arguments with a config file
References ========== https://xainag.atlassian.net/browse/XP-456 https://xainag.atlassian.net/browse/XP-512 Rationale ========= The CLI is getting complex so it is worth loading the configuration from a file instead. Implementation details ====================== TOML ---- We decided to use TOML for the following reasons: - it is human friendly, ie easy to read and write - our configuration has a pretty flat structure which makes TOML quite adapted - it is well specified and has lots of implementation - it is well known The other options we considered: - INI: it is quite frequent in the Python ecosystem to use INI for config files, and the standard library even provides support for this. However, INI is not as powerful as TOML and does not have a specification - JSON: it is very popular but is not human friendly. For instance, it does not support comments, is very verbose, and breaks easily (if a trailing comma is forgotten at the end of a list for instance) - YAML: another popular choice, but is in my opinion more complex than TOML. Validation ---------- We use the third-party `schema` library to validate the configuration. It provides a convenient way to: - declare a schema to validate our config - leverage third-party libraries to validate some inputs (we use the `idna` library to validate hostnames) - define our own validators - transform data after it has been validated: this can be useful to turn a relative path into an absolute one for example - provide user friendly error message when the configuration is invalid The `Config` class ------------------ By default, the `schema` library returns a dictionary containing a valid configuration, but that is not convenient to manipulate in Python. Therefore, we dynamically create a `Config` class from the configuration schema, and instantiate a `Config` object from the data returned by the `schema` validator. Package re-organization ----------------------- We moved the command line and config file logic into its own `config` sub-package, and moved the former `xain_fl.cli.main` entrypoint into the `xain_fl.__main__` module. Docker infrastructure --------------------- - Cache the xain_fl dependencies. This considerably reduces "edit->build-> debug" cycle, since installing the dependencies takes about 30 minutes. - Move all the docker related files into the `docker/` directory
- Loading branch information
1 parent
50f61a3
commit 6a90382
Showing
17 changed files
with
703 additions
and
277 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
FROM python:3.6-alpine | ||
|
||
RUN apk update && apk add python3-dev build-base git | ||
|
||
WORKDIR /app | ||
|
||
# Some dependencies require a very long compilation: protobuf, numpy, | ||
# grpcio. To avoid having to re-install these packages every time, we | ||
# pre-install them so that they are cached. | ||
# | ||
# However, we cannot just pre-install a few packages, because we don't | ||
# know exactly which version `pip` will pick for them. Instead, we | ||
# give pip all the package's dependencies, and let it resolve and | ||
# install them. | ||
COPY setup.py . | ||
RUN mkdir xain_fl && \ | ||
printf '__version__ = "0.2.0"\n__short_version__ = "0.2"' > xain_fl/__version__.py && \ | ||
touch README.md && \ | ||
python setup.py egg_info && \ | ||
cat *.egg-info/requires.txt | grep -v '^\[' | uniq | pip install -r /dev/stdin | ||
|
||
RUN rm -rf xain_fl README.md | ||
COPY README.md . | ||
COPY xain_fl xain_fl/ | ||
RUN pip install -e . | ||
|
||
COPY docker/dev/xain-fl.toml /xain-fl.toml | ||
COPY docker/dev/initial_weights.npy /initial_weights.npy | ||
|
||
CMD ["coordinator", "--config", "/xain-fl.toml"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
[server] | ||
# Address to listen on for incoming gRPC connections | ||
host = "0.0.0.0" | ||
# Port to listen on for incoming gRPC connections | ||
port = 50051 | ||
|
||
[ai] | ||
# Path to a file containing a numpy ndarray to use a initial model weights. | ||
initial_weights = "/initial_weights.npy" | ||
# Number of global rounds the model is going to be trained for. This | ||
# must be a positive integer. | ||
rounds = 1 | ||
# Number of local epochs per round | ||
epochs = 1 | ||
# Minimum number of participants to be selected for a round. | ||
min_participants = 1 | ||
# Fraction of total clients that participate in a training round. This | ||
# must be a float between 0 and 1. | ||
fraction_participants = 1.0 | ||
|
||
[storage] | ||
# URL to the storage service to use | ||
endpoint = "http://minio-dev:900" | ||
# Name of the bucket for storing the aggregated models | ||
bucket = "xain-fl-aggregated-weights" | ||
# AWS access key ID to use to authenticate to the storage service | ||
access_key_id = "minio" | ||
# AWS secret access to use to authenticate to the storage service | ||
secret_access_key = "minio123" |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# pylint: disable=missing-docstring,redefined-outer-name | ||
import re | ||
|
||
import pytest | ||
|
||
from xain_fl.config import Config, InvalidConfig | ||
|
||
|
||
@pytest.fixture | ||
def server_sample(): | ||
""" | ||
Return a valid "server" section | ||
""" | ||
return {"host": "localhost", "port": 50051} | ||
|
||
|
||
@pytest.fixture | ||
def ai_sample(): | ||
""" | ||
Return a valid "ai" section | ||
""" | ||
return { | ||
"initial_weights": "./test_array.npy", | ||
"rounds": 1, | ||
"epochs": 1, | ||
"min_participants": 1, | ||
"fraction_participants": 1.0, | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def storage_sample(): | ||
""" | ||
Return a valid "storage" section | ||
""" | ||
return { | ||
"endpoint": "http://localhost:9000", | ||
"bucket": "aggregated_weights", | ||
"secret_access_key": "my-secret", | ||
"access_key_id": "my-key-id", | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def config_sample(server_sample, ai_sample, storage_sample): | ||
""" | ||
Return a valid config | ||
""" | ||
return { | ||
"ai": ai_sample, | ||
"server": server_sample, | ||
"storage": storage_sample, | ||
} | ||
|
||
|
||
def test_load_valid_config(config_sample): | ||
""" | ||
Check that a valid config is loaded correctly | ||
""" | ||
config = Config.from_unchecked_dict(config_sample) | ||
|
||
assert config.server.host == "localhost" | ||
assert config.server.port == 50051 | ||
|
||
assert config.ai.initial_weights == "./test_array.npy" | ||
assert config.ai.rounds == 1 | ||
assert config.ai.epochs == 1 | ||
assert config.ai.min_participants == 1 | ||
assert config.ai.fraction_participants == 1.0 | ||
|
||
assert config.storage.endpoint == "http://localhost:9000" | ||
assert config.storage.bucket == "aggregated_weights" | ||
assert config.storage.secret_access_key == "my-secret" | ||
assert config.storage.access_key_id == "my-key-id" | ||
|
||
|
||
def test_server_config_ip_address(config_sample, server_sample): | ||
"""Check that the config is loaded correctly when the `server.host` | ||
key is an IP address | ||
""" | ||
# Ipv4 host | ||
server_sample["host"] = "1.2.3.4" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
# Ipv6 host | ||
server_sample["host"] = "::1" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
|
||
def test_server_config_extra_key(config_sample, server_sample): | ||
"""Check that the config is rejected when the server section contains | ||
an extra key | ||
""" | ||
server_sample["extra-key"] = "foo" | ||
config_sample["server"] = server_sample | ||
|
||
with AssertInvalid() as err: | ||
Config.from_unchecked_dict(config_sample) | ||
|
||
err.check_section("server") | ||
err.check_extra_key("extra-key") | ||
|
||
|
||
def test_server_config_invalid_host(config_sample, server_sample): | ||
"""Check that the config is rejected when the `server.host` key is | ||
invalid. | ||
""" | ||
server_sample["host"] = 1.0 | ||
config_sample["server"] = server_sample | ||
|
||
with AssertInvalid() as err: | ||
Config.from_unchecked_dict(config_sample) | ||
|
||
err.check_other( | ||
re.compile("Invalid `server.host`: value must be a valid domain name or IP address") | ||
) | ||
|
||
|
||
def test_server_config_valid_ipv6(config_sample, server_sample): | ||
"""Check some edge cases with IPv6 `server.host` key""" | ||
server_sample["host"] = "::" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
server_sample["host"] = "fe80::" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
|
||
# Adapted from unittest's assertRaises | ||
class AssertInvalid: | ||
"""A context manager that check that an `xainfl.config.InvalidConfig` | ||
exception is raised, and provides helpers to perform checks on the | ||
exception. | ||
""" | ||
|
||
def __init__(self): | ||
self.message = None | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, _tb): | ||
if exc_type is None: | ||
raise Exception("Did not get an exception") | ||
if not isinstance(exc_value, InvalidConfig): | ||
# let this un-expected exception be re-raised | ||
return False | ||
|
||
self.message = str(exc_value) | ||
|
||
return True | ||
|
||
def check_section(self, section): | ||
needle = re.compile(f"Key '{section}' error:") | ||
assert re.search(needle, self.message) | ||
|
||
def check_extra_key(self, key): | ||
needle = re.compile(f"Wrong keys '{key}' in") | ||
assert re.search(needle, self.message) | ||
|
||
def check_other(self, needle): | ||
assert re.search(needle, self.message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""This module is the entrypoint to start a new coordinator instance. | ||
""" | ||
import sys | ||
|
||
import numpy as np | ||
|
||
from xain_fl.config import Config, InvalidConfig, get_cmd_parameters | ||
from xain_fl.coordinator.coordinator import Coordinator | ||
from xain_fl.coordinator.store import Store | ||
from xain_fl.serve import serve | ||
|
||
|
||
def main(): | ||
"""Start a coordinator instance | ||
""" | ||
|
||
args = get_cmd_parameters() | ||
try: | ||
config = Config.load(args.config) | ||
except InvalidConfig as err: | ||
print(err, file=sys.stderr) | ||
sys.exit(1) | ||
|
||
coordinator = Coordinator( | ||
weights=list(np.load(config.ai.initial_weights, allow_pickle=True)), | ||
num_rounds=config.ai.rounds, | ||
epochs=config.ai.epochs, | ||
minimum_participants_in_round=config.ai.min_participants, | ||
fraction_of_participants=config.ai.fraction_participants, | ||
) | ||
|
||
store = Store(config.storage) | ||
|
||
serve(coordinator=coordinator, store=store, host=config.server.host, port=config.server.port) | ||
|
||
|
||
main() |
Oops, something went wrong.