Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
Make code more comprehensible by extracting code segment into function
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal committed Sep 6, 2019
1 parent 4074ff5 commit e9f7ba3
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions xain/benchmark/aggregation/task_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def read_task_values(task_result: TaskResult) -> Tuple[str, List[float], int]:
fname (str): path to results.json file containing required fields
Returns
class, label, final_accuracy (str, str, float): e.g. ("VisionTask", "cpp01", 0.92)
task_class, accuracies, epochs (str, List[float], int): e.g. ("VisionTask", [0.12, 0.33], 5)
"""
return (task_result.get_class(), task_result.get_accuracies(), task_result.get_E())

Expand All @@ -30,13 +30,25 @@ def read_all_task_values(group_dir: str) -> List[Tuple[str, List[float], int]]:
:param filter_substring: has to be part of the dir name in results directory
:returns: List of tuples (task_class, task_label, federated_accuracy)
:returns: List of tuples (task_class, accuracies, epochs)
"""
task_results = GroupResult(group_dir).get_results()
# Read accuracies from each file and return list of values in tuples
return [read_task_values(task_result) for task_result in task_results]


def build_plot_values(
task_class: str, task_accuracies: List[float], E: int
) -> PlotValues:
"""Returns PlotValues with appropriate indices based on task class (Unitary or Federated)"""
if "Unitary" in task_class:
indices = [i for i in range(1, len(task_accuracies) + 1, 1)]
else:
indices = [i for i in range(E, len(task_accuracies) * E + 1, E)]

return (task_class, task_accuracies, indices)


def prepare_aggregation_data(group_name: str) -> List[PlotValues]:
"""Constructs and returns curves and xticks_args
Expand All @@ -56,17 +68,10 @@ def prepare_aggregation_data(group_name: str) -> List[PlotValues]:
assert values, "No values for group found"
assert len(values) == 2, "Expecting only two tasks"

data: List[PlotValues] = []

for value in values:
task_class, task_accuracies, E = value

if "Unitary" in task_class:
indices = [i for i in range(1, len(task_accuracies) + 1, 1)]
else:
indices = [i for i in range(E, len(task_accuracies) * E + 1, E)]

data.append((task_class, task_accuracies, indices))
data: List[PlotValues] = [
build_plot_values(task_class, task_accuracies, E)
for task_class, task_accuracies, E in values
]

return data

Expand Down

0 comments on commit e9f7ba3

Please sign in to comment.