diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py
index 4e8885d0c..bf28c1c2f 100644
--- a/lightly/api/api_workflow_upload_embeddings.py
+++ b/lightly/api/api_workflow_upload_embeddings.py
@@ -259,7 +259,6 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]:
                     f"The filenames in the embedding file and "
                     f"the filenames on the server do not align"
                 )
-            io_utils.check_filenames(filenames)
 
             rows_without_header_ordered = self._order_list_by_filenames(
                 filenames, rows_without_header
diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py
index 4cd1ce87d..da7fd24b1 100644
--- a/lightly/data/dataset.py
+++ b/lightly/data/dataset.py
@@ -9,13 +9,11 @@
 
 import torchvision.datasets as datasets
 from PIL import Image
-from torch._C import Value
 from torchvision import transforms
 from torchvision.datasets.vision import StandardTransform, VisionDataset
 
 from lightly.data._helpers import DatasetFolder, _load_dataset_from_folder
 from lightly.data._video import VideoDataset
-from lightly.utils.io import check_filenames
 
 
 def _get_filename_by_index(dataset, index):
@@ -177,11 +175,6 @@ def is_valid_file(filepath: str):
         if index_to_filename is not None:
             self.index_to_filename = index_to_filename
 
-        # if created from an input directory with filenames, check if they
-        # are valid
-        if input_dir:
-            check_filenames(self.get_filenames())
-
     @classmethod
     def from_torch_dataset(cls, dataset, transform=None, index_to_filename=None):
         """Builds a LightlyDataset from a PyTorch (or torchvision) dataset.
diff --git a/lightly/utils/io.py b/lightly/utils/io.py
index 6da6b7813..74240bf6e 100644
--- a/lightly/utils/io.py
+++ b/lightly/utils/io.py
@@ -12,29 +12,6 @@
 import numpy as np
 from numpy.typing import NDArray
 
-INVALID_FILENAME_CHARACTERS = [","]
-
-
-def _is_valid_filename(filename: str) -> bool:
-    """Returns False if the filename is misformatted."""
-    for character in INVALID_FILENAME_CHARACTERS:
-        if character in filename:
-            return False
-    return True
-
-
-def check_filenames(filenames: List[str]) -> None:
-    """Raises an error if one of the filenames is misformatted
-
-    Args:
-        filenames:
-            A list of string being filenames
-
-    """
-    invalid_filenames = [f for f in filenames if not _is_valid_filename(f)]
-    if len(invalid_filenames) > 0:
-        raise ValueError(f"Invalid filename(s): {invalid_filenames}")
-
 
 def check_embeddings(path: str, remove_additional_columns: bool = False) -> None:
     """Raises an error if the embeddings csv file has not the correct format
@@ -147,7 +124,6 @@ def save_embeddings(
         >>>     labels,
         >>>     filenames)
     """
-    check_filenames(filenames)
 
    n_embeddings = len(embeddings)
    n_filenames = len(filenames)
@@ -203,8 +179,6 @@ def load_embeddings(path: str) -> Tuple[NDArray[np.float64], List[int], List[str
                 # read embeddings
                 embeddings.append(row[1:-1])
 
-    check_filenames(filenames)
-
     embedding_array = np.array(embeddings).astype(np.float64)
     return embedding_array, labels, filenames
 
diff --git a/mypy.ini b/mypy.ini
index edf952772..ae5fce7af 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -174,7 +174,6 @@ exclude = (?x)(
     tests/utils/benchmarking/test_linear_classifier.py |
     tests/utils/benchmarking/test_metric_callback.py |
     tests/utils/test_dist.py |
-    tests/utils/test_io.py |
     tests/models/test_ModelsSimSiam.py |
     tests/models/modules/test_masked_autoencoder.py |
     tests/models/test_ModelsSimCLR.py |
@@ -232,6 +231,9 @@ follow_imports = skip
 [mypy-lightly.utils.benchmarking.*]
 follow_imports = skip
 
+[mypy-tests.api_workflow.*]
+follow_imports = skip
+
 # Ignore errors in auto generated code.
 [mypy-lightly.openapi_generated.*]
 ignore_errors = True
\ No newline at end of file
diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py
index 4444d11d7..a3bb54c2b 100644
--- a/tests/api_workflow/test_api_workflow_upload_embeddings.py
+++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py
@@ -4,11 +4,7 @@
 import numpy as np
 
 from lightly.utils import io as io_utils
-from lightly.utils.io import INVALID_FILENAME_CHARACTERS
-from tests.api_workflow.mocked_api_workflow_client import (
-    N_FILES_ON_SERVER,
-    MockedApiWorkflowSetup,
-)
+from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
 
 
 class TestApiWorkflowUploadEmbeddings(MockedApiWorkflowSetup):
@@ -80,15 +76,6 @@ def test_upload_wrong_filenames(self):
         with self.assertRaises(ValueError):
             self.t_ester_upload_embedding(n_data=n_data, special_name_first_sample=True)
 
-    def test_upload_comma_filenames(self):
-        n_data = len(self.api_workflow_client._mappings_api.sample_names)
-        for invalid_char in INVALID_FILENAME_CHARACTERS:
-            with self.subTest(msg=f"invalid_char: {invalid_char}"):
-                with self.assertRaises(ValueError):
-                    self.t_ester_upload_embedding(
-                        n_data=n_data, special_char_in_first_filename=invalid_char
-                    )
-
     def test_set_embedding_id_default(self):
         self.api_workflow_client.set_embedding_id_to_latest()
         embeddings = (
diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py
index c5a1a7c6c..0b4c7b196 100644
--- a/tests/data/test_LightlyDataset.py
+++ b/tests/data/test_LightlyDataset.py
@@ -1,23 +1,18 @@
 import os
-import random
-import re
 import shutil
 import tempfile
 import unittest
-import warnings
 from typing import List, Tuple
 
 import numpy as np
-import torch
 import torchvision
 from PIL.Image import Image
 
 from lightly.data import LightlyDataset
 from lightly.data._utils import check_images
-from lightly.utils.io import INVALID_FILENAME_CHARACTERS
 
 try:
-    import av
+    import av as _
     import cv2
 
     from lightly.data._video import VideoDataset
@@ -137,24 +132,6 @@ def test_create_lightly_dataset_from_folder_nosubdir(self):
         for i in range(n_tot):
             sample, target, fname = dataset[i]
 
-    def test_create_lightly_dataset_with_invalid_char_in_filename(self):
-        # create a dataset
-        n_tot = 100
-        dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32))
-
-        for invalid_char in INVALID_FILENAME_CHARACTERS:
-            with self.subTest(msg=f"invalid_char: {invalid_char}"):
-                tmp_dir = tempfile.mkdtemp()
-                sample_names = [f"img_,_{i}.jpg" for i in range(n_tot)]
-                for sample_idx in range(n_tot):
-                    data = dataset[sample_idx]
-                    path = os.path.join(tmp_dir, sample_names[sample_idx])
-                    data[0].save(path)
-
-                # create lightly dataset
-                with self.assertRaises(ValueError):
-                    dataset = LightlyDataset(input_dir=tmp_dir)
-
     def test_check_images(self):
         # create a dataset
         tmp_dir = tempfile.mkdtemp()
diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py
index 988b622f1..04086be5f 100644
--- a/tests/utils/test_io.py
+++ b/tests/utils/test_io.py
@@ -1,62 +1,35 @@
 import csv
 import json
-import sys
 import tempfile
 import unittest
+from pathlib import Path
 
 import numpy as np
 
-from lightly.utils.io import (
-    check_embeddings,
-    check_filenames,
-    save_custom_metadata,
-    save_embeddings,
-    save_schema,
-    save_tasks,
-)
-from tests.api_workflow.mocked_api_workflow_client import (
-    MockedApiWorkflowClient,
-    MockedApiWorkflowSetup,
-)
-
-
-class TestCLICrop(MockedApiWorkflowSetup):
-    def test_save_metadata(self):
+from lightly.utils import io
+from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
+
+
+class TestCLICrop(MockedApiWorkflowSetup):  # type: ignore[misc]
+    def test_save_metadata(self) -> None:
         metadata = [("filename.jpg", {"random_metadata": 42})]
         metadata_filepath = tempfile.mktemp(".json", "metadata")
-        save_custom_metadata(metadata_filepath, metadata)
-
-    def test_valid_filenames(self):
-        valid = "img.png"
-        non_valid = "img,1.png"
-        filenames_list = [
-            ([valid], True),
-            ([valid, valid], True),
-            ([non_valid], False),
-            ([valid, non_valid], False),
-        ]
-        for filenames, valid in filenames_list:
-            with self.subTest(msg=f"filenames:{filenames}"):
-                if valid:
-                    check_filenames(filenames)
-                else:
-                    with self.assertRaises(ValueError):
-                        check_filenames(filenames)
+        io.save_custom_metadata(metadata_filepath, metadata)
 
 
 class TestEmbeddingsIO(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # correct embedding file as created through lightly
         self.embeddings_path = tempfile.mktemp(".csv", "embeddings")
         embeddings = np.random.rand(32, 2)
         labels = [0 for i in range(len(embeddings))]
         filenames = [f"img_{i}.jpg" for i in range(len(embeddings))]
-        save_embeddings(self.embeddings_path, embeddings, labels, filenames)
+        io.save_embeddings(self.embeddings_path, embeddings, labels, filenames)
 
-    def test_valid_embeddings(self):
-        check_embeddings(self.embeddings_path)
+    def test_valid_embeddings(self) -> None:
+        io.check_embeddings(self.embeddings_path)
 
-    def test_whitespace_in_embeddings(self):
+    def test_whitespace_in_embeddings(self) -> None:
         # should fail because there whitespaces in the header columns
         lines = [
             "filenames, embedding_0,embedding_1,labels\n",
@@ -65,19 +38,19 @@ def test_whitespace_in_embeddings(self):
         with open(self.embeddings_path, "w") as f:
             f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
-            check_embeddings(self.embeddings_path)
+            io.check_embeddings(self.embeddings_path)
         self.assertTrue("must not contain whitespaces" in str(context.exception))
 
-    def test_no_labels_in_embeddings(self):
+    def test_no_labels_in_embeddings(self) -> None:
         # should fail because there is no `labels` column in the header
         lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"]
         with open(self.embeddings_path, "w") as f:
             f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
-            check_embeddings(self.embeddings_path)
+            io.check_embeddings(self.embeddings_path)
         self.assertTrue("has no `labels` column" in str(context.exception))
 
-    def test_no_empty_rows_in_embeddings(self):
+    def test_no_empty_rows_in_embeddings(self) -> None:
         # should fail because there are empty rows in the embeddings file
         lines = [
             "filenames,embedding_0,embedding_1,labels\n",
@@ -86,10 +59,10 @@ def test_no_empty_rows_in_embeddings(self):
         with open(self.embeddings_path, "w") as f:
             f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
-            check_embeddings(self.embeddings_path)
+            io.check_embeddings(self.embeddings_path)
         self.assertTrue("must not have empty rows" in str(context.exception))
 
-    def test_embeddings_extra_rows(self):
+    def test_embeddings_extra_rows(self) -> None:
         rows = [
             ["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"],
             ["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
@@ -99,14 +72,14 @@ def test_embeddings_extra_rows(self):
             csv_writer = csv.writer(f)
             csv_writer.writerows(rows)
 
-        check_embeddings(self.embeddings_path, remove_additional_columns=True)
+        io.check_embeddings(self.embeddings_path, remove_additional_columns=True)
 
         with open(self.embeddings_path) as csv_file:
             csv_reader = csv.reader(csv_file, delimiter=",")
             for row_read, row_original in zip(csv_reader, rows):
                 self.assertListEqual(row_read, row_original[:-2])
 
-    def test_embeddings_extra_rows_special_order(self):
+    def test_embeddings_extra_rows_special_order(self) -> None:
         input_rows = [
             ["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"],
             ["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
@@ -121,26 +94,26 @@ def test_embeddings_extra_rows_special_order(self):
             csv_writer = csv.writer(f)
             csv_writer.writerows(input_rows)
 
-        check_embeddings(self.embeddings_path, remove_additional_columns=True)
+        io.check_embeddings(self.embeddings_path, remove_additional_columns=True)
 
         with open(self.embeddings_path) as csv_file:
             csv_reader = csv.reader(csv_file, delimiter=",")
             for row_read, row_original in zip(csv_reader, correct_output_rows):
                 self.assertListEqual(row_read, row_original)
 
-    def test_save_tasks(self):
+    def test_save_tasks(self) -> None:
         tasks = [
             "task1",
             "task2",
             "task3",
         ]
         with tempfile.NamedTemporaryFile(suffix=".json") as file:
-            save_tasks(file.name, tasks)
+            io.save_tasks(file.name, tasks)
             with open(file.name, "r") as f:
                 loaded = json.load(f)
         self.assertListEqual(tasks, loaded)
 
-    def test_save_schema(self):
+    def test_save_schema(self) -> None:
         description = "classification"
         ids = [1, 2, 3, 4]
         names = ["name1", "name2", "name3", "name4"]
@@ -154,16 +127,56 @@ def test_save_schema(self):
             ],
         }
         with tempfile.NamedTemporaryFile(suffix=".json") as file:
-            save_schema(file.name, description, ids, names)
+            io.save_schema(file.name, description, ids, names)
             with open(file.name, "r") as f:
                 loaded = json.load(f)
         self.assertListEqual(sorted(expected_format), sorted(loaded))
 
-    def test_save_schema_different(self):
+    def test_save_schema_different(self) -> None:
         with self.assertRaises(ValueError):
-            save_schema(
+            io.save_schema(
                 "name_doesnt_matter",
                 "description_doesnt_matter",
                 [1, 2],
                 ["name1"],
             )
+
+
+def test_save_and_load_embeddings(tmp_path: Path) -> None:
+    embeddings = np.random.rand(2, 32)
+    labels = [0, 1]
+    filenames = ["img_1.jpg", "img_2.jpg"]
+
+    io.save_embeddings(
+        path=str(tmp_path / "embeddings.csv"),
+        embeddings=embeddings,
+        labels=labels,
+        filenames=filenames,
+    )
+
+    loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
+        path=str(tmp_path / "embeddings.csv")
+    )
+    assert np.allclose(embeddings, loaded_embeddings)
+    assert labels == loaded_labels
+    assert filenames == loaded_filenames
+
+
+def test_save_and_load_embeddings__filename_with_comma(tmp_path: Path) -> None:
+    embeddings = np.random.rand(4, 32)
+    labels = [0, 1, 2, 3]
+    filenames = ["img,1.jpg", '",img,.jpg', ',"img".jpg', ',"img\n".jpg']
+
+    io.save_embeddings(
+        path=str(tmp_path / "embeddings.csv"),
+        embeddings=embeddings,
+        labels=labels,
+        filenames=filenames,
+    )
+
+    loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
+        path=str(tmp_path / "embeddings.csv")
+    )
+    assert np.allclose(embeddings, loaded_embeddings)
+    assert labels == loaded_labels
+    assert filenames == loaded_filenames