Skip to content

Commit

Permalink
Merge branch 'master' into polarised-dy
Browse files Browse the repository at this point in the history
  • Loading branch information
Radonirinaunimi committed Jul 29, 2024
2 parents 29346df + 68f5c66 commit 16580b7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
6 changes: 5 additions & 1 deletion n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
else: # in case of disaster
_to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()}

# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170
# makes jit compilation unusable in GPU.
# Before TF 2.16 it was set to `False` by default. From 2.16 onwards, it is set to `True`
JIT_COMPILE = False

# Define in this dictionary new optimizers as well as the arguments they accept
# (with default values if needed be)
Expand Down Expand Up @@ -307,7 +311,7 @@ def compile(
target_output = [target_output]
self.target_tensors = target_output

super().compile(optimizer=opt, loss=loss)
super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE)

def set_masks_to(self, names, val=0.0):
"""Set all mask value to the selected value
Expand Down
6 changes: 6 additions & 0 deletions n3fit/src/n3fit/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
which will tell `Validation` that no validation set was found and that the training is to
be used instead.
"""

import logging

import numpy as np
Expand Down Expand Up @@ -345,6 +346,8 @@ def __init__(
self._threshold_chi2 = threshold_chi2
self._stopping_degrees = np.zeros(self._n_replicas, dtype=int)
self._counts = np.zeros(self._n_replicas, dtype=int)
# Keep track of the replicas that should not be stopped yet
self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool)

self._dont_stop = dont_stop
self._stop_now = False
Expand Down Expand Up @@ -451,6 +454,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
passes &= fitstate.vl_loss < self._best_val_chi2s
# And the ones that pass positivity
passes &= self._positivity(fitstate)
# Stop replicas that are ok being stopped (because they are finished or otherwise)
passes &= self._dont_stop_me_now

self._stopping_degrees += self._counts

Expand All @@ -470,6 +475,7 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
for i_replica in np.where(stop_replicas)[0]:
self._stop_epochs[i_replica] = epoch
self._counts[i_replica] = 0
self._dont_stop_me_now[i_replica] = False

# By using the stopping degree we only stop when none of the replicas are improving anymore
if min(self._stopping_degrees) > self.stopping_patience:
Expand Down

0 comments on commit 16580b7

Please sign in to comment.