Commit 41a25f5: format style

chesterxgchen committed Aug 26, 2024
1 parent: a3cfd85
Showing 4 changed files with 60 additions and 62 deletions.
File 1 of 4:

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from kmeans_assembler import KMeansAssembler
+
 from nvflare import FedJob
 from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner

 if __name__ == "__main__":
     n_clients = 3
@@ -38,16 +39,18 @@
     job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
     job.to(KMeansAssembler(), "server", id=assembler_id)

-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )

     job.to(ctrl, "server")

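Each diff in this commit is truncated before the client side of the job is configured. As orientation only, a FedJob like the one above typically attaches a ScriptRunner per site and then exports or simulates the job; in the sketch below, the script name, arguments, and paths are hypothetical placeholders, not the hidden lines of this file:

# Hypothetical continuation sketch; script, script_args, and paths are placeholders.
for i in range(n_clients):
    runner = ScriptRunner(
        script="kmeans_fl.py",  # placeholder local training script
        script_args="--data_root_dir /tmp/data",  # placeholder CLI args
        framework=FrameworkType.RAW,  # sklearn params exchanged as raw dicts
    )
    job.to(runner, f"site-{i + 1}")

job.export_job("/tmp/nvflare/jobs")  # write out a deployable job config
job.simulator_run("/tmp/nvflare/workspace")  # or run locally in the simulator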
File 2 of 4:

@@ -17,7 +17,7 @@
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner

 if __name__ == "__main__":
     n_clients = 3
@@ -32,28 +32,24 @@
     job = FedJob("sklearn_svm")

     initial_params = dict(
-        n_classes=2,
-        learning_rate="constant",
-        eta0=1e-05,
-        loss="log_loss",
-        penalty="l2",
-        fit_intercept=True,
-        max_iter=1
+        n_classes=2, learning_rate="constant", eta0=1e-05, loss="log_loss", penalty="l2", fit_intercept=True, max_iter=1
     )
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
     job.to(InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS), "server", id=aggregator_id)

-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )

     job.to(ctrl, "server")

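Aside on the collapsed initial_params: apart from n_classes, which is presumably consumed by JoblibModelParamPersistor to shape the initial weights rather than passed to the estimator, these keyword arguments mirror scikit-learn's SGDClassifier constructor. A local, non-federated equivalent of the model being initialized would be roughly:

from sklearn.linear_model import SGDClassifier

# Local stand-in for the federated model's hyperparameters; needs
# scikit-learn >= 1.1 for loss="log_loss". n_classes is omitted because
# SGDClassifier has no such argument (assumption: the persistor uses it
# to size the coefficient matrix).
model = SGDClassifier(
    loss="log_loss",
    penalty="l2",
    learning_rate="constant",
    eta0=1e-05,
    fit_intercept=True,
    max_iter=1,  # one local epoch per federated round
)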
File 3 of 4:

@@ -17,7 +17,7 @@
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner

 if __name__ == "__main__":
     n_clients = 3
@@ -32,28 +32,24 @@
     job = FedJob("sklearn_sgd")

     initial_params = dict(
-        n_classes=2,
-        learning_rate="constant",
-        eta0=1e-05,
-        loss="log_loss",
-        penalty="l2",
-        fit_intercept=True,
-        max_iter=1
+        n_classes=2, learning_rate="constant", eta0=1e-05, loss="log_loss", penalty="l2", fit_intercept=True, max_iter=1
     )
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
     job.to(InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS), "server", id=aggregator_id)

-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )

     job.to(ctrl, "server")

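This job, like the previous one, aggregates WEIGHTS with InTimeAccumulateWeightedAggregator, which is conceptually a running weighted average of client parameters accumulated as results arrive. A toy sketch of that arithmetic (not NVFlare's actual implementation):

import numpy as np

def accumulate(total, weight_sum, params, weight):
    # Fold one client's contribution into the running sums.
    return total + weight * params, weight_sum + weight

total, weight_sum = np.zeros(4), 0.0
# e.g., two clients reporting params with aggregation weights 100 and 300
for params, weight in [(np.ones(4), 100.0), (2 * np.ones(4), 300.0)]:
    total, weight_sum = accumulate(total, weight_sum, params, weight)

global_params = total / weight_sum  # -> array of 1.75 everywhere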
File 4 of 4:

@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from svm_assembler import SVMAssembler
+
 from nvflare import FedJob
 from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
 from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
 from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
 from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
-from nvflare.job_config.script_runner import ScriptRunner, FrameworkType
-from svm_assembler import SVMAssembler
+from nvflare.job_config.script_runner import FrameworkType, ScriptRunner

 if __name__ == "__main__":
     n_clients = 3
@@ -35,19 +36,21 @@
     initial_params = dict(kernel="rbf")
     job.to(JoblibModelParamPersistor(initial_params=initial_params), "server", id=persistor_id)
     job.to(FullModelShareableGenerator(), "server", id=shareable_generator_id)
-    job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
-    job.to(SVMAssembler(kernel= "rbf"), "server", id = assembler_id)
-
-    ctrl = ScatterAndGather(min_clients=n_clients,
-                            num_rounds=num_rounds,
-                            start_round=0,
-                            wait_time_after_min_received=0,
-                            aggregator_id=aggregator_id,
-                            persistor_id=persistor_id,
-                            shareable_generator_id=shareable_generator_id,
-                            train_task_name="train",
-                            train_timeout=0,
-                            allow_empty_global_weights=True)
+    job.to(CollectAndAssembleAggregator(assembler_id=assembler_id), "server", id=aggregator_id)
+    job.to(SVMAssembler(kernel="rbf"), "server", id=assembler_id)
+
+    ctrl = ScatterAndGather(
+        min_clients=n_clients,
+        num_rounds=num_rounds,
+        start_round=0,
+        wait_time_after_min_received=0,
+        aggregator_id=aggregator_id,
+        persistor_id=persistor_id,
+        shareable_generator_id=shareable_generator_id,
+        train_task_name="train",
+        train_timeout=0,
+        allow_empty_global_weights=True,
+    )

     job.to(ctrl, "server")

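Here CollectAndAssembleAggregator collects each site's submission and delegates the combining to the assembler; SVMAssembler is a local module in this example and its code is not part of the diff. As the names suggest (an assumption, since the assembler itself is not shown), the idea is to pool the clients' support vectors and refit a global SVM on their union; in plain scikit-learn terms:

import numpy as np
from sklearn.svm import SVC

# Plain-sklearn sketch of support-vector assembly (assumed behavior; the
# real logic lives in the example's svm_assembler.py). Each entry is a
# hypothetical client payload: (support vectors, labels).
payloads = [
    (np.random.randn(20, 5), np.r_[np.zeros(10), np.ones(10)]),
    (np.random.randn(16, 5), np.r_[np.zeros(8), np.ones(8)]),
]

X = np.vstack([x for x, _ in payloads])
y = np.concatenate([y for _, y in payloads])

global_svm = SVC(kernel="rbf")  # same kernel the job passes to SVMAssembler
global_svm.fit(X, y)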