diff --git a/keras/src/layers/normalization/unit_normalization.py b/keras/src/layers/normalization/unit_normalization.py index 3b0b34f4d80..be77aa59c30 100644 --- a/keras/src/layers/normalization/unit_normalization.py +++ b/keras/src/layers/normalization/unit_normalization.py @@ -43,6 +43,18 @@ def call(self, inputs): return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12) def compute_output_shape(self, input_shape): + # Ensure axis is always treated as a list + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {self.axis} is out of bounds for " + f"input shape {input_shape}." + ) return input_shape def get_config(self):