
[backport 0.9] Fix missing DDP in torch distributed (intel-analytics#3185) (intel-analytics#3195)

* fix ddp

* add model join

* fix

* add ut

* remove join

* remove debug msg
hkvision committed Dec 7, 2020
1 parent 58b343c commit abc38b8
Showing 2 changed files with 23 additions and 7 deletions.
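
The gist of the fix: with backend="torch_distributed", each worker's model has to be wrapped in torch.nn.parallel.DistributedDataParallel so that gradients are all-reduced across workers during backward(); without the wrapper every worker trains its own copy and the weights drift apart, which is what the new test below asserts against. A minimal, self-contained sketch of that pattern (standalone two-process local run, not this repository's code; the address, port, model, and data are illustrative):

# Hypothetical sketch of the gloo + DistributedDataParallel pattern.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"   # assumed local rendezvous
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.Linear(4, 1)
    ddp_model = DistributedDataParallel(model)   # CPU/gloo: no device_ids needed
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-2)
    criterion = nn.MSELoss()

    # backward() on the DDP wrapper all-reduces the gradients, so every rank
    # applies the same update and the weights stay in sync across workers.
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
    optimizer.zero_grad()
    criterion(ddp_model(inputs), targets).backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)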
@@ -91,21 +91,33 @@ def get_optimizer(model, config):
 
 
 class TestPyTorchEstimator(TestCase):
-    def test_linear(self):
+    def test_data_creator(self):
         estimator = Estimator.from_torch(model=get_model,
                                          optimizer=get_optimizer,
                                          loss=nn.BCELoss(),
                                          config={"lr": 1e-2},
+                                         workers_per_node=2,
                                          backend="torch_distributed")
         train_stats = estimator.fit(train_data_loader, epochs=2, batch_size=128)
         print(train_stats)
         val_stats = estimator.evaluate(val_data_loader, batch_size=64)
         print(val_stats)
         assert 0 < val_stats["val_accuracy"] < 1
         assert estimator.get_model()
+
+        # Verify syncing weights, i.e. the two workers have the same weights after training
+        import ray
+        remote_workers = estimator.estimator.remote_workers
+        state_dicts = ray.get([worker.state_dict.remote() for worker in remote_workers])
+        weights = [state["models"] for state in state_dicts]
+        worker1_weights = weights[0][0]
+        worker2_weights = weights[1][0]
+        for layer in list(worker1_weights.keys()):
+            assert np.allclose(worker1_weights[layer].numpy(),
+                               worker2_weights[layer].numpy())
         estimator.shutdown()
 
-    def test_linear_spark_xshards(self):
+    def test_spark_xshards(self):
         from zoo import init_nncontext
         from zoo.orca.data import SparkXShards
         estimator = Estimator.from_torch(model=get_model,
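The new assertion verifies that, after training, both workers hold identical weights layer by layer. A standalone sketch of the same check in plain PyTorch, without Ray (hypothetical helper name, assumed tolerance):

# Hypothetical helper in the spirit of the assertion above: two state dicts
# are considered synchronized when every parameter tensor matches elementwise.
import numpy as np
import torch


def state_dicts_allclose(sd_a, sd_b, atol=1e-8):
    if sd_a.keys() != sd_b.keys():
        return False
    return all(
        np.allclose(sd_a[name].detach().numpy(),
                    sd_b[name].detach().numpy(), atol=atol)
        for name in sd_a
    )


if __name__ == "__main__":
    m1, m2 = torch.nn.Linear(4, 1), torch.nn.Linear(4, 1)
    m2.load_state_dict(m1.state_dict())          # force the two models to agree
    assert state_dicts_allclose(m1.state_dict(), m2.state_dict())
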
14 changes: 9 additions & 5 deletions pyzoo/zoo/orca/learn/pytorch/torch_runner.py
@@ -33,7 +33,6 @@
 
 from filelock import FileLock
 import logging
-import inspect
 import io
 import itertools
 import os
@@ -125,7 +124,7 @@ def setup_horovod(self):
         self.rank = hvd.rank()
         self.size = hvd.size()
         self.setup_components_horovod()
-        self.setup_operator()
+        self.setup_operator(self.models)
 
     def setup_address(self):
         ip = ray.services.get_node_ip_address()
@@ -134,6 +133,7 @@ def setup_address(self):
 
     def setup_torch_distribute(self, url, world_rank, world_size):
         import torch.distributed as dist
+        from torch.nn.parallel import DistributedDataParallel
         dist.init_process_group(
             backend="gloo",
             init_method=url,
@@ -143,7 +143,11 @@ def setup_torch_distribute(self, url, world_rank, world_size):
         self.rank = world_rank
         self.size = world_size
         self.setup_components()
-        self.setup_operator()
+        training_models = [
+            DistributedDataParallel(model)
+            for model in self.models
+        ]
+        self.setup_operator(training_models)
 
     def setup_components(self):
         """Runs the creator functions without any distributed coordination."""
@@ -193,12 +197,12 @@ def setup_components_horovod(self):
         self._create_schedulers_if_available()
         self._create_loss()
 
-    def setup_operator(self):
+    def setup_operator(self, training_models):
         """Create the training operator."""
         self.training_operator =\
             self.training_operator_cls(
                 self.config,
-                models=self.models,
+                models=training_models,
                 optimizers=self.optimizers,
                 criterion=self.criterion,
                 world_rank=self.rank,
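Taken together, both backends now funnel their models through the same setup_operator(training_models) signature: the Horovod path passes the bare models (Horovod's optimizer hooks already average gradients), while the torch-distributed path passes DDP wrappers. A condensed sketch of that calling convention (hypothetical class, not the repository's TorchRunner; constructing DistributedDataParallel assumes an initialized process group at runtime):

# Hypothetical runner sketch illustrating the two setup paths.
from torch.nn.parallel import DistributedDataParallel


class RunnerSketch:
    def __init__(self, models, operator_cls):
        self.models = models              # bare nn.Modules, as in self.models
        self.operator_cls = operator_cls

    def setup_horovod_like(self):
        # Horovod averages gradients via its distributed optimizer,
        # so the bare models are handed over unchanged.
        self.setup_operator(self.models)

    def setup_torch_distribute_like(self):
        # torch.distributed needs DDP wrappers to all-reduce gradients.
        training_models = [DistributedDataParallel(m) for m in self.models]
        self.setup_operator(training_models)

    def setup_operator(self, training_models):
        self.training_operator = self.operator_cls(models=training_models)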
