Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow filenames with commas in embedding file #1395

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lightly/api/api_workflow_upload_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]:
f"The filenames in the embedding file and "
f"the filenames on the server do not align"
)
io_utils.check_filenames(filenames)

rows_without_header_ordered = self._order_list_by_filenames(
filenames, rows_without_header
Expand Down
7 changes: 0 additions & 7 deletions lightly/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@

import torchvision.datasets as datasets
from PIL import Image
from torch._C import Value
from torchvision import transforms
from torchvision.datasets.vision import StandardTransform, VisionDataset

from lightly.data._helpers import DatasetFolder, _load_dataset_from_folder
from lightly.data._video import VideoDataset
from lightly.utils.io import check_filenames


def _get_filename_by_index(dataset, index):
Expand Down Expand Up @@ -177,11 +175,6 @@ def is_valid_file(filepath: str):
if index_to_filename is not None:
self.index_to_filename = index_to_filename

# if created from an input directory with filenames, check if they
# are valid
if input_dir:
check_filenames(self.get_filenames())

@classmethod
def from_torch_dataset(cls, dataset, transform=None, index_to_filename=None):
"""Builds a LightlyDataset from a PyTorch (or torchvision) dataset.
Expand Down
26 changes: 0 additions & 26 deletions lightly/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,6 @@
import numpy as np
from numpy.typing import NDArray

INVALID_FILENAME_CHARACTERS = [","]


def _is_valid_filename(filename: str) -> bool:
"""Returns False if the filename is misformatted."""
for character in INVALID_FILENAME_CHARACTERS:
if character in filename:
return False
return True


def check_filenames(filenames: List[str]) -> None:
philippmwirth marked this conversation as resolved.
Show resolved Hide resolved
"""Raises an error if one of the filenames is misformatted

Args:
filenames:
A list of string being filenames

"""
invalid_filenames = [f for f in filenames if not _is_valid_filename(f)]
if len(invalid_filenames) > 0:
raise ValueError(f"Invalid filename(s): {invalid_filenames}")


def check_embeddings(path: str, remove_additional_columns: bool = False) -> None:
"""Raises an error if the embeddings csv file has not the correct format
Expand Down Expand Up @@ -147,7 +124,6 @@ def save_embeddings(
>>> labels,
>>> filenames)
"""
check_filenames(filenames)

n_embeddings = len(embeddings)
n_filenames = len(filenames)
Expand Down Expand Up @@ -203,8 +179,6 @@ def load_embeddings(path: str) -> Tuple[NDArray[np.float64], List[int], List[str
# read embeddings
embeddings.append(row[1:-1])

check_filenames(filenames)

embedding_array = np.array(embeddings).astype(np.float64)
return embedding_array, labels, filenames

Expand Down
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ exclude = (?x)(
tests/utils/benchmarking/test_linear_classifier.py |
tests/utils/benchmarking/test_metric_callback.py |
tests/utils/test_dist.py |
tests/utils/test_io.py |
tests/models/test_ModelsSimSiam.py |
tests/models/modules/test_masked_autoencoder.py |
tests/models/test_ModelsSimCLR.py |
Expand Down Expand Up @@ -232,6 +231,9 @@ follow_imports = skip
[mypy-lightly.utils.benchmarking.*]
follow_imports = skip

[mypy-tests.api_workflow.*]
follow_imports = skip

# Ignore errors in auto generated code.
[mypy-lightly.openapi_generated.*]
ignore_errors = True
15 changes: 1 addition & 14 deletions tests/api_workflow/test_api_workflow_upload_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import numpy as np

from lightly.utils import io as io_utils
from lightly.utils.io import INVALID_FILENAME_CHARACTERS
from tests.api_workflow.mocked_api_workflow_client import (
N_FILES_ON_SERVER,
MockedApiWorkflowSetup,
)
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup


class TestApiWorkflowUploadEmbeddings(MockedApiWorkflowSetup):
Expand Down Expand Up @@ -80,15 +76,6 @@ def test_upload_wrong_filenames(self):
with self.assertRaises(ValueError):
self.t_ester_upload_embedding(n_data=n_data, special_name_first_sample=True)

def test_upload_comma_filenames(self):
n_data = len(self.api_workflow_client._mappings_api.sample_names)
for invalid_char in INVALID_FILENAME_CHARACTERS:
with self.subTest(msg=f"invalid_char: {invalid_char}"):
with self.assertRaises(ValueError):
self.t_ester_upload_embedding(
n_data=n_data, special_char_in_first_filename=invalid_char
)

def test_set_embedding_id_default(self):
self.api_workflow_client.set_embedding_id_to_latest()
embeddings = (
Expand Down
25 changes: 1 addition & 24 deletions tests/data/test_LightlyDataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import os
import random
import re
import shutil
import tempfile
import unittest
import warnings
from typing import List, Tuple

import numpy as np
import torch
import torchvision
from PIL.Image import Image

from lightly.data import LightlyDataset
from lightly.data._utils import check_images
from lightly.utils.io import INVALID_FILENAME_CHARACTERS

try:
import av
import av as _
import cv2

from lightly.data._video import VideoDataset
Expand Down Expand Up @@ -137,24 +132,6 @@ def test_create_lightly_dataset_from_folder_nosubdir(self):
for i in range(n_tot):
sample, target, fname = dataset[i]

def test_create_lightly_dataset_with_invalid_char_in_filename(self):
# create a dataset
n_tot = 100
dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32))

for invalid_char in INVALID_FILENAME_CHARACTERS:
with self.subTest(msg=f"invalid_char: {invalid_char}"):
tmp_dir = tempfile.mkdtemp()
sample_names = [f"img_,_{i}.jpg" for i in range(n_tot)]
for sample_idx in range(n_tot):
data = dataset[sample_idx]
path = os.path.join(tmp_dir, sample_names[sample_idx])
data[0].save(path)

# create lightly dataset
with self.assertRaises(ValueError):
dataset = LightlyDataset(input_dir=tmp_dir)

def test_check_images(self):
# create a dataset
tmp_dir = tempfile.mkdtemp()
Expand Down
123 changes: 68 additions & 55 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,35 @@
import csv
import json
import sys
import tempfile
import unittest
from pathlib import Path

import numpy as np

from lightly.utils.io import (
check_embeddings,
check_filenames,
save_custom_metadata,
save_embeddings,
save_schema,
save_tasks,
)
from tests.api_workflow.mocked_api_workflow_client import (
MockedApiWorkflowClient,
MockedApiWorkflowSetup,
)


class TestCLICrop(MockedApiWorkflowSetup):
def test_save_metadata(self):
from lightly.utils import io
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup


class TestCLICrop(MockedApiWorkflowSetup): # type: ignore[misc]
def test_save_metadata(self) -> None:
metadata = [("filename.jpg", {"random_metadata": 42})]
metadata_filepath = tempfile.mktemp(".json", "metadata")
save_custom_metadata(metadata_filepath, metadata)

def test_valid_filenames(self):
valid = "img.png"
non_valid = "img,1.png"
filenames_list = [
([valid], True),
([valid, valid], True),
([non_valid], False),
([valid, non_valid], False),
]
for filenames, valid in filenames_list:
with self.subTest(msg=f"filenames:{filenames}"):
if valid:
check_filenames(filenames)
else:
with self.assertRaises(ValueError):
check_filenames(filenames)
io.save_custom_metadata(metadata_filepath, metadata)


class TestEmbeddingsIO(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
# correct embedding file as created through lightly
self.embeddings_path = tempfile.mktemp(".csv", "embeddings")
embeddings = np.random.rand(32, 2)
labels = [0 for i in range(len(embeddings))]
filenames = [f"img_{i}.jpg" for i in range(len(embeddings))]
save_embeddings(self.embeddings_path, embeddings, labels, filenames)
io.save_embeddings(self.embeddings_path, embeddings, labels, filenames)

def test_valid_embeddings(self):
check_embeddings(self.embeddings_path)
def test_valid_embeddings(self) -> None:
io.check_embeddings(self.embeddings_path)

def test_whitespace_in_embeddings(self):
def test_whitespace_in_embeddings(self) -> None:
# should fail because there whitespaces in the header columns
lines = [
"filenames, embedding_0,embedding_1,labels\n",
Expand All @@ -65,19 +38,19 @@ def test_whitespace_in_embeddings(self):
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
check_embeddings(self.embeddings_path)
io.check_embeddings(self.embeddings_path)
self.assertTrue("must not contain whitespaces" in str(context.exception))

def test_no_labels_in_embeddings(self):
def test_no_labels_in_embeddings(self) -> None:
# should fail because there is no `labels` column in the header
lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"]
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
check_embeddings(self.embeddings_path)
io.check_embeddings(self.embeddings_path)
self.assertTrue("has no `labels` column" in str(context.exception))

def test_no_empty_rows_in_embeddings(self):
def test_no_empty_rows_in_embeddings(self) -> None:
# should fail because there are empty rows in the embeddings file
lines = [
"filenames,embedding_0,embedding_1,labels\n",
Expand All @@ -86,10 +59,10 @@ def test_no_empty_rows_in_embeddings(self):
with open(self.embeddings_path, "w") as f:
f.writelines(lines)
with self.assertRaises(RuntimeError) as context:
check_embeddings(self.embeddings_path)
io.check_embeddings(self.embeddings_path)
self.assertTrue("must not have empty rows" in str(context.exception))

def test_embeddings_extra_rows(self):
def test_embeddings_extra_rows(self) -> None:
rows = [
["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
Expand All @@ -99,14 +72,14 @@ def test_embeddings_extra_rows(self):
csv_writer = csv.writer(f)
csv_writer.writerows(rows)

check_embeddings(self.embeddings_path, remove_additional_columns=True)
io.check_embeddings(self.embeddings_path, remove_additional_columns=True)

with open(self.embeddings_path) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
for row_read, row_original in zip(csv_reader, rows):
self.assertListEqual(row_read, row_original[:-2])

def test_embeddings_extra_rows_special_order(self):
def test_embeddings_extra_rows_special_order(self) -> None:
input_rows = [
["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"],
["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
Expand All @@ -121,26 +94,26 @@ def test_embeddings_extra_rows_special_order(self):
csv_writer = csv.writer(f)
csv_writer.writerows(input_rows)

check_embeddings(self.embeddings_path, remove_additional_columns=True)
io.check_embeddings(self.embeddings_path, remove_additional_columns=True)

with open(self.embeddings_path) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
for row_read, row_original in zip(csv_reader, correct_output_rows):
self.assertListEqual(row_read, row_original)

def test_save_tasks(self):
def test_save_tasks(self) -> None:
tasks = [
"task1",
"task2",
"task3",
]
with tempfile.NamedTemporaryFile(suffix=".json") as file:
save_tasks(file.name, tasks)
io.save_tasks(file.name, tasks)
with open(file.name, "r") as f:
loaded = json.load(f)
self.assertListEqual(tasks, loaded)

def test_save_schema(self):
def test_save_schema(self) -> None:
description = "classification"
ids = [1, 2, 3, 4]
names = ["name1", "name2", "name3", "name4"]
Expand All @@ -154,16 +127,56 @@ def test_save_schema(self):
],
}
with tempfile.NamedTemporaryFile(suffix=".json") as file:
save_schema(file.name, description, ids, names)
io.save_schema(file.name, description, ids, names)
with open(file.name, "r") as f:
loaded = json.load(f)
self.assertListEqual(sorted(expected_format), sorted(loaded))

def test_save_schema_different(self):
def test_save_schema_different(self) -> None:
with self.assertRaises(ValueError):
save_schema(
io.save_schema(
"name_doesnt_matter",
"description_doesnt_matter",
[1, 2],
["name1"],
)


def test_save_and_load_embeddings(tmp_path: Path) -> None:
embeddings = np.random.rand(2, 32)
labels = [0, 1]
filenames = ["img_1.jpg", "img_2.jpg"]

io.save_embeddings(
path=str(tmp_path / "embeddings.csv"),
embeddings=embeddings,
labels=labels,
filenames=filenames,
)

loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
path=str(tmp_path / "embeddings.csv")
)
assert np.allclose(embeddings, loaded_embeddings)
assert labels == loaded_labels
assert filenames == loaded_filenames


def test_save_and_load_embeddings__filename_with_comma(tmp_path: Path) -> None:
embeddings = np.random.rand(2, 32)
labels = [0, 1]
filenames = ["img,1.jpg", "img_2.jpg"]
philippmwirth marked this conversation as resolved.
Show resolved Hide resolved

io.save_embeddings(
path=str(tmp_path / "embeddings.csv"),
embeddings=embeddings,
labels=labels,
filenames=filenames,
)

loaded_embeddings, loaded_labels, loaded_filenames = io.load_embeddings(
path=str(tmp_path / "embeddings.csv")
)
assert np.allclose(embeddings, loaded_embeddings)
assert labels == loaded_labels
assert filenames == loaded_filenames
Loading