Disable jit compilation in tf > 2.16 #2135

Merged: 2 commits, Jul 28, 2024
n3fit/src/n3fit/backends/keras_backend/MetaModel.py (5 additions, 1 deletion)
@@ -25,6 +25,10 @@
 else:  # in case of disaster
     _to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()}

+# Starting with TF 2.16, a memory leak in TF (https://github.com/tensorflow/tensorflow/issues/64170)
+# makes jit compilation unusable on GPU.
+# Before TF 2.16, `jit_compile` defaulted to `False`; from TF 2.16 onwards it defaults to `True`.
+JIT_COMPILE = False

 # Define in this dictionary new optimizers as well as the arguments they accept
 # (with default values if need be)
@@ -307,7 +311,7 @@ def compile(
             target_output = [target_output]
         self.target_tensors = target_output

-        super().compile(optimizer=opt, loss=loss)
+        super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE)

     def set_masks_to(self, names, val=0.0):
         """Set all mask values to the selected value
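For context, here is how the flag behaves in a plain Keras model, outside of this PR. This is a minimal sketch with a made-up toy model and random data; only the jit_compile keyword itself comes from the change above:

import numpy as np
import tensorflow as tf

# Hypothetical toy model, for illustration only.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

# jit_compile=False turns off XLA compilation of the train step, avoiding
# the TF >= 2.16 GPU memory leak referenced in the comment above. Pinning
# it explicitly also keeps behaviour identical across TF versions, since
# the default flipped from False to True in 2.16.
model.compile(optimizer="adam", loss="mse", jit_compile=False)

model.fit(np.random.rand(16, 4), np.random.rand(16, 1), epochs=1, verbose=0)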
n3fit/src/n3fit/stopping.py (6 additions, 0 deletions)
@@ -27,6 +27,7 @@
 which will tell `Validation` that no validation set was found and that the training is to
 be used instead.
 """
+
 import logging

 import numpy as np
@@ -345,6 +346,8 @@ def __init__(
         self._threshold_chi2 = threshold_chi2
         self._stopping_degrees = np.zeros(self._n_replicas, dtype=int)
         self._counts = np.zeros(self._n_replicas, dtype=int)
+        # Keep track of the replicas that have not been stopped yet
+        self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool)

         self._dont_stop = dont_stop
         self._stop_now = False
@@ -451,6 +454,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
         passes &= fitstate.vl_loss < self._best_val_chi2s
         # And the ones that pass positivity
         passes &= self._positivity(fitstate)
+        # And exclude replicas that have already been stopped (because they finished or otherwise)
+        passes &= self._dont_stop_me_now

         self._stopping_degrees += self._counts
@@ -470,6 +475,7 @@
         for i_replica in np.where(stop_replicas)[0]:
             self._stop_epochs[i_replica] = epoch
             self._counts[i_replica] = 0
+            self._dont_stop_me_now[i_replica] = False

         # By using the stopping degree we only stop when none of the replicas are improving anymore
         if min(self._stopping_degrees) > self.stopping_patience:
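Taken together, the three stopping.py hunks implement a one-way boolean mask: a replica that stops once can never pass the chi2/positivity check again. The standalone sketch below reproduces that pattern with invented replica numbers; none of these names belong to the actual Stopping class:

import numpy as np

n_replicas = 4
# True while a replica is still running; flipped to False once it stops.
dont_stop_me_now = np.ones(n_replicas, dtype=bool)
best_val_chi2s = np.full(n_replicas, np.inf)

def check_passes(vl_loss, positivity_ok):
    # A replica "passes" only if it improves on its best chi2, satisfies
    # positivity, and has not been stopped at an earlier epoch.
    passes = vl_loss < best_val_chi2s
    passes &= positivity_ok
    passes &= dont_stop_me_now
    return passes

# Epoch 1: every replica improves on chi2 and passes positivity.
all_ok = np.ones(n_replicas, dtype=bool)
print(check_passes(np.array([1.0, 0.9, 0.8, 1.1]), all_ok))  # [ True  True  True  True]

# Replica 2 reaches its stopping condition and is frozen, as in the last hunk.
dont_stop_me_now[2] = False

# Epoch 2: replica 2 can no longer pass, however good its loss becomes.
best_val_chi2s[:] = [1.0, 0.9, 0.8, 1.1]
print(check_passes(np.array([0.9, 0.8, 0.1, 1.0]), all_ok))  # [ True  True False  True]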