diff --git a/kimm/layers/learnable_affine.py b/kimm/layers/learnable_affine.py index ec4ede9..6e41b3f 100644 --- a/kimm/layers/learnable_affine.py +++ b/kimm/layers/learnable_affine.py @@ -34,7 +34,9 @@ def build(self, input_shape): self.built = True def call(self, inputs, training=None, mask=None): - return ops.add(ops.multiply(inputs, self.scale), self.bias) + scale = ops.cast(self.scale, self.compute_dtype) + bias = ops.cast(self.bias, self.compute_dtype) + return ops.add(ops.multiply(inputs, scale), bias) def get_config(self): config = super().get_config()