From f864fc6b5f74680af461cd844eced4cf5679e4e9 Mon Sep 17 00:00:00 2001 From: Ruge Li Date: Tue, 5 Sep 2023 15:50:16 -0700 Subject: [PATCH] rename and reorg DB handler --- cellpack/autopack/DBRecipeHandler.py | 162 +++++++++++++++++---- cellpack/autopack/loaders/recipe_loader.py | 90 +----------- cellpack/bin/pack.py | 4 +- cellpack/bin/upload.py | 4 +- 4 files changed, 137 insertions(+), 123 deletions(-) diff --git a/cellpack/autopack/DBRecipeHandler.py b/cellpack/autopack/DBRecipeHandler.py index 3dedc836b..364bb611c 100644 --- a/cellpack/autopack/DBRecipeHandler.py +++ b/cellpack/autopack/DBRecipeHandler.py @@ -21,6 +21,28 @@ def should_write(): def is_key(string_or_dict): return not isinstance(string_or_dict, dict) + @staticmethod + def is_nested_list(item): + return ( + isinstance(item, list) + and len(item) > 0 + and isinstance(item[0], (list, tuple)) + ) + + @staticmethod + def is_db_dict(item): + if isinstance(item, dict) and len(item) > 0: + for key, value in item.items(): + if key.isdigit() and isinstance(value, list): + return True + return False + + @staticmethod + def is_obj(comp_or_obj): + # in resolved DB data, if the top level of a downloaded comp doesn't have the key `name`, it's an obj + # TODO: true for all cases? better approaches? + return not comp_or_obj.get("name") and "object" in comp_or_obj + class CompositionDoc(DataDoc): """ @@ -124,7 +146,7 @@ def resolve_local_regions(self, local_data, recipe_data, db): Recursively resolves the regions of a composition from local data. Restructure the local data to match the db data. """ - unpack_recipe_data = DBHandler.prep_data_for_db(recipe_data) + unpack_recipe_data = DBUploader.prep_data_for_db(recipe_data) prep_recipe_data = ObjectDoc.convert_representation(unpack_recipe_data, db) # `gradients` is a list, convert it to dict for easy access and replace CompositionDoc.gradient_list_to_dict(prep_recipe_data) @@ -321,7 +343,7 @@ def should_write(self, db): # if there is repr in the obj doc from db full_doc_data = ObjectDoc.convert_representation(doc, db) # unpack objects to dicts in local data for comparison - local_data = DBHandler.prep_data_for_db(self.as_dict()) + local_data = DBUploader.prep_data_for_db(self.as_dict()) difference = DeepDiff(full_doc_data, local_data, ignore_order=True) if not difference: return doc, db.doc_id(doc) @@ -337,7 +359,7 @@ def should_write(self, db, grad_name): docs = db.get_doc_by_name("gradients", grad_name) if docs and len(docs) >= 1: for doc in docs: - local_data = DBHandler.prep_data_for_db(db.doc_to_dict(doc)) + local_data = DBUploader.prep_data_for_db(db.doc_to_dict(doc)) db_data = db.doc_to_dict(doc) difference = DeepDiff(db_data, local_data, ignore_order=True) if not difference: @@ -345,32 +367,17 @@ def should_write(self, db, grad_name): return None, None -class DBHandler(object): +class DBUploader(object): + """ + Handles the uploading of data to the database. + """ + def __init__(self, db_handler): self.db = db_handler self.objects_to_path_map = {} self.comp_to_path_map = {} self.grad_to_path_map = {} - @staticmethod - def is_nested_list(item): - return ( - isinstance(item, list) - and len(item) > 0 - and isinstance(item[0], (list, tuple)) - ) - - @staticmethod - def is_db_dict(item): - if isinstance(item, dict) and len(item) > 0: - for key, value in item.items(): - if key.isdigit() and isinstance(value, list): - return True - return False - - def collect_docs_by_id(self, collection, id): - return self.db.get_doc_by_id(collection, id) - @staticmethod def prep_data_for_db(data): """ @@ -379,18 +386,18 @@ def prep_data_for_db(data): modified_data = {} for key, value in data.items(): # convert 2d array to dict - if DBHandler.is_nested_list(value): + if DataDoc.is_nested_list(value): flatten_dict = dict(zip([str(i) for i in range(len(value))], value)) - modified_data[key] = DBHandler.prep_data_for_db(flatten_dict) + modified_data[key] = DBUploader.prep_data_for_db(flatten_dict) # If the value is an object, we want to convert it to dict elif isinstance(value, object) and "__dict__" in dir(value): unpacked_value = vars(value) modified_data[key] = unpacked_value if isinstance(unpacked_value, dict): - modified_data[key] = DBHandler.prep_data_for_db(unpacked_value) + modified_data[key] = DBUploader.prep_data_for_db(unpacked_value) # If the value is a dictionary, recursively convert its nested lists to dictionaries elif isinstance(value, dict): - modified_data[key] = DBHandler.prep_data_for_db(value) + modified_data[key] = DBUploader.prep_data_for_db(value) else: modified_data[key] = value return modified_data @@ -400,7 +407,7 @@ def upload_data(self, collection, data, id=None): If should_write is true, upload the data to the database """ # check if we need to convert part of the data(2d arrays and objs to dict) - modified_data = DBHandler.prep_data_for_db(data) + modified_data = DBUploader.prep_data_for_db(data) if id is None: name = modified_data["name"] doc = self.db.upload_doc(collection, modified_data) @@ -491,7 +498,7 @@ def upload_compositions(self, compositions, recipe_to_save, recipe_data): references_to_update[comp_name].update({"comp_id": doc_id}) return references_to_update - def get_recipe_id(self, recipe_data): + def _get_recipe_id(self, recipe_data): """ We use customized recipe id to declare recipe's name and version """ @@ -535,16 +542,25 @@ def upload_recipe(self, recipe_meta_data, recipe_data): """ After all other collections are checked or uploaded, upload the recipe with references into db """ - recipe_id = self.get_recipe_id(recipe_data) + recipe_id = self._get_recipe_id(recipe_data) # if the recipe is already exists in db, just return recipe, _ = self.db.get_doc_by_id("recipes", recipe_id) if recipe: print(f"{recipe_id} is already in firestore") return recipe_to_save = self.upload_collections(recipe_meta_data, recipe_data) - key = self.get_recipe_id(recipe_to_save) + key = self._get_recipe_id(recipe_to_save) self.upload_data("recipes", recipe_to_save, key) + +class DBRecipeLoader(object): + """ + Handles the logic for downloading and parsing the recipe data from the database. + """ + + def __init__(self, db_handler): + self.db = db_handler + def prep_db_doc_for_download(self, db_doc): """ convert data from db and resolve references. @@ -552,7 +568,7 @@ def prep_db_doc_for_download(self, db_doc): prep_data = {} if isinstance(db_doc, dict): for key, value in db_doc.items(): - if self.is_db_dict(value): + if DataDoc.is_db_dict(value): unpack_dict = [value[str(i)] for i in range(len(value))] prep_data[key] = unpack_dict elif key == "composition": @@ -575,3 +591,85 @@ def prep_db_doc_for_download(self, db_doc): else: prep_data[key] = value return prep_data + + def collect_docs_by_id(self, collection, id): + return self.db.get_doc_by_id(collection, id) + + @staticmethod + def _get_grad_and_obj(obj_data, obj_dict, grad_dict): + try: + grad_name = obj_data["gradient"]["name"] + obj_name = obj_data["name"] + except KeyError as e: + print(f"Missing keys in object: {e}") + return obj_dict, grad_dict + + grad_dict[grad_name] = obj_data["gradient"] + obj_dict[obj_name]["gradient"] = grad_name + return obj_dict, grad_dict + + @staticmethod + def _collect_and_sort_data(comp_data): + """ + Collect all object and gradient info from the downloaded composition data + Return autopack object data dict and gradient data dict with name as key + Return restructured composition dict with "composition" as key + """ + objects = {} + gradients = {} + composition = {} + for comp_name, comp_value in comp_data.items(): + composition[comp_name] = {} + if "count" in comp_value and comp_value["count"] is not None: + composition[comp_name]["count"] = comp_value["count"] + if "object" in comp_value and comp_value["object"] is not None: + composition[comp_name]["object"] = comp_value["object"]["name"] + object_copy = copy.deepcopy(comp_value["object"]) + objects[object_copy["name"]] = object_copy + if "gradient" in object_copy and isinstance( + object_copy["gradient"], dict + ): + objects, gradients = DBRecipeLoader._get_grad_and_obj( + object_copy, objects, gradients + ) + if "regions" in comp_value and comp_value["regions"] is not None: + for region_name in comp_value["regions"]: + composition[comp_name].setdefault("regions", {})[region_name] = [] + for region_item in comp_value["regions"][region_name]: + if DataDoc.is_obj(region_item): + composition[comp_name]["regions"][region_name].append( + { + "object": region_item["object"].get("name"), + "count": region_item.get("count"), + } + ) + object_copy = copy.deepcopy(region_item["object"]) + objects[object_copy["name"]] = object_copy + if "gradient" in object_copy and isinstance( + object_copy["gradient"], dict + ): + objects, gradients = DBRecipeLoader._get_grad_and_obj( + object_copy, objects, gradients + ) + else: + composition[comp_name]["regions"][region_name].append( + region_item["name"] + ) + return objects, gradients, composition + + @staticmethod + def _compile_db_recipe_data(db_recipe_data, obj_dict, grad_dict, comp_dict): + """ + Compile recipe data from db recipe data into a ready-to-pack structure + """ + recipe_data = { + **{ + k: db_recipe_data[k] + for k in ["format_version", "version", "name", "bounding_box"] + }, + "objects": obj_dict, + "composition": comp_dict, + } + if grad_dict: + recipe_data["gradients"] = [{**v} for v in grad_dict.values()] + return recipe_data diff --git a/cellpack/autopack/loaders/recipe_loader.py b/cellpack/autopack/loaders/recipe_loader.py index 2f318177f..487ce44fb 100644 --- a/cellpack/autopack/loaders/recipe_loader.py +++ b/cellpack/autopack/loaders/recipe_loader.py @@ -17,6 +17,7 @@ ) from cellpack.autopack.loaders.migrate_v1_to_v2 import convert as convert_v1_to_v2 from cellpack.autopack.loaders.migrate_v2_to_v2_1 import convert as convert_v2_to_v2_1 +from cellpack.autopack.DBRecipeHandler import DBRecipeLoader encoder.FLOAT_REPR = lambda o: format(o, ".8g") CURRENT_VERSION = "2.1" @@ -166,100 +167,15 @@ def _migrate_version(self, old_recipe): f"{old_recipe['format_version']} is not a format version we support" ) - @staticmethod - def _get_grad_and_obj(obj_data, obj_dict, grad_dict): - try: - grad_name = obj_data["gradient"]["name"] - obj_name = obj_data["name"] - except KeyError as e: - print(f"Missing keys in object: {e}") - return obj_dict, grad_dict - - grad_dict[grad_name] = obj_data["gradient"] - obj_dict[obj_name]["gradient"] = grad_name - return obj_dict, grad_dict - - @staticmethod - def _is_obj(comp_or_obj): - # if the top level of a downloaded comp doesn't have the key `name`, it's an obj - # TODO: true for all cases? better approaches? - return not comp_or_obj.get("name") and "object" in comp_or_obj - - @staticmethod - def _collect_and_sort_data(comp_data): - """ - Collect all object and gradient info from the downloaded firebase composition data - Return autopack object data dict and gradient data dict with name as key - Return restructured composition dict with "composition" as key - """ - objects = {} - gradients = {} - composition = {} - for comp_name, comp_value in comp_data.items(): - composition[comp_name] = {} - if "count" in comp_value and comp_value["count"] is not None: - composition[comp_name]["count"] = comp_value["count"] - if "object" in comp_value and comp_value["object"] is not None: - composition[comp_name]["object"] = comp_value["object"]["name"] - object_copy = copy.deepcopy(comp_value["object"]) - objects[object_copy["name"]] = object_copy - if "gradient" in object_copy and isinstance( - object_copy["gradient"], dict - ): - objects, gradients = RecipeLoader._get_grad_and_obj( - object_copy, objects, gradients - ) - if "regions" in comp_value and comp_value["regions"] is not None: - for region_name in comp_value["regions"]: - composition[comp_name].setdefault("regions", {})[region_name] = [] - for region_item in comp_value["regions"][region_name]: - if RecipeLoader._is_obj(region_item): - composition[comp_name]["regions"][region_name].append( - { - "object": region_item["object"].get("name"), - "count": region_item.get("count"), - } - ) - object_copy = copy.deepcopy(region_item["object"]) - objects[object_copy["name"]] = object_copy - if "gradient" in object_copy and isinstance( - object_copy["gradient"], dict - ): - objects, gradients = RecipeLoader._get_grad_and_obj( - object_copy, objects, gradients - ) - else: - composition[comp_name]["regions"][region_name].append( - region_item["name"] - ) - return objects, gradients, composition - - @staticmethod - def _compile_recipe_from_firebase(db_recipe_data, obj_dict, grad_dict, comp_dict): - """ - Compile recipe data from firebase recipe data into a ready-to-pack structure - """ - recipe_data = { - **{ - k: db_recipe_data[k] - for k in ["format_version", "version", "name", "bounding_box"] - }, - "objects": obj_dict, - "composition": comp_dict, - } - if grad_dict: - recipe_data["gradients"] = [{**v} for v in grad_dict.values()] - return recipe_data - def _read(self): new_values, database_name = autopack.load_file( self.file_path, self.db_handler, cache="recipes" ) if database_name == "firebase": - objects, gradients, composition = RecipeLoader._collect_and_sort_data( + objects, gradients, composition = DBRecipeLoader._collect_and_sort_data( new_values["composition"] ) - new_values = RecipeLoader._compile_recipe_from_firebase( + new_values = DBRecipeLoader._compile_db_recipe_data( new_values, objects, gradients, composition ) recipe_data = RecipeLoader.default_values.copy() diff --git a/cellpack/bin/pack.py b/cellpack/bin/pack.py index 6b983d196..45a6dcbda 100644 --- a/cellpack/bin/pack.py +++ b/cellpack/bin/pack.py @@ -14,7 +14,7 @@ from cellpack.autopack.loaders.recipe_loader import RecipeLoader from cellpack.autopack.loaders.analysis_config_loader import AnalysisConfigLoader from cellpack.autopack.FirebaseHandler import FirebaseHandler -from cellpack.autopack.DBRecipeHandler import DBHandler +from cellpack.autopack.DBRecipeHandler import DBRecipeLoader ############################################################################### log_file_path = path.abspath(path.join(__file__, "../../logging.conf")) @@ -42,7 +42,7 @@ def pack( """ if db_id == DATABASE_IDS.FIREBASE: db = FirebaseHandler() - db_handler = DBHandler(db) + db_handler = DBRecipeLoader(db) packing_config_data = ConfigLoader(config_path).config recipe_data = RecipeLoader( recipe, db_handler, packing_config_data["save_converted_recipe"] diff --git a/cellpack/bin/upload.py b/cellpack/bin/upload.py index f1a299197..6b811032b 100644 --- a/cellpack/bin/upload.py +++ b/cellpack/bin/upload.py @@ -1,7 +1,7 @@ from enum import Enum import fire from cellpack.autopack.FirebaseHandler import FirebaseHandler -from cellpack.autopack.DBRecipeHandler import DBHandler +from cellpack.autopack.DBRecipeHandler import DBUploader from cellpack.autopack.loaders.recipe_loader import RecipeLoader @@ -26,7 +26,7 @@ def upload( recipe_loader = RecipeLoader(recipe_path) recipe_full_data = recipe_loader.recipe_data recipe_meta_data = recipe_loader.get_only_recipe_metadata() - recipe_db_handler = DBHandler(db_handler) + recipe_db_handler = DBUploader(db_handler) recipe_db_handler.upload_recipe(recipe_meta_data, recipe_full_data)