diff --git a/pyproject.toml b/pyproject.toml index 7fdac6f7..983c6c79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ lint = { composite = [ "black --check ./src/adam_core", "isort --check-only ./src/adam_core", ] } -fix = "ruff ./src/adam_core --fix" +fix = "ruff check ./src/adam_core --fix" typecheck = "mypy --strict ./src/adam_core" test = "pytest --benchmark-skip -m 'not profile' {args}" diff --git a/src/adam_core/dynamics/propagation.py b/src/adam_core/dynamics/propagation.py index 75fdf4f7..77791b45 100644 --- a/src/adam_core/dynamics/propagation.py +++ b/src/adam_core/dynamics/propagation.py @@ -159,8 +159,9 @@ def propagate_2body( else: cartesian_covariances = None - origin_code = np.empty(n_orbits * n_times, dtype="object") - origin_code.fill("SUN") + origin_code = np.repeat( + orbits.coordinates.origin.code.to_numpy(zero_copy_only=False), n_times + ) # Convert from the jax array to a numpy array orbits_propagated = np.asarray(orbits_propagated) diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index 49095ab0..fe5230b2 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Type, Union +from typing import List, Literal, Optional, Tuple, Type, Union import numpy as np import numpy.typing as npt @@ -103,7 +103,7 @@ def _add_light_time( observers, lt_tol: float = 1e-12, max_iter: int = 10, - ): + ) -> Tuple[Orbits, np.ndarray]: orbits_aberrated = Orbits.empty() lts = np.zeros(len(orbits)) for i, (orbit, observer) in enumerate(zip(orbits, observers)): @@ -433,8 +433,8 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp def propagate_orbits( self, - orbits: OrbitType, - times: TimestampType, + orbits: Union[OrbitType, ObjectRef], + times: Union[TimestampType, ObjectRef], covariance: bool = False, covariance_method: Literal[ "auto", "sigma-point", "monte-carlo" @@ -495,6 +495,7 @@ def propagate_orbits( times_ref = ray.put(times) else: times_ref = times + times = ray.get(times_ref) if not isinstance(orbits, ObjectRef): orbits_ref = ray.put(orbits) @@ -574,6 +575,12 @@ def propagate_orbits( if propagated_variants is not None: propagated = propagated_variants.collapse(propagated) + # Preserve the time scale of the requested times + propagated = propagated.set_column( + "coordinates.time", + propagated.coordinates.time.rescale(times.scale), + ) + # Return the results with the original origin and frame # Preserve the original output origin for the input orbits # by orbit id diff --git a/src/adam_core/propagator/tests/test_propagator.py b/src/adam_core/propagator/tests/test_propagator.py index c3192e88..a273f213 100644 --- a/src/adam_core/propagator/tests/test_propagator.py +++ b/src/adam_core/propagator/tests/test_propagator.py @@ -1,6 +1,5 @@ import numpy as np import pyarrow as pa -import pytest import quivr as qv from ...coordinates.cartesian import CartesianCoordinates @@ -90,7 +89,6 @@ def test_propagator_single_worker(): pass -@pytest.mark.skipif(RAY_INSTALLED is False, reason="Ray is not installed.") def test_propagator_multiple_workers_ray(): orbits = make_real_orbits(10) times = Timestamp.from_iso8601(["2020-01-01T00:00:00", "2020-01-01T00:00:01"])