Skip to content

Commit

Permalink
refactorize Coco merging and category updating (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored May 1, 2021
1 parent af9dc47 commit 0acf928
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 55 deletions.
183 changes: 148 additions & 35 deletions sahi/utils/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,15 @@ def add_annotation(self, annotation):
), "annotation must be a CocoAnnotation instance"
self.annotations.append(annotation)

@property
def json(self):
return {
"id": self.id,
"file_name": self.file_name,
"height": self.height,
"width": self.width,
}

def __repr__(self):
return f"""CocoImage<
id: {self.id},
Expand Down Expand Up @@ -728,18 +737,21 @@ def __repr__(self):


class Coco:
def __init__(self, name=None, remapping_dict=None):
def __init__(self, name=None, image_dir=None, remapping_dict=None):
"""
Creates Coco object.
Args:
name: str
Name of the Coco dataset, it determines exported json name.
image_dir: str
Base file directory that contains dataset images. Required for dataset merging.
remapping_dict: dict
{1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
"""
self.name = name
self.remapping_dict = remapping_dict # TODO: utilize remapping_dict
self.image_dir = image_dir
self.remapping_dict = remapping_dict
self.categories = []
self.images = []

Expand Down Expand Up @@ -790,8 +802,116 @@ def add_image(self, image):

self.images.append(image)

def update_categories(self, desired_name2id, update_image_filepaths=False):
"""
Rearranges category mapping of given COCO object based on given desired_name2id.
Can also be used to filter some of the categories.
Args:
desired_name2id: dict
{"big_vehicle": 1, "car": 2, "human": 3}
update_image_filepaths: bool
If True, updates image file_paths with absolute file paths.
"""
# init vars
currentid2desiredid_mapping = {}
updated_coco = Coco(
name=self.name,
image_dir=self.image_dir,
remapping_dict=self.remapping_dict
)
# create category id mapping (currentid2desiredid_mapping)
for coco_category in copy.deepcopy(self.categories):
current_category_id = coco_category.id
current_category_name = coco_category.name
if current_category_name in desired_name2id.keys():
currentid2desiredid_mapping[current_category_id] = desired_name2id[
current_category_name
]
else:
# ignore categories that are not included in desired_name2id
currentid2desiredid_mapping[current_category_id] = None

# add updated categories
for name in desired_name2id.keys():
updated_coco_category = CocoCategory(
id=desired_name2id[name],
name=name,
supercategory=name
)
updated_coco.add_category(updated_coco_category)

# add updated images & annotations
for coco_image in copy.deepcopy(self.images):
updated_coco_image = CocoImage.from_coco_image_dict(coco_image.json)
if update_image_filepaths:
updated_coco_image.file_name = str(Path(os.path.abspath(self.image_dir)) / coco_image.file_name)
for coco_annotation in coco_image.annotations:
current_category_id = coco_annotation.category_id
desired_category_id = currentid2desiredid_mapping[current_category_id]
# append annotations with category id present in desired_name2id
if desired_category_id is not None:
# update cetegory id
coco_annotation.category_id = desired_category_id
# append updated annotation to target coco dict
updated_coco_image.add_annotation(coco_annotation)
updated_coco.add_image(updated_coco_image)

# overwrite instance
self.__class__ = updated_coco.__class__
self.__dict__ = updated_coco.__dict__

def merge(self, coco, desired_name2id=None, verbose=1):
"""
Combines the images/annotations/categories of given coco object with current one.
Args:
coco : sahi.utils.coco.Coco instance
A COCO dataset object
desired_name2id : dict
{"human": 1, "car": 2, "big_vehicle": 3}
verbose: bool
If True, merging info is printed
"""
assert self.image_dir and coco.image_dir, "image_dir should be provided for merging."

if verbose:
if not desired_name2id:
print("'desired_name2id' is not specified, combining all categories.")

# create desired_name2id by combining all categories, if desired_name2id is not specified
coco1 = self
coco2 = coco
category_ind = 0
if desired_name2id is None:
desired_name2id = {}
for coco in [coco1, coco2]:
temp_categories = copy.deepcopy(coco.json_categories)
for temp_category in temp_categories:
if temp_category["name"] not in desired_name2id:
desired_name2id[temp_category["name"]] = category_ind
category_ind += 1
else:
continue

# update categories and image paths
for coco in [coco1, coco2]:
coco.update_categories(desired_name2id=desired_name2id, update_image_filepaths=True)

# combine images and categories
coco1.images.extend(coco2.images)
self.images = coco1.images
self.categories = coco1.categories

# print categories
if verbose:
print(
"Categories are formed as:\n",
self.json_categories,
)

@classmethod
def from_coco_dict_or_path(cls, coco_dict_or_path, desired_name2id=None, remapping_dict=None, mp=False):
def from_coco_dict_or_path(cls, coco_dict_or_path, desired_name2id=None, image_dir=None, remapping_dict=None, mp=False):
"""
Creates coco object from COCO formatted dict or COCO dataset file path.
Expand All @@ -801,6 +921,8 @@ def from_coco_dict_or_path(cls, coco_dict_or_path, desired_name2id=None, remappi
List of COCO formatted dict or COCO dataset file path
desired_name2id : dict
{"human": 1, "car": 2, "big_vehicle": 3}
image_dir: str
Base file directory that contains dataset images. Required for merging and yolov5 conversion.
remapping_dict: dict
{1:0, 2:1} maps category id 1 to 0 and category id 2 to 1
mp: bool
Expand All @@ -812,28 +934,17 @@ def from_coco_dict_or_path(cls, coco_dict_or_path, desired_name2id=None, remappi
category_mapping: dict
"""
# init coco object
coco = cls(remapping_dict=remapping_dict)

if type(coco_dict_or_path) == list: # merge coco datasets if given as list
# create coco_dict_list
coco_dict_list = []
coco_dict_or_path_list = copy.deepcopy(coco_dict_or_path)
for coco_dict_or_path in coco_dict_or_path_list:
# load coco dataset dict
if type(coco_dict_or_path) == str:
coco_dict = load_json(coco_dict_or_path)
else:
coco_dict = coco_dict_or_path
# append to list
coco_dict_list.append(coco_dict)
# merge coco dicts
coco_dict = merge_from_list(coco_dict_list, desired_name2id=None)
coco = cls(image_dir=image_dir, remapping_dict=remapping_dict)

assert (type(coco_dict_or_path) == str
or type(coco_dict_or_path) == dict), \
"coco_dict_or_path should be dict or str"

# load coco dict if path is given
if type(coco_dict_or_path) == str:
coco_dict = load_json(coco_dict_or_path)
else:
# load coco dict if path is given
if type(coco_dict_or_path) == str:
coco_dict = load_json(coco_dict_or_path)
else:
coco_dict = coco_dict_or_path
coco_dict = coco_dict_or_path

# arrange image id to annotation id mapping
coco.add_categories_from_coco_category_list(coco_dict["categories"])
Expand All @@ -843,28 +954,26 @@ def from_coco_dict_or_path(cls, coco_dict_or_path, desired_name2id=None, remappi
imageid2annotationlist = get_imageid2annotationlist_mapping(coco_dict)
category_mapping = coco.category_mapping

coco_image_list = []
for coco_image_dict in coco_dict["images"]:
coco_image = CocoImage.from_coco_image_dict(coco_image_dict)
annotation_list = imageid2annotationlist[coco_image_dict["id"]]
for coco_annotation_dict in annotation_list:
# apply category remapping if remapping_dict is provided
if coco.remapping_dict is not None:
# apply category remapping (id:id)
category_id = coco.remapping_dict[coco_annotation_dict["category_id"]]
category_id = coco.remapping_dict[coco_annotation_dict["category_id"]]
# update category id
coco_annotation_dict["category_id"] = category_id
coco_annotation_dict["category_id"] = category_id
else:
category_id = coco_annotation_dict["category_id"]
# get category name (id:name)
category_name = category_mapping[category_id]
category_name = category_mapping[category_id]
coco_annotation = CocoAnnotation.from_coco_annotation_dict(
category_name=category_name, annotation_dict=coco_annotation_dict
)
coco_image.add_annotation(coco_annotation)
coco_image_list.append(coco_image)
coco.add_image(coco_image)

coco.images = coco_image_list
return coco

@property
Expand Down Expand Up @@ -1225,13 +1334,15 @@ def merge_from_list(coco_dict_list, desired_name2id=None, verbose=1):
Arguments:
---------
coco_dict)list : list of dict
coco_dict_list: list of dict
A list of coco dicts
desired_name2id : dict
desired_name2id: dict
{"human": 1, "car": 2, "big_vehicle": 3}
verbose: bool
If True, merging info is printed
Returns:
---------
merged_coco_dict : dict
merged_coco_dict: dict
Merged COCO dict.
"""
if verbose:
Expand All @@ -1241,11 +1352,13 @@ def merge_from_list(coco_dict_list, desired_name2id=None, verbose=1):
# create desired_name2id by combinin all categories, if desired_name2id is not specified
if desired_name2id is None:
desired_name2id = {}
for ind, coco_dict in enumerate(coco_dict_list):
ind = 0
for coco_dict in coco_dict_list:
temp_categories = copy.deepcopy(coco_dict["categories"])
for temp_category in temp_categories:
if temp_category["name"] not in desired_name2id:
desired_name2id[temp_category["name"]] = ind + 1
desired_name2id[temp_category["name"]] = ind
ind += 1
else:
continue

Expand Down
69 changes: 49 additions & 20 deletions tests/test_cocoutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,38 @@ def test_update_categories(self):
)
self.assertEqual(target_coco_dict["annotations"][1]["category_id"], 2)

def test_coco_update_categories(self):
from sahi.utils.coco import Coco

coco_path = "tests/data/coco_utils/terrain2_coco.json"
coco = Coco.from_coco_dict_or_path(coco_path)

self.assertEqual(len(coco.json["annotations"]), 5)
self.assertEqual(len(coco.json["images"]), 1)
self.assertEqual(len(coco.json["categories"]), 1)
self.assertEqual(
coco.json["categories"],
[{"id": 1, "name": "car", "supercategory": "car"}],
)
self.assertEqual(coco.json["annotations"][1]["category_id"], 1)

# update categories
desired_name2id = {"human": 1, "car": 2, "big_vehicle": 3}
coco.update_categories(desired_name2id=desired_name2id)

self.assertEqual(len(coco.json["annotations"]), 5)
self.assertEqual(len(coco.json["images"]), 1)
self.assertEqual(len(coco.json["categories"]), 3)
self.assertEqual(
coco.json["categories"],
[
{"id": 1, "name": "human", "supercategory": "human"},
{"id": 2, "name": "car", "supercategory": "car"},
{"id": 3, "name": "big_vehicle", "supercategory": "big_vehicle"},
],
)
self.assertEqual(coco.json["annotations"][1]["category_id"], 2)

def test_get_imageid2annotationlist_mapping(self):
from sahi.utils.coco import get_imageid2annotationlist_mapping

Expand Down Expand Up @@ -454,54 +486,51 @@ def test_merge_from_list(self):
merged_coco_dict["annotations"][12]["id"],
13,
)
self.assertEqual(
merged_coco_dict["annotations"][12]["category_id"],
coco_dict3["annotations"][0]["category_id"],
)
self.assertEqual(
merged_coco_dict["annotations"][12]["image_id"],
3,
)
self.assertEqual(
merged_coco_dict["annotations"][12]["category_id"],
coco_dict3["annotations"][0]["category_id"],
)
self.assertEqual(
merged_coco_dict["annotations"][9]["category_id"],
2,
1,
)
self.assertEqual(
merged_coco_dict["annotations"][9]["image_id"],
2,
)

def test_multi_coco_init(self):
def test_coco_merge(self):
from sahi.utils.coco import Coco

# load coco files to be combined
coco_path1 = "tests/data/coco_utils/terrain1_coco.json"
coco_path2 = "tests/data/coco_utils/terrain2_coco.json"
coco_path3 = "tests/data/coco_utils/terrain3_coco.json"
coco = Coco.from_coco_dict_or_path([coco_path1, coco_path2, coco_path3])
self.assertEqual(len(coco.json["images"]), 3)
self.assertEqual(len(coco.json["annotations"]), 22)
self.assertEqual(len(coco.json["categories"]), 2)
self.assertEqual(len(coco.images), 3)
image_dir = "tests/data/coco_utils/"
coco1 = Coco.from_coco_dict_or_path(coco_path1, image_dir=image_dir)
coco2 = Coco.from_coco_dict_or_path(coco_path2, image_dir=image_dir)
coco3 = Coco.from_coco_dict_or_path(coco_path3, image_dir=image_dir)
coco1.merge(coco2)
coco1.merge(coco3)
self.assertEqual(len(coco1.json["images"]), 3)
self.assertEqual(len(coco1.json["annotations"]), 22)
self.assertEqual(len(coco1.json["categories"]), 2)
self.assertEqual(len(coco1.images), 3)

self.assertEqual(
coco.json["annotations"][12]["id"],
coco1.json["annotations"][12]["id"],
13,
)
self.assertEqual(
coco.json["annotations"][12]["image_id"],
coco1.json["annotations"][12]["image_id"],
3,
)
self.assertEqual(
coco.json["annotations"][9]["category_id"],
2,
coco1.json["annotations"][9]["category_id"],
1,
)
self.assertEqual(
coco.json["annotations"][9]["image_id"],
coco1.json["annotations"][9]["image_id"],
2,
)

Expand Down

0 comments on commit 0acf928

Please sign in to comment.