diff --git a/nvflare/app_common/app_constant.py b/nvflare/app_common/app_constant.py index ec92e2eccd..98b8bc21d6 100644 --- a/nvflare/app_common/app_constant.py +++ b/nvflare/app_common/app_constant.py @@ -180,6 +180,8 @@ class StatisticsConstants(AppConstants): STATS_2nd_STATISTICS = "fed_stats_2nd_statistics" GLOBAL = "Global" + LOCAL = "Local" + NAME = "Name" ordered_statistics = { STATS_1st_STATISTICS: [STATS_COUNT, STATS_FAILURE_COUNT, STATS_SUM, STATS_MEAN, STATS_MIN, STATS_MAX], diff --git a/nvflare/app_common/statistics/hierarchical_numeric_stats.py b/nvflare/app_common/statistics/hierarchical_numeric_stats.py new file mode 100644 index 0000000000..0f905e7502 --- /dev/null +++ b/nvflare/app_common/statistics/hierarchical_numeric_stats.py @@ -0,0 +1,611 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from math import sqrt +from typing import Dict, List, TypeVar + +from nvflare.app_common.abstract.statistics_spec import Bin, BinRange, DataType, Feature, Histogram, HistogramType +from nvflare.app_common.app_constant import StatisticsConstants as StC + +T = TypeVar("T") + + +def get_initial_structure(client_metrics: dict, ordered_metrics: dict) -> dict: + """Calculate initial output structure that is common at all the hierarchical levels. + + Args: + client_metrics: Local stats for each client. + ordered_metrics: Ordered target metrics. 
+ + Returns: + A dict containing initial output structure. + """ + stats = {} + for metric in ordered_metrics: + stats[metric] = {} + for stat in client_metrics: + for site in client_metrics[stat]: + for ds in client_metrics[stat][site]: + stats[metric][ds] = {} + for feature in client_metrics[stat][site][ds]: + stats[metric][ds][feature] = 0 + return stats + + +def create_output_structure( + client_metrics: dict, metric_task: str, ordered_metrics: dict, hierarchy_config: dict +) -> dict: + """Recursively calculate the hierarchical global stats structure from the given hierarchy config. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing hierarchical global stats structure. + """ + + def recursively_add_values(structure: dict, value_json: dict, metric_task: str, ordered_metrics: dict): + if isinstance(structure, dict): + new_items = {} + for key, value in list(structure.items()): + if key == StC.NAME: + continue + if isinstance(value, list): + if key not in new_items: + new_items[StC.GLOBAL] = get_initial_structure(value_json, ordered_metrics) + for i, item in enumerate(value): + if isinstance(item, str): + value[i] = { + StC.NAME: item, + StC.LOCAL: get_initial_structure(value_json, ordered_metrics), + } + else: + recursively_add_values(item, value_json, metric_task, ordered_metrics) + else: + recursively_add_values(value, value_json, metric_task, ordered_metrics) + structure.update(new_items) + elif isinstance(structure, list): + for item in structure: + recursively_add_values(item, value_json, metric_task, ordered_metrics) + return structure + + filled_structure = copy.deepcopy(hierarchy_config) + final_strcture = recursively_add_values(filled_structure, client_metrics, metric_task, ordered_metrics) + return final_strcture + + +def get_output_structure(client_metrics: dict, metric_task: 
str, ordered_metrics: dict, hierarchy_config: dict) -> dict: + """Create required global statistics hierarchical output structure. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing hierarchical global stats structure that also includes + top level global stats structure. + """ + top_strcture = get_initial_structure(client_metrics, ordered_metrics) + output_structure = { + StC.GLOBAL: top_strcture, + **create_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config), + } + return output_structure + + +def update_output_strcture( + client_metrics: dict, + metric_task: str, + ordered_metrics: dict, + global_metrics: dict, +) -> None: + """Update global statistics hierarchical output structure with the new ordered metrics. + + Args: + client_metrics: Local stats for each client. + metric_task: Statistics task. + ordered_metrics: Ordered target metrics. + global_metrics: The current global metrics. + + Returns: + A dict containing updated hierarchical global stats. + """ + if isinstance(global_metrics, dict): + for key, value in list(global_metrics.items()): + if key == StC.NAME: + continue + elif key == StC.GLOBAL: + global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics)) + elif key == StC.LOCAL: + global_metrics[key].update(get_initial_structure(client_metrics, ordered_metrics)) + return + elif isinstance(value, list): + update_output_strcture(client_metrics, metric_task, ordered_metrics, value) + elif isinstance(global_metrics, list): + for item in global_metrics: + update_output_strcture(client_metrics, metric_task, ordered_metrics, item) + + +def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: str, hierarchy_config: dict) -> dict: + """Get global hierarchical statistics for the given hierarchy config. 
+ + Args: + global_metrics: The current global metrics. + client_metrics: Local stats for each client. + metric_task: Statistics task. + hierarchy_config: Hierarchy configuration for the global stats. + + + Returns: + A dict containing global hierarchical statistics. + """ + # create stats structure + ordered_target_metrics = StC.ordered_statistics[metric_task] + ordered_metrics = [metric for metric in ordered_target_metrics if metric in client_metrics] + + # Create hierarchical output structure + if StC.GLOBAL not in global_metrics: + global_metrics = get_output_structure(client_metrics, metric_task, ordered_metrics, hierarchy_config) + else: + update_output_strcture(client_metrics, metric_task, ordered_metrics, global_metrics) + + for metric in ordered_metrics: + stats = client_metrics[metric] + if metric == StC.STATS_COUNT or metric == StC.STATS_FAILURE_COUNT or metric == StC.STATS_SUM: + for client_name in stats: + global_metrics = accumulate_hierarchical_metrics( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_MAX or metric == StC.STATS_MIN: + for client_name in stats: + global_metrics = get_hierarchical_mins_or_maxs( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_MEAN: + global_metrics = get_hierarchical_means(metric, global_metrics) + elif metric == StC.STATS_HISTOGRAM: + for client_name in stats: + global_metrics = get_hierarchical_histograms( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_VAR: + for client_name in stats: + global_metrics = accumulate_hierarchical_metrics( + metric, client_name, stats[client_name], global_metrics, hierarchy_config + ) + elif metric == StC.STATS_STDDEV: + global_metrics = get_hierarchical_stddevs(global_metrics) + + return global_metrics + + +def accumulate_hierarchical_metrics( + metric: str, client_name: str, metrics: dict, global_metrics: dict, 
hierarchy_config: dict
+) -> dict:
+    """Accumulate metrics at each hierarchical level.
+
+    Args:
+        metric: Metric to accumulate.
+        client_name: Client name.
+        metrics: Client metrics.
+        global_metrics: The current global metrics.
+        hierarchy_config: Hierarchy configuration for the global stats.
+
+    Returns:
+        A dict containing accumulated hierarchical global statistics.
+    """
+
+    def recursively_accumulate_hierarchical_metrics(
+        metric: str, client_name: str, metrics: dict, global_metrics: dict, dataset: str, feature: str, org: list
+    ) -> dict:
+        if isinstance(global_metrics, dict):
+            for key, value in global_metrics.items():
+                if key == StC.GLOBAL and StC.NAME not in global_metrics:
+                    global_metrics[StC.GLOBAL][metric][dataset][feature] += metrics[dataset][feature]
+                    continue
+                if key == StC.NAME:
+                    if org and value in org:
+                        # The client belongs to this org so update current global metrics before sending it further
+                        global_metrics[StC.GLOBAL][metric][dataset][feature] += metrics[dataset][feature]
+                    elif value == client_name:
+                        # This is a client local metrics update
+                        global_metrics[StC.LOCAL][metric][dataset][feature] += metrics[dataset][feature]
+                    else:
+                        break
+                if isinstance(value, list):
+                    for item in value:
+                        recursively_accumulate_hierarchical_metrics(
+                            metric, client_name, metrics, item, dataset, feature, org
+                        )
+
+    client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name)
+    for dataset in metrics:
+        for feature in metrics[dataset]:
+            recursively_accumulate_hierarchical_metrics(
+                metric, client_name, metrics, global_metrics, dataset, feature, client_org
+            )
+
+    return global_metrics
+
+
+def get_hierarchical_mins_or_maxs(
+    metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict
+) -> dict:
+    """Calculate min or max at each hierarchical level.
+
+    Args:
+        metric: Metric to accumulate.
+        client_name: Client name.
+        metrics: Client metrics.
+        global_metrics: The current global metrics.
+ hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated mins or maxs. + """ + + def recursively_update_org_mins_or_maxs( + metric: str, + client_name: str, + metrics: dict, + global_metrics: dict, + dataset: str, + feature: str, + org: list, + op: str, + ) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL and StC.NAME not in global_metrics: + if global_metrics[StC.GLOBAL][metric][dataset][feature]: + global_metrics[StC.GLOBAL][metric][dataset][feature] = op( + global_metrics[StC.GLOBAL][metric][dataset][feature], metrics[dataset][feature] + ) + else: + global_metrics[StC.GLOBAL][metric][dataset][feature] = metrics[dataset][feature] + continue + if key == StC.NAME: + if org and value in org: + # The client belongs to this org so update current global matrics before sending it further + if global_metrics[StC.GLOBAL][metric][dataset][feature]: + global_metrics[StC.GLOBAL][metric][dataset][feature] = op( + global_metrics[StC.GLOBAL][metric][dataset][feature], metrics[dataset][feature] + ) + else: + global_metrics[StC.GLOBAL][metric][dataset][feature] = metrics[dataset][feature] + elif value == client_name: + # This is a client local metrics update + global_metrics[StC.LOCAL][metric][dataset][feature] = metrics[dataset][feature] + else: + break + if isinstance(value, list): + for item in value: + recursively_update_org_mins_or_maxs( + metric, client_name, metrics, item, dataset, feature, org, op + ) + + if metric == "min": + op = min + else: + op = max + client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name) + for dataset in metrics: + for feature in metrics[dataset]: + recursively_update_org_mins_or_maxs( + metric, client_name, metrics, global_metrics, dataset, feature, client_org, op + ) + + return global_metrics + + +def get_hierarchical_means(metric: str, global_metrics: dict) 
-> dict: + """Calculate means at each hierarchical level. + + Args: + metric: Metric to accumulate. + global_metrics: The current global metrics. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated means. + """ + + def recursively_update_org_means(metrics: dict, global_metrics: dict, dataset: str, feature: str) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL: + global_metrics[StC.GLOBAL][metric][dataset][feature] = ( + global_metrics[StC.GLOBAL][StC.STATS_SUM][dataset][feature] + / global_metrics[StC.GLOBAL][StC.STATS_COUNT][dataset][feature] + ) + if key == StC.LOCAL: + global_metrics[StC.LOCAL][metric][dataset][feature] = ( + global_metrics[StC.LOCAL][StC.STATS_SUM][dataset][feature] + / global_metrics[StC.LOCAL][StC.STATS_COUNT][dataset][feature] + ) + if isinstance(value, list): + for item in value: + recursively_update_org_means(metrics, item, dataset, feature) + + # Iterate each hierarchical level and calculate 'mean' from 'sum' and 'count'. + for dataset in global_metrics[StC.GLOBAL][StC.STATS_COUNT]: + for feature in global_metrics[StC.GLOBAL][StC.STATS_COUNT][dataset]: + recursively_update_org_means(metric, global_metrics, dataset, feature) + + return global_metrics + + +def get_hierarchical_histograms( + metric: str, client_name: str, metrics: dict, global_metrics: dict, hierarchy_config: dict +) -> dict: + """Calculate histograms at each hierarchical level. + + Args: + metric: Metric to accumulate. + client_name: Client name. + metrics: Client metrics. + global_metrics: The current global metrics. + hierarchy_config: Hierarchy configuration for the global stats. + + Returns: + A dict containing updated hierarchical global statistics with + accumulated histograms. 
+ """ + + def recursively_accumulate_org_histograms( + metric: str, + client_name: str, + metrics: dict, + global_metrics: dict, + dataset: str, + feature: str, + org: list, + histogram: dict, + ) -> dict: + if isinstance(global_metrics, dict): + for key, value in global_metrics.items(): + if key == StC.GLOBAL and StC.NAME not in global_metrics: + if ( + feature not in global_metrics[StC.GLOBAL][metric][dataset] + or not global_metrics[StC.GLOBAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.GLOBAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.GLOBAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.GLOBAL][metric][dataset][feature] = Histogram(g_hist.hist_type, updated_bins) + continue + if key == StC.NAME: + if org and value in org: + # The client belongs to this org so update current global matrics before sending it further + if ( + feature not in global_metrics[StC.GLOBAL][metric][dataset] + or not global_metrics[StC.GLOBAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.GLOBAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.GLOBAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + 
for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.GLOBAL][metric][dataset][feature] = Histogram( + g_hist.hist_type, updated_bins + ) + elif value == client_name: + # This is a client local metrics update + if ( + feature not in global_metrics[StC.LOCAL][metric][dataset] + or not global_metrics[StC.LOCAL][metric][dataset][feature] + ): + g_bins = [] + for bucket in histogram.bins: + g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_hist = Histogram(HistogramType.STANDARD, g_bins) + global_metrics[StC.LOCAL][metric][dataset][feature] = g_hist + else: + g_hist = global_metrics[StC.LOCAL][metric][dataset][feature] + g_buckets = bins_to_dict(g_hist.bins) + for bucket in histogram.bins: + bin_range = BinRange(bucket.low_value, bucket.high_value) + if bin_range in g_buckets: + g_buckets[bin_range] += bucket.sample_count + else: + g_buckets[bin_range] = bucket.sample_count + # update ordered bins + updated_bins = [] + for gb in g_hist.bins: + bin_range = BinRange(gb.low_value, gb.high_value) + updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + global_metrics[StC.LOCAL][metric][dataset][feature] = Histogram( + g_hist.hist_type, updated_bins + ) + else: + break + if isinstance(value, list): + for item in value: + recursively_accumulate_org_histograms( + metric, client_name, metrics, item, dataset, feature, org, histogram + ) + + client_org = get_client_hierarchy(copy.deepcopy(hierarchy_config), client_name) + for dataset in metrics: + for feature in metrics[dataset]: + histogram = metrics[dataset][feature] + 
            recursively_accumulate_org_histograms(
+                metric, client_name, metrics, global_metrics, dataset, feature, client_org, histogram
+            )
+
+    return global_metrics
+
+
+def get_hierarchical_stddevs(global_metrics: dict) -> dict:
+    """Calculate stddevs at each hierarchical level.
+
+    Args:
+        global_metrics: The current global metrics.
+
+    Returns:
+        A dict containing updated hierarchical global statistics with
+        accumulated stddevs.
+    """
+
+    def recursively_update_org_stddevs(global_metrics: dict, dataset: str, feature: str) -> dict:
+        if isinstance(global_metrics, dict):
+            for key, value in global_metrics.items():
+                if key == StC.GLOBAL:
+                    global_metrics[StC.GLOBAL][StC.STATS_STDDEV][dataset][feature] = sqrt(
+                        global_metrics[StC.GLOBAL][StC.STATS_VAR][dataset][feature]
+                    )
+                if key == StC.LOCAL:
+                    global_metrics[StC.LOCAL][StC.STATS_STDDEV][dataset][feature] = sqrt(
+                        global_metrics[StC.LOCAL][StC.STATS_VAR][dataset][feature]
+                    )
+                if isinstance(value, list):
+                    for item in value:
+                        recursively_update_org_stddevs(item, dataset, feature)
+
+    for dataset in global_metrics[StC.GLOBAL][StC.STATS_VAR]:
+        for feature in global_metrics[StC.GLOBAL][StC.STATS_VAR][dataset]:
+            recursively_update_org_stddevs(global_metrics, dataset, feature)
+
+    return global_metrics
+
+
+def get_hierarchical_levels(data: dict, level: int = 0, levels_dict: dict = None) -> dict:
+    """Calculate number of hierarchical levels from the given hierarchy config.
+
+    Args:
+        data: Hierarchy configuration for the global stats.
+        level: The current hierarchical level (used for recursive calls).
+        levels_dict: The accumulated levels dict (used for recursive calls).
+
+    Returns:
+        A dict containing hierarchical levels.
+ """ + if levels_dict is None: + levels_dict = {} + + if isinstance(data, list): + for item in data: + get_hierarchical_levels(item, level, levels_dict) + elif isinstance(data, dict): + for key, value in data.items(): + if key == StC.NAME: + continue + if key not in levels_dict: + levels_dict[key] = level + get_hierarchical_levels(value, level + 1, levels_dict) + + return levels_dict + + +def get_client_hierarchy(hierarchy_config: dict, client_name: str, path=None) -> list: + """Calculate hierarchy for the given client name. + + Args: + hierarchy_config: Hierarchy configuration for the global stats. + client_name: Client name. + path: The accumulated hierarchy path (used for recursive calls). + + Returns: + A list containing hierarchy levels for the client. + """ + if path is None: + path = [] + + if isinstance(hierarchy_config, dict): + for key, value in hierarchy_config.items(): + if isinstance(value, list): + result = get_client_hierarchy(value, client_name, path) + if result: + return result + elif isinstance(hierarchy_config, list): + for item in hierarchy_config: + if item == client_name: + return path + if isinstance(item, dict): + result = get_client_hierarchy(item, client_name, path + [item.get(StC.NAME)]) + if result: + return result + + return None + + +def bins_to_dict(bins: List[Bin]) -> Dict[BinRange, float]: + """Convert histogram bins to a 'dict'. + + Args: + bins: Histogram bins. + + Returns: + A dict containing histogram bins. + """ + buckets = {} + for bucket in bins: + bucket_range = BinRange(bucket.low_value, bucket.high_value) + buckets[bucket_range] = bucket.sample_count + return buckets + + +def filter_numeric_features(ds_features: Dict[str, List[Feature]]) -> Dict[str, List[Feature]]: + """Filter numeric features. + + Args: + ds_features: A features dict. + + Returns: + A dict containing numeric features. 
+ """ + numeric_ds_features = {} + for ds_name in ds_features: + features: List[Feature] = ds_features[ds_name] + n_features = [f for f in features if (f.data_type == DataType.INT or f.data_type == DataType.FLOAT)] + numeric_ds_features[ds_name] = n_features + + return numeric_ds_features diff --git a/nvflare/app_common/workflows/hierarchical_statistics_controller.py b/nvflare/app_common/workflows/hierarchical_statistics_controller.py new file mode 100644 index 0000000000..d53490860d --- /dev/null +++ b/nvflare/app_common/workflows/hierarchical_statistics_controller.py @@ -0,0 +1,337 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import os
+from typing import Dict, List, Optional
+
+from nvflare.apis.controller_spec import Task
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable
+from nvflare.apis.signal import Signal
+from nvflare.app_common.abstract.statistics_spec import Histogram, StatisticConfig
+from nvflare.app_common.app_constant import StatisticsConstants as StC
+from nvflare.app_common.statistics.hierarchical_numeric_stats import get_global_stats
+from nvflare.app_common.workflows.statistics_controller import StatisticsController
+from nvflare.fuel.utils import fobs
+
+
+class HierarchicalStatisticsController(StatisticsController):
+    def __init__(
+        self,
+        statistic_configs: Dict[str, dict],
+        writer_id: str,
+        wait_time_after_min_received: int = 1,
+        result_wait_timeout: int = 10,
+        precision=4,
+        min_clients: Optional[int] = None,
+        enable_pre_run_task: bool = True,
+        hierarchy_config: str = None,
+    ):
+        """Controller for hierarchical statistics.
+
+        Args:
+            statistic_configs: defines the input statistic to be computed and each statistic's configuration, see below for details.
+            writer_id: ID for StatisticsWriter. The StatisticWriter will save the result to output specified by the
+                StatisticsWriter
+            wait_time_after_min_received: number of seconds to wait after the specified minimum number of clients have responded.
+            result_wait_timeout: number of seconds to wait until we have received all results.
+                Notice this is after the min_clients have arrived, and we wait for result process
+                callback, this becomes important if the data size to be processed is large
+            precision: number of precision digits
+            min_clients: if specified, min number of clients we have to wait before process.
+            hierarchy_config: Hierarchy specification file providing details about all the clients and their hierarchy.
+ + This class is derived from 'StatisticsController' and overrides only methods required to output calculated global + statistics in given hierarchical order. + + For statistic_configs, the key is one of statistics' names sum, count, mean, stddev, histogram, and + the value is the arguments needed. All other statistics except histogram require no argument. + + .. code-block:: text + + "statistic_configs": { + "count": {}, + "mean": {}, + "sum": {}, + "stddev": {}, + "histogram": { + "*": {"bins": 20}, + "Age": {"bins": 10, "range": [0, 120]} + } + }, + + Histogram requires the following arguments: + 1) numbers of bins or buckets of the histogram + 2) the histogram range values [min, max] + + These arguments are different for each feature. Here are few examples: + + .. code-block:: text + + "histogram": { + "*": {"bins": 20 }, + "Age": {"bins": 10, "range":[0,120]} + } + + The configuration specifies that the + feature 'Age' will have 10 bins for and the range is within [0, 120). + For all other features, the default ("*") configuration is used, with bins = 20. + The range of histogram is not specified, thus requires the Statistics controller + to dynamically estimate histogram range for each feature. Then this estimated global + range (est global min, est. global max) will be used as the histogram range. + + To dynamically estimate such a histogram range, we need the client to provide the local + min and max values in order to calculate the global bin and max value. In order to protect + data privacy and avoid data leakage, a noise level is added to the local min/max + value before sending to the controller. Therefore the controller only gets the 'estimated' + values, and the global min/max are estimated, or more accurately, they are noised global min/max + values. + + Here is another example: + + .. 
code-block:: text + + "histogram": { + "density": {"bins": 10, "range":[0,120]} + } + + In this example, there is no default histogram configuration for other features. + + This will work correctly if there is only one feature called "density" + but will fail if there are other features in the dataset. + + In the following configuration: + + .. code-block:: text + + "statistic_configs": { + "count": {}, + "mean": {}, + "stddev": {} + } + + Only count, mean and stddev statistics are specified, so the statistics_controller + will only set tasks to calculate these three statistics. + + For 'hierarchy_config', below is an example hierarchy specification with 4 level hierarchy for 9 NVFLARE clients with the names ranging + from 'Device-1' to 'Device-9' and with hierarchical levels named 'Manufacturers', 'Orgs', 'Locations', 'Devices' with + 'Manufacturers' being the top most hierarchical level and "Devices" being the lowest hierarchical level: + + .. code-block:: text + { + "Manufacturers": [ + { + "Name": "Manufacturer-1", + "Orgs": [ + { + "Name": "Org-1", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-1", "Device-2"] + }, + { + "Name": "Location-2", + "Devices": ["Device-3"] + } + ] + }, + { + "Name": "Org-2", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-4", "Device-5"] + }, + { + "Name": "Location-2", + "Devices": ["Device-6"] + } + ] + } + ] + }, + { + "Name": "Manufacturer-2", + "Orgs": [ + { + "Name": "Org-3", + "Locations": [ + { + "Name": "Location-1", + "Devices": ["Device-7", "Device-8"] + }, + { + "Name": "Location-6", + "Devices": ["Device-9"] + } + ] + } + ] + } + ] + } + + """ + super().__init__( + statistic_configs, + writer_id, + wait_time_after_min_received, + result_wait_timeout, + precision, + min_clients, + enable_pre_run_task, + ) + self.hierarchy_config = hierarchy_config + + def statistics_task_flow(self, abort_signal: Signal, fl_ctx: FLContext, statistic_task: str): + """Statistics task flow for the 
given task. + + Args: + abort_signal: Abort signal. + fl_ctx: The FLContext. + statistic_task: Statistics task. + """ + if self.hierarchy_config: + engine = fl_ctx.get_engine() + ws = engine.get_workspace() + app_conf_dir = ws.get_app_config_dir(fl_ctx.get_job_id()) + hierarchy_config_file_path = os.path.join(app_conf_dir, self.hierarchy_config) + try: + with open(hierarchy_config_file_path) as hierarchy_config_file: + hierarchy_config_json = json.load(hierarchy_config_file) + except FileNotFoundError: + self.system_panic(f"The hierarchy config file {hierarchy_config_file_path} does not exist.", fl_ctx) + return False + except IOError as e: + self.system_panic( + f"An I/O error occurred while loading hierarchy config file {hierarchy_config_file_path}: {e}", + fl_ctx, + ) + return False + except json.decoder.JSONDecodeError as e: + self.system_panic( + f"Failed to decode hierarchy config JSON from the file {hierarchy_config_file_path}: {e}", fl_ctx + ) + return False + except Exception as e: + self.system_panic( + f"An unexpected error occurred while loading hierarchy config file {hierarchy_config_file_path}: {e}", + fl_ctx, + ) + return False + else: + self.system_panic(f"Error: No hierarchy config file provided.", fl_ctx) + return False + + self.log_info(fl_ctx, f"start prepare inputs for task {statistic_task}") + inputs = self._prepare_inputs(statistic_task) + results_cb_fn = self._get_result_cb(statistic_task) + + self.log_info(fl_ctx, f"task: {self.task_name} statistics_flow for {statistic_task} started.") + + if abort_signal.triggered: + return False + + task_props = {StC.STATISTICS_TASK_KEY: statistic_task} + task = Task(name=self.task_name, data=inputs, result_received_cb=results_cb_fn, props=task_props) + + self.broadcast_and_wait( + task=task, + targets=None, + min_responses=self.min_clients, + fl_ctx=fl_ctx, + wait_time_after_min_received=self.wait_time_after_min_received, + abort_signal=abort_signal, + ) + + self.global_statistics = get_global_stats( + 
self.global_statistics, self.client_statistics, statistic_task, hierarchy_config_json + ) + + self.log_info(fl_ctx, f"task {self.task_name} statistics_flow for {statistic_task} flow end.") + + def _recursively_round_global_stats(self, global_stats): + """Apply given precision to the calculated global statistics. + + Args: + global_stats: Global stats. + + Returns: + A dict containing global stats with applied precision. + """ + if isinstance(global_stats, dict): + for key, value in global_stats.items(): + if key == StC.GLOBAL or key == StC.LOCAL: + for key, metric in value.items(): + if key == StC.STATS_HISTOGRAM: + for ds in metric: + for name, val in metric[ds].items(): + hist: Histogram = metric[ds][name] + buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision) + metric[ds][name] = buckets + else: + for ds in metric: + for name, val in metric[ds].items(): + metric[ds][name] = round(metric[ds][name], self.precision) + continue + if isinstance(value, list): + for item in value: + self._recursively_round_global_stats(item) + elif isinstance(global_stats, list): + for item in global_stats: + self._recursively_round_global_stats(item) + + return global_stats + + def _combine_all_statistics(self): + """Get combined global statistics with precision applied. + + Returns: + A dict containing global statistics with precision applied. + """ + result = self.global_statistics + return self._recursively_round_global_stats(result) + + def _prepare_inputs(self, statistic_task: str) -> Shareable: + """Prepare inputs for the given task. + + Args: + statistic_task: Statistics task. + + Returns: + A dict containing inputs. 
+ """ + inputs = Shareable() + target_statistics: List[StatisticConfig] = StatisticsController._get_target_statistics( + self.statistic_configs, StC.ordered_statistics[statistic_task] + ) + for tm in target_statistics: + if tm.name == StC.STATS_HISTOGRAM: + if StC.STATS_MIN in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_MIN] = self.global_statistics[StC.GLOBAL][StC.STATS_MIN] + if StC.STATS_MAX in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_MAX] = self.global_statistics[StC.GLOBAL][StC.STATS_MAX] + elif tm.name == StC.STATS_VAR: + if StC.STATS_COUNT in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_GLOBAL_COUNT] = self.global_statistics[StC.GLOBAL][StC.STATS_COUNT] + if StC.STATS_MEAN in self.global_statistics[StC.GLOBAL]: + inputs[StC.STATS_GLOBAL_MEAN] = self.global_statistics[StC.GLOBAL][StC.STATS_MEAN] + + inputs[StC.STATISTICS_TASK_KEY] = statistic_task + + inputs[StC.STATS_TARGET_STATISTICS] = fobs.dumps(target_statistics) + + return inputs diff --git a/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py b/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py new file mode 100644 index 0000000000..e70d03bee2 --- /dev/null +++ b/tests/unit_test/app_common/statistics/hierarchical_numeric_stats_test.py @@ -0,0 +1,918 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from math import sqrt

import pytest

from nvflare.app_common.abstract.statistics_spec import Bin, Histogram, HistogramType
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_common.workflows.statistics_controller import StatisticsController

# Unit tests cover up to four hierarchical levels.
HIERARCHY_CONFIGS = [
    {"Sites": ["Site-1", "Site-2", "Site-3", "Site-4"]},
    {
        "Manufacturers": [
            {"Name": "Manufacturer-1", "Sites": ["Site-1", "Site-2"]},
            {"Name": "Manufacturer-2", "Sites": ["Site-3", "Site-4"]},
        ]
    },
    {
        "Manufacturers": [
            {
                "Name": "Manufacturer-1",
                "Orgs": [{"Name": "Org-1", "Sites": ["Site-1"]}, {"Name": "Org-2", "Sites": ["Site-2"]}],
            },
            {"Name": "Manufacturer-2", "Orgs": [{"Name": "Org-3", "Sites": ["Site-3", "Site-4"]}]},
        ]
    },
    {
        "Manufacturers": [
            {
                "Name": "Manufacturer-1",
                "Orgs": [
                    {"Name": "Org-1", "Locations": [{"Name": "Location-1", "Sites": ["Site-1"]}]},
                    {"Name": "Org-2", "Locations": [{"Name": "Location-1", "Sites": ["Site-2"]}]},
                ],
            },
            {
                "Name": "Manufacturer-2",
                "Orgs": [
                    {
                        "Name": "Org-3",
                        "Locations": [
                            {"Name": "Location-1", "Sites": ["Site-3"]},
                            {"Name": "Location-2", "Sites": ["Site-4"]},
                        ],
                    }
                ],
            },
        ]
    },
]

# Every site shares the same local histogram: 100 samples split evenly over [0, 1).
hist_bins = [Bin(0.0, 0.5, 50), Bin(0.5, 1.0, 50)]
g_hist = Histogram(HistogramType.STANDARD, hist_bins)

_SITE_IDS = [1, 2, 3, 4]


def _feature(value):
    """Wrap a single Feature1 value in the dataset/feature nesting used by all fixtures."""
    return {"data_set1": {"Feature1": value}}


# Per-site input pattern: count/sum/min/max scale linearly with the site index;
# mean/var/stddev and the histogram are identical for every site.
_SITE_VALUES = {
    "count": lambda i: 100 * i,
    "sum": lambda i: 1000 * i,
    "max": lambda i: 10 * (i + 1),
    "min": lambda i: i - 1,
    "mean": lambda i: 10,
    "var": lambda i: 0.1,
    "stddev": lambda i: 0.1,
    "histogram": lambda i: g_hist,
}

CLIENT_STATS = {stat: {f"Site-{i}": _feature(fn(i)) for i in _SITE_IDS} for stat, fn in _SITE_VALUES.items()}


def _bins(samples):
    """Expected post-precision histogram: two equal bins over [0, 1)."""
    return [
        Bin(low_value=0.0, high_value=0.5, sample_count=samples),
        Bin(low_value=0.5, high_value=1.0, sample_count=samples),
    ]


def _local_stats(i):
    """Expected 'Local' stats block for site i (stddev already rounded to 2 digits)."""
    return {
        "count": _feature(100 * i),
        "sum": _feature(1000 * i),
        "mean": _feature(10.0),
        "min": _feature(i - 1),
        "max": _feature(10 * (i + 1)),
        "histogram": {"data_set1": {"Feature1": _bins(50)}},
        "var": _feature(0.1),
        "stddev": _feature(0.32),
    }


def _global_stats(site_ids):
    """Expected 'Global' aggregate block over the given site ids."""
    # The aggregation seeds min with 0 and only takes min(acc, v) once acc is
    # non-zero, which is why Site-1's local min of 0 does not propagate into
    # multi-site groups (e.g. Sites 1+2 yield min 1) while a single-site group
    # containing Site-1 still yields 0 — mirrored here to match the implementation.
    acc_min = 0
    for i in site_ids:
        v = i - 1
        acc_min = min(acc_min, v) if acc_min else v
    n = len(site_ids)
    return {
        "count": _feature(100 * sum(site_ids)),
        "sum": _feature(1000 * sum(site_ids)),
        "mean": _feature(10.0),
        "min": _feature(acc_min),
        "max": _feature(max(10 * (i + 1) for i in site_ids)),
        "histogram": {"data_set1": {"Feature1": _bins(50 * n)}},
        "var": _feature(round(0.1 * n, 2)),
        "stddev": _feature(round(sqrt(0.1 * n), 2)),
    }


def _site(i):
    """Expected per-site node: name plus its Local stats."""
    return {"Name": f"Site-{i}", "Local": _local_stats(i)}


def _sites(*ids):
    return [_site(i) for i in ids]


# One-level hierarchy: a flat list of sites.
global_stats_0 = {
    "Global": _global_stats([1, 2, 3, 4]),
    "Sites": _sites(1, 2, 3, 4),
}

# Two levels: manufacturers -> sites.
global_stats_1 = {
    "Global": _global_stats([1, 2, 3, 4]),
    "Manufacturers": [
        {"Name": "Manufacturer-1", "Sites": _sites(1, 2), "Global": _global_stats([1, 2])},
        {"Name": "Manufacturer-2", "Sites": _sites(3, 4), "Global": _global_stats([3, 4])},
    ],
}

# Three levels: manufacturers -> orgs -> sites.
global_stats_2 = {
    "Global": _global_stats([1, 2, 3, 4]),
    "Manufacturers": [
        {
            "Name": "Manufacturer-1",
            "Orgs": [
                {"Name": "Org-1", "Sites": _sites(1), "Global": _global_stats([1])},
                {"Name": "Org-2", "Sites": _sites(2), "Global": _global_stats([2])},
            ],
            "Global": _global_stats([1, 2]),
        },
        {
            "Name": "Manufacturer-2",
            "Orgs": [{"Name": "Org-3", "Sites": _sites(3, 4), "Global": _global_stats([3, 4])}],
            "Global": _global_stats([3, 4]),
        },
    ],
}

# Four levels: manufacturers -> orgs -> locations -> sites.
global_stats_3 = {
    "Global": _global_stats([1, 2, 3, 4]),
    "Manufacturers": [
        {
            "Name": "Manufacturer-1",
            "Orgs": [
                {
                    "Name": "Org-1",
                    "Locations": [{"Name": "Location-1", "Sites": _sites(1), "Global": _global_stats([1])}],
                    "Global": _global_stats([1]),
                },
                {
                    "Name": "Org-2",
                    "Locations": [{"Name": "Location-1", "Sites": _sites(2), "Global": _global_stats([2])}],
                    "Global": _global_stats([2]),
                },
            ],
            "Global": _global_stats([1, 2]),
        },
        {
            "Name": "Manufacturer-2",
            "Orgs": [
                {
                    "Name": "Org-3",
                    "Locations": [
                        {"Name": "Location-1", "Sites": _sites(3), "Global": _global_stats([3])},
                        {"Name": "Location-2", "Sites": _sites(4), "Global": _global_stats([4])},
                    ],
                    "Global": _global_stats([3, 4]),
                }
            ],
            "Global": _global_stats([3, 4]),
        },
    ],
}

PARAMS = [
    (CLIENT_STATS, global_stats_0, HIERARCHY_CONFIGS[0]),
    (CLIENT_STATS, global_stats_1, HIERARCHY_CONFIGS[1]),
    (CLIENT_STATS, global_stats_2, HIERARCHY_CONFIGS[2]),
    (CLIENT_STATS, global_stats_3, HIERARCHY_CONFIGS[3]),
]


class TestHierarchicalNumericStats:
    def _round_global_stats(self, global_stats):
        """Round computed stats the same way StatisticsController does, with precision 2.

        Histograms become precision-limited Bin lists; every other statistic
        is rounded to 2 digits. Recurses through nested hierarchy levels.
        """
        if isinstance(global_stats, dict):
            for key, value in global_stats.items():
                if key == StC.GLOBAL or key == StC.LOCAL:
                    for stat_name, metric in value.items():
                        if stat_name == StC.STATS_HISTOGRAM:
                            for ds in metric:
                                for name in metric[ds]:
                                    hist: Histogram = metric[ds][name]
                                    metric[ds][name] = StatisticsController._apply_histogram_precision(hist.bins, 2)
                        else:
                            for ds in metric:
                                for name in metric[ds]:
                                    metric[ds][name] = round(metric[ds][name], 2)
                    continue
                if isinstance(value, list):
                    for item in value:
                        self._round_global_stats(item)
        elif isinstance(global_stats, list):
            for item in global_stats:
                self._round_global_stats(item)
        return global_stats

    @pytest.mark.parametrize("client_stats, expected_global_stats, hierarchy_configs", PARAMS)
    def test_global_stats(self, client_stats, expected_global_stats, hierarchy_configs):
        from nvflare.app_common.statistics.hierarchical_numeric_stats import get_global_stats

        # Two-pass flow: first pass computes count/sum/mean/min/max, second
        # pass computes var/stddev/histogram using first-pass global values.
        global_stats = get_global_stats({}, client_stats, StC.STATS_1st_STATISTICS, hierarchy_configs)
        global_stats = get_global_stats(global_stats, client_stats, StC.STATS_2nd_STATISTICS, hierarchy_configs)
        result = self._round_global_stats(global_stats)

        assert expected_global_stats == result