From af0c4e5f114fe10bb0e1ec6fac4de966571e1983 Mon Sep 17 00:00:00 2001 From: chesterxgchen Date: Mon, 26 Aug 2024 15:52:15 -0700 Subject: [PATCH] switch the id prefix to real id --- nvflare/job_config/stats_job.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/nvflare/job_config/stats_job.py b/nvflare/job_config/stats_job.py index 610d4550fb..929132a06a 100644 --- a/nvflare/job_config/stats_job.py +++ b/nvflare/job_config/stats_job.py @@ -38,7 +38,7 @@ def __init__( ): super().__init__() self.writer_id = "stats_writer" - self.stats_generator_id = "stats_generator" + self.stats_generator_id_prefix = "stats_generator" self.job_name = job_name self.stats_generator = stats_generator self.statistic_configs = statistic_configs @@ -53,17 +53,17 @@ def __init__( def setup_server(self): # define stats controller ctr = self.get_stats_controller() + self.to(ctr, "server") # define stat writer to output Json file stats_writer = self.get_stats_output_writer() - self.to(ctr, "server") self.to(stats_writer, "server", id=self.writer_id) def setup_client(self, sites: List[str]): # Client side job config # Add client site for site_id in sites: - self.to(self.stats_generator, site_id, id=self.stats_generator_id) - executor = StatisticsExecutor(generator_id=self.stats_generator_id) + stats_generator_id = self.to(self.stats_generator, site_id, id=self.stats_generator_id_prefix) + executor = StatisticsExecutor(generator_id=stats_generator_id) self.to(executor, site_id, tasks=["fed_stats_pre_run", "fed_stats"]) self.add_privacy_result_filters(site_id) @@ -78,13 +78,12 @@ def get_stats_output_writer(self): def add_privacy_result_filters(self, site_id: str): # add privacy filters - result_cleanser_ids = ["min_count_cleanser", "min_max_noise_cleanser", "hist_bins_cleanser"] - result_filter = StatisticsPrivacyFilter(result_cleanser_ids=result_cleanser_ids) - min_count_cleanser = MinCountCleanser(min_count=self.min_count) min_max_cleanser = AddNoiseToMinMax(min_noise_level=self.min_noise_level, max_noise_level=self.max_noise_level) hist_bins_cleanser = HistogramBinsCleanser(max_bins_percent=self.max_bins_percent) - self.to(min_count_cleanser, site_id, id="min_count_cleanser") - self.to(min_max_cleanser, site_id, id="min_max_noise_cleanser") - self.to(hist_bins_cleanser, site_id, id="hist_bins_cleanser") + result_cleanser_ids = [self.to(min_count_cleanser, site_id, id="min_count_cleanser"), + self.to(min_max_cleanser, site_id, id="min_max_noise_cleanser"), + self.to(hist_bins_cleanser, site_id, id="hist_bins_cleanser")] + + result_filter = StatisticsPrivacyFilter(result_cleanser_ids=result_cleanser_ids) self.to(result_filter, site_id, filter_type=FilterType.TASK_RESULT, tasks=["fed_stats"])