This is a preliminary design discussion on how the multiple TPU core module interfaces will be implemented.
Module Import, DataBlock, DataLoaders and Learner Setup
The following is the proposed exemplar code:
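(A minimal sketch of what such a setup might look like, assuming a standard fastai vision workflow; the dataset, labelling function, and architecture below are illustrative, not part of the actual proposal.)

```python
from fastai.vision.all import *
from fastai_xla_extensions.all import *   # the only XLA-specific import

path = untar_data(URLs.PETS)/'images'

def is_cat(f): return f.name[0].isupper()   # illustrative labelling function

datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(seed=42),
    get_y=is_cat,
    item_tfms=Resize(224),
)
dls = datablock.dataloaders(path, bs=64)

# concat_pool=False works around the Pytorch XLA issue with fastai's
# adaptive concat pool layer described below.
learner = cnn_learner(dls, resnet34, metrics=accuracy, concat_pool=False)
```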
Notice that aside from setting concat_pool=False, this is just normal fastai code.
The concat_pool flag is set to False because of a current bug, which may simply be Pytorch XLA not supporting fastai's adaptive concat pool layer.
The only real change is the addition of the fastai_xla_extensions.all module import, similar in spirit to importing fastai's top-level modules.
If we later find a need for domain-specific modules in fastai_xla_extensions (e.g. vision, tabular, text) to enable them to run on multiple TPU cores, we can split it along the same lines as well. (Right now, there isn't one.)
Training
Next comes the training, and this is the exemplar code:
learner.xla_fit(5, lr=2e-3)
or
learner.xla_fit_one_cycle(5, lr_max=slice(2e-3))
or for any other fit methods (including fine_tune and lr_find) as well as the inference method get_preds: the equivalent multi TPU core method is the same method, but with the prefix xla_.
These xla_-prefixed methods support the same signature as their equivalent fit methods, with the addition of three default arguments: num_cores=8, start_method='fork', and master_cbs=None.
The num_cores=8 argument sets the number of processes (and matching TPU cores utilized); 1 and 8 are the supported values on Colab.
I don't know what other values are possible on Google Cloud TPUs, but whatever Pytorch XLA supports should work, as this is just passed to xmp.spawn as the nprocs argument.
Also, as specified in the Pytorch XLA documentation, start_method='fork' is the only supported value on Colab -- this argument is also passed to the xmp.spawn call.
The master_cbs=None argument will be discussed in the following section.
The rationale for this design is that the xla_-prefixed methods spawn one process per TPU core, and in each spawned process the function invoked by the prefixed method recreates the learner and then calls the equivalent fit method. For example, the xla_fit method calls a function which creates a learner and then calls fit on that learner.
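A rough sketch of this spawning design follows; build_learner is a hypothetical helper that recreates the DataLoaders and Learner (it is not part of the actual library API), while xmp.spawn and xm.xla_device are existing Pytorch XLA calls.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from fastai.vision.all import *

def build_learner():
    # Hypothetical helper: rebuild the DataLoaders and Learner inside each
    # spawned process (e.g. from the DataBlock defined in the setup above).
    dls = datablock.dataloaders(path, bs=64)
    return cnn_learner(dls, resnet34, metrics=accuracy, concat_pool=False)

def _xla_fit_worker(rank, n_epoch, lr):
    # Runs inside one spawned process per TPU core: rebuild the learner,
    # move its model to this process's TPU device, and call plain fit.
    learner = build_learner()
    learner.model.to(xm.xla_device())
    learner.fit(n_epoch, lr=lr)

def xla_fit(n_epoch, lr=None, num_cores=8, start_method='fork'):
    # Called on the main (parent) process: spawn num_cores processes.
    xmp.spawn(_xla_fit_worker, args=(n_epoch, lr),
              nprocs=num_cores, start_method=start_method)
```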
This is why patching the fit methods would not be ideal: each fit method would have to figure out whether it was called on the main process (the parent process), where it would spawn child processes for each TPU core, or on a spawned process, where it would execute the original fit algorithm. That would be more complicated and might also lead to duplicated code.
I am open to other proposals if we can figure out a way to accomplish this cleanly, without having to duplicate the original fit code.
Handling Callbacks
One of the complications of running on multiple processes is that the processes can interfere with each other, especially when the methods weren't designed to run in parallel.
One example where this conflict becomes apparent is the ProgressCallback, which is responsible for displaying a progress bar whenever the model is training or running batch inference.
Since multiple training processes are running, each copy of the ProgressCallback would print its own version of the progress bar plus the associated running data (losses, time remaining, etc.), which becomes an ugly mess.
The current solution is to remove the progress callback on all the spawned processes except the one on the master ordinal (rank 0), so only one progress bar is displayed.
For additional callbacks, this might become a problem. To handle it, a new attribute called master_cbs will be added to the Learner, holding callbacks that are only added on the process with the master ordinal (rank 0). You can add these master-only callbacks by calling a new patched method, add_master_cbs, on the Learner.
For the xla_-prefixed methods that support a cbs=None parameter (for callbacks which run only for the duration of the fit call), an additional master_cbs=None parameter is also added, allowing you to add ephemeral callbacks that run only on the master ordinal process.
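As an illustration, the master-only handling inside each spawned process could look roughly like the sketch below; _setup_callbacks is a hypothetical helper, while xm.is_master_ordinal(), add_cbs() and remove_cb() are existing Pytorch XLA / fastai calls.

```python
import torch_xla.core.xla_model as xm
from fastai.callback.progress import ProgressCallback

def _setup_callbacks(learner, master_cbs=None):
    if xm.is_master_ordinal():
        # Rank 0 keeps its progress bar and receives the master-only callbacks.
        if master_cbs: learner.add_cbs(master_cbs)
    else:
        # All other ranks drop the progress bar so only one is displayed.
        learner.remove_cb(ProgressCallback)
```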
Model Weights and other Training Artifacts
One of the problems of running on multiple processes is that the learner instances running on the spawned processes go away once those processes complete (including the master ordinal process).
This means the trained model weights are not synced with the model weights on the learner instance that spawned the processes. This is why, at the end of training, the master ordinal process saves the updated model weights to a temporary file. The main process then loads the updated weights and the learner syncs its model to them.
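An illustrative sketch of this weight hand-off (the file name is arbitrary):

```python
import torch
import torch_xla.core.xla_model as xm

# In the spawned processes, at the end of training: xm.save moves the
# tensors to CPU before serializing and, by default, only writes the file
# on the master ordinal process.
xm.save(learner.model.state_dict(), 'xla_tmp_weights.pth')

# Back in the main (parent) process, after xmp.spawn returns:
learner.model.load_state_dict(torch.load('xla_tmp_weights.pth'))
```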
However, other artifacts, such as the list of losses and the lr schedules, also need to be saved so that the recorder can display plot_loss, plot_sched, and other useful graphs correctly.
Lastly, one limitation of the spawned processes is that matplotlib figures drawn there don't seem to show up (even when drawn only on the master ordinal process), so these need to be fixed so that the figures are saved to disk by the master ordinal process and then displayed by the parent process.
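A sketch of that figure hand-off, assuming a notebook environment (the file name is illustrative):

```python
import matplotlib.pyplot as plt
from IPython.display import Image, display

# In the spawned master ordinal process: draw the figure and save it to disk
# instead of relying on it being rendered in the notebook.
learner.recorder.plot_loss()
plt.savefig('xla_plot_loss.png')

# Back in the parent process, after the spawned processes finish:
display(Image('xla_plot_loss.png'))
```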