diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8c8a2be..99263f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,7 @@ repos:
       args: [--fix=lf]
     - id: name-tests-test
       args: ["--pytest-test-first"]
+      exclude: ^tests/fixtures
     - id: requirements-txt-fixer
     - id: trailing-whitespace
   - repo: https://github.com/pre-commit/pygrep-hooks
diff --git a/ethology/annotations/json_schemas.py b/ethology/annotations/json_schemas.py
new file mode 100644
index 0000000..4cbdd5a
--- /dev/null
+++ b/ethology/annotations/json_schemas.py
@@ -0,0 +1,145 @@
+"""JSON schemas for manual annotation files.
+
+We use JSON schemas to check the structure of supported
+annotation files via validators.
+
+Note that the schema validation only checks the type of a key
+if that key is present. It does not check for the presence of
+the keys.
+
+References
+----------
+- https://github.com/python-jsonschema/jsonschema
+- https://json-schema.org/understanding-json-schema/
+- https://cocodataset.org/#format-data
+- https://gitlab.com/vgg/via/-/blob/master/via-2.x.y/CodeDoc.md?ref_type=heads#description-of-via-project-json-file
+
+"""
+
+# The VIA schema corresponds to the
+# format exported by VGG Image Annotator 2.x.y
+# for manual labels
+VIA_SCHEMA = {
+    "type": "object",
+    "properties": {
+        # settings for the browser-based UI of VIA
+        "_via_settings": {
+            "type": "object",
+            "properties": {
+                "ui": {"type": "object"},
+                "core": {"type": "object"},
+                "project": {"type": "object"},
+            },
+        },
+        # annotations data per image
+        "_via_img_metadata": {
+            "type": "object",
+            "additionalProperties": {
+                # Each image under _via_img_metadata is indexed
+                # using a unique key: FILENAME-FILESIZE.
+                # We use "additionalProperties" to allow for any
+                # key name, see https://stackoverflow.com/a/69811612/24834957
+                "type": "object",
+                "properties": {
+                    "filename": {"type": "string"},
+                    "size": {"type": "integer"},
+                    "regions": {
+                        "type": "array",  # 'regions' is a list of dicts
+                        "items": {
+                            "type": "object",
+                            "properties": {
+                                "shape_attributes": {
+                                    "type": "object",
+                                    "properties": {
+                                        "name": {"type": "string"},
+                                        "x": {"type": "integer"},
+                                        "y": {"type": "integer"},
+                                        "width": {"type": "integer"},
+                                        "height": {"type": "integer"},
+                                    },
+                                },
+                                "region_attributes": {"type": "object"},
+                            },
+                        },
+                    },
+                    "file_attributes": {"type": "object"},
+                },
+            },
+        },
+        # _via_image_id_list is an ordered list of the unique
+        # image keys (FILENAME-FILESIZE); the position of a key
+        # in the list defines the image ID
+        "_via_image_id_list": {
+            "type": "array",
+            "items": {"type": "string"},
+        },
+        # region attributes and file attributes, to
+        # display in VIA's UI and to classify the data
+        "_via_attributes": {
+            "type": "object",
+            "properties": {
+                "region": {"type": "object"},
+                "file": {"type": "object"},
+            },
+        },
+        # version of the VIA tool used
+        "_via_data_format_version": {"type": "string"},
+    },
+}

+# The COCO schema follows the COCO dataset
+# format for object detection
+# See https://cocodataset.org/#format-data
+COCO_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "info": {"type": "object"},
+        "licenses": {
+            "type": "array",
+        },
+        "images": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "file_name": {"type": "string"},
+                    "id": {"type": "integer"},
+                    "width": {"type": "integer"},
+                    "height": {"type": "integer"},
+                },
+            },
+        },
+        "annotations": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "id": {"type": "integer"},
+                    "image_id": {"type": "integer"},
+                    # box coordinates are measured from the
+                    # top left corner of the image and are 0-indexed
+                    "bbox": {
+                        "type": "array",
+                        "items": {"type": "integer"},
+                    },
+                    "category_id": {"type": "integer"},
+                    # float according to the official schema
+                    "area": {"type": "number"},
+                    # 0 or 1 according to the official schema
+                    "iscrowd": {"type": "integer"},
+                },
+            },
+        },
+        "categories": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "id": {"type": "integer"},
+                    "name": {"type": "string"},
+                    "supercategory": {"type": "string"},
+                },
+            },
+        },
+    },
+}
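Note on the schemas above: as stated in the module docstring, `jsonschema` type-checks only the keys that are present in the instance, which is why the validators introduced below additionally check for required keys. A minimal sketch of this behavior (the toy schema and instances here are illustrative, not part of the patch):

```python
import jsonschema
import jsonschema.exceptions

# Toy schema mirroring the pattern used above: types are declared
# per key, but no key is marked as required.
schema = {
    "type": "object",
    "properties": {
        "filename": {"type": "string"},
        "size": {"type": "integer"},
    },
}

# Passes: "size" is absent, and absent keys are not type-checked.
jsonschema.validate(instance={"filename": "img_01.png"}, schema=schema)

# Raises: "size" is present but has the wrong type.
try:
    jsonschema.validate(instance={"size": "100kB"}, schema=schema)
except jsonschema.exceptions.ValidationError as error:
    print(error.message)  # '100kB' is not of type 'integer'
```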
"image_id": {"type": "integer"}, + "bbox": { + "type": "array", + "items": {"type": "integer"}, + }, + # (box coordinates are measured from the + # top left image corner and are 0-indexed) + "category_id": {"type": "integer"}, + "area": {"type": "number"}, + # float according to the official schema + "iscrowd": {"type": "integer"}, + # 0 or 1 according to the official schema + }, + }, + }, + "categories": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + "supercategory": {"type": "string"}, + }, + }, + }, + }, +} diff --git a/ethology/annotations/validators.py b/ethology/annotations/validators.py new file mode 100644 index 0000000..4ff1179 --- /dev/null +++ b/ethology/annotations/validators.py @@ -0,0 +1,244 @@ +"""Validators for annotation files.""" + +import json +from pathlib import Path + +import attrs +import jsonschema +import jsonschema.exceptions +import jsonschema.validators +from attrs import define, field, validators + +from ethology.annotations.json_schemas import COCO_SCHEMA, VIA_SCHEMA + + +@define +class ValidJSON: + """Class for valid JSON files. + + It checks the JSON file exists, can be decoded, and optionally + validates the file against a JSON schema. + + Attributes + ---------- + path : pathlib.Path + Path to the JSON file. + + schema : dict, optional + JSON schema to validate the file against. + + Raises + ------ + FileNotFoundError + If the file does not exist. + ValueError + If the JSON file cannot be decoded. + jsonschema.exceptions.ValidationError + If the type of any of the keys in the JSON file + does not match the type specified in the schema. + + + Notes + ----- + https://json-schema.org/understanding-json-schema/ + + """ + + # Required attributes + path: Path = field(validator=validators.instance_of(Path)) + + # Optional attributes + schema: dict | None = field(default=None) + + @path.validator + def _file_is_json(self, attribute, value): + """Ensure that the file is a JSON file.""" + try: + with open(value) as file: + json.load(file) + except FileNotFoundError as not_found_error: + raise FileNotFoundError( + f"File not found: {value}." + ) from not_found_error + except json.JSONDecodeError as decode_error: + raise ValueError( + f"Error decoding JSON data from file: {value}." + ) from decode_error + + @path.validator + def _file_matches_JSON_schema(self, attribute, value): + """Ensure that the JSON file matches the expected schema. + + The schema validation only checks the type for each specified + key if the key exists. It does not check for the presence of + the keys. + """ + # read json file + with open(value) as file: + data = json.load(file) + + # check against schema if provided + if self.schema: + try: + jsonschema.validate(instance=data, schema=self.schema) + except jsonschema.exceptions.ValidationError as val_err: + # forward the error message as it is quite informative + raise val_err + + +@define +class ValidVIAJSON(ValidJSON): + """Class for valid VIA JSON files for untracked data. + + It checks the input VIA JSON file contains the required keys. + + Attributes + ---------- + path : pathlib.Path + Path to the VIA JSON file. + + schema : dict, optional + JSON schema to validate the file against. Default is VIA_SCHEMA. + + Raises + ------ + ValueError + If the VIA JSON file misses any of the required keys. 
+ + """ + + # run the parent's validators first + path: Path = field(validator=attrs.fields(ValidJSON).path.validator) + schema: dict = field( + validator=attrs.fields(ValidJSON).schema.validator, # type: ignore + default=VIA_SCHEMA, + ) + + # TODO: add a validator to check the schema defines types + # for the required keys + + # run additional validators + @path.validator + def _file_contains_required_keys(self, attribute, value): + """Ensure that the VIA JSON file contains the required keys.""" + required_keys = { + "main": ["_via_img_metadata", "_via_image_id_list"], + "image_keys": ["filename", "regions"], + "region_keys": ["shape_attributes", "region_attributes"], + "shape_attributes_keys": ["x", "y", "width", "height"], + } + + # Read data as dict + with open(value) as file: + data = json.load(file) + + # Check first level keys + _check_keys(required_keys["main"], data) + + # Check keys in nested dicts + for img_str, img_dict in data["_via_img_metadata"].items(): + # Check keys for each image dictionary + _check_keys( + required_keys["image_keys"], + img_dict, + additional_message=f" for {img_str}", + ) + # Check keys for each region + for i, region in enumerate(img_dict["regions"]): + _check_keys( + required_keys["region_keys"], + region, + additional_message=f" for region {i} under {img_str}", + ) + + # Check keys under shape_attributes + _check_keys( + required_keys["shape_attributes_keys"], + region["shape_attributes"], + additional_message=f" for region {i} under {img_str}", + ) + + +@define +class ValidCOCOJSON(ValidJSON): + """Class valid COCO JSON files for untracked data. + + It checks the input COCO JSON file contains the required keys. + + Attributes + ---------- + path : pathlib.Path + Path to the COCO JSON file. + + Raises + ------ + ValueError + If the COCO JSON file misses any of the required keys. + + """ + + # run the parent's validators first + path: Path = field(validator=attrs.fields(ValidJSON).path.validator) + schema: dict = field( + validator=attrs.fields(ValidJSON).schema.validator, # type: ignore + default=COCO_SCHEMA, + ) + + # TODO: add a validator to check the schema defines types + # for the required keys + + # run additional validators + @path.validator + def _file_contains_required_keys(self, attribute, value): + """Ensure that the COCO JSON file contains the required keys.""" + required_keys = { + "main": ["images", "annotations", "categories"], + "image_keys": ["id", "file_name"], # add "height" and "width"? 
+ "annotations_keys": ["id", "image_id", "bbox", "category_id"], + "categories_keys": ["id", "name", "supercategory"], + } + + # Read data as dict + with open(value) as file: + data = json.load(file) + + # Check first level keys + _check_keys(required_keys["main"], data) + + # Check keys in images dicts + for img_dict in data["images"]: + _check_keys( + required_keys["image_keys"], + img_dict, + additional_message=f" for image dict {img_dict}", + ) + + # Check keys in annotations dicts + for annot_dict in data["annotations"]: + _check_keys( + required_keys["annotations_keys"], + annot_dict, + additional_message=f" for annotation dict {annot_dict}", + ) + + # Check keys in categories dicts + for cat_dict in data["categories"]: + _check_keys( + required_keys["categories_keys"], + cat_dict, + additional_message=f" for category dict {cat_dict}", + ) + + +def _check_keys( + list_required_keys: list[str], + data_dict: dict, + additional_message: str = "", +): + """Check if the required keys are present in the input data_dict.""" + missing_keys = set(list_required_keys) - data_dict.keys() + if missing_keys: + raise ValueError( + f"Required key(s) {sorted(missing_keys)} not " + f"found in {list(data_dict.keys())}{additional_message}." + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..28d0ec7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +"""Pytest configuration file with shared fixtures across all tests.""" + +from pathlib import Path + +import pooch +import pytest + +GIN_TEST_DATA_REPO = ( + "https://gin.g-node.org/neuroinformatics/ethology-test-data" +) + +pytest_plugins = [ + "tests.fixtures.annotations", +] + + +@pytest.fixture(scope="session") +def pooch_registry() -> dict: + """Pooch registry for the test data. + + This fixture is common to the entire test session. The + file registry is downloaded fresh for every test session. + + Returns + ------- + dict + URL and hash of the GIN repository with the test data + + """ + # Cache the test data in the user's home directory + test_data_dir = Path.home() / ".ethology-test-data" + + # Remove the file registry if it exists + # otherwise it is not downloaded from scratch every time + file_registry_path = test_data_dir / "files-registry.txt" + if file_registry_path.is_file(): + Path(file_registry_path).unlink() + + # Initialise pooch registry + registry = pooch.create( + test_data_dir, + base_url=f"{GIN_TEST_DATA_REPO}/raw/master/test_data", + ) + + # Download only the registry file from GIN + file_registry = pooch.retrieve( + url=f"{GIN_TEST_DATA_REPO}/raw/master/files-registry.txt", + known_hash=None, + fname=file_registry_path.name, + path=file_registry_path.parent, + ) + + # Load registry file onto pooch registry + registry.load_registry(file_registry) + + return registry + + +@pytest.fixture() +def get_paths_test_data(): + """Define a factory fixture to get the paths of the data files + under a specific subdirectory in the GIN repository. + + The name of the subdirectories is intended to match a testing module. For + example, to get the paths to the test files for the annotations + module, we would call `get_paths_test_data(pooch_registry, + "test_annotations")` in a test. This assumes in the GIN repository + there is a subdirectory named `test_annotations` under the `test_data` + directory with the relevant test files. + """ + + def _get_paths_test_data(pooch_registry, subdir_name: str) -> dict: + """Return the paths of the test files under the specified zip filename. 
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/fixtures/annotations.py b/tests/fixtures/annotations.py
new file mode 100644
index 0000000..dced3eb
--- /dev/null
+++ b/tests/fixtures/annotations.py
@@ -0,0 +1,8 @@
+"""Pytest fixtures shared across annotation tests."""
+
+import pytest
+
+
+@pytest.fixture()
+def annotations_test_data(pooch_registry, get_paths_test_data):
+    """Return the paths of the test files for the annotations tests."""
+    return get_paths_test_data(pooch_registry, "test_annotations")
diff --git a/tests/test_unit/test_annotations/__init__.py b/tests/test_unit/test_annotations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_unit/test_annotations/test_validators.py b/tests/test_unit/test_annotations/test_validators.py
new file mode 100644
index 0000000..604c7c4
--- /dev/null
+++ b/tests/test_unit/test_annotations/test_validators.py
@@ -0,0 +1,451 @@
+import json
+from contextlib import nullcontext as does_not_raise
+
+import jsonschema
+import pytest
+
+from ethology.annotations.json_schemas import COCO_SCHEMA, VIA_SCHEMA
+from ethology.annotations.validators import (
+    ValidCOCOJSON,
+    ValidJSON,
+    ValidVIAJSON,
+    _check_keys,
+)
+
+
+@pytest.fixture()
+def json_file_with_decode_error(tmp_path):
+    """Return the path to a JSON file with a decoding error."""
+    json_file = tmp_path / "JSON_decode_error.json"
+    with open(json_file, "w") as f:
+        f.write("just-a-string")
+    return json_file
+
+
+@pytest.fixture()
+def json_file_with_not_found_error(tmp_path):
+    """Return the path to a JSON file that does not exist."""
+    return tmp_path / "JSON_file_not_found.json"
+
+
+@pytest.fixture()
+def via_json_file_with_schema_error(
+    tmp_path,
+    annotations_test_data,
+):
+    """Return path to a VIA JSON file that doesn't match its schema."""
+    return _json_file_with_schema_error(
+        tmp_path,
+        annotations_test_data["VIA_JSON_sample_1.json"],
+    )
+
+
+@pytest.fixture()
+def coco_json_file_with_schema_error(
+    tmp_path,
+    annotations_test_data,
+):
+    """Return path to a COCO JSON file that doesn't match its schema."""
+    return _json_file_with_schema_error(
+        tmp_path,
+        annotations_test_data["COCO_JSON_sample_1.json"],
+    )
+
+
+def _json_file_with_schema_error(out_parent_path, json_valid_path):
+    """Return path to a JSON file that doesn't match the expected schema."""
+    # read valid json file
+    with open(json_valid_path) as f:
+        data = json.load(f)
+
+    # modify so that it doesn't match the corresponding schema:
+    # if VIA, change "width" of a bounding box from int to float;
+    # if COCO, change "annotations" from a list of dicts to a list of lists
+    if "VIA" in json_valid_path.name:
+        _, img_dict = list(data["_via_img_metadata"].items())[0]
+        img_dict["regions"][0]["shape_attributes"]["width"] = 49.5
+    elif "COCO" in json_valid_path.name:
+        data["annotations"] = [[d] for d in data["annotations"]]
+
+    # save the modified json to a new file
+    out_json = out_parent_path / f"{json_valid_path.name}_schema_error.json"
+    with open(out_json, "w") as f:
+        json.dump(data, f)
+    return out_json
+
+
+@pytest.fixture()
+def via_json_file_with_missing_keys(tmp_path, annotations_test_data):
+    """Return factory of paths to VIA JSON files with missing required
+    keys.
+    """
+
+    def _via_json_file_with_missing_keys(
+        valid_json_filename, required_keys_to_pop
+    ):
+        """Return path to a JSON file that is missing required keys."""
+        # read valid json file
+        valid_json_path = annotations_test_data[valid_json_filename]
+        with open(valid_json_path) as f:
+            data = json.load(f)
+
+        # remove any keys in the first level
+        for key in required_keys_to_pop.get("main", []):
+            data.pop(key)
+
+        # remove keys in nested dictionaries
+        edited_image_dicts = {}
+        if "_via_img_metadata" in data:
+            list_img_metadata_tuples = list(data["_via_img_metadata"].items())
+
+            # remove image keys for first image dictionary
+            img_str, img_dict = list_img_metadata_tuples[0]
+            edited_image_dicts["image_keys"] = img_str
+            for key in required_keys_to_pop.get("image_keys", []):
+                img_dict.pop(key)
+
+            # remove region keys for first region under second image
+            # dictionary
+            img_str, img_dict = list_img_metadata_tuples[1]
+            edited_image_dicts["region_keys"] = img_str
+            for key in required_keys_to_pop.get("region_keys", []):
+                img_dict["regions"][0].pop(key)
+
+            # remove shape_attributes keys for first region under third
+            # image dictionary
+            img_str, img_dict = list_img_metadata_tuples[2]
+            edited_image_dicts["shape_attributes_keys"] = img_str
+            for key in required_keys_to_pop.get("shape_attributes_keys", []):
+                img_dict["regions"][0]["shape_attributes"].pop(key)
+
+        # save the modified json to a new file
+        out_json = tmp_path / f"{valid_json_path.name}_missing_keys.json"
+        with open(out_json, "w") as f:
+            json.dump(data, f)
+        return out_json, edited_image_dicts
+
+    return _via_json_file_with_missing_keys
+
+
+@pytest.fixture()
+def coco_json_file_with_missing_keys(tmp_path, annotations_test_data):
+    """Return factory of paths to COCO JSON files with missing required
+    keys.
+ """ + + def _coco_json_file_with_missing_keys( + valid_json_filename, required_keys_to_pop + ): + """Return path to a JSON file that is missing required keys.""" + # read valid json file + valid_json_path = annotations_test_data[valid_json_filename] + with open(valid_json_path) as f: + data = json.load(f) + + # remove any keys in the first level + for key in required_keys_to_pop.get("main", []): + data.pop(key) + + edited_image_dicts = {} + + # remove required keys in first images dictionary + if "images" in data: + edited_image_dicts["image_keys"] = data["images"][0] + for key in required_keys_to_pop.get("image_keys", []): + data["images"][0].pop(key) + + # remove required keys in first annotations dictionary + if "annotations" in data: + edited_image_dicts["annotations_keys"] = data["annotations"][0] + for key in required_keys_to_pop.get("annotations_keys", []): + data["annotations"][0].pop(key) + + # remove required keys in first categories dictionary + if "categories" in data: + edited_image_dicts["categories_keys"] = data["categories"][0] + for key in required_keys_to_pop.get("categories_keys", []): + data["categories"][0].pop(key) + + # save the modified json to a new file + out_json = tmp_path / f"{valid_json_path.name}_missing_keys.json" + with open(out_json, "w") as f: + json.dump(data, f) + return out_json, edited_image_dicts + + return _coco_json_file_with_missing_keys + + +@pytest.mark.parametrize( + "input_file_standard, input_schema", + [ + ("VIA", None), + ("VIA", VIA_SCHEMA), + ("COCO", None), + ("COCO", COCO_SCHEMA), + ], +) +@pytest.mark.parametrize( + "input_json_file_suffix", + ["JSON_sample_1.json", "JSON_sample_2.json"], +) +def test_valid_json( + input_file_standard, + input_json_file_suffix, + input_schema, + annotations_test_data, +): + """Test the ValidJSON validator with valid inputs.""" + filepath = annotations_test_data[ + f"{input_file_standard}_{input_json_file_suffix}" + ] + + with does_not_raise(): + ValidJSON( + path=filepath, + schema=input_schema, + ) + + +@pytest.mark.parametrize( + "invalid_json_file_str, input_schema, expected_exception, log_message", + [ + ( + "json_file_with_decode_error", + None, # should be independent of schema + pytest.raises(ValueError), + "Error decoding JSON data from file: {}.", + ), + ( + "json_file_with_not_found_error", + None, # should be independent of schema + pytest.raises(FileNotFoundError), + "File not found: {}.", + ), + ( + "via_json_file_with_schema_error", + VIA_SCHEMA, + pytest.raises(jsonschema.exceptions.ValidationError), + "49.5 is not of type 'integer'\n\n", + ), + ( + "coco_json_file_with_schema_error", + COCO_SCHEMA, + pytest.raises(jsonschema.exceptions.ValidationError), + "[{'area': 432, 'bbox': [1278, 556, 16, 27], 'category_id': 1, " + "'id': 8917, 'image_id': 199, 'iscrowd': 0}] is not of type " + "'object'\n\n", + ), + ], +) +def test_valid_json_errors( + invalid_json_file_str, + input_schema, + expected_exception, + log_message, + request, +): + """Test the ValidJSON validator throws the expected errors.""" + invalid_json_file = request.getfixturevalue(invalid_json_file_str) + + with expected_exception as excinfo: + ValidJSON(path=invalid_json_file, schema=input_schema) + + if input_schema: + assert log_message in str(excinfo.value) + else: + assert log_message.format(invalid_json_file) == str(excinfo.value) + + +@pytest.mark.parametrize( + "input_json_file", + [ + "VIA_JSON_sample_1.json", + "VIA_JSON_sample_2.json", + ], +) +def test_valid_via_json(annotations_test_data, input_json_file): + 
"""Test the ValidVIAJSON validator with valid inputs.""" + filepath = annotations_test_data[input_json_file] + with does_not_raise(): + ValidVIAJSON( + path=filepath, + ) + + +@pytest.mark.parametrize( + "valid_via_json_file", + [ + "VIA_JSON_sample_1.json", + "VIA_JSON_sample_2.json", + ], +) +@pytest.mark.parametrize( + "missing_keys, expected_exception, log_message", + [ + ( + {"main": ["_via_image_id_list"]}, + pytest.raises(ValueError), + "Required key(s) ['_via_image_id_list'] not found " + "in ['_via_settings', '_via_img_metadata', '_via_attributes', " + "'_via_data_format_version'].", + ), + ( + {"main": ["_via_image_id_list", "_via_img_metadata"]}, + pytest.raises(ValueError), + "Required key(s) ['_via_image_id_list', '_via_img_metadata'] " + "not found in ['_via_settings', '_via_attributes', " + "'_via_data_format_version'].", + ), + ( + {"image_keys": ["filename"]}, + pytest.raises(ValueError), + "Required key(s) ['filename'] not found " + "in ['size', 'regions', 'file_attributes'] " + "for {}.", + ), + ( + {"region_keys": ["shape_attributes"]}, + pytest.raises(ValueError), + "Required key(s) ['shape_attributes'] not found in " + "['region_attributes'] for region 0 under {}.", + ), + ( + {"shape_attributes_keys": ["x"]}, + pytest.raises(ValueError), + "Required key(s) ['x'] not found in " + "['name', 'y', 'width', 'height'] for region 0 under {}.", + ), + ], +) +def test_valid_via_json_missing_keys( + valid_via_json_file, + missing_keys, + via_json_file_with_missing_keys, + expected_exception, + log_message, +): + """Test the ValidVIAJSON when input has missing keys.""" + # create invalid VIA json file with missing keys + invalid_json_file, edited_image_dicts = via_json_file_with_missing_keys( + valid_via_json_file, missing_keys + ) + + # get key of affected images in _via_img_metadata + img_key_str = edited_image_dicts.get(list(missing_keys.keys())[0], None) + + # run validation + with expected_exception as excinfo: + ValidVIAJSON( + path=invalid_json_file, + ) + + assert str(excinfo.value) == log_message.format(img_key_str) + + +@pytest.mark.parametrize( + "valid_coco_json_file", + [ + "COCO_JSON_sample_1.json", + "COCO_JSON_sample_2.json", + ], +) +@pytest.mark.parametrize( + "missing_keys, expected_exception, log_message", + [ + ( + {"main": ["categories"]}, + pytest.raises(ValueError), + "Required key(s) ['categories'] not found " + "in ['annotations', 'images', 'info', 'licenses'].", + ), + ( + {"main": ["categories", "images"]}, + pytest.raises(ValueError), + "Required key(s) ['categories', 'images'] not found " + "in ['annotations', 'info', 'licenses'].", + ), + ( + {"image_keys": ["file_name"]}, + pytest.raises(ValueError), + "Required key(s) ['file_name'] not found in " + "['height', 'id', 'width'] for image dict {}.", + ), + ( + {"annotations_keys": ["category_id"]}, + pytest.raises(ValueError), + "Required key(s) ['category_id'] not found in " + "['area', 'bbox', 'id', 'image_id', 'iscrowd'] for " + "annotation dict {}.", + ), + ( + {"categories_keys": ["id"]}, + pytest.raises(ValueError), + "Required key(s) ['id'] not found in " + "['name', 'supercategory'] for category dict {}.", + ), + ], +) +def test_valid_coco_json_missing_keys( + valid_coco_json_file, + missing_keys, + coco_json_file_with_missing_keys, + expected_exception, + log_message, +): + """Test the ValidCOCOJSON when input has missing keys.""" + # create invalid json file with missing keys + invalid_json_file, edited_image_dicts = coco_json_file_with_missing_keys( + valid_coco_json_file, missing_keys + 
+    )
+
+    # get the edited dictionary for the section with missing keys
+    img_dict = edited_image_dicts.get(list(missing_keys.keys())[0], None)
+
+    # run validation
+    with expected_exception as excinfo:
+        ValidCOCOJSON(
+            path=invalid_json_file,
+        )
+
+    assert str(excinfo.value) == log_message.format(img_dict)
+
+
+@pytest.mark.parametrize(
+    "list_required_keys, data_dict, additional_message, expected_exception",
+    [
+        (
+            ["images", "annotations", "categories"],
+            {"images": "", "annotations": "", "categories": ""},
+            "",
+            does_not_raise(),
+        ),  # zero missing keys
+        (
+            ["images", "annotations", "categories"],
+            {"annotations": "", "categories": ""},
+            "",
+            pytest.raises(ValueError),
+        ),  # one missing key
+        (
+            ["images", "annotations", "categories"],
+            {"annotations": ""},
+            "",
+            pytest.raises(ValueError),
+        ),  # two missing keys
+        (
+            ["images", "annotations", "categories"],
+            {"annotations": "", "categories": ""},
+            "FOO",
+            pytest.raises(ValueError),
+        ),  # one missing key with additional message
+    ],
+)
+def test_check_keys(
+    list_required_keys, data_dict, additional_message, expected_exception
+):
+    """Test the _check_keys helper function."""
+    with expected_exception as excinfo:
+        _check_keys(list_required_keys, data_dict, additional_message)
+
+    if excinfo:
+        missing_keys = set(list_required_keys) - data_dict.keys()
+        assert str(excinfo.value) == (
+            f"Required key(s) {sorted(missing_keys)} not "
+            f"found in {list(data_dict.keys())}{additional_message}."
+        )
diff --git a/tests/test_unit/test_placeholder.py b/tests/test_unit/test_placeholder.py
deleted file mode 100644
index 3ada1ee..0000000
--- a/tests/test_unit/test_placeholder.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def test_placeholder():
-    assert True
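As a closing illustration of the helper the last test exercises, a minimal sketch of `_check_keys` behavior (the dictionaries below are toy inputs, and `_check_keys` is a private helper, so this is for understanding rather than recommended API usage):

```python
from ethology.annotations.validators import _check_keys

# All required keys present: returns silently
_check_keys(["images", "annotations"], {"images": [], "annotations": []})

# A missing key raises a ValueError naming the missing and the found keys
try:
    _check_keys(["images", "annotations"], {"images": []})
except ValueError as error:
    print(error)  # Required key(s) ['annotations'] not found in ['images'].
```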