Merge prefactors into single layer
APJansen committed Dec 4, 2023
Parent: 453a98f · Commit: 3c6cbdf
Showing 3 changed files with 65 additions and 24 deletions.
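
The idea of the change: before this commit, each replica carried its own Preprocessing layer whose alpha and beta exponents were scalar weights of shape (1,); after it, a single layer holds the exponents for all replicas stacked along a leading replica axis, shape (num_replicas, 1), and per-replica access becomes a slice along that axis. A minimal sketch of the pattern in plain TensorFlow (illustrative names, not the n3fit code):

    import tensorflow as tf

    NUM_REPLICAS = 3

    # Before: one scalar weight per replica, living in separate layers.
    per_replica = [tf.Variable([0.1 * (i + 1)]) for i in range(NUM_REPLICAS)]  # each (1,)

    # After: one stacked weight holding all replicas, shape (R, 1).
    merged = tf.Variable(tf.stack(per_replica, axis=0))

    # Replica i is a slice along the leading axis; `i : i + 1` keeps the axis.
    i = 1
    print(merged[i : i + 1].numpy())  # -> [[0.2]]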
n3fit/src/n3fit/backends/keras_backend/MetaModel.py (43 additions, 3 deletions)
@@ -358,7 +358,7 @@ def get_replica_weights(self, i_replica):
         ]
         prepro_weights = [
             tf.Variable(w, name=w.name)
-            for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights
+            for w in get_layer_replica_weights(self.get_layer(PREPROCESSING_PREFIX), i_replica)
         ]
         weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}

@@ -379,8 +379,10 @@ def set_replica_weights(self, weights, i_replica=0):
             the replica number to set, defaulting to 0
         """
         self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
-        self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights(
-            weights[PREPROCESSING_PREFIX]
+        set_layer_replica_weights(
+            layer=self.get_layer(PREPROCESSING_PREFIX),
+            weights=weights[PREPROCESSING_PREFIX],
+            i_replica=i_replica,
         )

     def split_replicas(self):
@@ -465,3 +467,41 @@ def append_weights(name, node):
             weights_ordered.append(w_h5)

     return weights_ordered
+
+
+def get_layer_replica_weights(layer, i_replica: int):
+    """
+    Get the weights for the given single replica `i_replica`,
+    from a `layer` that has weights for all replicas.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to get the weights from
+        i_replica: int
+            the replica number
+
+    Returns
+    -------
+        weights: list
+            list of weights for the replica
+    """
+    return [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]
+
+
+def set_layer_replica_weights(layer, weights, i_replica: int):
+    """
+    Set the weights for the given single replica `i_replica`,
+    from a `layer` that has weights for all replicas.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to set the weights for
+        weights: list
+            list of weights for the replica
+        i_replica: int
+            the replica number
+    """
+    for w, w_new in zip(layer.weights, weights):
+        w[i_replica].assign(w_new[0])
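
The two new helpers are inverses of each other: the getter slices replica i out of each stacked weight (using i : i + 1 so the replica axis is kept), and the setter assigns into that slice. A toy round trip under the same convention (standalone TensorFlow; DummyLayer only mimics a layer's .weights list, it is not the real MetaLayer):

    import tensorflow as tf

    class DummyLayer:
        # Stand-in for a layer whose weights carry a leading replica axis (R, 1).
        def __init__(self, num_replicas):
            self.weights = [tf.Variable(tf.zeros((num_replicas, 1)))]

    layer = DummyLayer(num_replicas=3)

    # Write replica 1, mirroring set_layer_replica_weights.
    new_weights = [tf.constant([[0.5]])]  # single-replica weight, shape (1, 1)
    for w, w_new in zip(layer.weights, new_weights):
        w[1].assign(w_new[0])

    # Read it back, mirroring get_layer_replica_weights.
    got = [w[1:2] for w in layer.weights]
    print(got[0].numpy())  # -> [[0.5]]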
n3fit/src/n3fit/layers/preprocessing.py (12 additions, 4 deletions)
@@ -33,13 +33,16 @@ class Preprocessing(MetaLayer):
             Whether large x preprocessing factor should be active
         seed: int
             seed for the initializer of the random alpha and beta values
+        num_replicas: int (default 1)
+            The number of replicas
     """

     def __init__(
         self,
         flav_info: Optional[list] = None,
         seed: int = 0,
         large_x: bool = True,
+        num_replicas: int = 1,
         **kwargs,
     ):
         if flav_info is None:
@@ -49,6 +52,8 @@ def __init__(
         self.flav_info = flav_info
         self.seed = seed
         self.large_x = large_x
+        self.num_replicas = num_replicas
+
         self.alphas = []
         self.betas = []
         super().__init__(**kwargs)
@@ -87,7 +92,7 @@ def generate_weight(self, name: str, kind: str, dictionary: dict, set_to_zero: bool = False):
         # Generate the new trainable (or not) parameter
         newpar = self.builder_helper(
             name=name,
-            kernel_shape=(1,),
+            kernel_shape=(self.num_replicas, 1),
             initializer=initializer,
             trainable=trainable,
             constraint=constraint,
@@ -117,9 +122,12 @@ def call(self, x):
         Returns
         -------
-            prefactor: tensor(shape=[1,N,F])
+            prefactor: tensor(shape=[1,R,N,F])
         """
-        alphas = op.stack(self.alphas, axis=1)
-        betas = op.stack(self.betas, axis=1)
+        # weight tensors of shape (R, 1, F)
+        alphas = op.stack(self.alphas, axis=-1)
+        betas = op.stack(self.betas, axis=-1)

         x = op.batchit(x, batch_dimension=0)

         return x ** (1 - alphas) * (1 - x) ** betas
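
The call method now leans on broadcasting: per the comment in the diff, the stacked exponents have shape (R, 1, F), while the incoming grid (already carrying a batch dimension) is batched once more by op.batchit to (1, 1, N, 1), so x ** (1 - alphas) * (1 - x) ** betas broadcasts to the documented [1, R, N, F]. A quick shape check with NumPy standing in for the backend ops (shapes inferred from the diff, so treat them as an assumption):

    import numpy as np

    R, N, F = 3, 20, 8  # replicas, grid points, flavours

    alphas = np.ones((R, 1, F))  # like op.stack(self.alphas, axis=-1)
    betas = np.ones((R, 1, F))
    x = np.random.rand(1, N, 1)  # input grid with its batch dimension
    x = x[np.newaxis, ...]       # op.batchit(x, batch_dimension=0) -> (1, 1, N, 1)

    prefactor = x ** (1 - alphas) * (1 - x) ** betas
    print(prefactor.shape)  # -> (1, 3, 20, 8), i.e. [1, R, N, F]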
n3fit/src/n3fit/model_gen.py (10 additions, 17 deletions)
@@ -669,19 +669,18 @@ def pdfNN_layer_generator(
     else:
         sumrule_layer = lambda x: x

+    compute_preprocessing_factor = Preprocessing(
+        flav_info=flav_info,
+        input_shape=(1,),
+        name="preprocessing_factor",
+        seed=42,  # TODO Aron: figure out what to do here
+        large_x=not subtract_one,
+        num_replicas=num_replicas,
+    )
+
     # Only these layers change from replica to replica:
     nn_replicas = []
-    preprocessing_factor_replicas = []
     for i_replica, replica_seed in enumerate(seed):
-        preprocessing_factor_replicas.append(
-            Preprocessing(
-                flav_info=flav_info,
-                input_shape=(1,),
-                name=f"preprocessing_factor_{i_replica}",
-                seed=replica_seed + number_of_layers,
-                large_x=not subtract_one,
-            )
-        )
         nn_replicas.append(
             generate_nn(
                 layer_type=layer_type,
@@ -713,12 +712,6 @@ def neural_network_replicas(x, postfix=""):

         return NNs_x

-    # Apply preprocessing factors for all replicas to a given input grid
-    def preprocessing_replicas(x, postfix=""):
-        return Lambda(lambda pfs: op.stack(pfs, axis=1), name=f"prefactors{postfix}")(
-            [pf(x) for pf in preprocessing_factor_replicas]
-        )
-
     def compute_unnormalized_pdf(x, postfix=""):
         # Preprocess the input grid
         x_nn_input = extract_nn_input(x)
@@ -729,7 +722,7 @@ def compute_unnormalized_pdf(x, postfix=""):
         NNs_x = neural_network_replicas(x_processed, postfix=postfix)

         # Compute the preprocessing factor
-        preprocessing_factors_x = preprocessing_replicas(x_original, postfix=postfix)
+        preprocessing_factors_x = compute_preprocessing_factor(x_original)

         # Apply the preprocessing factor
         pref_NNs_x = apply_preprocessing_factor([preprocessing_factors_x, NNs_x])
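
Net effect in model_gen: the per-replica loop (build R Preprocessing layers, call each, stack the results with a Lambda) is replaced by a single call to one layer with a replica axis. The two formulations compute the same prefactors for identical exponents, which a small NumPy check illustrates (values are made up; note the merged layer also switches to a single seed, so the actual trained exponents are initialized differently):

    import numpy as np

    rng = np.random.default_rng(0)
    R, N, F = 3, 10, 4
    alphas = rng.uniform(0.5, 1.5, size=(R, 1, F))
    betas = rng.uniform(1.0, 3.0, size=(R, 1, F))
    x = rng.uniform(0.01, 0.99, size=(N, 1))

    # Old formulation: one computation per replica, stacked afterwards.
    looped = np.stack(
        [x ** (1 - alphas[r]) * (1 - x) ** betas[r] for r in range(R)], axis=0
    )  # (R, N, F)

    # New formulation: one broadcast over the replica axis.
    merged = x[np.newaxis] ** (1 - alphas) * (1 - x[np.newaxis]) ** betas

    assert np.allclose(looped, merged)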