Skip to content

Commit

Permalink
Add test for get_reparameterized_model
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jan 29, 2024
1 parent 203b2a0 commit b748d37
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions kimm/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,19 @@ def test_model_feature_extractor(
name, shape = feature_info
self.assertEqual(list(y[name].shape), shape)

@parameterized.named_parameters(
(kimm_models.RepVGGA0.__name__, kimm_models.RepVGGA0, 224)
)
def test_model_get_reparameterized_model(self, model_class, image_size):
x = random.uniform([1, image_size, image_size, 3]) * 255.0
model = model_class(weights=None)
reparameterized_model = model.get_reparameterized_model()

y1 = model(x, training=False)
y2 = reparameterized_model(x, training=False)

self.assertAllClose(y1, y2, atol=1e-5)

@pytest.mark.serialization
@parameterized.named_parameters(MODEL_CONFIGS)
def test_model_serialization(
Expand Down

0 comments on commit b748d37

Please sign in to comment.