Skip to content

Commit

Permalink
Type check io tests
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth committed Sep 14, 2023
1 parent cee8e03 commit d59bdfe
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ exclude = (?x)(
tests/utils/benchmarking/test_linear_classifier.py |
tests/utils/benchmarking/test_metric_callback.py |
tests/utils/test_dist.py |
tests/utils/test_io.py |
tests/models/test_ModelsSimSiam.py |
tests/models/modules/test_masked_autoencoder.py |
tests/models/test_ModelsSimCLR.py |
Expand Down Expand Up @@ -232,6 +231,9 @@ follow_imports = skip
[mypy-lightly.utils.benchmarking.*]
follow_imports = skip

[mypy-tests.api_workflow.*]
follow_imports = skip

# Ignore errors in auto generated code.
[mypy-lightly.openapi_generated.*]
ignore_errors = True
24 changes: 12 additions & 12 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,26 @@
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup


class TestCLICrop(MockedApiWorkflowSetup):
def test_save_metadata(self):
class TestCLICrop(MockedApiWorkflowSetup): # type: ignore[misc]
def test_save_metadata(self) -> None:
metadata = [("filename.jpg", {"random_metadata": 42})]
metadata_filepath = tempfile.mktemp(".json", "metadata")
io.save_custom_metadata(metadata_filepath, metadata)


class TestEmbeddingsIO(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
# correct embedding file as created through lightly
self.embeddings_path = tempfile.mktemp(".csv", "embeddings")
embeddings = np.random.rand(32, 2)
labels = [0 for i in range(len(embeddings))]
filenames = [f"img_{i}.jpg" for i in range(len(embeddings))]
io.save_embeddings(self.embeddings_path, embeddings, labels, filenames)

def test_valid_embeddings(self):
def test_valid_embeddings(self) -> None:
io.check_embeddings(self.embeddings_path)

def test_whitespace_in_embeddings(self):
def test_whitespace_in_embeddings(self) -> None:
# should fail because there whitespaces in the header columns
lines = [
"filenames, embedding_0,embedding_1,labels\n",
Expand All @@ -41,7 +41,7 @@ def test_whitespace_in_embeddings(self):
io.check_embeddings(self.embeddings_path)
self.assertTrue("must not contain whitespaces" in str(context.exception))

def test_no_labels_in_embeddings(self):
def test_no_labels_in_embeddings(self) -> None:
# should fail because there is no `labels` column in the header
lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"]
with open(self.embeddings_path, "w") as f:
Expand All @@ -50,7 +50,7 @@ def test_no_labels_in_embeddings(self):
io.check_embeddings(self.embeddings_path)
self.assertTrue("has no `labels` column" in str(context.exception))

def test_no_empty_rows_in_embeddings(self):
def test_no_empty_rows_in_embeddings(self) -> None:
# should fail because there are empty rows in the embeddings file
lines = [
"filenames,embedding_0,embedding_1,labels\n",
Expand All @@ -62,7 +62,7 @@ def test_no_empty_rows_in_embeddings(self):
io.check_embeddings(self.embeddings_path)
self.assertTrue("must not have empty rows" in str(context.exception))

def test_embeddings_extra_rows(self):
def test_embeddings_extra_rows(self) -> None:
rows = [
["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
Expand All @@ -79,7 +79,7 @@ def test_embeddings_extra_rows(self):
for row_read, row_original in zip(csv_reader, rows):
self.assertListEqual(row_read, row_original[:-2])

def test_embeddings_extra_rows_special_order(self):
def test_embeddings_extra_rows_special_order(self) -> None:
input_rows = [
["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
Expand All @@ -101,7 +101,7 @@ def test_embeddings_extra_rows_special_order(self):
for row_read, row_original in zip(csv_reader, correct_output_rows):
self.assertListEqual(row_read, row_original)

def test_save_tasks(self):
def test_save_tasks(self) -> None:
tasks = [
"task1",
"task2",
Expand All @@ -113,7 +113,7 @@ def test_save_tasks(self):
loaded = json.load(f)
self.assertListEqual(tasks, loaded)

def test_save_schema(self):
def test_save_schema(self) -> None:
description = "classification"
ids = [1, 2, 3, 4]
names = ["name1", "name2", "name3", "name4"]
Expand All @@ -132,7 +132,7 @@ def test_save_schema(self):
loaded = json.load(f)
self.assertListEqual(sorted(expected_format), sorted(loaded))

def test_save_schema_different(self):
def test_save_schema_different(self) -> None:
with self.assertRaises(ValueError):
io.save_schema(
"name_doesnt_matter",
Expand Down

0 comments on commit d59bdfe

Please sign in to comment.