
Go from one step being an epoch to one step being a batch #1802

Closed

Conversation

@APJansen (Collaborator) commented Aug 30, 2023

As part of the changes mentioned in issue #1803, this is a very simple change that results in a factor 2 speedup.

It simply copies the input x grids epochs times along the batch axis and calls one step a batch rather than an epoch. This avoids some TensorFlow overhead that, once the other improvements mentioned there are in place, takes up nearly 50% of the total training time.

The current state is that the fit runs, but it crashes downstream, so some changes still need to be made there (perhaps just undoing, after the fit, the changes made just before the fit?).

If anyone wants to take this up, please do.

To illustrate what this does, here is a TensorBoard profile without this change:
[TensorBoard profile screenshot: large gaps between training steps]
These gaps are almost completely removed by this PR.
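For concreteness, here is a minimal, self-contained sketch of the trick with toy shapes and names (x_grid, y_data, n_epochs and the tiny model are illustrative stand-ins, not the actual n3fit objects):

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Toy stand-ins: a single "sample" holding the whole x grid, one target, a trivial model.
n_epochs = 1000
x_grid = np.random.rand(1, 50, 1).astype("float32")  # leading batch dimension of 1
y_data = np.random.rand(1, 50, 1).astype("float32")

inp = keras.Input(shape=(50, 1))
model = keras.Model(inp, keras.layers.Dense(1)(inp))
model.compile(optimizer="adam", loss="mse")

# Before: one training iteration per Keras epoch, paying the per-epoch overhead every time.
# model.fit(x_grid, y_data, epochs=n_epochs, batch_size=1)

# This PR's trick: repeat the single point n_epochs times along the batch axis and run a
# single epoch of n_epochs batches, so each batch plays the role of one former epoch.
x_repeated = tf.repeat(x_grid, n_epochs, axis=0)
y_repeated = tf.repeat(y_data, n_epochs, axis=0)
model.fit(x_repeated, y_repeated, epochs=1, batch_size=1)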

@APJansen added the "help wanted" (Extra attention is needed) label on Aug 30, 2023
@APJansen (Collaborator, Author) commented Sep 4, 2023

Tested this on the basic runcard with 1 replica on the CPU. It's about 10% faster even there. It now runs fully; the results are identical to master when looking at the validation chi2, although the training chi2 is completely different. But since after hundreds of epochs the validation is still identical, it must be a bug in the computation of this training chi2 rather than an actual difference in the training.

@APJansen (Collaborator, Author) commented Sep 6, 2023

The issue was that for the training losses, Keras computes a running average over batches. I added a function in between that corrects the logs to give the current batch's losses.
There is still something wrong, though. The relative difference in training chi2 starts out at 10^-8, so that could be round-off error from taking and "un-taking" the average over batches. But after 1000 epochs it has grown to order 1, whereas these numerical errors shouldn't even accumulate, since each computation is independent. Probably Keras takes a weighted average biased towards more recent batches, or something similar. (The validation remains identical to the last digit.)
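For the record, the correction amounts to undoing Keras's cumulative within-epoch mean: if m_n is the running average reported after n batches, the loss of batch n alone is n*m_n - (n-1)*m_{n-1}. A sketch of that bookkeeping (not the actual function added in this PR):

class LossUnaverager:
    """Recover per-batch losses from the running within-epoch average Keras puts in `logs`."""

    def __init__(self):
        self._previous_means = {}

    def correct(self, batch_index, logs):
        n = batch_index + 1  # Keras batch indices start at 0
        corrected = {}
        for key, mean_n in logs.items():
            mean_prev = self._previous_means.get(key, 0.0)
            # l_n = n * m_n - (n - 1) * m_{n-1}
            corrected[key] = n * mean_n - (n - 1) * mean_prev
            self._previous_means[key] = mean_n
        return corrected

Note that this subtraction amplifies round-off roughly in proportion to n, so some growth of the reporting error with the number of batches is expected from the un-averaging alone.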

Apart from this I think there are 2 points remaining:

  1. Is the extra memory use ever an issue? I haven't had any problems, but it does seem very stupid to do it like this. Perhaps it can be rewritten in terms of a data generator that just returns the single datapoint epochs times (see the sketch after this list).
  2. Whether and how to change terminology. Perhaps the cleanest would be to use "steps" rather than "epochs" and implement a step as a batch. But doing that consistently would require a lot of changes, including to the runcard syntax, and I imagine that is not wanted. We can also keep using "epochs" everywhere except in the changes here. I can only see this causing issues when implementing a new callback, but since that would probably start from the existing ones, where now only on_batch_end is used, it will probably be fine.
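A sketch of the generator idea from point 1, reusing the toy x_grid, y_data and model from the first sketch above; Keras's fit accepts a plain Python generator as long as steps_per_epoch is given:

def repeat_single_point(x, y, n_times):
    """Yield the same single (x, y) point n_times, without materializing any copies."""
    for _ in range(n_times):
        yield x, y

model.fit(
    repeat_single_point(x_grid, y_data, n_epochs),
    epochs=1,
    steps_per_epoch=n_epochs,
)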

@scarlehoff (Member) commented:

> Is the extra memory use ever an issue?

Yes. When running many replicas in parallel, increasing the memory usage drastically reduces how many you can actually fit on a cluster. If the gain on CPU is only ~10% it might not be worth it. We might want to have some branching, even a "run_in_batches=True" option in the runcard.

> Whether and how to change terminology.

Better not to change the terminology. The on_X_end callbacks have to change because those are internal to TensorFlow, but we now use "epoch" to mean "training iteration" (take into account that most people in the collaboration don't actually touch the code, and "epoch" is now part of the vocabulary). The people touching this part of the code are only about 3, including you, and they are all aware of the change :P

Comment on lines 168 to 172
# This looks stupid, but it's actually faster, as it avoids some Tensorflow overhead
# every epoch. Each step is now a batch rather than an epoch
for k, v in x_params.items():
    x_params[k] = tf.repeat(v, epochs, axis=0)
y = [tf.repeat(yi, epochs, axis=0) for yi in y]
@scarlehoff (Member) commented Sep 6, 2023

Suggested change:

- # This looks stupid, but it's actually faster, as it avoids some Tensorflow overhead
- # every epoch. Each step is now a batch rather than an epoch
+ # Instead of running # epochs, run #epochs batches in a single epoch
+ # this avoids some Tensorflow overhead and it's actually faster
  for k, v in x_params.items():
      x_params[k] = tf.repeat(v, epochs, axis=0)
  y = [tf.repeat(yi, epochs, axis=0) for yi in y]

I'd say this is a clever trick, not a stupid one.

RE the memory usage, I wonder whether there's a way of tricking TensorFlow into passing a tensor of length 1 in the batch dimension while making it believe there are actually many.
It depends on how the batch-taking is implemented in TensorFlow, but maybe we can do something like:

class FakeTensor(Tensor):
    def get_batch(self, i):
        if i < self._epochs:
            return self._true_tensor
        return end_signal
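For reference, a stock-TensorFlow way to get this kind of "virtual" repetition without a custom tensor class is tf.data, whose repeat is lazy and does not copy the data (a sketch using the toy names from the first example above, not necessarily what was tried below):

import tensorflow as tf

# The single point is stored once; .repeat() re-yields it lazily n_epochs times,
# and each dataset element is treated as one batch by model.fit.
dataset = tf.data.Dataset.from_tensors((x_grid, y_data)).repeat(n_epochs)
model.fit(dataset, epochs=1)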

@APJansen (Collaborator, Author) commented:

I implemented something like this; it works, but unfortunately it seems to completely negate the benefit of the copying. No idea why.

@scarlehoff (Member) commented:

I think you might need to actually trick TensorFlow for that to work.

However, I think the timings are quite OK as they are, so we can leave this on the back burner for the time being. No need to over-optimize when there are many other bigger problems in the way.

@APJansen (Collaborator, Author) commented:

Well, they seem OK here, but that's just because the rest is slow ;P It's about a factor 2 speedup once the other optimizations in issue #1803 are implemented.
But sure, it can wait until the rest is done. I hoped to fix this quickly before I go on holiday, but of course it's always a bit more tricky than you expect. I think the simplest is to just make, say, 1000 copies and train for epochs/1000 epochs.

@scarlehoff (Member) commented:

> It's about a factor 2 speedup once the other optimizations in issue #1803 are implemented.

But is it also a factor 2 on CPU? Because if it is only an improvement for GPU, I'd say it's better to branch there. GPU and CPU are different enough devices that I think some branching is OK, like the eigen/tensordot thing, and on CPU the memory growth can hurt more than the 10% gain you quoted before.

@APJansen (Collaborator, Author) commented:

I don't know; it reduces the gap between steps from ~50 to ~3 ms, so it depends on how much effect the other refactorings have on the CPU. If it's still not significant after that, we can always default to the old behavior when the number of replicas is 1.

@scarlehoff (Member) commented:

> The issue was that for the training losses, Keras computes a running average over batches. I added a function in between that corrects the logs to give the current batch's losses.
> There is still something wrong, though. The relative difference in training chi2 starts out at 10^-8, so that could be round-off error from taking and "un-taking" the average over batches. But after 1000 epochs it has grown to order 1, whereas these numerical errors shouldn't even accumulate, since each computation is independent. Probably Keras takes a weighted average biased towards more recent batches, or something similar. (The validation remains identical to the last digit.)

Thinking about this: you said that the validation remains identical, but what about the weights of the neural networks after a complete fit? If those are the same, then it is clear that everything stays the same and the difference is only in the reporting, which is OK. For the final training loss (i.e., the one that we write in the .json file) we can evaluate the loss manually and write that.
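Evaluating the final training loss manually would be a one-liner on the un-repeated inputs (toy names from the first sketch again):

# Recompute the training loss directly instead of trusting the batch-averaged Keras logs.
final_training_loss = model.evaluate(x_grid, y_data, batch_size=1, verbose=0)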

@APJansen (Collaborator, Author) commented Sep 8, 2023

> Better not to change the terminology. The on_X_end callbacks have to change because those are internal to TensorFlow, but we now use "epoch" to mean "training iteration" (take into account that most people in the collaboration don't actually touch the code, and "epoch" is now part of the vocabulary). The people touching this part of the code are only about 3, including you, and they are all aware of the change :P

Haha ok, I agree. I just put a comment at the top of the callbacks module.

> Thinking about this: you said that the validation remains identical, but what about the weights of the neural networks after a complete fit? If those are the same, then it is clear that everything stays the same and the difference is only in the reporting, which is OK. For the final training loss (i.e., the one that we write in the .json file) we can evaluate the loss manually and write that.

The weights must also be the same. I didn't check, but literally every digit of every one of the first 1000 epochs (I didn't train for longer) is identical; it would be quite the coincidence if the weights were different ;P I'm not sure it's harmless though: I think we want more than the final loss to understand what the model is doing. My current approach for correcting this doesn't work, not just because it's off (which I still don't understand), but also because it relies on being run every epoch, which is what I was testing with, whereas usually it only runs every 100 epochs.

@scarlehoff (Member) commented:

> Yes. When running many replicas in parallel, increasing the memory usage drastically reduces how many you can actually fit on a cluster. If the gain on CPU is only ~10% it might not be worth it. We might want to have some branching, even a "run_in_batches=True" option in the runcard.

There's a solution that would work in both cases, without the branching. We could put the data in the first layer as a fixed layer (shared between all replicas); the input could then be a very long list of None, which should have no effect on the memory.
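A sketch of what "data in a fixed first layer" could look like, with the toy shapes from the first example (not the actual n3fit layers): the x grid lives inside the graph as a constant shared by all replicas, and the real model input is a throw-away scalar, so feeding epochs of them costs essentially no memory:

import numpy as np
import tensorflow as tf
from tensorflow import keras

x_constant = tf.constant(x_grid)  # the real grid, stored once inside the graph

dummy_in = keras.Input(shape=(1,))
# The Lambda ignores the dummy's value and just broadcasts the stored grid to the batch size.
grid = keras.layers.Lambda(lambda d: tf.repeat(x_constant, tf.shape(d)[0], axis=0))(dummy_in)
out = keras.layers.Dense(1)(grid)
fixed_data_model = keras.Model(dummy_in, out)
fixed_data_model.compile(optimizer="adam", loss="mse")

# "a very long list of None": n_epochs dummy inputs of negligible size.
dummies = np.zeros((n_epochs, 1), dtype="float32")
# (The targets would need a similar treatment; they are simply repeated here for brevity.)
fixed_data_model.fit(dummies, tf.repeat(y_data, n_epochs, axis=0), epochs=1, batch_size=1)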

@APJansen (Collaborator, Author) commented:

So the approach I've chosen is:

  • for 1 replica (in practice, CPU), don't change anything;
  • for multiple replicas, copy up to 100 times, and if the number of epochs isn't divisible by 100, try 10 and log a warning.

The reason for the first point is that it's not worth it there: it comes with the cost of correcting the training logs, which would slow it down overall.
For the latter, using 100 copies leaves only a 0.5% speedup on the table (after the other refactorings to come), so this seems like a good tradeoff.
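A sketch of that copy-count choice (the final fallback, when the number of epochs is divisible by neither 100 nor 10, is my assumption and is not specified above):

import logging

log = logging.getLogger(__name__)

def choose_num_copies(epochs, num_replicas):
    """Illustrative version of the copy-count logic described above."""
    if num_replicas == 1:
        return 1  # single replica (in practice CPU): keep the old behaviour
    if epochs % 100 == 0:
        return 100
    if epochs % 10 == 0:
        log.warning("Number of epochs not divisible by 100, using 10 copies instead")
        return 10
    # Assumed fallback, not stated in the comment above.
    log.warning("Number of epochs divisible by neither 100 nor 10, not batching")
    return 1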

I did this mostly by creating a class CallbackStep, which all the other callbacks now inherit from; it takes care of these conversions between epochs and batches and calls an on_step_end that the subclasses define.

I've tested that for 1 replica the results are identical, and for multiple replicas only the training chi2s differ slightly, by up to 0.1%. This comes from converting the logs back from an average to a single-step loss, so it doesn't affect the training at all.

I still need to check the performance, but maybe I'll hold off on that until enough other refactorings are merged that the gain becomes substantial.
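As a rough picture of the CallbackStep idea (a sketch, not the PR's actual class): subclasses implement only on_step_end, and the base class maps either Keras epochs or Keras batches onto "steps", depending on how many copies are packed into one Keras epoch:

from tensorflow.keras.callbacks import Callback

class CallbackStep(Callback):
    """Base class sketch: subclasses define on_step_end and need not know whether one
    training step is implemented as a Keras epoch or as a batch."""

    def __init__(self, steps_per_epoch=1):
        super().__init__()
        self.steps_per_epoch = steps_per_epoch  # copies packed into one Keras epoch
        self._current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        # Old behaviour (1 replica): one step per Keras epoch.
        if self.steps_per_epoch == 1:
            self.on_step_end(epoch, logs)

    def on_batch_end(self, batch, logs=None):
        # New behaviour: one step per batch, translated back to a global step count
        # so subclasses still see "epochs" counting up as before.
        if self.steps_per_epoch > 1:
            self.on_step_end(self._current_epoch * self.steps_per_epoch + batch, logs)

    def on_step_end(self, step, logs=None):
        raise NotImplementedError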

@APJansen (Collaborator, Author) commented:

I have revived this in #1939 (I was having trouble rebasing and didn't want to waste time if I could just cherry-pick), so I think this can be closed.

@APJansen closed this on Feb 13, 2024