Skip to content

Commit

Permalink
Automatically find last checkpoint and support multi-stage training
Browse files Browse the repository at this point in the history
  • Loading branch information
DSaurus committed Dec 14, 2023
1 parent 3fe3153 commit de84aaa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
10 changes: 9 additions & 1 deletion threestudio/systems/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
update_if_possible,
)
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import C, cleanup, get_device, load_module_weights
from threestudio.utils.misc import (
C,
cleanup,
find_last_path,
get_device,
load_module_weights,
)
from threestudio.utils.saving import SaverMixin
from threestudio.utils.typing import *

Expand Down Expand Up @@ -241,6 +247,8 @@ class Config(BaseSystem.Config):
cfg: Config

def configure(self) -> None:
self.cfg.geometry_convert_from = find_last_path(self.cfg.geometry_convert_from)
self.cfg.weights = find_last_path(self.cfg.weights)
if (
self.cfg.geometry_convert_from # from_coarse must be specified
and not self.cfg.weights # not initialized from coarse when weights are specified
Expand Down
21 changes: 21 additions & 0 deletions threestudio/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,24 @@ def broadcast(tensor, src=0):
def enable_gradient(model, enabled: bool = True) -> None:
for param in model.parameters():
param.requires_grad_(enabled)


def find_last_path(path: str):
if (path is not None) and ("LAST" in path):
path = path.replace(" ", "_")
base_dir_prefix, suffix = path.split("LAST", 1)
base_dir = os.path.dirname(base_dir_prefix)
prefix = os.path.split(base_dir_prefix)[-1]
base_dir_prefix = os.path.join(base_dir, prefix)
all_path = os.listdir(base_dir)
all_path = [os.path.join(base_dir, dir) for dir in all_path]
filtered_path = [dir for dir in all_path if dir.startswith(base_dir_prefix)]
filtered_path.sort(reverse=True)
last_path = filtered_path[0]
new_path = last_path + suffix
if os.path.exists(new_path):
return new_path
else:
raise FileNotFoundError(new_path)
else:
return path

0 comments on commit de84aaa

Please sign in to comment.