From 2dd0ace494a6525fa0881cc6fb1b4b09b397c548 Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Mon, 7 Dec 2020 17:09:36 +0800
Subject: [PATCH] Fix missing DDP in torch distributed (#3185)

* fix ddp
* add model join
* fix
* add ut
* remove join
* remove debug msg
---
 .../src/bigdl/orca/learn/pytorch/torch_runner.py | 14 +++++++++-----
 .../pytorch/test_estimator_pytorch_backend.py    | 16 ++++++++++++++--
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
index c8c0b3c4930..917281d57ab 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
@@ -33,7 +33,6 @@
 from filelock import FileLock
 
 import logging
-import inspect
 import io
 import itertools
 import os
@@ -125,7 +124,7 @@ def setup_horovod(self):
         self.rank = hvd.rank()
         self.size = hvd.size()
         self.setup_components_horovod()
-        self.setup_operator()
+        self.setup_operator(self.models)
 
     def setup_address(self):
         ip = ray.services.get_node_ip_address()
@@ -134,6 +133,7 @@
 
     def setup_torch_distribute(self, url, world_rank, world_size):
         import torch.distributed as dist
+        from torch.nn.parallel import DistributedDataParallel
         dist.init_process_group(
             backend="gloo",
             init_method=url,
@@ -143,7 +143,11 @@
         self.rank = world_rank
         self.size = world_size
         self.setup_components()
-        self.setup_operator()
+        training_models = [
+            DistributedDataParallel(model)
+            for model in self.models
+        ]
+        self.setup_operator(training_models)
 
     def setup_components(self):
         """Runs the creator functions without any distributed coordination."""
@@ -193,12 +197,12 @@
         self._create_schedulers_if_available()
         self._create_loss()
 
-    def setup_operator(self):
+    def setup_operator(self, training_models):
         """Create the training operator."""
         self.training_operator =\
             self.training_operator_cls(
                 self.config,
-                models=self.models,
+                models=training_models,
                 optimizers=self.optimizers,
                 criterion=self.criterion,
                 world_rank=self.rank,
diff --git a/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pytorch_backend.py b/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pytorch_backend.py
index faf22529083..83a09356b8b 100644
--- a/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pytorch_backend.py
+++ b/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pytorch_backend.py
@@ -91,11 +91,12 @@ def get_optimizer(model, config):
 
 
 class TestPyTorchEstimator(TestCase):
-    def test_linear(self):
+    def test_data_creator(self):
         estimator = Estimator.from_torch(model=get_model,
                                          optimizer=get_optimizer,
                                          loss=nn.BCELoss(),
                                          config={"lr": 1e-2},
+                                         workers_per_node=2,
                                          backend="torch_distributed")
         train_stats = estimator.fit(train_data_loader, epochs=2, batch_size=128)
         print(train_stats)
@@ -103,9 +104,20 @@
         print(val_stats)
         assert 0 < val_stats["val_accuracy"] < 1
         assert estimator.get_model()
+
+        # Verify syncing weights, i.e. the two workers have the same weights after training
+        import ray
+        remote_workers = estimator.estimator.remote_workers
+        state_dicts = ray.get([worker.state_dict.remote() for worker in remote_workers])
+        weights = [state["models"] for state in state_dicts]
+        worker1_weights = weights[0][0]
+        worker2_weights = weights[1][0]
+        for layer in list(worker1_weights.keys()):
+            assert np.allclose(worker1_weights[layer].numpy(),
+                               worker2_weights[layer].numpy())
         estimator.shutdown()
 
-    def test_linear_spark_xshards(self):
+    def test_spark_xshards(self):
         from zoo import init_nncontext
         from zoo.orca.data import SparkXShards
         estimator = Estimator.from_torch(model=get_model,