This repository has been archived by the owner on Aug 30, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from xainag/aggregations
Get applicable aggregation from config and run
- Loading branch information
Showing
10 changed files
with
110 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from absl import app, flags | ||
|
||
from .aggregation import aggregate | ||
|
||
|
||
def app_run_aggregate(): | ||
flags.mark_flag_as_required("group_name") | ||
app.run(main=lambda _: aggregate()) | ||
|
||
|
||
if __name__ == "__main__": | ||
app_run_aggregate() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,35 @@ | ||
import os | ||
from abc import ABC | ||
from typing import List | ||
from typing import Callable, Dict | ||
|
||
from absl import logging | ||
from absl import flags, logging | ||
|
||
from xain.benchmark.aggregation import final_task_accuracies, task_accuracies | ||
from xain.helpers import storage | ||
|
||
|
||
def flul_aggregation(): | ||
logging.info("flul_aggregation started") | ||
raise NotImplementedError() | ||
|
||
|
||
def cpp_aggregation(): | ||
logging.info("cpp_aggregation started") | ||
raise NotImplementedError() | ||
FLAGS = flags.FLAGS | ||
|
||
|
||
class TaskResult(ABC): | ||
def __init__(self, fname: str): | ||
self.data = storage.read_json(fname) | ||
def aggregate(): | ||
"""Calls aggregation defined in group config.json""" | ||
fname = os.path.join(FLAGS.results_dir, FLAGS.group_name, "config.json") | ||
config = storage.read_json(fname) | ||
|
||
def get_class(self) -> str: | ||
return self.data["task_name"].split("_")[0] | ||
aggregation_name = config["aggregation_name"] | ||
|
||
def get_label(self) -> str: | ||
return self.data["dataset"].split("-")[-1] | ||
aggregations[aggregation_name]() | ||
|
||
def get_final_accuracy(self) -> float: | ||
return self.data["acc"] | ||
|
||
def get_accuracies(self) -> List[float]: | ||
return self.data["hist"]["val_acc"] | ||
|
||
|
||
class GroupResult(ABC): | ||
def __init__(self, group_dir: str): | ||
assert os.path.isdir(group_dir) | ||
def flul_aggregation(): | ||
logging.info("flul_aggregation started") | ||
task_accuracies.aggregate() | ||
|
||
# get list of all directories which contain given substring | ||
json_files = [ | ||
fname | ||
for fname in storage.listdir_recursive(group_dir, relpath=False) | ||
if fname.endswith("results.json") | ||
] | ||
|
||
if not json_files: | ||
raise Exception(f"No values results found in group_dir: {group_dir}") | ||
def cpp_aggregation(): | ||
logging.info("cpp_aggregation started") | ||
final_task_accuracies.aggregate() | ||
|
||
self.task_results = [TaskResult(fname) for fname in json_files] | ||
|
||
def get_results(self) -> List[TaskResult]: | ||
return self.task_results | ||
aggregations: Dict[str, Callable] = { | ||
"flul-aggregation": flul_aggregation, | ||
"cpp-aggregation": cpp_aggregation, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
from abc import ABC | ||
from typing import List | ||
|
||
from xain.helpers import storage | ||
|
||
|
||
class TaskResult(ABC): | ||
def __init__(self, fname: str): | ||
self.data = storage.read_json(fname) | ||
|
||
def get_class(self) -> str: | ||
return self.data["task_name"].split("_")[0] | ||
|
||
def get_label(self) -> str: | ||
return self.data["dataset"].split("-")[-1] | ||
|
||
def get_final_accuracy(self) -> float: | ||
return self.data["acc"] | ||
|
||
def get_accuracies(self) -> List[float]: | ||
return self.data["hist"]["val_acc"] | ||
|
||
def get_E(self) -> int: | ||
return self.data["E"] | ||
|
||
|
||
class GroupResult(ABC): | ||
def __init__(self, group_dir: str): | ||
assert os.path.isdir(group_dir) | ||
|
||
# get list of all directories which contain given substring | ||
json_files = [ | ||
fname | ||
for fname in storage.listdir_recursive(group_dir, relpath=False) | ||
if fname.endswith("results.json") | ||
] | ||
|
||
if not json_files: | ||
raise Exception(f"No values results found in group_dir: {group_dir}") | ||
|
||
self.task_results = [TaskResult(fname) for fname in json_files] | ||
|
||
def get_results(self) -> List[TaskResult]: | ||
return self.task_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters