
Finding Good Learning Rate For Different Values of Depth, Heads #84

Closed
afiaka87 opened this issue Mar 13, 2021 · 8 comments
@afiaka87 (Contributor) commented Mar 13, 2021

I've discovered that dalle-pytorch is somehow resilient to a wide variety of learning rates. Use 3e-4 unless you have a reason to do otherwise. The information below is incorrect.

The most important result I've found is that a learning rate of 4e-4 to 5e-4 works better than 3e-4 for depth >= 26. Increase the default when training at higher depth!

I had access to two A100s with 40 GiB of VRAM yesterday, so I ran a hyperparameter sweep with Weights and Biases.

I only chose three parameters to tune: learning rate, depth and heads.

Wandb ran the first 1200 iterations of a training session 48 times while varying those values. Here are the results:

https://wandb.ai/afiaka87/hp_tuning/reports/DALLE-Pytorch-Sweep-over-Learning-Rate-Depth-and-Heads--Vmlldzo1Mjg4Mjk

[Images: loss over time with hyperparameter importance; parallel-coordinates plot of all parameters]

@afiaka87 (Contributor, Author)

Here is the sweep.yml I used. I also had to modify train_dalle.py to accept learning rate, depth and heads as arguments (a rough sketch of that change follows the config below). Aside from that, it's pretty much a one-button process, which is super useful! You can even run the same code in the same Docker container on a totally different machine. That makes hyperparameter tuning much cheaper, because it lends itself well to pre-emptible instances and even Google Colab.

```yaml
method: random
metric:
  goal: minimize
  name: dalle_loss
parameters:
  depth:
    distribution: int_uniform
    max: 32
    min: 16
  heads:
    distribution: int_uniform
    max: 8
    min: 2
  learning_rate:
    distribution: uniform
    max: 0.0005
    min: 0.0001
program: train_dalle.py
```
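
For reference, the train_dalle.py change was roughly the following (a minimal sketch, not the exact diff; the flag names are assumed to match the parameter names in the config, which is what the wandb agent passes by default):

```python
# Hypothetical excerpt from a modified train_dalle.py -- the real script has
# many more options; only the swept hyperparameters are shown here.
import argparse

import wandb

parser = argparse.ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=3e-4,
                    help='optimizer learning rate (set by the sweep agent)')
parser.add_argument('--depth', type=int, default=16,
                    help='number of transformer layers (set by the sweep agent)')
parser.add_argument('--heads', type=int, default=8,
                    help='number of attention heads (set by the sweep agent)')
args = parser.parse_args()

# The agent launches `train_dalle.py --learning_rate=... --depth=... --heads=...`
# for each run; wandb.init records those values on the run.
wandb.init(project='hp_tuning', config=vars(args))

# ... build the model and optimizer from args, e.g.
#     DALLE(depth=args.depth, heads=args.heads, ...)
#     Adam(dalle.parameters(), lr=args.learning_rate)
# and inside the training loop log the metric the sweep minimizes:
#     wandb.log({'dalle_loss': loss.item()})
```

After that it's just `wandb sweep sweep.yml` to register the sweep and `wandb agent <sweep_id>` on each machine you want to contribute runs.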

afiaka87 changed the title from "Finding Good Hyperparameters" to "Finding Good Learning Rate For Different Values of Depth, Heads" on Mar 13, 2021
@afiaka87 (Contributor, Author)

While the parallel-coordinates graphs are cool, it's tough to get information out of them at first glance. Here are the top 12 performers along with their parameters, loss and runtime (another important metric that's easy to forget about).
[Image: table of the top 12 runs with their parameters, loss and runtime]

@TheodoreGalanos commented Mar 14, 2021

Interesting, thanks for sharing! I was just trying to find out whether there's data / intuition about the impact of the different hyperparameters in DALL-E.

Does that mean that a smaller number of heads is viable? The default value seems to be 16, although I still haven't checked whether that follows the paper or is just a placeholder.

Also, I'm curious whether you have some intuition on how that relates to the size of the codebook. I'm considering a 2048-entry codebook and wonder how that affects the other parameters (maybe 12 layers are enough?).

Thanks!

@afiaka87 (Contributor, Author) commented Mar 14, 2021

> Interesting, thanks for sharing! I was just trying to find out whether there's data / intuition about the impact of the different hyperparameters in DALL-E.
>
> Does that mean that a smaller number of heads is viable? The default value seems to be 16, although I still haven't checked whether that follows the paper or is just a placeholder.
>
> Also, I'm curious whether you have some intuition on how that relates to the size of the codebook. I'm considering a 2048-entry codebook and wonder how that affects the other parameters (maybe 12 layers are enough?).
>
> Thanks!

It's tough to say whether a smaller number of heads is viable, but you're correct that a few runs got lucky enough to make it work well. I really don't think 46 runs is representative enough. More importantly, though, if you check that link you'll see an "importance" score. This is calculated by fitting a random forest to your hyperparameters in order to give you a "correlation" metric that accounts for second-order effects.

As you'll see, the most important parameter for decreasing loss is depth, learning rate is second, and heads is last. But I only let it run with a heads count of 4 to 8, so it had minimal impact somewhat by design.
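
For anyone curious how an importance score like that works, here's a minimal sketch of the general idea with scikit-learn (illustrative only, not W&B's actual implementation; the run values below are made up):

```python
# Fit a random forest on (hyperparameters -> final loss) and read off the
# feature importances. Data is fabricated purely to show the mechanics.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# One row per sweep run: [learning_rate, depth, heads]
X = np.array([
    [3e-4, 16, 4],
    [4e-4, 26, 8],
    [5e-4, 32, 6],
    [1e-4, 20, 2],
])
y = np.array([3.9, 3.4, 3.3, 4.2])  # final dalle_loss of each run

forest = RandomForestRegressor(n_estimators=200, random_state=0).fit(X, y)
for name, score in zip(['learning_rate', 'depth', 'heads'], forest.feature_importances_):
    print(f'{name}: {score:.3f}')
```

Because the forest can pick up interactions between parameters, the resulting importances capture more than a simple linear correlation would.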

Anyway, it's not rigorous, but your theory might be correct. You will, however, have to test it yourself, ha. Let me know how it goes.

Edit: Also, I'm seeing a default of 8 for heads in the current codebase.

And yeah, it's early days, and while lucid et al. have taken a lot of the details from OpenAI, it's still important to do these sanity checks at the beginning so we don't waste too much time.

@afiaka87 (Contributor, Author)

I think a good next step would be to lock in as many of the OpenAI parameters as we know of and try to find a good learning rate for that configuration. Unfortunately, it seems that even with the (super useful) reversible=True parameter, going to a depth of 64 requires a lot of VRAM. More than just 16 GB, I think.

@awilson9
> The default value seems to be 16, although I still haven't checked whether that follows the paper or is just a placeholder.

This is what is mentioned in the paper (Section B.1):

> We use 64 attention layers, each of which uses 62 attention heads with a per-head state size of 64.

Correct me if I'm wrong, but I believe this corresponds to

DEPTH=64
HEADS=62
DIM_HEAD=64
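
If that reading is right, plugging those numbers into dalle-pytorch would look roughly like this (a sketch based on the README-style constructor; the VAE and text settings are placeholders, only depth, heads and dim_head come from the Section B.1 quote):

```python
# Rough sketch of a paper-sized DALL-E config in dalle-pytorch.
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(           # placeholder VAE settings, not from the paper
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 1024,              # placeholder
    vae = vae,
    num_text_tokens = 10000, # placeholder
    text_seq_len = 256,      # placeholder
    depth = 64,              # DEPTH=64
    heads = 62,              # HEADS=62
    dim_head = 64,           # DIM_HEAD=64
    reversible = True        # helps with memory, but 64 layers is still heavy
)
```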

@awilson9 commented Mar 15, 2021

Unfortunately, with these parameters the largest batch size I'm able to run on my 24 GB card is 2.

Edit: Apparently CUDA 11.2 has known performance issues with PyTorch. Downgrading to 11.1.1 allowed me to increase the batch size to 4.
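
If anyone wants to verify which CUDA build their PyTorch is actually running before and after the downgrade, a quick check:

```python
# Print the CUDA / cuDNN versions this PyTorch build was compiled against.
import torch

print('PyTorch:', torch.__version__)
print('CUDA (build):', torch.version.cuda)
print('cuDNN:', torch.backends.cudnn.version())
print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none')
```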

@afiaka87 (Contributor, Author)

I'm forced to admit that this network is ultimately just very resilient to a variety of learning rates somehow. Use 3e-4 unless you have a reason to do otherwise.
