[BUG] Issue with using optimise_hyperparameter with PyTorch DDP #1588

Open
aman1b opened this issue Jul 31, 2024 · 2 comments
Labels
bug Something isn't working

Comments

aman1b commented Jul 31, 2024

Hi community,

I have been stuck on this issue for some time now and would greatly appreciate any help! I am trying to run the optimize_hyperparameters function over 2 A100 GPUs using the PyTorch DDP strategy.

When I run this, I get the following error:
RuntimeError: DDP expects same model across all ranks, but Rank 0 has 160 params, while rank 1 has inconsistent 137 params.

I have tried setting the seed across ranks, but no luck. Has anyone experienced this issue, or does anyone have an example of using this function to train a TFT with DDP?
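(For reference, a minimal sketch of what seeding across ranks can look like with Lightning's seed_everything; the exact call used for this report is not shown above, so the import path and arguments below are illustrative assumptions.)

# Illustrative only: pin the Python, NumPy, and torch RNGs to the same seed on every rank.
from lightning.pytorch import seed_everything  # on older setups: from pytorch_lightning import seed_everything

seed_everything(42, workers=True)  # workers=True also seeds DataLoader worker processes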

I am using the latest package versions and training on an Azure VM. The application is run once I trigger the train_model function.

import time

import torch
from lightning.pytorch.strategies import DDPStrategy  # on older setups: from pytorch_lightning.strategies import DDPStrategy
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Note: `constants` and `logger` are project-specific objects assumed to be defined elsewhere.

def prepare_data(data_prep_folder):
    # Load in training and validation datasets
    training = torch.load(f"{data_prep_folder}/{constants.TRAIN_DATASET_FILE_NAME}")
    validation = torch.load(f"{data_prep_folder}/{constants.VALIDATION_DATASET_FILE_NAME}")

    logger.info(f"Training set loaded with length {len(training)}.")
    logger.info(f"Validation set loaded with length {len(validation)}.")

    # Create dataloaders
    train_dataloader = training.to_dataloader(
        train=True,
        batch_size=128,
        num_workers=47,
        pin_memory=True
    )

    val_dataloader = validation.to_dataloader(
        train=False,
        batch_size=128,
        num_workers=47,
        pin_memory=True
    )

    logger.info("Dataloaders created with batch size 128 and 47 workers.")
    return train_dataloader, val_dataloader

def hyperparameter_tuner(train_dataloader, val_dataloader, model_train_folder):
    # Start time
    start_time = time.time()
    logger.info("Starting hyperparameter tuning...")

    # Create study
    study = optimize_hyperparameters(
        train_dataloader,
        val_dataloader,
        model_path=model_train_folder,
        n_trials=2,
        max_epochs=30,
        gradient_clip_val_range=(0.01, 1.0),
        hidden_size_range=(8, 128),
        hidden_continuous_size_range=(8, 128),
        attention_head_size_range=(1, 4),
        learning_rate_range=(0.001, 0.1),
        dropout_range=(0.1, 0.3),
        trainer_kwargs=dict(
            accelerator='gpu',
            strategy=DDPStrategy(),
            devices='auto',
            limit_train_batches=10
        ),
        reduce_on_plateau_patience=4,
        use_learning_rate_finder=False
    )

    logger.info("Hyperparameter tuning finished.")

    # Get best parameters
    best_params = study.best_trial.params

    logger.info(f"Best trial parameters: {best_params}")

    training_time = time.time() - start_time
    hours, remainder = divmod(training_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    logger.info(f"Tuning took {int(hours)} hours, {int(minutes)} minutes, and {int(seconds)} seconds.")

    return best_params
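For context, here is a minimal sketch of one possible workaround, under the assumption that the error comes from each DDP rank sampling its own Optuna trial (so different ranks build TFTs with different numbers of parameter tensors): restrict the search itself to a single device and keep DDP for the final training run only. It reuses the arguments and variable names from hyperparameter_tuner above; dropping strategy and setting devices=1 are assumptions, not a confirmed fix.

# Hedged sketch: same optimize_hyperparameters call as above, but on one device,
# so each trial's hyperparameters are sampled by exactly one process.
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path=model_train_folder,
    n_trials=2,
    max_epochs=30,
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 128),
    attention_head_size_range=(1, 4),
    learning_rate_range=(0.001, 0.1),
    dropout_range=(0.1, 0.3),
    trainer_kwargs=dict(
        accelerator="gpu",
        devices=1,  # single process: the model built for each trial is consistent
        limit_train_batches=10
    ),
    reduce_on_plateau_patience=4,
    use_learning_rate_finder=False
)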
aman1b commented Aug 6, 2024

Can anyone help here? How can I use DDP with the optimize_hyperparameters function?

@fkiraly fkiraly added the bug Something isn't working label Aug 30, 2024
@fkiraly fkiraly changed the title from "Issue with using optimise_hyperparameter with PyTorch DDP -> please help! :)" to "[BUG] Issue with using optimise_hyperparameter with PyTorch DDP" on Aug 30, 2024
fkiraly (Collaborator) commented Aug 30, 2024

Potentially related to the Windows failures reported here: #1623

Can you kindly paste the full output of pip list from your Python environment, and also let us know what your operating system and Python version are?
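If it is easier to gather everything in one go, a small standard-library Python sketch along these lines prints the OS, Python version, and installed package versions (purely a convenience suggestion):

import platform
import sys
from importlib.metadata import distributions

# Print OS and Python version, then every installed distribution with its version.
print("OS:", platform.platform())
print("Python:", sys.version)
for dist in sorted(distributions(), key=lambda d: (d.metadata["Name"] or "").lower()):
    print(dist.metadata["Name"], dist.version)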
