Skip to content

Commit

Permalink
Speed up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed May 16, 2024
1 parent 2f6f374 commit 8febbf5
Showing 1 changed file with 20 additions and 35 deletions.
55 changes: 20 additions & 35 deletions kimm/_src/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def test_feature_extractor(self):


@pytest.mark.requires_trainable_backend # numpy is too slow to test
class ModelTest(testing.TestCase, parameterized.TestCase):
class ModelsTest(testing.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
cls.original_image_data_format = backend.image_data_format()
Expand All @@ -475,29 +475,41 @@ def test_predict(
features,
weights="imagenet",
):
# We also enable feature_extractor=True in model instantiation to
# speed up the testing

# Load the image
image_path = keras.utils.get_file(
"elephant.png",
"https://github.com/james77777778/keras-image-models/releases/download/0.1.0/elephant.png",
)
image = utils.load_img(image_path, target_size=(image_size, image_size))

# Test channels_last
# Test channels_last and feature_extractor=True
backend.set_image_data_format("channels_last")
model = model_class(weights=weights)
if hasattr(model, "get_reparameterized_model"):
model = model.get_reparameterized_model()
model = model_class(weights=weights, feature_extractor=True)
x = utils.img_to_array(image, data_format="channels_last")
x = ops.expand_dims(ops.convert_to_tensor(x), axis=0)

y = model(x, training=False)

# Verify output correctness
prob = y["TOP"]
if weights == "imagenet":
names = [p[1] for p in decode_predictions(y)[0]]
names = [p[1] for p in decode_predictions(prob)[0]]
# Test correct label is in top 3 (weak correctness test).
self.assertIn("African_elephant", names[:3])
elif weights is None:
self.assertEqual(list(y.shape), [1, 1000])
self.assertEqual(list(prob.shape), [1, 1000])

# Verify features
self.assertIsInstance(y, dict)
self.assertContainsSubset(
model_class.available_feature_keys, list(y.keys())
)
for feature_info in features:
name, shape = feature_info
self.assertEqual(list(y[name].shape), shape)

# Test channels_first
if (
Expand All @@ -509,46 +521,19 @@ def test_predict(

backend.set_image_data_format("channels_first")
model = model_class(weights=weights)
if hasattr(model, "get_reparameterized_model"):
model = model.get_reparameterized_model()
x = utils.img_to_array(image, data_format="channels_first")
x = ops.expand_dims(ops.convert_to_tensor(x), axis=0)

y = model(x, training=False)

# Verify output correctness
if weights == "imagenet":
names = [p[1] for p in decode_predictions(y)[0]]
# Test correct label is in top 3 (weak correctness test).
self.assertIn("African_elephant", names[:3])
elif weights is None:
self.assertEqual(list(y.shape), [1, 1000])

@parameterized.named_parameters(MODEL_CONFIGS)
def test_feature_extractor(
self,
model_class,
image_size,
features,
weights="imagenet",
):
backend.set_image_data_format("channels_last")
x = random.uniform([1, image_size, image_size, 3]) * 255.0
model = model_class(
include_top=False, weights=None, feature_extractor=True
)
if hasattr(model, "get_reparameterized_model"):
model = model.get_reparameterized_model()

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertContainsSubset(
model_class.available_feature_keys, list(y.keys())
)
for feature_info in features:
name, shape = feature_info
self.assertEqual(list(y[name].shape), shape)

@parameterized.named_parameters(
(
kimm_models.repvgg.RepVGGA0.__name__,
Expand Down

0 comments on commit 8febbf5

Please sign in to comment.