rename and reorg DB handler
rugeli committed Sep 5, 2023
1 parent 09fcbef commit f864fc6
Showing 4 changed files with 137 additions and 123 deletions.
162 changes: 130 additions & 32 deletions cellpack/autopack/DBRecipeHandler.py
@@ -21,6 +21,28 @@ def should_write():
def is_key(string_or_dict):
return not isinstance(string_or_dict, dict)

@staticmethod
def is_nested_list(item):
return (
isinstance(item, list)
and len(item) > 0
and isinstance(item[0], (list, tuple))
)

@staticmethod
def is_db_dict(item):
if isinstance(item, dict) and len(item) > 0:
for key, value in item.items():
if key.isdigit() and isinstance(value, list):
return True
return False

@staticmethod
def is_obj(comp_or_obj):
# in resolved DB data, if the top level of a downloaded comp doesn't have the key `name`, it's an obj
# TODO: true for all cases? better approaches?
return not comp_or_obj.get("name") and "object" in comp_or_obj
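
# Hypothetical usage sketch (not part of this commit) of the three
# predicates above, which classify raw DB payloads before conversion:
# a Firestore-safe "db dict" is a nested list flattened into a dict
# keyed by stringified indices, and an "obj" is a resolved node with
# an `object` field but no top-level `name`.
assert DataDoc.is_nested_list([[0, 0, 0], [1, 1, 1]])
assert DataDoc.is_db_dict({"0": [0, 0, 0], "1": [1, 1, 1]})
assert DataDoc.is_obj({"object": {"name": "sphere_25"}, "count": 1})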


class CompositionDoc(DataDoc):
"""
@@ -124,7 +146,7 @@ def resolve_local_regions(self, local_data, recipe_data, db):
Recursively resolves the regions of a composition from local data.
Restructures the local data to match the db data.
"""
unpack_recipe_data = DBHandler.prep_data_for_db(recipe_data)
unpack_recipe_data = DBUploader.prep_data_for_db(recipe_data)
prep_recipe_data = ObjectDoc.convert_representation(unpack_recipe_data, db)
# `gradients` is a list, convert it to dict for easy access and replace
CompositionDoc.gradient_list_to_dict(prep_recipe_data)
@@ -321,7 +343,7 @@ def should_write(self, db):
# if there is repr in the obj doc from db
full_doc_data = ObjectDoc.convert_representation(doc, db)
# unpack objects to dicts in local data for comparison
local_data = DBHandler.prep_data_for_db(self.as_dict())
local_data = DBUploader.prep_data_for_db(self.as_dict())
difference = DeepDiff(full_doc_data, local_data, ignore_order=True)
if not difference:
return doc, db.doc_id(doc)
@@ -337,40 +359,25 @@ def should_write(self, db, grad_name):
docs = db.get_doc_by_name("gradients", grad_name)
if docs and len(docs) >= 1:
for doc in docs:
local_data = DBHandler.prep_data_for_db(db.doc_to_dict(doc))
local_data = DBUploader.prep_data_for_db(db.doc_to_dict(doc))
db_data = db.doc_to_dict(doc)
difference = DeepDiff(db_data, local_data, ignore_order=True)
if not difference:
return doc, db.doc_id(doc)
return None, None


class DBHandler(object):
class DBUploader(object):
"""
Handles the uploading of data to the database.
"""

def __init__(self, db_handler):
self.db = db_handler
self.objects_to_path_map = {}
self.comp_to_path_map = {}
self.grad_to_path_map = {}

@staticmethod
def is_nested_list(item):
return (
isinstance(item, list)
and len(item) > 0
and isinstance(item[0], (list, tuple))
)

@staticmethod
def is_db_dict(item):
if isinstance(item, dict) and len(item) > 0:
for key, value in item.items():
if key.isdigit() and isinstance(value, list):
return True
return False

def collect_docs_by_id(self, collection, id):
return self.db.get_doc_by_id(collection, id)

@staticmethod
def prep_data_for_db(data):
"""
@@ -379,18 +386,18 @@ def prep_data_for_db(data):
modified_data = {}
for key, value in data.items():
# convert 2d array to dict
if DBHandler.is_nested_list(value):
if DataDoc.is_nested_list(value):
flatten_dict = dict(zip([str(i) for i in range(len(value))], value))
modified_data[key] = DBHandler.prep_data_for_db(flatten_dict)
modified_data[key] = DBUploader.prep_data_for_db(flatten_dict)
# If the value is an object with a __dict__, unpack it to a dict
elif isinstance(value, object) and "__dict__" in dir(value):
unpacked_value = vars(value)
modified_data[key] = unpacked_value
if isinstance(unpacked_value, dict):
modified_data[key] = DBHandler.prep_data_for_db(unpacked_value)
modified_data[key] = DBUploader.prep_data_for_db(unpacked_value)
# If the value is a dictionary, recursively convert its nested lists to dictionaries
elif isinstance(value, dict):
modified_data[key] = DBHandler.prep_data_for_db(value)
modified_data[key] = DBUploader.prep_data_for_db(value)
else:
modified_data[key] = value
return modified_data
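
# A minimal sketch (assumed example) of the conversion performed by
# prep_data_for_db: 2D arrays become dicts keyed by stringified
# indices, since Firestore cannot store directly nested arrays.
example = {"bounding_box": [[0, 0, 0], [500, 500, 500]]}
prepped = DBUploader.prep_data_for_db(example)
# prepped == {"bounding_box": {"0": [0, 0, 0], "1": [500, 500, 500]}}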
@@ -400,7 +407,7 @@ def upload_data(self, collection, data, id=None):
If should_write is true, upload the data to the database
"""
# check if we need to convert part of the data (2d arrays and objs to dicts)
modified_data = DBHandler.prep_data_for_db(data)
modified_data = DBUploader.prep_data_for_db(data)
if id is None:
name = modified_data["name"]
doc = self.db.upload_doc(collection, modified_data)
@@ -491,7 +498,7 @@ def upload_compositions(self, compositions, recipe_to_save, recipe_data):
references_to_update[comp_name].update({"comp_id": doc_id})
return references_to_update

def get_recipe_id(self, recipe_data):
def _get_recipe_id(self, recipe_data):
"""
We use a customized recipe id that encodes the recipe's name and version
"""
@@ -535,24 +542,33 @@ def upload_recipe(self, recipe_meta_data, recipe_data):
"""
After all other collections are checked or uploaded, upload the recipe with references into db
"""
recipe_id = self.get_recipe_id(recipe_data)
recipe_id = self._get_recipe_id(recipe_data)
# if the recipe already exists in the db, just return
recipe, _ = self.db.get_doc_by_id("recipes", recipe_id)
if recipe:
print(f"{recipe_id} is already in firestore")
return
recipe_to_save = self.upload_collections(recipe_meta_data, recipe_data)
key = self.get_recipe_id(recipe_to_save)
key = self._get_recipe_id(recipe_to_save)
self.upload_data("recipes", recipe_to_save, key)
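
# Hypothetical end-to-end call (argument values assumed; FirebaseHandler
# import assumed available): upload a local recipe through the renamed
# uploader class.
uploader = DBUploader(FirebaseHandler())
uploader.upload_recipe(recipe_meta_data, recipe_data)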


class DBRecipeLoader(object):
"""
Handles the logic for downloading and parsing the recipe data from the database.
"""

def __init__(self, db_handler):
self.db = db_handler

def prep_db_doc_for_download(self, db_doc):
"""
Convert data from the db and resolve references.
"""
prep_data = {}
if isinstance(db_doc, dict):
for key, value in db_doc.items():
if self.is_db_dict(value):
if DataDoc.is_db_dict(value):
unpack_dict = [value[str(i)] for i in range(len(value))]
prep_data[key] = unpack_dict
elif key == "composition":
@@ -575,3 +591,85 @@ def prep_db_doc_for_download(self, db_doc):
else:
prep_data[key] = value
return prep_data

def collect_docs_by_id(self, collection, id):
return self.db.get_doc_by_id(collection, id)
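
# Assumed round trip (doc id made up): fetch a recipe doc and unpack
# DB-shaped dicts (e.g. {"0": [...], "1": [...]}) back into lists,
# resolving any references along the way. get_doc_by_id returns a
# (doc, id) pair elsewhere in this file, so we unpack a tuple here.
loader = DBRecipeLoader(FirebaseHandler())
db_doc, _ = loader.collect_docs_by_id("recipes", "one_sphere_v_1.0.0")
recipe_dict = loader.prep_db_doc_for_download(db_doc)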

@staticmethod
def _get_grad_and_obj(obj_data, obj_dict, grad_dict):
try:
grad_name = obj_data["gradient"]["name"]
obj_name = obj_data["name"]
except KeyError as e:
print(f"Missing keys in object: {e}")
return obj_dict, grad_dict

grad_dict[grad_name] = obj_data["gradient"]
obj_dict[obj_name]["gradient"] = grad_name
return obj_dict, grad_dict

@staticmethod
def _collect_and_sort_data(comp_data):
"""
Collect all object and gradient info from the downloaded composition data.
Return an autopack object data dict and a gradient data dict with name as key,
and a restructured composition dict with "composition" as key.
"""
objects = {}
gradients = {}
composition = {}
for comp_name, comp_value in comp_data.items():
composition[comp_name] = {}
if "count" in comp_value and comp_value["count"] is not None:
composition[comp_name]["count"] = comp_value["count"]
if "object" in comp_value and comp_value["object"] is not None:
composition[comp_name]["object"] = comp_value["object"]["name"]
object_copy = copy.deepcopy(comp_value["object"])
objects[object_copy["name"]] = object_copy
if "gradient" in object_copy and isinstance(
object_copy["gradient"], dict
):
objects, gradients = DBRecipeLoader._get_grad_and_obj(
object_copy, objects, gradients
)
if "regions" in comp_value and comp_value["regions"] is not None:
for region_name in comp_value["regions"]:
composition[comp_name].setdefault("regions", {})[region_name] = []
for region_item in comp_value["regions"][region_name]:
if DataDoc.is_obj(region_item):
composition[comp_name]["regions"][region_name].append(
{
"object": region_item["object"].get("name"),
"count": region_item.get("count"),
}
)
object_copy = copy.deepcopy(region_item["object"])
objects[object_copy["name"]] = object_copy
if "gradient" in object_copy and isinstance(
object_copy["gradient"], dict
):
objects, gradients = DBRecipeLoader._get_grad_and_obj(
object_copy, objects, gradients
)
else:
composition[comp_name]["regions"][region_name].append(
region_item["name"]
)
return objects, gradients, composition
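
# Shape illustration (data assumed): resolved composition data from
# the DB is split into an objects dict, a gradients dict, and a
# composition tree that references objects by name.
comp_data = {
    "space": {"count": 1, "object": {"name": "sphere_25", "radius": 25}}
}
objects, gradients, composition = DBRecipeLoader._collect_and_sort_data(comp_data)
# objects     == {"sphere_25": {"name": "sphere_25", "radius": 25}}
# gradients   == {}
# composition == {"space": {"count": 1, "object": "sphere_25"}}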

@staticmethod
def _compile_db_recipe_data(db_recipe_data, obj_dict, grad_dict, comp_dict):
"""
Compile recipe data from db recipe data into a ready-to-pack structure
"""
recipe_data = {
**{
k: db_recipe_data[k]
for k in ["format_version", "version", "name", "bounding_box"]
},
"objects": obj_dict,
"composition": comp_dict,
}
if grad_dict:
recipe_data["gradients"] = [{**v} for v in grad_dict.values()]
return recipe_data
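
# Final assembly sketch (field values assumed): metadata keys are
# copied from the downloaded recipe doc, then the dicts produced by
# _collect_and_sort_data above are merged into a ready-to-pack recipe.
recipe_data = DBRecipeLoader._compile_db_recipe_data(
    {
        "format_version": "2.1",
        "version": "1.0.0",
        "name": "demo_recipe",
        "bounding_box": [[0, 0, 0], [100, 100, 100]],
    },
    obj_dict=objects,
    grad_dict=gradients,
    comp_dict=composition,
)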
90 changes: 3 additions & 87 deletions cellpack/autopack/loaders/recipe_loader.py
@@ -17,6 +17,7 @@
)
from cellpack.autopack.loaders.migrate_v1_to_v2 import convert as convert_v1_to_v2
from cellpack.autopack.loaders.migrate_v2_to_v2_1 import convert as convert_v2_to_v2_1
from cellpack.autopack.DBRecipeHandler import DBRecipeLoader

encoder.FLOAT_REPR = lambda o: format(o, ".8g")
CURRENT_VERSION = "2.1"
@@ -166,100 +167,15 @@ def _migrate_version(self, old_recipe):
f"{old_recipe['format_version']} is not a format version we support"
)

@staticmethod
def _get_grad_and_obj(obj_data, obj_dict, grad_dict):
try:
grad_name = obj_data["gradient"]["name"]
obj_name = obj_data["name"]
except KeyError as e:
print(f"Missing keys in object: {e}")
return obj_dict, grad_dict

grad_dict[grad_name] = obj_data["gradient"]
obj_dict[obj_name]["gradient"] = grad_name
return obj_dict, grad_dict

@staticmethod
def _is_obj(comp_or_obj):
# if the top level of a downloaded comp doesn't have the key `name`, it's an obj
# TODO: true for all cases? better approaches?
return not comp_or_obj.get("name") and "object" in comp_or_obj

@staticmethod
def _collect_and_sort_data(comp_data):
"""
Collect all object and gradient info from the downloaded firebase composition data
Return autopack object data dict and gradient data dict with name as key
Return restructured composition dict with "composition" as key
"""
objects = {}
gradients = {}
composition = {}
for comp_name, comp_value in comp_data.items():
composition[comp_name] = {}
if "count" in comp_value and comp_value["count"] is not None:
composition[comp_name]["count"] = comp_value["count"]
if "object" in comp_value and comp_value["object"] is not None:
composition[comp_name]["object"] = comp_value["object"]["name"]
object_copy = copy.deepcopy(comp_value["object"])
objects[object_copy["name"]] = object_copy
if "gradient" in object_copy and isinstance(
object_copy["gradient"], dict
):
objects, gradients = RecipeLoader._get_grad_and_obj(
object_copy, objects, gradients
)
if "regions" in comp_value and comp_value["regions"] is not None:
for region_name in comp_value["regions"]:
composition[comp_name].setdefault("regions", {})[region_name] = []
for region_item in comp_value["regions"][region_name]:
if RecipeLoader._is_obj(region_item):
composition[comp_name]["regions"][region_name].append(
{
"object": region_item["object"].get("name"),
"count": region_item.get("count"),
}
)
object_copy = copy.deepcopy(region_item["object"])
objects[object_copy["name"]] = object_copy
if "gradient" in object_copy and isinstance(
object_copy["gradient"], dict
):
objects, gradients = RecipeLoader._get_grad_and_obj(
object_copy, objects, gradients
)
else:
composition[comp_name]["regions"][region_name].append(
region_item["name"]
)
return objects, gradients, composition

@staticmethod
def _compile_recipe_from_firebase(db_recipe_data, obj_dict, grad_dict, comp_dict):
"""
Compile recipe data from firebase recipe data into a ready-to-pack structure
"""
recipe_data = {
**{
k: db_recipe_data[k]
for k in ["format_version", "version", "name", "bounding_box"]
},
"objects": obj_dict,
"composition": comp_dict,
}
if grad_dict:
recipe_data["gradients"] = [{**v} for v in grad_dict.values()]
return recipe_data

def _read(self):
new_values, database_name = autopack.load_file(
self.file_path, self.db_handler, cache="recipes"
)
if database_name == "firebase":
objects, gradients, composition = RecipeLoader._collect_and_sort_data(
objects, gradients, composition = DBRecipeLoader._collect_and_sort_data(
new_values["composition"]
)
new_values = RecipeLoader._compile_recipe_from_firebase(
new_values = DBRecipeLoader._compile_db_recipe_data(
new_values, objects, gradients, composition
)
recipe_data = RecipeLoader.default_values.copy()
4 changes: 2 additions & 2 deletions cellpack/bin/pack.py
@@ -14,7 +14,7 @@
from cellpack.autopack.loaders.recipe_loader import RecipeLoader
from cellpack.autopack.loaders.analysis_config_loader import AnalysisConfigLoader
from cellpack.autopack.FirebaseHandler import FirebaseHandler
from cellpack.autopack.DBRecipeHandler import DBHandler
from cellpack.autopack.DBRecipeHandler import DBRecipeLoader

###############################################################################
log_file_path = path.abspath(path.join(__file__, "../../logging.conf"))
@@ -42,7 +42,7 @@ def pack(
"""
if db_id == DATABASE_IDS.FIREBASE:
db = FirebaseHandler()
db_handler = DBHandler(db)
db_handler = DBRecipeLoader(db)
packing_config_data = ConfigLoader(config_path).config
recipe_data = RecipeLoader(
recipe, db_handler, packing_config_data["save_converted_recipe"]