From e4a8eece7673c90d34d8b8304b04514a0a7616d6 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 23 Mar 2024 19:02:49 +0800 Subject: [PATCH] Fix mixed precision for `LearnableAffine` --- kimm/layers/learnable_affine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kimm/layers/learnable_affine.py b/kimm/layers/learnable_affine.py index ec4ede9..6e41b3f 100644 --- a/kimm/layers/learnable_affine.py +++ b/kimm/layers/learnable_affine.py @@ -34,7 +34,9 @@ def build(self, input_shape): self.built = True def call(self, inputs, training=None, mask=None): - return ops.add(ops.multiply(inputs, self.scale), self.bias) + scale = ops.cast(self.scale, self.compute_dtype) + bias = ops.cast(self.bias, self.compute_dtype) + return ops.add(ops.multiply(inputs, scale), bias) def get_config(self): config = super().get_config()