Skip to content

Commit

Permalink
format/style
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Aug 26, 2024
1 parent 72f928b commit ab8743f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
19 changes: 9 additions & 10 deletions examples/hello-world/step-by-step/cifar10/stats/image_stats_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

from image_statistics import ImageStatistics

from nvflare.job_config.stats_job import StatsJob


Expand All @@ -38,19 +40,16 @@ def main():
work_dir = args.work_dir
export_config = args.export_config

statistic_configs = {
"count": {},
"histogram": {
"*": {"bins": 20, "range": [0, 256]}
}
}
statistic_configs = {"count": {}, "histogram": {"*": {"bins": 20, "range": [0, 256]}}}
# define local stats generator
stats_generator = ImageStatistics(data_root_dir)

job = StatsJob(job_name="stats_image",
statistic_configs=statistic_configs,
stats_generator=stats_generator,
output_path=output_path)
job = StatsJob(
job_name="stats_image",
statistic_configs=statistic_configs,
stats_generator=stats_generator,
output_path=output_path,
)

sites = [f"site-{i + 1}" for i in range(n_clients)]
job.setup_client(sites)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import argparse

from df_stats import DFStatistics

from nvflare.job_config.stats_job import StatsJob


Expand Down Expand Up @@ -50,10 +51,12 @@ def main():
# define local stats generator
df_stats_generator = DFStatistics(data_root_dir=data_root_dir)

job = StatsJob(job_name="stats_df",
statistic_configs=statistic_configs,
stats_generator=df_stats_generator,
output_path=output_path)
job = StatsJob(
job_name="stats_df",
statistic_configs=statistic_configs,
stats_generator=df_stats_generator,
output_path=output_path,
)

sites = [f"site-{i + 1}" for i in range(n_clients)]
job.setup_client(sites)
Expand Down
27 changes: 14 additions & 13 deletions nvflare/job_config/stats_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@


class StatsJob(FedJob):

def __init__(self,
job_name: str,
statistic_configs: dict,
stats_generator: Statistics,
output_path: str,
min_count: int = 10,
min_noise_level=0.1,
max_noise_level=0.3,
max_bins_percent=10,
):
def __init__(
self,
job_name: str,
statistic_configs: dict,
stats_generator: Statistics,
output_path: str,
min_count: int = 10,
min_noise_level=0.1,
max_noise_level=0.3,
max_bins_percent=10,
):
super().__init__()
self.writer_id = "stats_writer"
self.stats_generator_id = "stats_generator"
Expand Down Expand Up @@ -68,8 +68,9 @@ def setup_client(self, sites: List[str]):
self.add_privacy_result_filters(site_id)

def get_stats_controller(self) -> StatisticsController:
return StatisticsController(statistic_configs=self.statistic_configs, writer_id=self.writer_id,
enable_pre_run_task=False)
return StatisticsController(
statistic_configs=self.statistic_configs, writer_id=self.writer_id, enable_pre_run_task=False
)

def get_stats_output_writer(self):
json_encoder_path = "nvflare.app_common.utils.json_utils.ObjectEncoder"
Expand Down

0 comments on commit ab8743f

Please sign in to comment.