From 0be8b16154f19d7a51a3e5d80b91d0a5075a1a40 Mon Sep 17 00:00:00 2001 From: Arun Patole Date: Mon, 3 Jun 2024 14:22:30 +0530 Subject: [PATCH] Add hierarchical statistics support This change adds hierarchical federated statistics support. The existing `StatisticsController` outputs global statistics in a flat hierarchy. There are use cases where global statistics are required to be generated as per the given hierarchy configuration where NVFlare clients can be specified to belong to a particular hierarchy. The new class `HierarchicalStatisticsController` is added to support hierarchical statistics. It is derived from `StatisticsController` and takes an additional argument `hierarchy_config` for a hierarchy specification file providing details about all the clients and their hierarchy. Example hierarchy config file contents: { "Manufacturers": [ { "Name": "Manufacturer-1", "Orgs": [ { "Name": "Org-1", "Sites": ["Site-1"] }, { "Name": "Org-2", "Sites": ["Site-2"] } ], }, { "Name": "Manufacturer-2", "Orgs": [ { "Name": "Org-3", "Sites": ["Site-3", "Site-4"] } ] }, ] } The above hierarchy config specifies a three-level hierarchy for four NVFlare clients named 'Site-1', 'Site-2', 'Site-3' and 'Site-4'. The generated global statistics output will be in a hierarchical format like below. At each hierarchical level, global statistics are calculated, whereas local statistics are reported at the last hierarchical level. { "Global": { }, "Manufacturers": [ { "Name": "Manufacturer-1", "Global": { }, "Orgs": [ { "Name": "Org-1", "Global": { }, "Sites": [ { "Name": "Site-1", "Local": { }, }, { "Name": "Site-2", "Local": { }, } ], }, ] }, ], } The number of hierarchical levels and the hierarchical level names are automatically calculated from the given hierarchy config. Any number of hierarchical levels is supported. This change also adds a unit test to test hierarchical global statistics up to four hierarchical levels. 
--- nvflare/app_common/app_constant.py | 2 + .../statistics/hierarchical_numeric_stats.py | 611 ++++++++++++ .../hierarchical_statistics_controller.py | 335 +++++++ .../hierarchical_numeric_stats_test.py | 918 ++++++++++++++++++ 4 files changed, 1866 insertions(+) create mode 100644 nvflare/app_common/statistics/hierarchical_numeric_stats.py create mode 100644 nvflare/app_common/workflows/hierarchical_statistics_controller.py create mode 100644 tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py diff --git a/nvflare/app_common/app_constant.py b/nvflare/app_common/app_constant.py index ec92e2eccd..98b8bc21d6 100644 --- a/nvflare/app_common/app_constant.py +++ b/nvflare/app_common/app_constant.py @@ -180,6 +180,8 @@ class StatisticsConstants(AppConstants): STATS_2nd_STATISTICS = "fed_stats_2nd_statistics" GLOBAL = "Global" + LOCAL = "Local" + NAME = "Name" ordered_statistics = { STATS_1st_STATISTICS: [STATS_COUNT, STATS_FAILURE_COUNT, STATS_SUM, STATS_MEAN, STATS_MIN, STATS_MAX], diff --git a/nvflare/app_common/statistics/hierarchical_numeric_stats.py b/nvflare/app_common/statistics/hierarchical_numeric_stats.py new file mode 100644 index 0000000000..0f905e7502 --- /dev/null +++ b/nvflare/app_common/statistics/hierarchical_numeric_stats.py @@ -0,0 +1,611 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +from math import sqrt +from typing import Dict, List, TypeVar + +from nvflare.app_common.abstract.statistics_spec import Bin, BinRange, DataType, Feature, Histogram, HistogramType +from nvflare.app_common.app_constant import StatisticsConstants as StC + +T = TypeVar("T") + + +def get_initial_structure(client_metrics: dict, ordered_metrics: dict) -> dict: + """Calculate initial output structure that is common at all the hierarchical levels. + + Args: + client_metrics: Local stats for each client. + ordered_metrics: Ordered target metrics. + + Returns: + A dict containing initial output structure. + """ + stats = {} + for metric in ordered_metrics: + stats[metric] = {} + for stat in client_metrics: + for site in client_metrics[stat]: + for ds in client_metrics[stat][site]: + stats[metric][ds] = {} + for feature in client_metrics[stat][site][ds]: + stats[metric][ds][feature] = 0 + return stats + + +def create_output_structure( + client_metrics: dict, metric_task: str, ordered_metrics: dict, hierarchy_config: dict +) -> dict: + """Recursively calculate the hierarchical global stats structure from the given hierarchy config. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing hierarchical global stats structure. 
+ """ + + def recursively_add_values(structure: dict, value_json: dict, metric_task: str, ordered_metrics: dict): + if isinstance(structure, dict): + new_items = {} + for key, value in list(structure.items()): + if key == StC.NAME: + continue + if isinstance(value, list): + if key not in new_items: + new_items[StC.GLOBAL] = get_initial_structure(value_json, ordered_metrics) + for i, item in enumerate(value): + if isinstance(item, str): + value[i] = { + StC.NAME: item, + StC.LOCAL: get_initial_structure(value_json, ordered_metrics), + } + else: + recursively_add_values(item, value_json, metric_task, ordered_metrics) + else: + recursively_add_values(value, value_json, metric_task, ordered_metrics) + structure.update(new_items) + elif isinstance(structure, list): + for item in structure: + recursively_add_values(item, value_json, metric_task, ordered_metrics) + return structure + + filled_structure = copy.deepcopy(hierarchy_config) + final_strcture = recursively_add_values(filled_structure, client_metrics, metric_task, ordered_metrics) + return final_strcture + + +def get_output_structure(client_metrics: dict, metric_task: str, ordered_metrics: dict, hierarchy_config: dict) -> dict: + """Create required global statistics hierarchical output structure. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing hierarchical global stats structure that also includes + top level global stats structure. 
+ """ + top_strcture = get_initial_structure(client_metrics, ordered_metrics) + output_structure = { + StC.GLOBAL: top_strcture, + **create_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config), + } + return output_structure + + +def update_output_strcture( + client_metrics: dict, + metric_task: str, + ordered_metrics: dict, + global_metrics: dict, +) -> None: + """Update global statistics hierarchical output structure with the new ordered metrics. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + global_metrics: The current global metrics. + + Returns: + A dict containing updated hierarchical global stats. + """ + if isinstance(global_metrics, dict): + for key, value in list(global_metrics.items()): + if key == StC.NAME: + continue + elif key == StC.GLOBAL: + global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics)) + elif key == StC.LOCAL: + global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics)) + return + elif isinstance(value, list): + update_output_strcture(client_metrics, metric_task, ordered_metrics, value) + elif isinstance(global_metrics, list): + for item in global_metrics: + update_output_strcture(client_metrics, metric_task, ordered_metrics, item) + + +def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: str, hierarchy_config: dict) -> dict: + """Get global hierarchical statistics for the given hierarchy config. + + Args: + global_metrics: The current global metrics. + client_metrics: Local stats for each client. + metric_task: Statistics task. + hierarchy_config: Hierarchy configuration for the global stats. + + + Returns: + A dict containing global hierarchical statistics. 
+ """ + # create stats structure + ordered_target_metrics = StC.ordered_statistics[metric_task] + ordered_metrics = [metric for metric in ordered_target_metrics if metric in client_metrics] + + # Create hierarchical output structure + if StC.GLOBAL not in global_metrics: + global_metrics = get_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config) + else: + update_output_strcture(client_metrics, metric_task, ordered_metrics, global_metrics) + + for metric in ordered_metrics: + stats = client_metrics[metric] + if metric == StC.STATS_COUNT or metric == StC.STATS_FAILURE_COUNT or metric == StC.STATS_SUM: + for client_name in stats: + global_metrics = accumulate_hierarchical_metrics( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_MAX or metric == StC.STATS_MIN: + for client_name in stats: + global_metrics = get_hierarchical_mins_or_maxs( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_MEAN: + global_metrics = get_hierarchical_means(metric, global_metrics) + elif metric == StC.STATS_HISTOGRAM: + for client_name in stats: + global_metrics = get_hierarchical_histograms( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_VAR: + for client_name in stats: + global_metrics = accumulate_hierarchical_metrics( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_STDDEV: + global_metrics = get_hierarchical_stddevs(global_metrics) + + return global_metrics + + +def accumulate_hierarchical_metrics( + metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict +) -> dict: + """Accumulate matrics at each hierarchical level. + + Args: + metric: Metric to accumulate. + client_name: Client name. + metrics: Client metrics. + global_metrics: The current global metrics. 
+ hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing accumulated hierarchical global statistics. + """ + + def recursively_accumulate_hierarchical_metrics( + metric: str, client_name: str, metrics: dict, global_metrics: dict, dataset: str, feature: str, org: list + ) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL and StC.NAME not in global_metrics: + global_metrics[StC.GLOBAL][metric][dataset][feature] += metrics[dataset][feature] + continue + if key == StC.NAME: + if org and value in org: + # The client belongs to this org so update current global matrics before sending it further + global_metrics[StC.GLOBAL][metric][dataset][feature] += metrics[dataset][feature] + elif value == client_name: + # This is a client local metrics update + global_metrics[StC.LOCAL][metric][dataset][feature] += metrics[dataset][feature] + else: + break + if isinstance(value, list): + for item in value: + recursively_accumulate_hierarchical_metrics( + metric, client_name, metrics, item, dataset, feature, org + ) + + client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name) + for dataset in metrics: + for feature in metrics[dataset]: + recursively_accumulate_hierarchical_metrics( + metric, client_name, metrics, global_metrics, dataset, feature, client_org + ) + + return global_metrics + + +def get_hierarchical_mins_or_maxs( + metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict +) -> dict: + """Calculate min or max at each hierarchical level. + + Args: + metric: Metric to accumulate. + client_name: Client name. + metrics: Client metrics. + global_metrics: The current global metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated mins or maxs. 
+ """ + + def recursively_update_org_mins_or_maxs( + metric: str, + client_name: str, + metrics: dict, + global_metrics: dict, + dataset: str, + feature: str, + org: list, + op: str, + ) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL and StC.NAME not in global_metrics: + if global_metrics[StC.GLOBAL][metric][dataset][feature]: + global_metrics[StC.GLOBAL][metric][dataset][feature] = op( + global_metrics[StC.GLOBAL][metric][dataset][feature], metrics[dataset][feature] + ) + else: + global_metrics[StC.GLOBAL][metric][dataset][feature] = metrics[dataset][feature] + continue + if key == StC.NAME: + if org and value in org: + # The client belongs to this org so update current global matrics before sending it further + if global_metrics[StC.GLOBAL][metric][dataset][feature]: + global_metrics[StC.GLOBAL][metric][dataset][feature] = op( + global_metrics[StC.GLOBAL][metric][dataset][feature], metrics[dataset][feature] + ) + else: + global_metrics[StC.GLOBAL][metric][dataset][feature] = metrics[dataset][feature] + elif value == client_name: + # This is a client local metrics update + global_metrics[StC.LOCAL][metric][dataset][feature] = metrics[dataset][feature] + else: + break + if isinstance(value, list): + for item in value: + recursively_update_org_mins_or_maxs( + metric, client_name, metrics, item, dataset, feature, org, op + ) + + if metric == "min": + op = min + else: + op = max + client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name) + for dataset in metrics: + for feature in metrics[dataset]: + recursively_update_org_mins_or_maxs( + metric, client_name, metrics, global_metrics, dataset, feature, client_org, op + ) + + return global_metrics + + +def get_hierarchical_means(metric: str, global_metrics: dict) -> dict: + """Calculate means at each hierarchical level. + + Args: + metric: Metric to accumulate. + global_metrics: The current global metrics. 
+ + Returns: + A dict containing updated hierarchical global statistics with + accumulated means. + """ + + def recursively_update_org_means(metrics: dict, global_metrics: dict, dataset: str, feature: str) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL: + global_metrics[StC.GLOBAL][metric][dataset][feature] = ( + global_metrics[StC.GLOBAL][StC.STATS_SUM][dataset][feature] + / global_metrics[StC.GLOBAL][StC.STATS_COUNT][dataset][feature] + ) + if key == StC.LOCAL: + global_metrics[StC.LOCAL][metric][dataset][feature] = ( + global_metrics[StC.LOCAL][StC.STATS_SUM][dataset][feature] + / global_metrics[StC.LOCAL][StC.STATS_COUNT][dataset][feature] + ) + if isinstance(value, list): + for item in value: + recursively_update_org_means(metrics, item, dataset, feature) + + # Iterate each hierarchical level and calculate 'mean' from 'sum' and 'count'. + for dataset in global_metrics[StC.GLOBAL][StC.STATS_COUNT]: + for feature in global_metrics[StC.GLOBAL][StC.STATS_COUNT][dataset]: + recursively_update_org_means(metric, global_metrics, dataset, feature) + + return global_metrics + + +def get_hierarchical_histograms( + metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict +) -> dict: + """Calculate histograms at each hierarchical level. + + Args: + metric: Metric to accumulate. + client_name: Client name. + metrics: Client metrics. + global_metrics: The current global metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated histograms. 
+ """ + + def recursively_accumulate_org_histograms( + metric: str, + client_name: str, + metrics: dict, + global_metrics: dict, + dataset: str, + feature: str, + org: list, + histogram: dict, + ) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL and StC.NAME not in global_metrics: + if ( + feature not in global_metrics[StC.GLOBAL][metric][dataset] + or not global_metrics[StC.GLOBAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.GLOBAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.GLOBAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.GLOBAL][metric][dataset][feature] = Histogram(g_hist.hist_type, updated_bins) + continue + if key == StC.NAME: + if org and value in org: + # The client belongs to this org so update current global matrics before sending it further + if ( + feature not in global_metrics[StC.GLOBAL][metric][dataset] + or not global_metrics[StC.GLOBAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.GLOBAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.GLOBAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + 
for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.GLOBAL][metric][dataset][feature] = Histogram( + g_hist.hist_type, updated_bins + ) + elif value == client_name: + # This is a client local metrics update + if ( + feature not in global_metrics[StC.LOCAL][metric][dataset] + or not global_metrics[StC.LOCAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.LOCAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.LOCAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.LOCAL][metric][dataset][feature] = Histogram( + g_hist.hist_type, updated_bins + ) + else: + break + if isinstance(value, list): + for item in value: + recursively_accumulate_org_histograms( + metric, client_name, metrics, item, dataset, feature, org, histogram + ) + + client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name) + for dataset in metrics: + for feature in metrics[dataset]: + histogram = metrics[dataset][feature] + 
recursively_accumulate_org_histograms( + metric, client_name, metrics, global_metrics, dataset, feature, client_org, histogram + ) + + return global_metrics + + +def get_hierarchical_stddevs(global_metrics: dict) -> dict: + """Calculate stddevs at each hierarchical level. + + Args: + global_metrics: The current global metrics. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated stddevs. + """ + + def recursively_update_org_stddevs(global_metrics: dict, dataset: str, feature: str) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL: + global_metrics[StC.GLOBAL][StC.STATS_STDDEV][dataset][feature] = sqrt( + global_metrics[StC.GLOBAL][StC.STATS_VAR][dataset][feature] + ) + if key == StC.LOCAL: + global_metrics[StC.LOCAL][StC.STATS_STDDEV][dataset][feature] = sqrt( + global_metrics[StC.LOCAL][StC.STATS_VAR][dataset][feature] + ) + if isinstance(value, list): + for item in value: + recursively_update_org_stddevs(item, dataset, feature) + + for dataset in global_metrics[StC.GLOBAL][StC.STATS_VAR]: + for feature in global_metrics[StC.GLOBAL][StC.STATS_VAR][dataset]: + recursively_update_org_stddevs(global_metrics, dataset, feature) + + return global_metrics + + +def get_hierarchical_levels(data: dict, level: int = 0, levels_dict: dict = None) -> dict: + """Calculate number of hierarchical levels from the given hierarchy config. + + Args: + data: Hierarchy configuration for the global stats. + level: The current hierarchical level (used for recursive calls). + levels_dict: The accumulated levels dict (used for recursive calls). + + Returns: + A dict containing containing hierarchical levels. 
+ """ + if levels_dict is None: + levels_dict = {} + + if isinstance(data, list): + for item in data: + get_hierarchical_levels(item, level, levels_dict) + elif isinstance(data, dict): + for key, value in data.items(): + if key == StC.NAME: + continue + if key not in levels_dict: + levels_dict[key] = level + get_hierarchical_levels(value, level + 1, levels_dict) + + return levels_dict + + +def get_client_hierarchy(hierarchy_config: dict, client_name: str, path=None) -> list: + """Calculate hierarchy for the given client name. + + Args: + hierarchy_config: Hierarchy configuration for the global stats. + client_name: Client name. + path: The accumulated hierarchy path (used for recursive calls). + + Returns: + A list containing hierarchy levels for the client. + """ + if path is None: + path = [] + + if isinstance(hierarchy_config, dict): + for key, value in hierarchy_config.items(): + if isinstance(value, list): + result = get_client_hierarchy(value, client_name, path) + if result: + return result + elif isinstance(hierarchy_config, list): + for item in hierarchy_config: + if item == client_name: + return path + if isinstance(item, dict): + result = get_client_hierarchy(item, client_name, path + [item.get(StC.NAME)]) + if result: + return result + + return None + + +def bins_to_dict(bins: List[Bin]) -> Dict[BinRange, float]: + """Convert histogram bins to a 'dict'. + + Args: + bins: Histogram bins. + + Returns: + A dict containing histogram bins. + """ + buckets = {} + for bucket in bins: + bucket_range = BinRange(bucket.low_value, bucket.high_value) + buckets[bucket_range] = bucket.sample_count + return buckets + + +def filter_numeric_features(ds_features: Dict[str, List[Feature]]) -> Dict[str, List[Feature]]: + """Filter numeric features. + + Args: + ds_features: A features dict. + + Returns: + A dict containing numeric features. 
+ """ + numeric_ds_features = {} + for ds_name in ds_features: + features: List[Feature] = ds_features[ds_name] + n_features = [f for f in features if (f.data_type == DataType.INT or f.data_type == DataType.FLOAT)] + numeric_ds_features[ds_name] = n_features + + return numeric_ds_features diff --git a/nvflare/app_common/workflows/hierarchical_statistics_controller.py b/nvflare/app_common/workflows/hierarchical_statistics_controller.py new file mode 100644 index 0000000000..46d2975675 --- /dev/null +++ b/nvflare/app_common/workflows/hierarchical_statistics_controller.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +from typing import Dict, List, Optional + +from nvflare.apis.controller_spec import Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.abstract.statistics_spec import Histogram, StatisticConfig +from nvflare.app_common.app_constant import StatisticsConstants as StC +from nvflare.app_common.statistics.hierarchical_numeric_stats import get_global_stats +from nvflare.app_common.workflows.statistics_controller import StatisticsController +from nvflare.fuel.utils import fobs + + +class HierarchicalStatisticsController(StatisticsController): + def __init__( + self, + statistic_configs: Dict[str, dict], + writer_id: str, + wait_time_after_min_received: int = 1, + result_wait_timeout: int = 10, + precision=4, + min_clients: Optional[int] = None, + enable_pre_run_task: bool = True, + hierarchy_config: str = None, + ): + """Controller for hierarchical statistics. + + Args: + statistic_configs: defines the input statistic to be computed and each statistic's configuration, see below for details. + writer_id: ID for StatisticsWriter. The StatisticWriter will save the result to output specified by the + StatisticsWriter + wait_time_after_min_received: numbers of seconds to wait after minimum numer of clients specified has received. + result_wait_timeout: numbers of seconds to wait until we received all results. + Notice this is after the min_clients have arrived, and we wait for result process + callback, this becomes important if the data size to be processed is large + precision: number of precision digits + min_clients: if specified, min number of clients we have to wait before process. + hierarchy_config: Hierarchy specification file providing details about all the clients and their hierarchy. 
+ + This class is derived from 'StatisticsController' and overrides only methods required to output calculated global + statistics in given hierarchical order. + + For statistic_configs, the key is one of statistics' names sum, count, mean, stddev, histogram, and + the value is the arguments needed. All other statistics except histogram require no argument. + + .. code-block:: text + + "statistic_configs": { + "count": {}, + "mean": {}, + "sum": {}, + "stddev": {}, + "histogram": { + "*": {"bins": 20}, + "Age": {"bins": 10, "range": [0, 120]} + } + }, + + Histogram requires the following arguments: + 1) numbers of bins or buckets of the histogram + 2) the histogram range values [min, max] + + These arguments are different for each feature. Here are few examples: + + .. code-block:: text + + "histogram": { + "*": {"bins": 20 }, + "Age": {"bins": 10, "range":[0,120]} + } + + The configuration specifies that the + feature 'Age' will have 10 bins for and the range is within [0, 120). + For all other features, the default ("*") configuration is used, with bins = 20. + The range of histogram is not specified, thus requires the Statistics controller + to dynamically estimate histogram range for each feature. Then this estimated global + range (est global min, est. global max) will be used as the histogram range. + + To dynamically estimate such a histogram range, we need the client to provide the local + min and max values in order to calculate the global bin and max value. In order to protect + data privacy and avoid data leakage, a noise level is added to the local min/max + value before sending to the controller. Therefore the controller only gets the 'estimated' + values, and the global min/max are estimated, or more accurately, they are noised global min/max + values. + + Here is another example: + + .. 
code-block:: text + + "histogram": { + "density": {"bins": 10, "range":[0,120]} + } + + In this example, there is no default histogram configuration for other features. + + This will work correctly if there is only one feature called "density" + but will fail if there are other features in the dataset. + + In the following configuration: + + .. code-block:: text + + "statistic_configs": { + "count": {}, + "mean": {}, + "stddev": {} + } + + Only count, mean and stddev statistics are specified, so the statistics_controller + will only set tasks to calculate these three statistics. + + For 'hierarchy_config', below is an example hierarchy specification with 4 level hierarchy for 9 NVFLARE clients with the names ranging + from 'Device-1' to 'Device-9' and with hierarchical levels named 'Manufacturers', 'Orgs', 'Locations', 'Devices' with + 'Manufacturers' being the top most hierarchical level and "Devices" being the lowest hierarchical level: + + .. code-block:: text + { + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [ + { + "Name": "Org-1", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-1", "Device-2"] + }, + { + "Name": "Location-2", + "Devices": ["Device-3"] + } + ] + }, + { + "Name": "Org-2", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-4", "Device-5"] + }, + { + "Name": "Location-2", + "Devices": ["Device-6"] + } + ] + } + ] + }, + { + "Name": "Manufacturer-2", + "Orgs": [ + { + "Name": "Org-3", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-7", "Device-8"] + }, + { + "Name": "Location-6", + "Devices": ["Device-9"] + } + ] + } + ] + } + ] + } + + """ + super().__init__( + statistic_configs, + writer_id, + wait_time_after_min_received, + result_wait_timeout, + precision, + min_clients, + enable_pre_run_task, + ) + self.hierarchy_config = hierarchy_config + + def statistics_task_flow(self, abort_signal: Signal, fl_ctx: FLContext, statistic_task: str): + """Statistics task flow for the 
given task. + + Args: + abort_signal: Abort signal. + fl_ctx: The FLContext. + statistic_task: Statistics task. + """ + if self.hierarchy_config: + engine = fl_ctx.get_engine() + ws = engine.get_workspace() + app_conf_dir = ws.get_app_config_dir(fl_ctx.get_job_id()) + hierarchy_config_file_path = os.path.join(app_conf_dir, self.hierarchy_config) + try: + with open(hierarchy_config_file_path) as hierarchy_config_file: + hierarchy_config_json = json.load(hierarchy_config_file) + except FileNotFoundError: + self.system_panic(f"The hierarchy config file {hierarchy_config_file_path} does not exist.", fl_ctx) + return False + except IOError as e: + self.system_panic( + f"An I/O error occurred while loading hierarchy config file {hierarchy_config_file_path}: {e}", + fl_ctx, + ) + return False + except json.decoder.JSONDecodeError as e: + self.system_panic( + f"Failed to decode hierarchy config JSON from the file {hierarchy_config_file_path}: {e}", fl_ctx + ) + return False + except Exception as e: + self.system_panic( + f"An unexpected error occurred while loading hierarchy config file {hierarchy_config_file_path}: {e}", + fl_ctx, + ) + return False + + self.log_info(fl_ctx, f"start prepare inputs for task {statistic_task}") + inputs = self._prepare_inputs(statistic_task) + results_cb_fn = self._get_result_cb(statistic_task) + + self.log_info(fl_ctx, f"task: {self.task_name} statistics_flow for {statistic_task} started.") + + if abort_signal.triggered: + return False + + task_props = {StC.STATISTICS_TASK_KEY: statistic_task} + task = Task(name=self.task_name, data=inputs, result_received_cb=results_cb_fn, props=task_props) + + self.broadcast_and_wait( + task=task, + targets=None, + min_responses=self.min_clients, + fl_ctx=fl_ctx, + wait_time_after_min_received=self.wait_time_after_min_received, + abort_signal=abort_signal, + ) + + self.global_statistics = get_global_stats( + self.global_statistics, self.client_statistics, statistic_task, hierarchy_config_json + ) + + 
self.log_info(fl_ctx, f"task {self.task_name} statistics_flow for {statistic_task} flow end.") + + def _recursively_round_global_stats(self, global_stats): + """Apply given precision to the calculated global statistics. + + Args: + global_stats: Global stats. + + Returns: + A dict containing global stats with applied precision. + """ + if isinstance(global_stats, dict): + for key, value in global_stats.items(): + if key == StC.GLOBAL or key == StC.LOCAL: + for key, metric in value.items(): + if key == StC.STATS_HISTOGRAM: + for ds in metric: + for name, val in metric[ds].items(): + hist: Histogram = metric[ds][name] + buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision) + metric[ds][name] = buckets + else: + for ds in metric: + for name, val in metric[ds].items(): + metric[ds][name] = round(metric[ds][name], self.precision) + continue + if isinstance(value, list): + for item in value: + self._recursively_round_global_stats(item) + if isinstance(global_stats, list): + for item in global_stats: + self._recursively_round_global_stats(item) + + return global_stats + + def _combine_all_statistics(self): + """Get combined global statistics with precision applied. + + Returns: + A dict containing global statistics with precision applied. + """ + result = self.global_statistics + return self._recursively_round_global_stats(result) + + def _prepare_inputs(self, statistic_task: str) -> Shareable: + """Prepare inputs for the given task. + + Args: + statistic_task: Statistics task. + + Returns: + A dict containing inputs. 
+ """ + inputs = Shareable() + target_statistics: List[StatisticConfig] = StatisticsController._get_target_statistics( + self.statistic_configs, StC.ordered_statistics[statistic_task] + ) + for tm in target_statistics: + if tm.name == StC.STATS_HISTOGRAM: + if StC.STATS_MIN in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_MIN] = self.global_statistics[StC.GLOBAL][StC.STATS_MIN] + if StC.STATS_MAX in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_MAX] = self.global_statistics[StC.GLOBAL][StC.STATS_MAX] + + if tm.name == StC.STATS_VAR: + if StC.STATS_COUNT in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_GLOBAL_COUNT] = self.global_statistics[StC.GLOBAL][StC.STATS_COUNT] + if StC.STATS_MEAN in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_GLOBAL_MEAN] = self.global_statistics[StC.GLOBAL][StC.STATS_MEAN] + + inputs[StC.STATISTICS_TASK_KEY] = statistic_task + + inputs[StC.STATS_TARGET_STATISTICS] = fobs.dumps(target_statistics) + + return inputs diff --git a/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py b/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py new file mode 100644 index 0000000000..e70d03bee2 --- /dev/null +++ b/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py @@ -0,0 +1,918 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, HistogramType +from nvflare.app_common.app_constant import StatisticsConstants as StC +from nvflare.app_common.workflows.statistics_controller import StatisticsController + +# Unit test upto four hierarchical levels +HIERARCHY_CONFIGS = [ + {"Sites": ["Site-1", "Site-2", "Site-3", "Site-4"]}, + { + "Manufacturers": [ + {"Name": "Manufacturer-1", "Sites": ["Site-1", "Site-2"]}, + {"Name": "Manufacturer-2", "Sites": ["Site-3", "Site-4"]}, + ] + }, + { + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [{"Name": "Org-1", "Sites": ["Site-1"]}, {"Name": "Org-2", "Sites": ["Site-2"]}], + }, + {"Name": "Manufacturer-2", "Orgs": [{"Name": "Org-3", "Sites": ["Site-3", "Site-4"]}]}, + ] + }, + { + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [ + {"Name": "Org-1", "Locations": [{"Name": "Location-1", "Sites": ["Site-1"]}]}, + {"Name": "Org-2", "Locations": [{"Name": "Location-1", "Sites": ["Site-2"]}]}, + ], + }, + { + "Name": "Manufacturer-2", + "Orgs": [ + { + "Name": "Org-3", + "Locations": [ + {"Name": "Location-1", "Sites": ["Site-3"]}, + {"Name": "Location-2", "Sites": ["Site-4"]}, + ], + } + ], + }, + ] + }, +] + +hist_bins = [] +hist_bins.append(Bin(0.0, 0.5, 50)) +hist_bins.append(Bin(0.5, 1.0, 50)) +g_hist = Histogram(HistogramType.STANDARD, hist_bins) + +CLIENT_STATS = { + "count": { + "Site-1": {"data_set1": {"Feature1": 100}}, + "Site-2": {"data_set1": {"Feature1": 200}}, + "Site-3": {"data_set1": {"Feature1": 300}}, + "Site-4": {"data_set1": {"Feature1": 400}}, + }, + "sum": { + "Site-1": {"data_set1": {"Feature1": 1000}}, + "Site-2": {"data_set1": {"Feature1": 2000}}, + "Site-3": {"data_set1": {"Feature1": 3000}}, + "Site-4": {"data_set1": {"Feature1": 4000}}, + }, + "max": { + "Site-1": {"data_set1": {"Feature1": 20}}, + "Site-2": {"data_set1": {"Feature1": 30}}, + "Site-3": {"data_set1": {"Feature1": 40}}, + "Site-4": {"data_set1": 
{"Feature1": 50}}, + }, + "min": { + "Site-1": {"data_set1": {"Feature1": 0}}, + "Site-2": {"data_set1": {"Feature1": 1}}, + "Site-3": {"data_set1": {"Feature1": 2}}, + "Site-4": {"data_set1": {"Feature1": 3}}, + }, + "mean": { + "Site-1": {"data_set1": {"Feature1": 10}}, + "Site-2": {"data_set1": {"Feature1": 10}}, + "Site-3": {"data_set1": {"Feature1": 10}}, + "Site-4": {"data_set1": {"Feature1": 10}}, + }, + "var": { + "Site-1": {"data_set1": {"Feature1": 0.1}}, + "Site-2": {"data_set1": {"Feature1": 0.1}}, + "Site-3": {"data_set1": {"Feature1": 0.1}}, + "Site-4": {"data_set1": {"Feature1": 0.1}}, + }, + "stddev": { + "Site-1": {"data_set1": {"Feature1": 0.1}}, + "Site-2": {"data_set1": {"Feature1": 0.1}}, + "Site-3": {"data_set1": {"Feature1": 0.1}}, + "Site-4": {"data_set1": {"Feature1": 0.1}}, + }, + "histogram": { + "Site-1": {"data_set1": {"Feature1": g_hist}}, + "Site-2": {"data_set1": {"Feature1": g_hist}}, + "Site-3": {"data_set1": {"Feature1": g_hist}}, + "Site-4": {"data_set1": {"Feature1": g_hist}}, + }, +} + +global_stats_0 = { + "Global": { + "count": {"data_set1": {"Feature1": 1000}}, + "sum": {"data_set1": {"Feature1": 10000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=200), + Bin(low_value=0.5, high_value=1.0, sample_count=200), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.4}}, + "stddev": {"data_set1": {"Feature1": 0.63}}, + }, + "Sites": [ + { + "Name": "Site-1", + "Local": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + 
}, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-2", + "Local": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-3", + "Local": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 40}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-4", + "Local": { + "count": {"data_set1": {"Feature1": 400}}, + "sum": {"data_set1": {"Feature1": 4000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 3}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], +} +global_stats_1 = { + "Global": { + "count": {"data_set1": {"Feature1": 1000}}, + "sum": {"data_set1": {"Feature1": 10000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + 
Bin(low_value=0.0, high_value=0.5, sample_count=200), + Bin(low_value=0.5, high_value=1.0, sample_count=200), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.4}}, + "stddev": {"data_set1": {"Feature1": 0.63}}, + }, + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Sites": [ + { + "Name": "Site-1", + "Local": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-2", + "Local": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + { + "Name": "Manufacturer-2", + "Sites": [ + { + "Name": "Site-3", + "Local": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": 
{"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 40}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-4", + "Local": { + "count": {"data_set1": {"Feature1": 400}}, + "sum": {"data_set1": {"Feature1": 4000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 3}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 700}}, + "sum": {"data_set1": {"Feature1": 7000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + ], +} +global_stats_2 = { + "Global": { + "count": {"data_set1": {"Feature1": 1000}}, + "sum": {"data_set1": {"Feature1": 10000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=200), + Bin(low_value=0.5, high_value=1.0, sample_count=200), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.4}}, + "stddev": {"data_set1": {"Feature1": 0.63}}, + }, + 
"Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [ + { + "Name": "Org-1", + "Sites": [ + { + "Name": "Site-1", + "Local": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Org-2", + "Sites": [ + { + "Name": "Site-2", + "Local": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + 
Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + { + "Name": "Manufacturer-2", + "Orgs": [ + { + "Name": "Org-3", + "Sites": [ + { + "Name": "Site-3", + "Local": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 40}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Site-4", + "Local": { + "count": {"data_set1": {"Feature1": 400}}, + "sum": {"data_set1": {"Feature1": 4000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 3}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 700}}, + "sum": {"data_set1": {"Feature1": 7000}}, + "mean": {"data_set1": 
{"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 700}}, + "sum": {"data_set1": {"Feature1": 7000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + ], +} +global_stats_3 = { + "Global": { + "count": {"data_set1": {"Feature1": 1000}}, + "sum": {"data_set1": {"Feature1": 10000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=200), + Bin(low_value=0.5, high_value=1.0, sample_count=200), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.4}}, + "stddev": {"data_set1": {"Feature1": 0.63}}, + }, + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [ + { + "Name": "Org-1", + "Locations": [ + { + "Name": "Location-1", + "Sites": [ + { + "Name": "Site-1", + "Local": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + 
"var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 100}}, + "sum": {"data_set1": {"Feature1": 1000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 0}}, + "max": {"data_set1": {"Feature1": 20}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Org-2", + "Locations": [ + { + "Name": "Location-1", + "Sites": [ + { + "Name": "Site-2", + "Local": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + 
"Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 200}}, + "sum": {"data_set1": {"Feature1": 2000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 1}}, + "max": {"data_set1": {"Feature1": 30}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + { + "Name": "Manufacturer-2", + "Orgs": [ + { + "Name": "Org-3", + "Locations": [ + { + "Name": "Location-1", + "Sites": [ + { + "Name": "Site-3", + "Local": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": {"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 40}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 300}}, + "sum": {"data_set1": 
{"Feature1": 3000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 40}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + { + "Name": "Location-2", + "Sites": [ + { + "Name": "Site-4", + "Local": { + "count": {"data_set1": {"Feature1": 400}}, + "sum": {"data_set1": {"Feature1": 4000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 3}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 400}}, + "sum": {"data_set1": {"Feature1": 4000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 3}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=50), + Bin(low_value=0.5, high_value=1.0, sample_count=50), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.1}}, + "stddev": {"data_set1": {"Feature1": 0.32}}, + }, + }, + ], + "Global": { + "count": {"data_set1": {"Feature1": 700}}, + "sum": {"data_set1": {"Feature1": 7000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, 
+ }, + } + ], + "Global": { + "count": {"data_set1": {"Feature1": 700}}, + "sum": {"data_set1": {"Feature1": 7000}}, + "mean": {"data_set1": {"Feature1": 10.0}}, + "min": {"data_set1": {"Feature1": 2}}, + "max": {"data_set1": {"Feature1": 50}}, + "histogram": { + "data_set1": { + "Feature1": [ + Bin(low_value=0.0, high_value=0.5, sample_count=100), + Bin(low_value=0.5, high_value=1.0, sample_count=100), + ] + } + }, + "var": {"data_set1": {"Feature1": 0.2}}, + "stddev": {"data_set1": {"Feature1": 0.45}}, + }, + }, + ], +} + +PARAMS = [ + (CLIENT_STATS, global_stats_0, HIERARCHY_CONFIGS[0]), + (CLIENT_STATS, global_stats_1, HIERARCHY_CONFIGS[1]), + (CLIENT_STATS, global_stats_2, HIERARCHY_CONFIGS[2]), + (CLIENT_STATS, global_stats_3, HIERARCHY_CONFIGS[3]), +] + + +class TestHierarchicalNumericStats: + def _round_global_stats(self, global_stats): + if isinstance(global_stats, dict): + for key, value in global_stats.items(): + if key == StC.GLOBAL or key == StC.LOCAL: + for key, metric in value.items(): + if key == StC.STATS_HISTOGRAM: + for ds in metric: + for name, val in metric[ds].items(): + hist: Histogram = metric[ds][name] + buckets = StatisticsController._apply_histogram_precision(hist.bins, 2) + metric[ds][name] = buckets + else: + for ds in metric: + for name, val in metric[ds].items(): + metric[ds][name] = round(metric[ds][name], 2) + continue + if isinstance(value, list): + for item in value: + self._round_global_stats(item) + if isinstance(global_stats, list): + for item in global_stats: + self._round_global_stats(item) + return global_stats + + @pytest.mark.parametrize("client_stats, expected_global_stats, hierarchy_configs", PARAMS) + def test_global_stats(self, client_stats, expected_global_stats, hierarchy_configs): + from nvflare.app_common.statistics.hierarchical_numeric_stats import get_global_stats + + global_stats = get_global_stats({}, client_stats, StC.STATS_1st_STATISTICS, hierarchy_configs) + global_stats = get_global_stats(global_stats, 
client_stats, StC.STATS_2nd_STATISTICS, hierarchy_configs) + result = self._round_global_stats(global_stats) + + assert expected_global_stats == result