Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
Merge pull request #38 from xainag/aggregations
Browse files Browse the repository at this point in the history
Get applicable aggregation from config and run
  • Loading branch information
Taner Topal committed Sep 6, 2019
2 parents 3e1a249 + ed304db commit c847c73
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 72 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"console_scripts": [
"pull_results=xain.ops.__main__:download",
"train_remote=xain.benchmark.__main__:main",
"aggregate=xain.benchmark.aggregation.__main__:app_run_aggregate",
"aggregate_final_task_accuracies=xain.benchmark.aggregation.final_task_accuracies:app_run_aggregate",
"aggregate_task_accuracies=xain.benchmark.aggregation.task_accuracies:app_run_aggregate",
]
Expand Down
12 changes: 12 additions & 0 deletions xain/benchmark/aggregation/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from absl import app, flags

from .aggregation import aggregate


def app_run_aggregate():
flags.mark_flag_as_required("group_name")
app.run(main=lambda _: aggregate())


if __name__ == "__main__":
app_run_aggregate()
59 changes: 20 additions & 39 deletions xain/benchmark/aggregation/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,35 @@
import os
from abc import ABC
from typing import List
from typing import Callable, Dict

from absl import logging
from absl import flags, logging

from xain.benchmark.aggregation import final_task_accuracies, task_accuracies
from xain.helpers import storage


def flul_aggregation():
logging.info("flul_aggregation started")
raise NotImplementedError()


def cpp_aggregation():
logging.info("cpp_aggregation started")
raise NotImplementedError()
FLAGS = flags.FLAGS


class TaskResult(ABC):
def __init__(self, fname: str):
self.data = storage.read_json(fname)
def aggregate():
"""Calls aggregation defined in group config.json"""
fname = os.path.join(FLAGS.results_dir, FLAGS.group_name, "config.json")
config = storage.read_json(fname)

def get_class(self) -> str:
return self.data["task_name"].split("_")[0]
aggregation_name = config["aggregation_name"]

def get_label(self) -> str:
return self.data["dataset"].split("-")[-1]
aggregations[aggregation_name]()

def get_final_accuracy(self) -> float:
return self.data["acc"]

def get_accuracies(self) -> List[float]:
return self.data["hist"]["val_acc"]


class GroupResult(ABC):
def __init__(self, group_dir: str):
assert os.path.isdir(group_dir)
def flul_aggregation():
logging.info("flul_aggregation started")
task_accuracies.aggregate()

# get list of all directories which contain given substring
json_files = [
fname
for fname in storage.listdir_recursive(group_dir, relpath=False)
if fname.endswith("results.json")
]

if not json_files:
raise Exception(f"No values results found in group_dir: {group_dir}")
def cpp_aggregation():
logging.info("cpp_aggregation started")
final_task_accuracies.aggregate()

self.task_results = [TaskResult(fname) for fname in json_files]

def get_results(self) -> List[TaskResult]:
return self.task_results
aggregations: Dict[str, Callable] = {
"flul-aggregation": flul_aggregation,
"cpp-aggregation": cpp_aggregation,
}
8 changes: 4 additions & 4 deletions xain/benchmark/aggregation/final_task_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from xain.types import PlotValues, XticksLabels, XticksLocations

from .aggregation import GroupResult, TaskResult
from .plot import plot
from .results import GroupResult, TaskResult

FLAGS = flags.FLAGS

Expand All @@ -33,9 +33,9 @@ def read_all_task_values(group_dir: str) -> List[Tuple[str, str, float]]:
Reads results directory for given group id and
extracts values from results.json files
:param filter_substring: has to be part of the dir name in results directory
Args:
group_dir (str): path to directory to be read
:returns: List of tuples (task_class, task_label, final_accuracy)
"""
task_results = GroupResult(group_dir).get_results()
# Read accuracies from each file and return list of values in tuples
Expand Down Expand Up @@ -128,7 +128,7 @@ def aggregate() -> str:
legend_loc="upper right",
)

logging.info(f"Data plotted and saved in {fname}")
logging.info(f"Data plotted and saved in {fpath}")

return fpath

Expand Down
2 changes: 1 addition & 1 deletion xain/benchmark/aggregation/final_task_accuracies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xain.helpers import sha1

from . import final_task_accuracies
from .aggregation import TaskResult
from .results import TaskResult

FLAGS = flags.FLAGS

Expand Down
45 changes: 45 additions & 0 deletions xain/benchmark/aggregation/results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
from abc import ABC
from typing import List

from xain.helpers import storage


class TaskResult(ABC):
def __init__(self, fname: str):
self.data = storage.read_json(fname)

def get_class(self) -> str:
return self.data["task_name"].split("_")[0]

def get_label(self) -> str:
return self.data["dataset"].split("-")[-1]

def get_final_accuracy(self) -> float:
return self.data["acc"]

def get_accuracies(self) -> List[float]:
return self.data["hist"]["val_acc"]

def get_E(self) -> int:
return self.data["E"]


class GroupResult(ABC):
def __init__(self, group_dir: str):
assert os.path.isdir(group_dir)

# get list of all directories which contain given substring
json_files = [
fname
for fname in storage.listdir_recursive(group_dir, relpath=False)
if fname.endswith("results.json")
]

if not json_files:
raise Exception(f"No values results found in group_dir: {group_dir}")

self.task_results = [TaskResult(fname) for fname in json_files]

def get_results(self) -> List[TaskResult]:
return self.task_results
41 changes: 23 additions & 18 deletions xain/benchmark/aggregation/task_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,49 @@

from xain.types import PlotValues

from .aggregation import GroupResult, TaskResult
from .plot import plot
from .results import GroupResult, TaskResult

FLAGS = flags.FLAGS


def read_task_values(task_result: TaskResult) -> Tuple[str, List[float]]:
def read_task_values(task_result: TaskResult) -> Tuple[str, List[float], int]:
"""Reads unitary and federated accuracy from results.json
Args:
fname (str): path to results.json file containing required fields
Returns
class, label, final_accuracy (str, str, float): e.g. ("VisionTask", "cpp01", 0.92)
Returns:
task_class, accuracies, epochs (str, List[float], int): e.g. ("VisionTask", [0.12, 0.33], 5)
"""
return (task_result.get_class(), task_result.get_accuracies())
return (task_result.get_class(), task_result.get_accuracies(), task_result.get_E())


def read_all_task_values(group_dir: str) -> List[Tuple[str, List[float]]]:
def read_all_task_values(group_dir: str) -> List[Tuple[str, List[float], int]]:
"""
Reads results directory for given group id and
extracts values from results.json files
:param filter_substring: has to be part of the dir name in results directory
:returns: List of tuples (task_class, task_label, federated_accuracy)
Args:
group_dir: path to directory to be read
"""
task_results = GroupResult(group_dir).get_results()
# Read accuracies from each file and return list of values in tuples
return [read_task_values(task_result) for task_result in task_results]


def build_plot_values(values: Tuple[str, List[float], int]) -> PlotValues:
"""Returns PlotValues with appropriate indices based on task class (Unitary or Federated)"""
task_class, task_accuracies, E = values

if "Unitary" in task_class:
indices = [i for i in range(1, len(task_accuracies) + 1, 1)]
else:
indices = [i for i in range(E, len(task_accuracies) * E + 1, E)]

return (task_class, task_accuracies, indices)


def prepare_aggregation_data(group_name: str) -> List[PlotValues]:
"""Constructs and returns curves and xticks_args
Expand All @@ -56,13 +67,7 @@ def prepare_aggregation_data(group_name: str) -> List[PlotValues]:
assert values, "No values for group found"
assert len(values) == 2, "Expecting only two tasks"

data: List[PlotValues] = []

for value in values:
print(value)
task_class, task_accuracies = value
indices = list(range(1, len(task_accuracies) + 1))
data.append((task_class, task_accuracies, indices))
data: List[PlotValues] = list(map(build_plot_values, values))

return data

Expand All @@ -81,7 +86,7 @@ def aggregate() -> str:
data = prepare_aggregation_data(group_name)

# Take highest length of values list as xlim_max
xlim_max = max([len(values) for _, values, _ in data])
xlim_max = max([len(values) for _, values, _ in data]) + 1

fpath = plot(
data,
Expand All @@ -94,7 +99,7 @@ def aggregate() -> str:
xlim_max=xlim_max,
)

logging.info(f"Data plotted and saved in {fname}")
logging.info(f"Data plotted and saved in {fpath}")

return fpath

Expand Down
2 changes: 1 addition & 1 deletion xain/benchmark/aggregation/task_accuracies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_plot_task_accuracies(output_dir, group_name, monkeypatch):
]
fname = f"plot_{group_name}.png"
expected_filepath = os.path.join(output_dir, fname)
expected_sha1 = "457baa8179f08f06c4e60213eb0bbbe79a4f9d3e"
expected_sha1 = "211f0d94bd56cc526b61ded4371af7aef6762f92"

def mock_prepare_aggregation_data(_: str):
return data
Expand Down
9 changes: 1 addition & 8 deletions xain/benchmark/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from tempfile import TemporaryDirectory
from time import strftime
from typing import Callable, Dict, List, Optional
from typing import Dict, List, Optional

from absl import flags, logging

from xain.benchmark.aggregation import aggregation
from xain.helpers import storage
from xain.ops import docker, results, run

Expand All @@ -14,12 +13,6 @@
FLAGS = flags.FLAGS


aggregations: Dict[str, Callable] = {
"flul-aggregation": aggregation.flul_aggregation,
"cpp-aggregation": aggregation.cpp_aggregation,
}


class Benchmark:
def __init__(self, tasks: List[Task], aggregation_name: str):
self.tasks = tasks
Expand Down
3 changes: 2 additions & 1 deletion xain/benchmark/benchmark_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from xain.datasets.dataset import config
from xain.ops.run import cores

from .benchmark import aggregations, benchmarks
from .aggregation.aggregation import aggregations
from .benchmark import benchmarks


def test_valid_aggregation_names():
Expand Down

0 comments on commit c847c73

Please sign in to comment.