-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add HGNet and HGNetV2 * Update README and requirements.txt * Update configs * Fix mixed precision for `LearnableAffine`
- Loading branch information
1 parent
8995b40
commit 195a622
Showing
14 changed files
with
1,224 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from kimm.layers.attention import Attention | ||
from kimm.layers.layer_scale import LayerScale | ||
from kimm.layers.learnable_affine import LearnableAffine | ||
from kimm.layers.mobile_one_conv2d import MobileOneConv2D | ||
from kimm.layers.position_embedding import PositionEmbedding | ||
from kimm.layers.rep_conv2d import RepConv2D |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import keras | ||
from keras import layers | ||
from keras import ops | ||
|
||
|
||
@keras.saving.register_keras_serializable(package="kimm") | ||
class LearnableAffine(layers.Layer): | ||
def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs): | ||
super().__init__(**kwargs) | ||
if isinstance(scale_value, int): | ||
raise ValueError( | ||
f"scale_value must be a integer. Received: {scale_value}" | ||
) | ||
if isinstance(bias_value, int): | ||
raise ValueError( | ||
f"bias_value must be a integer. Received: {bias_value}" | ||
) | ||
self.scale_value = scale_value | ||
self.bias_value = bias_value | ||
|
||
def build(self, input_shape): | ||
self.scale = self.add_weight( | ||
shape=(1,), | ||
initializer=lambda shape, dtype: ops.cast(self.scale_value, dtype), | ||
trainable=True, | ||
name="scale", | ||
) | ||
self.bias = self.add_weight( | ||
shape=(1,), | ||
initializer=lambda shape, dtype: ops.cast(self.bias_value, dtype), | ||
trainable=True, | ||
name="bias", | ||
) | ||
self.built = True | ||
|
||
def call(self, inputs, training=None, mask=None): | ||
scale = ops.cast(self.scale, self.compute_dtype) | ||
bias = ops.cast(self.bias, self.compute_dtype) | ||
return ops.add(ops.multiply(inputs, scale), bias) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"scale_value": self.scale_value, | ||
"bias_value": self.bias_value, | ||
"name": self.name, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
from absl.testing import parameterized | ||
from keras.src import testing | ||
|
||
from kimm.layers.learnable_affine import LearnableAffine | ||
|
||
|
||
class LearnableAffineTest(testing.TestCase, parameterized.TestCase): | ||
@pytest.mark.requires_trainable_backend | ||
def test_layer_scale_basic(self): | ||
self.run_layer_test( | ||
LearnableAffine, | ||
init_kwargs={"scale_value": 1.0, "bias_value": 0.0}, | ||
input_shape=(1, 10), | ||
expected_output_shape=(1, 10), | ||
expected_num_trainable_weights=2, | ||
expected_num_non_trainable_weights=0, | ||
expected_num_losses=0, | ||
supports_masking=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.