Skip to content

Commit

Permalink
Add a standalone detect checkpoint function (#158)
Browse files Browse the repository at this point in the history
* Add a standalone detect checkpoint function
  • Loading branch information
akoumjian authored Feb 12, 2024
1 parent 790e919 commit 3fb2bed
Showing 1 changed file with 153 additions and 96 deletions.
249 changes: 153 additions & 96 deletions thor/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,42 +120,86 @@ def create_checkpoint_data(stage: VALID_STAGES, **data) -> CheckpointData:
raise ValueError(f"Invalid stage: {stage}")


def detect_checkpoint_stage(test_orbit_directory: pathlib.Path) -> VALID_STAGES:
"""
Looks for existing files and indicates the next stage to run
"""
if not test_orbit_directory.is_dir():
raise ValueError(f"{test_orbit_directory} is not a directory")

if not test_orbit_directory.exists():
logger.info(f"Working directory does not exist, starting at beginning.")
return "filter_observations"

if not (test_orbit_directory / "filtered_observations.parquet").exists():
logger.info(
"No filtered observations found, starting stage filter_observations"
)
return "filter_observations"

if (test_orbit_directory / "recovered_orbits.parquet").exists() and (
test_orbit_directory / "recovered_orbit_members.parquet"
).exists():
logger.info("Found recovered orbits, pipeline is complete.")
return "complete"

if (test_orbit_directory / "od_orbits.parquet").exists() and (
test_orbit_directory / "od_orbit_members.parquet"
).exists():
logger.info("Found OD orbits, starting stage recover_orbits")
return "recover_orbits"

if (test_orbit_directory / "iod_orbits.parquet").exists() and (
test_orbit_directory / "iod_orbit_members.parquet"
).exists():
logger.info("Found IOD orbits, starting stage differential_correction")
return "differential_correction"

if (test_orbit_directory / "clusters.parquet").exists() and (
test_orbit_directory / "cluster_members.parquet"
).exists():
logger.info("Found clusters, starting stage initial_orbit_determination")
return "initial_orbit_determination"

if (test_orbit_directory / "transformed_detections.parquet").exists():
logger.info("Found transformed detections, starting stage cluster_and_link")
return "cluster_and_link"

if (test_orbit_directory / "filtered_observations.parquet").exists():
logger.info("Found filtered observations, starting stage range_and_transform")
return "range_and_transform"

raise ValueError(f"Could not detect stage from {test_orbit_directory}")


def load_initial_checkpoint_values(
test_orbit_directory: Optional[pathlib.Path] = None,
test_orbit_directory: Optional[Union[pathlib.Path, str]] = None,
) -> CheckpointData:
"""
Check for completed stages and return values from disk if they exist.
We want to avoid loading objects into memory that are not required.
"""
stage: VALID_STAGES = "filter_observations"
# Without a checkpoint directory, we always start at the beginning
if isinstance(test_orbit_directory, str):
test_orbit_directory = pathlib.Path(test_orbit_directory)

if test_orbit_directory is None:
return create_checkpoint_data(stage)
logger.info("Not using a workign directory, start at beginning.")
return create_checkpoint_data("filter_observations")

# filtered_observations is always needed when it exists
filtered_observations_path = pathlib.Path(
test_orbit_directory, "filtered_observations.parquet"
)
# If it doesn't exist, start at the beginning.
if not filtered_observations_path.exists():
return create_checkpoint_data(stage)
logger.info("Found filtered observations")
filtered_observations = Observations.from_parquet(filtered_observations_path)
stage: VALID_STAGES = detect_checkpoint_stage(test_orbit_directory)

if filtered_observations.fragmented():
filtered_observations = qv.defragment(filtered_observations)
if stage == "filter_observations":
return create_checkpoint_data(stage)

# If the pipeline was started but we have recovered_orbits already, we
# are done and should exit early.
recovered_orbits_path = pathlib.Path(
test_orbit_directory, "recovered_orbits.parquet"
)
recovered_orbit_members_path = pathlib.Path(
test_orbit_directory, "recovered_orbit_members.parquet"
)
if recovered_orbits_path.exists() and recovered_orbit_members_path.exists():
logger.info("Found recovered orbits in checkpoint")
# If we've already completed the pipeline, we can load the recovered orbits
if stage == "complete":
recovered_orbits_path = pathlib.Path(
test_orbit_directory, "recovered_orbits.parquet"
)
recovered_orbit_members_path = pathlib.Path(
test_orbit_directory, "recovered_orbit_members.parquet"
)
recovered_orbits = FittedOrbits.from_parquet(recovered_orbits_path)
recovered_orbit_members = FittedOrbitMembers.from_parquet(
recovered_orbit_members_path
Expand All @@ -167,88 +211,101 @@ def load_initial_checkpoint_values(
recovered_orbit_members = qv.defragment(recovered_orbit_members)

return create_checkpoint_data(
"complete",
stage,
recovered_orbits=recovered_orbits,
recovered_orbit_members=recovered_orbit_members,
)

# Now with filtered_observations available, we can check for the later
# stages in reverse order.
od_orbits_path = pathlib.Path(test_orbit_directory, "od_orbits.parquet")
od_orbit_members_path = pathlib.Path(
test_orbit_directory, "od_orbit_members.parquet"
)
if od_orbits_path.exists() and od_orbit_members_path.exists():
logger.info("Found OD orbits in checkpoint")
od_orbits = FittedOrbits.from_parquet(od_orbits_path)
od_orbit_members = FittedOrbitMembers.from_parquet(od_orbit_members_path)

if od_orbits.fragmented():
od_orbits = qv.defragment(od_orbits)
if od_orbit_members.fragmented():
od_orbit_members = qv.defragment(od_orbit_members)

return create_checkpoint_data(
"recover_orbits",
filtered_observations=filtered_observations,
od_orbits=od_orbits,
od_orbit_members=od_orbit_members,
)

iod_orbits_path = pathlib.Path(test_orbit_directory, "iod_orbits.parquet")
iod_orbit_members_path = pathlib.Path(
test_orbit_directory, "iod_orbit_members.parquet"
# Filtered observations are required for all other stages
filtered_observations_path = pathlib.Path(
test_orbit_directory, "filtered_observations.parquet"
)
if iod_orbits_path.exists() and iod_orbit_members_path.exists():
logger.info("Found IOD orbits")
iod_orbits = FittedOrbits.from_parquet(iod_orbits_path)
iod_orbit_members = FittedOrbitMembers.from_parquet(iod_orbit_members_path)

if iod_orbits.fragmented():
iod_orbits = qv.defragment(iod_orbits)
if iod_orbit_members.fragmented():
iod_orbit_members = qv.defragment(iod_orbit_members)
filtered_observations = Observations.from_parquet(filtered_observations_path)
if filtered_observations.fragmented():
filtered_observations = qv.defragment(filtered_observations)

return create_checkpoint_data(
"differential_correction",
filtered_observations=filtered_observations,
iod_orbits=iod_orbits,
iod_orbit_members=iod_orbit_members,
if stage == "recover_orbits":
od_orbits_path = pathlib.Path(test_orbit_directory, "od_orbits.parquet")
od_orbit_members_path = pathlib.Path(
test_orbit_directory, "od_orbit_members.parquet"
)

clusters_path = pathlib.Path(test_orbit_directory, "clusters.parquet")
cluster_members_path = pathlib.Path(test_orbit_directory, "cluster_members.parquet")
if clusters_path.exists() and cluster_members_path.exists():
logger.info("Found clusters")
clusters = Clusters.from_parquet(clusters_path)
cluster_members = ClusterMembers.from_parquet(cluster_members_path)

if clusters.fragmented():
clusters = qv.defragment(clusters)
if cluster_members.fragmented():
cluster_members = qv.defragment(cluster_members)

return create_checkpoint_data(
"initial_orbit_determination",
filtered_observations=filtered_observations,
clusters=clusters,
cluster_members=cluster_members,
if od_orbits_path.exists() and od_orbit_members_path.exists():
logger.info("Found OD orbits in checkpoint")
od_orbits = FittedOrbits.from_parquet(od_orbits_path)
od_orbit_members = FittedOrbitMembers.from_parquet(od_orbit_members_path)

if od_orbits.fragmented():
od_orbits = qv.defragment(od_orbits)
if od_orbit_members.fragmented():
od_orbit_members = qv.defragment(od_orbit_members)

return create_checkpoint_data(
"recover_orbits",
filtered_observations=filtered_observations,
od_orbits=od_orbits,
od_orbit_members=od_orbit_members,
)

if stage == "differential_correction":
iod_orbits_path = pathlib.Path(test_orbit_directory, "iod_orbits.parquet")
iod_orbit_members_path = pathlib.Path(
test_orbit_directory, "iod_orbit_members.parquet"
)

transformed_detections_path = pathlib.Path(
test_orbit_directory, "transformed_detections.parquet"
)
if transformed_detections_path.exists():
logger.info("Found transformed detections")
transformed_detections = TransformedDetections.from_parquet(
transformed_detections_path
if iod_orbits_path.exists() and iod_orbit_members_path.exists():
logger.info("Found IOD orbits")
iod_orbits = FittedOrbits.from_parquet(iod_orbits_path)
iod_orbit_members = FittedOrbitMembers.from_parquet(iod_orbit_members_path)

if iod_orbits.fragmented():
iod_orbits = qv.defragment(iod_orbits)
if iod_orbit_members.fragmented():
iod_orbit_members = qv.defragment(iod_orbit_members)

return create_checkpoint_data(
"differential_correction",
filtered_observations=filtered_observations,
iod_orbits=iod_orbits,
iod_orbit_members=iod_orbit_members,
)

if stage == "initial_orbit_determination":
clusters_path = pathlib.Path(test_orbit_directory, "clusters.parquet")
cluster_members_path = pathlib.Path(
test_orbit_directory, "cluster_members.parquet"
)

return create_checkpoint_data(
"cluster_and_link",
filtered_observations=filtered_observations,
transformed_detections=transformed_detections,
if clusters_path.exists() and cluster_members_path.exists():
logger.info("Found clusters")
clusters = Clusters.from_parquet(clusters_path)
cluster_members = ClusterMembers.from_parquet(cluster_members_path)

if clusters.fragmented():
clusters = qv.defragment(clusters)
if cluster_members.fragmented():
cluster_members = qv.defragment(cluster_members)

return create_checkpoint_data(
"initial_orbit_determination",
filtered_observations=filtered_observations,
clusters=clusters,
cluster_members=cluster_members,
)

if stage == "cluster_and_link":
transformed_detections_path = pathlib.Path(
test_orbit_directory, "transformed_detections.parquet"
)
if transformed_detections_path.exists():
logger.info("Found transformed detections")
transformed_detections = TransformedDetections.from_parquet(
transformed_detections_path
)

return create_checkpoint_data(
"cluster_and_link",
filtered_observations=filtered_observations,
transformed_detections=transformed_detections,
)

return create_checkpoint_data(
"range_and_transform", filtered_observations=filtered_observations
Expand Down

0 comments on commit 3fb2bed

Please sign in to comment.