Skip to content

Commit

Permalink
Fix mixed precision for LearnableAffine
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Mar 23, 2024
1 parent 6ef977c commit e4a8eec
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion kimm/layers/learnable_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def build(self, input_shape):
self.built = True

def call(self, inputs, training=None, mask=None):
return ops.add(ops.multiply(inputs, self.scale), self.bias)
scale = ops.cast(self.scale, self.compute_dtype)
bias = ops.cast(self.bias, self.compute_dtype)
return ops.add(ops.multiply(inputs, scale), bias)

def get_config(self):
config = super().get_config()
Expand Down

0 comments on commit e4a8eec

Please sign in to comment.