From 3c6cbdf392bb9373373362a36a10c8367808d4bb Mon Sep 17 00:00:00 2001 From: Aron Date: Wed, 25 Oct 2023 08:13:20 +0200 Subject: [PATCH] Merge prefactors into single layer --- .../n3fit/backends/keras_backend/MetaModel.py | 46 +++++++++++++++++-- n3fit/src/n3fit/layers/preprocessing.py | 16 +++++-- n3fit/src/n3fit/model_gen.py | 27 ++++------- 3 files changed, 65 insertions(+), 24 deletions(-) diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py index dfefd38401..dfaad2d13d 100644 --- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py +++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py @@ -358,7 +358,7 @@ def get_replica_weights(self, i_replica): ] prepro_weights = [ tf.Variable(w, name=w.name) - for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights + for w in get_layer_replica_weights(self.get_layer(PREPROCESSING_PREFIX), i_replica) ] weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights} @@ -379,8 +379,10 @@ def set_replica_weights(self, weights, i_replica=0): the replica number to set, defaulting to 0 """ self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX]) - self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights( - weights[PREPROCESSING_PREFIX] + set_layer_replica_weights( + layer=self.get_layer(PREPROCESSING_PREFIX), + weights=weights[PREPROCESSING_PREFIX], + i_replica=i_replica, ) def split_replicas(self): @@ -465,3 +467,41 @@ def append_weights(name, node): weights_ordered.append(w_h5) return weights_ordered + + +def get_layer_replica_weights(layer, i_replica: int): + """ + Get the weights for the given single replica `i_replica`, + from a `layer` that has weights for all replicas. 
+ + Parameters + ---------- + layer: MetaLayer + the layer to get the weights from + i_replica: int + the replica number + + Returns + ------- + weights: list + list of weights for the replica + """ + return [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights] + + +def set_layer_replica_weights(layer, weights, i_replica: int): + """ + Set the weights for the given single replica `i_replica`, + in a `layer` that has weights for all replicas. + + Parameters + ---------- + layer: MetaLayer + the layer to set the weights for + weights: list + list of weights for the replica + i_replica: int + the replica number + """ + for w, w_new in zip(layer.weights, weights): + w[i_replica].assign(w_new[0]) diff --git a/n3fit/src/n3fit/layers/preprocessing.py b/n3fit/src/n3fit/layers/preprocessing.py index 77ea760607..f8ab1f8f55 100644 --- a/n3fit/src/n3fit/layers/preprocessing.py +++ b/n3fit/src/n3fit/layers/preprocessing.py @@ -33,6 +33,8 @@ class Preprocessing(MetaLayer): Whether large x preprocessing factor should be active seed: int seed for the initializer of the random alpha and beta values + num_replicas: int (default 1) + The number of replicas """ def __init__( @@ -40,6 +42,7 @@ def __init__( flav_info: Optional[list] = None, seed: int = 0, large_x: bool = True, + num_replicas: int = 1, **kwargs, ): if flav_info is None: @@ -49,6 +52,8 @@ def __init__( self.flav_info = flav_info self.seed = seed self.large_x = large_x + self.num_replicas = num_replicas + self.alphas = [] self.betas = [] super().__init__(**kwargs) @@ -87,7 +92,7 @@ def generate_weight(self, name: str, kind: str, dictionary: dict, set_to_zero: b # Generate the new trainable (or not) parameter newpar = self.builder_helper( name=name, - kernel_shape=(1,), + kernel_shape=(self.num_replicas, 1), initializer=initializer, trainable=trainable, constraint=constraint, @@ -117,9 +122,12 @@ def call(self, x): Returns ------- - prefactor: tensor(shape=[1,N,F]) + prefactor: 
tensor(shape=[1,R,N,F]) """ - alphas = op.stack(self.alphas, axis=1) - betas = op.stack(self.betas, axis=1) + # weight tensors of shape (R, 1, F) + alphas = op.stack(self.alphas, axis=-1) + betas = op.stack(self.betas, axis=-1) + + x = op.batchit(x, batch_dimension=0) return x ** (1 - alphas) * (1 - x) ** betas diff --git a/n3fit/src/n3fit/model_gen.py b/n3fit/src/n3fit/model_gen.py index a700a10fa9..49be7b3645 100644 --- a/n3fit/src/n3fit/model_gen.py +++ b/n3fit/src/n3fit/model_gen.py @@ -669,19 +669,18 @@ def pdfNN_layer_generator( else: sumrule_layer = lambda x: x + compute_preprocessing_factor = Preprocessing( + flav_info=flav_info, + input_shape=(1,), + name="preprocessing_factor", + seed=42, # TODO Aron: figure out what to do here + large_x=not subtract_one, + num_replicas=num_replicas, + ) + # Only these layers change from replica to replica: nn_replicas = [] - preprocessing_factor_replicas = [] for i_replica, replica_seed in enumerate(seed): - preprocessing_factor_replicas.append( - Preprocessing( - flav_info=flav_info, - input_shape=(1,), - name=f"preprocessing_factor_{i_replica}", - seed=replica_seed + number_of_layers, - large_x=not subtract_one, - ) - ) nn_replicas.append( generate_nn( layer_type=layer_type, @@ -713,12 +712,6 @@ def neural_network_replicas(x, postfix=""): return NNs_x - # Apply preprocessing factors for all replicas to a given input grid - def preprocessing_replicas(x, postfix=""): - return Lambda(lambda pfs: op.stack(pfs, axis=1), name=f"prefactors{postfix}")( - [pf(x) for pf in preprocessing_factor_replicas] - ) - def compute_unnormalized_pdf(x, postfix=""): # Preprocess the input grid x_nn_input = extract_nn_input(x) @@ -729,7 +722,7 @@ def compute_unnormalized_pdf(x, postfix=""): NNs_x = neural_network_replicas(x_processed, postfix=postfix) # Compute the preprocessing factor - preprocessing_factors_x = preprocessing_replicas(x_original, postfix=postfix) + preprocessing_factors_x = compute_preprocessing_factor(x_original) # Apply 
the preprocessing factor pref_NNs_x = apply_preprocessing_factor([preprocessing_factors_x, NNs_x])