diff --git a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/code/sklearn_kmeans_job.py b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/code/sklearn_kmeans_job.py
index 6a7d18ae1b..3a0f6a2184 100644
--- a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/code/sklearn_kmeans_job.py
+++ b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/code/sklearn_kmeans_job.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from kmeans_assembler import KMeansAssembler
+
 from nvflare import FedJob
 from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner
 
 if __name__ == "__main__":
     n_clients = 3
@@ -38,16 +39,18 @@
     job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
     job.to(KMeansAssembler(), "server", id=assembler_id)
 
-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )
 
     job.to(ctrl, "server")
 
diff --git a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans_job.py b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans_job.py
index 093b4f5bd0..98a5fae371 100644
--- a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans_job.py
+++ b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans_job.py
@@ -17,7 +17,7 @@
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner
 
 if __name__ == "__main__":
     n_clients = 3
@@ -32,28 +32,24 @@
     job = FedJob("sklearn_svm")
 
     initial_params = dict(
-        n_classes=2,
-        learning_rate="constant",
-        eta0=1e-05,
-        loss="log_loss",
-        penalty="l2",
-        fit_intercept=True,
-        max_iter=1
+        n_classes=2, learning_rate="constant", eta0=1e-05, loss="log_loss", penalty="l2", fit_intercept=True, max_iter=1
     )
 
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
     job.to(InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS), "server", id=aggregator_id)
 
-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )
 
     job.to(ctrl, "server")
diff --git a/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear_job.py b/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear_job.py
index 2389c519d4..51035f3f33 100644
--- a/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear_job.py
+++ b/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear_job.py
@@ -17,7 +17,7 @@
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner
 
 if __name__ == "__main__":
     n_clients = 3
@@ -32,28 +32,24 @@
     job = FedJob("sklearn_sgd")
 
     initial_params = dict(
-        n_classes=2,
-        learning_rate="constant",
-        eta0=1e-05,
-        loss="log_loss",
-        penalty="l2",
-        fit_intercept=True,
-        max_iter=1
+        n_classes=2, learning_rate="constant", eta0=1e-05, loss="log_loss", penalty="l2", fit_intercept=True, max_iter=1
     )
 
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
     job.to(InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS), "server", id=aggregator_id)
 
-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )
 
     job.to(ctrl, "server")
diff --git a/examples/hello-world/step-by-step/higgs/sklearn-svm/code/sklearn_svm_job.py b/examples/hello-world/step-by-step/higgs/sklearn-svm/code/sklearn_svm_job.py
index ff5489880d..be744747cb 100644
--- a/examples/hello-world/step-by-step/higgs/sklearn-svm/code/sklearn_svm_job.py
+++ b/examples/hello-world/step-by-step/higgs/sklearn-svm/code/sklearn_svm_job.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from svm_assembler import SVMAssembler
+
 from nvflare import FedJob
 from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
-from svm_assembler import SVMAssembler
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner
 
 if __name__ == "__main__":
     n_clients = 3
@@ -35,19 +36,21 @@
     initial_params = dict(kernel="rbf")
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
-    job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
-    job.to(SVMAssembler(kernel= "rbf"), "server", id = assembler_id)
-
-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
+    job.to(SVMAssembler(kernel="rbf"), "server", id=assembler_id)
+
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )
 
     job.to(ctrl, "server")
 