diff --git a/src/cellmap_models/pytorch/__init__.py b/src/cellmap_models/pytorch/__init__.py index 5cdcb2d..963639f 100755 --- a/src/cellmap_models/pytorch/__init__.py +++ b/src/cellmap_models/pytorch/__init__.py @@ -1,2 +1 @@ -from . import cosem -from . import cellpose +from . import cosem, cellpose, untrained_models diff --git a/src/cellmap_models/pytorch/cosem/README.md b/src/cellmap_models/pytorch/cosem/README.md index ef9c8e3..1629ccb 100755 --- a/src/cellmap_models/pytorch/cosem/README.md +++ b/src/cellmap_models/pytorch/cosem/README.md @@ -25,7 +25,7 @@ Each model has a separate backbone and single layer prediction head. The `backbo ```python import cellmap_models.cosem as cosem_models model = cosem_models.load_model('setup04/1820500') -backnone = model.backbone +backbone = model.backbone head = model.prediction_head ``` diff --git a/src/cellmap_models/pytorch/cosem/load_model.py b/src/cellmap_models/pytorch/cosem/load_model.py index 5d0c8b6..1857087 100755 --- a/src/cellmap_models/pytorch/cosem/load_model.py +++ b/src/cellmap_models/pytorch/cosem/load_model.py @@ -81,6 +81,7 @@ def load_model(checkpoint_name: str) -> torch.nn.Module: new_checkpoint["model"].pop(key) continue new_key = key.replace("architecture.", "") + new_key = new_key.replace("unet.", "backbone.") new_checkpoint["model"][new_key] = new_checkpoint["model"].pop(key) model.load_state_dict(new_checkpoint["model"]) model.eval()