Skip to content

Commit

Permalink
fix coco subsampling and category updating, increase test coverage (#64)
Browse files: browse the repository at this point in the history
* fix get_subsampled_coco

* fix update_categories

* update tests for new image_dir property of Coco
  • Loading branch information
fcakyon authored May 1, 2021
1 parent 29b675d commit b25bf6c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
19 changes: 13 additions & 6 deletions sahi/utils/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,16 +802,16 @@ def add_image(self, image):

self.images.append(image)

def update_categories(self, desired_name2id, update_image_filepaths=False):
def update_categories(self, desired_name2id, update_image_filenames=False):
"""
Rearranges category mapping of given COCO object based on given desired_name2id.
Can also be used to filter some of the categories.
Args:
desired_name2id: dict
{"big_vehicle": 1, "car": 2, "human": 3}
update_image_filepaths: bool
If True, updates image file_paths with absolute file paths.
update_image_filenames: bool
If True, replaces coco image file_names that are not already absolute with absolute file paths resolved against image_dir.
"""
# init vars
currentid2desiredid_mapping = {}
Expand Down Expand Up @@ -844,8 +844,11 @@ def update_categories(self, desired_name2id, update_image_filepaths=False):
# add updated images & annotations
for coco_image in copy.deepcopy(self.images):
updated_coco_image = CocoImage.from_coco_image_dict(coco_image.json)
if update_image_filepaths:
# update filename to abspath
file_name_is_abspath = True if os.path.abspath(coco_image.file_name) == coco_image.file_name else False
if update_image_filenames and not file_name_is_abspath:
updated_coco_image.file_name = str(Path(os.path.abspath(self.image_dir)) / coco_image.file_name)
# update annotations
for coco_annotation in coco_image.annotations:
current_category_id = coco_annotation.category_id
desired_category_id = currentid2desiredid_mapping[current_category_id]
Expand Down Expand Up @@ -896,7 +899,7 @@ def merge(self, coco, desired_name2id=None, verbose=1):

# update categories and image paths
for coco in [coco1, coco2]:
coco.update_categories(desired_name2id=desired_name2id, update_image_filepaths=True)
coco.update_categories(desired_name2id=desired_name2id, update_image_filenames=True)

# combine images and categories
coco1.images.extend(coco2.images)
Expand Down Expand Up @@ -1131,7 +1134,11 @@ def get_subsampled_coco(self, subsample_ratio=10):
Returns:
subsampled_coco: sahi.utils.coco.Coco
"""
subsampled_coco = Coco(name=self.name, remapping_dict=self.remapping_dict)
subsampled_coco = Coco(
name=self.name,
image_dir=self.image_dir,
remapping_dict=self.remapping_dict
)
subsampled_coco.add_categories_from_coco_category_list(self.json_categories)
for image_ind in tqdm(range(0, len(self.images), subsample_ratio)):
subsampled_coco.add_image(self.images[image_ind])
Expand Down
29 changes: 26 additions & 3 deletions tests/test_cocoutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,20 @@ def test_split_coco_as_train_val(self):
from sahi.utils.coco import Coco

coco_dict_path = "tests/data/coco_utils/combined_coco.json"
coco = Coco.from_coco_dict_or_path(coco_dict_path)
image_dir = "tests/data/coco_utils/"
coco = Coco.from_coco_dict_or_path(coco_dict_path, image_dir=image_dir)
result = coco.split_coco_as_train_val(
train_split_rate=0.5, numpy_seed=0
)
self.assertEqual(len(result["train_coco"].json["images"]), 1)
self.assertEqual(len(result["train_coco"].json["annotations"]), 5)
self.assertEqual(result["train_coco"].json["images"][0]["height"], 682)
self.assertEqual(result["train_coco"].image_dir, image_dir)

self.assertEqual(len(result["val_coco"].json["images"]), 1)
self.assertEqual(len(result["val_coco"].json["annotations"]), 7)
self.assertEqual(result["val_coco"].json["images"][0]["height"], 1365)
self.assertEqual(result["val_coco"].image_dir, image_dir)

def test_coco2yolo(self):
from sahi.utils.coco import Coco
Expand Down Expand Up @@ -352,7 +355,8 @@ def test_coco_update_categories(self):
from sahi.utils.coco import Coco

coco_path = "tests/data/coco_utils/terrain2_coco.json"
coco = Coco.from_coco_dict_or_path(coco_path)
image_dir = "tests/data/coco_utils/"
coco = Coco.from_coco_dict_or_path(coco_path, image_dir=image_dir)

self.assertEqual(len(coco.json["annotations"]), 5)
self.assertEqual(len(coco.json["images"]), 1)
Expand All @@ -362,6 +366,7 @@ def test_coco_update_categories(self):
[{"id": 1, "name": "car", "supercategory": "car"}],
)
self.assertEqual(coco.json["annotations"][1]["category_id"], 1)
self.assertEqual(coco.image_dir, image_dir)

# update categories
desired_name2id = {"human": 1, "car": 2, "big_vehicle": 3}
Expand All @@ -379,6 +384,7 @@ def test_coco_update_categories(self):
],
)
self.assertEqual(coco.json["annotations"][1]["category_id"], 2)
self.assertEqual(coco.image_dir, image_dir)

def test_get_imageid2annotationlist_mapping(self):
from sahi.utils.coco import get_imageid2annotationlist_mapping
Expand Down Expand Up @@ -533,13 +539,22 @@ def test_coco_merge(self):
coco1.json["annotations"][9]["image_id"],
2,
)
self.assertEqual(
coco1.image_dir,
image_dir,
)
self.assertEqual(
coco2.image_dir,
image_dir,
)

def test_get_subsampled_coco(self):
from sahi.utils.coco import Coco
from sahi.utils.file import load_json

coco_path = "tests/data/coco_utils/visdrone2019-det-train-first50image.json"
coco = Coco.from_coco_dict_or_path(coco_path)
image_dir = "tests/data/coco_utils/"
coco = Coco.from_coco_dict_or_path(coco_path, image_dir=image_dir)
subsampled_coco = coco.get_subsampled_coco(subsample_ratio=5)
self.assertEqual(
len(coco.json["images"]),
Expand All @@ -557,6 +572,14 @@ def test_get_subsampled_coco(self):
len(coco.images[5].annotations),
len(subsampled_coco.images[1].annotations),
)
self.assertEqual(
coco.image_dir,
image_dir,
)
self.assertEqual(
subsampled_coco.image_dir,
image_dir,
)

def test_cocovid(self):
from sahi.utils.coco import CocoVid
Expand Down

0 comments on commit b25bf6c

Please sign in to comment.