Skip to content

Commit

Permalink
Ak/more perf improvements (#130)
Browse files Browse the repository at this point in the history
* this seems like a better chunk size

* Change chunking
  • Loading branch information
akoumjian authored Dec 9, 2024
1 parent 3059a45 commit 3d2a90d
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 42 deletions.
32 changes: 16 additions & 16 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def _cartesian_to_spherical(
)


def cartesian_to_spherical(
coords_cartesian: Union[np.ndarray, jnp.ndarray]
) -> jnp.ndarray:
def cartesian_to_spherical(coords_cartesian: np.ndarray) -> np.ndarray:
"""
Convert Cartesian coordinates to a spherical coordinates.
Expand Down Expand Up @@ -163,16 +161,17 @@ def cartesian_to_spherical(
(same unit of time as the x, y, and z velocities).
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_spherical_chunks = []
# Process in chunk
coords_spherical: np.ndarray = np.empty((0, 6))
for cartesian_chunk in process_in_chunks(coords_cartesian, chunk_size):
coords_spherical_chunk = _cartesian_to_spherical_vmap(cartesian_chunk)
coords_spherical_chunks.append(coords_spherical_chunk)
coords_spherical = np.concatenate(
(coords_spherical, np.asarray(coords_spherical_chunk))
)

# Concatenate chunks and remove padding
coords_spherical = jnp.concatenate(coords_spherical_chunks, axis=0)
coords_spherical = coords_spherical[: len(coords_cartesian)]

return coords_spherical
Expand Down Expand Up @@ -290,16 +289,17 @@ def spherical_to_cartesian(
vz : z-velocity in the same units of z per arbitrary unit of time.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
coords_cartesian: np.ndarray = np.empty((0, 6))
for spherical_chunk in process_in_chunks(coords_spherical, chunk_size):
coords_cartesian_chunk = _spherical_to_cartesian_vmap(spherical_chunk)
coords_cartesian_chunks.append(coords_cartesian_chunk)
coords_cartesian = np.concatenate(
(coords_cartesian, np.asarray(coords_cartesian_chunk))
)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
# Remove padding
coords_cartesian = coords_cartesian[: len(coords_spherical)]

return coords_cartesian
Expand Down Expand Up @@ -585,7 +585,7 @@ def cartesian_to_keplerian(
tp : time of periapsis passage in days.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_keplerian_chunks = []
Expand Down Expand Up @@ -989,7 +989,7 @@ def keplerian_to_cartesian(
raise ValueError(err)

# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def cometary_to_cartesian(
vz : z-velocity in units of au per day.
"""
# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
coords_cartesian_chunks = []
Expand Down
19 changes: 9 additions & 10 deletions src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def generate_ephemeris_2body(
times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

# Define chunk size
chunk_size = 50
chunk_size = 200

# Process in chunks
ephemeris_chunks = []
light_time_chunks = []
ephemeris_spherical: np.ndarray = np.empty((0, 6))
light_time: np.ndarray = np.empty((0,))

for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip(
process_in_chunks(propagated_orbits_barycentric.coordinates.values, chunk_size),
Expand All @@ -230,15 +230,14 @@ def generate_ephemeris_2body(
tol,
stellar_aberration,
)
ephemeris_chunks.append(ephemeris_chunk)
light_time_chunks.append(light_time_chunk)
ephemeris_spherical = np.concatenate(
(ephemeris_spherical, np.asarray(ephemeris_chunk))
)
light_time = np.concatenate((light_time, np.asarray(light_time_chunk)))

# Concatenate chunks and remove padding
ephemeris_spherical = jnp.concatenate(ephemeris_chunks, axis=0)[:num_entries]
light_time = jnp.concatenate(light_time_chunks, axis=0)[:num_entries]

ephemeris_spherical = np.array(ephemeris_spherical)
light_time = np.array(light_time)
ephemeris_spherical = np.array(ephemeris_spherical)[:num_entries]
light_time = np.array(light_time)[:num_entries]

if not propagated_orbits.coordinates.covariance.is_all_nan():

Expand Down
11 changes: 5 additions & 6 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def propagate_2body(
object_ids = orbits.object_id.to_numpy(zero_copy_only=False)

# Define chunk size
chunk_size = 50 # Example chunk size
chunk_size = 200 # Changed from 1000

# Prepare arrays for chunk processing
# This creates a n x m matrix where n is the number of orbits and m is the number of times
Expand All @@ -121,7 +121,7 @@ def propagate_2body(
t1_ = np.tile(t1, n_orbits)

# Process in chunks
orbits_propagated_chunks = []
orbits_propagated: np.ndarray = np.empty((0, 6))
for orbits_chunk, t0_chunk, t1_chunk, mu_chunk in zip(
process_in_chunks(orbits_array_, chunk_size),
process_in_chunks(t0_, chunk_size),
Expand All @@ -131,10 +131,9 @@ def propagate_2body(
orbits_propagated_chunk = _propagate_2body_vmap(
orbits_chunk, t0_chunk, t1_chunk, mu_chunk, max_iter, tol
)
orbits_propagated_chunks.append(orbits_propagated_chunk)

# Concatenate all chunks
orbits_propagated = jnp.concatenate(orbits_propagated_chunks, axis=0)
orbits_propagated = np.concatenate(
(orbits_propagated, np.asarray(orbits_propagated_chunk))
)

# Remove padding
orbits_propagated = orbits_propagated[: n_orbits * n_times]
Expand Down
6 changes: 4 additions & 2 deletions src/adam_core/observers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_observer_state(
o_vec_ITRF93 = np.dot(R_EARTH_EQUATORIAL, o_hat_ITRF93)

# Warning! Converting times to ET will incur a loss of precision.
epochs_et = times.rescale("tdb").et()
epochs_et = times.et()
unique_epochs_et_tdb = epochs_et.unique()

N = len(epochs_et)
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_observer_state(
-OMEGA_EARTH * R_EARTH_EQUATORIAL * rotation_direction
)

return CartesianCoordinates.from_kwargs(
observer_states = CartesianCoordinates.from_kwargs(
time=times,
x=r_obs[:, 0],
y=r_obs[:, 1],
Expand All @@ -161,3 +161,5 @@ def get_observer_state(
frame=frame,
origin=Origin.from_kwargs(code=[origin.name for i in range(len(times))]),
)

return observer_states
1 change: 0 additions & 1 deletion src/adam_core/time/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def et(self) -> pa.lib.DoubleArray:
Returns the times as ET seconds in a pyarrow array.
"""
tdb = self.rescale("tdb")

mjd = tdb.mjd()
return pc.multiply(pc.subtract(mjd, _J2000_TDB_MJD), 86400)

Expand Down
4 changes: 2 additions & 2 deletions src/adam_core/utils/chunking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import jax.numpy as jnp
import numpy as np


def pad_to_fixed_size(array, target_shape, pad_value=0):
Expand All @@ -20,7 +20,7 @@ def pad_to_fixed_size(array, target_shape, pad_value=0):
Padded array with desired shape
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)
return np.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
Expand Down
9 changes: 4 additions & 5 deletions src/adam_core/utils/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ def get_perturber_state(
setup_SPICE()

# Convert epochs to ET in TDB
epochs_et = times.rescale("tdb").et()
epochs_et = times.et()
unique_epochs_et = epochs_et.unique()
N = len(times)
# Get position of the body in km and km/s in the desired frame and measured from the desired origin
states = np.empty((N, 6), dtype=np.float64)

for i, epoch in enumerate(unique_epochs_et):
Expand All @@ -144,9 +143,9 @@ def get_perturber_state(
)
states[mask, :] = state

# Convert to AU and AU per day
# Convert units (vectorized operations)
states = states / KM_P_AU
states[:, 3:] = states[:, 3:] * S_P_DAY
states[:, 3:] *= S_P_DAY

return CartesianCoordinates.from_kwargs(
time=times,
Expand All @@ -157,5 +156,5 @@ def get_perturber_state(
vy=states[:, 4],
vz=states[:, 5],
frame=frame,
origin=Origin.from_kwargs(code=[origin.name for i in range(N)]),
origin=Origin.from_kwargs(code=[origin.name] * N),
)

0 comments on commit 3d2a90d

Please sign in to comment.