Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use chunking / padding in more jitted functions #129

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 84 additions & 10 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax import config, jit, lax, vmap

from ..constants import Constants as c
from ..utils.chunking import process_in_chunks
from . import types
from .cartesian import CartesianCoordinates
from .cometary import CometaryCoordinates
Expand Down Expand Up @@ -161,7 +162,19 @@ def cartesian_to_spherical(
vlat : Latitudinal velocity in degrees per arbitrary unit of time.
(same unit of time as the x, y, and z velocities).
"""
coords_spherical = _cartesian_to_spherical_vmap(coords_cartesian)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_spherical_chunks = []
for cartesian_chunk in process_in_chunks(coords_cartesian, chunk_size):
coords_spherical_chunk = _cartesian_to_spherical_vmap(cartesian_chunk)
coords_spherical_chunks.append(coords_spherical_chunk)

# Concatenate chunks and remove padding
coords_spherical = jnp.concatenate(coords_spherical_chunks, axis=0)
coords_spherical = coords_spherical[: len(coords_cartesian)]

return coords_spherical


Expand Down Expand Up @@ -276,7 +289,19 @@ def spherical_to_cartesian(
vy : y-velocity in the same units of y per arbitrary unit of time.
vz : z-velocity in the same units of z per arbitrary unit of time.
"""
coords_cartesian = _spherical_to_cartesian_vmap(coords_spherical)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for spherical_chunk in process_in_chunks(coords_spherical, chunk_size):
coords_cartesian_chunk = _spherical_to_cartesian_vmap(spherical_chunk)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_spherical)]

return coords_cartesian


Expand Down Expand Up @@ -537,7 +562,7 @@ def cartesian_to_keplerian(
vz : z-velocity in units of au per day.
t0 : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N)
Epoch at which cometary elements are defined in MJD TDB.
mu : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N, 6)
mu : {`~numpy.ndarray`, `~jax.numpy.ndarray`} (N)
Gravitational parameter (GM) of the attracting body in units of
au**3 / d**2.

Expand All @@ -559,7 +584,25 @@ def cartesian_to_keplerian(
P : period in days.
tp : time of periapsis passage in days.
"""
coords_keplerian = _cartesian_to_keplerian_vmap(coords_cartesian, t0, mu)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_keplerian_chunks = []
for cartesian_chunk, t0_chunk, mu_chunk in zip(
process_in_chunks(coords_cartesian, chunk_size),
process_in_chunks(t0, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_keplerian_chunk = _cartesian_to_keplerian_vmap(
cartesian_chunk, t0_chunk, mu_chunk
)
coords_keplerian_chunks.append(coords_keplerian_chunk)

# Concatenate chunks and remove padding
coords_keplerian = jnp.concatenate(coords_keplerian_chunks, axis=0)
coords_keplerian = coords_keplerian[: len(coords_cartesian)]

return coords_keplerian


Expand Down Expand Up @@ -945,9 +988,24 @@ def keplerian_to_cartesian(
)
raise ValueError(err)

coords_cartesian = _keplerian_to_cartesian_a_vmap(
coords_keplerian, mu, max_iter, tol
)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for keplerian_chunk, mu_chunk in zip(
process_in_chunks(coords_keplerian, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_cartesian_chunk = _keplerian_to_cartesian_a_vmap(
keplerian_chunk, mu_chunk, max_iter, tol
)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_keplerian)]

return coords_cartesian


Expand Down Expand Up @@ -1188,9 +1246,25 @@ def cometary_to_cartesian(
vy : y-velocity in units of au per day.
vz : z-velocity in units of au per day.
"""
coords_cartesian = _cometary_to_cartesian_vmap(
coords_cometary, t0, mu, max_iter, tol
)
# Define chunk size
chunk_size = 50

# Process in chunks
coords_cartesian_chunks = []
for cometary_chunk, t0_chunk, mu_chunk in zip(
process_in_chunks(coords_cometary, chunk_size),
process_in_chunks(t0, chunk_size),
process_in_chunks(mu, chunk_size),
):
coords_cartesian_chunk = _cometary_to_cartesian_vmap(
cometary_chunk, t0_chunk, mu_chunk, max_iter, tol
)
coords_cartesian_chunks.append(coords_cartesian_chunk)

# Concatenate chunks and remove padding
coords_cartesian = jnp.concatenate(coords_cartesian_chunks, axis=0)
coords_cartesian = coords_cartesian[: len(coords_cometary)]

return coords_cartesian


Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..utils.chunking import process_in_chunks
from .aberrations import _add_light_time, add_stellar_aberration
from .propagation import process_in_chunks


@jit
Expand Down
21 changes: 1 addition & 20 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..coordinates.origin import Origin
from ..orbits.orbits import Orbits
from ..time import Timestamp
from ..utils.chunking import process_in_chunks
from .lagrange import apply_lagrange_coefficients, calc_lagrange_coefficients

config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -69,26 +70,6 @@ def _propagate_2body(
)


def pad_to_fixed_size(array, target_shape, pad_value=0):
"""
Pad an array to a fixed shape with a specified pad value.
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
"""
Yield chunks of the array with a fixed size, padding the last chunk if necessary.
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk


def propagate_2body(
orbits: Orbits,
times: Timestamp,
Expand Down
47 changes: 47 additions & 0 deletions src/adam_core/utils/chunking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.numpy as jnp


def pad_to_fixed_size(array, target_shape, pad_value=0):
"""
Pad an array to a fixed shape with a specified pad value.

Parameters
----------
array : array-like
Array to pad
target_shape : tuple
Desired output shape
pad_value : int or float, optional
Value to use for padding, by default 0

Returns
-------
padded_array : array-like
Padded array with desired shape
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
"""
Yield chunks of the array with a fixed size, padding the last chunk if necessary.

Parameters
----------
array : array-like
Array to process in chunks
chunk_size : int
Size of each chunk

Yields
------
chunk : array-like
Array chunk of fixed size (padded if necessary)
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk
Loading