Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Structured array for manifest #39

Closed
wants to merge 10 commits into from
7 changes: 1 addition & 6 deletions virtualizarr/manifests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,4 @@
# This is just to avoid conflicting with some type of file called manifest that .gitignore recommends ignoring.

from .array import ManifestArray # type: ignore # noqa
from .manifest import ( # type: ignore # noqa
ChunkEntry,
ChunkManifest,
concat_manifests,
stack_manifests,
)
from .manifest import ChunkEntry, ChunkManifest # type: ignore # noqa
12 changes: 7 additions & 5 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..zarr import Codec, ZArray
from .manifest import concat_manifests, stack_manifests
from .manifest import ChunkManifest

if TYPE_CHECKING:
from .array import ManifestArray
Expand Down Expand Up @@ -123,10 +123,11 @@ def concatenate(
new_shape = list(first_shape)
new_shape[axis] = new_length_along_concat_axis

concatenated_manifest = concat_manifests(
[arr.manifest for arr in arrays],
concatenated_manifest_entries = np.concatenate(
[arr.manifest.entries for arr in arrays],
axis=axis,
)
concatenated_manifest = ChunkManifest(entries=concatenated_manifest_entries)

new_zarray = ZArray(
chunks=first_arr.chunks,
Expand Down Expand Up @@ -206,10 +207,11 @@ def stack(
new_shape = list(first_shape)
new_shape.insert(axis, length_along_new_stacked_axis)

stacked_manifest = stack_manifests(
[arr.manifest for arr in arrays],
stacked_manifest_entries = np.stack(
[arr.manifest.entries for arr in arrays],
axis=axis,
)
stacked_manifest = ChunkManifest(entries=stacked_manifest_entries)

# chunk size has changed because a length-1 axis has been inserted
old_chunks = first_arr.chunks
Expand Down
213 changes: 90 additions & 123 deletions virtualizarr/manifests/manifest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import itertools
import re
from typing import Any, Iterable, Iterator, List, Mapping, Tuple, Union, cast
from typing import Any, Iterable, Iterator, List, NewType, Tuple, Union, cast

import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict

from ..types import ChunkKey

Expand All @@ -14,6 +13,9 @@
_CHUNK_KEY = rf"^{_INTEGER}+({_SEPARATOR}{_INTEGER})*$" # matches 1 integer, optionally followed by more integers each separated by a separator (i.e. a period)


ChunkDict = NewType("ChunkDict", dict[ChunkKey, dict[str, Union[str, int]]])


class ChunkEntry(BaseModel):
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we really need this class anymore

Information for a single chunk in the manifest.
Expand Down Expand Up @@ -42,11 +44,20 @@ def to_kerchunk(self) -> List[Union[str, int]]:
return [self.path, self.offset, self.length]


# TODO we want the path field to contain a variable-length string, but that's not available until numpy 2.0
# See https://numpy.org/neps/nep-0055-string_dtype.html
MANIFEST_STRUCTURED_ARRAY_DTYPES = np.dtype(
[("path", "<U32"), ("offset", np.int32), ("length", np.int32)]
)

Comment on lines +51 to +55
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because file paths can be strings of any length, we really need to be using numpy's new variable-width string dtype here.

Unfortunately it's only coming out with numpy 2.0, and although there is a release candidate for numpy 2.0, it's so new that pandas doesn't support it yet. Xarray has a pandas dependency, so currently we can't actually build an environment that lets us try virtualizarr with the variable-length string dtype yet.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pandas just released 2.2.2, which is compatible with the upcoming numpy 2.0 release.

Not sure if that will break any part of xarray that we need for VirtualiZarr, but this might now be close enough to test out variable-length dtypes now.


class ChunkManifest(BaseModel):
"""
In-memory representation of a single Zarr chunk manifest.

Stores the manifest as a dictionary under the .chunks attribute, in this form:
Stores the manifest internally as a numpy structured array.

The manifest can be converted to or from a dictionary form looking like this

{
"0.0.0": {"path": "s3://bucket/foo.nc", "offset": 100, "length": 100},
Expand All @@ -55,24 +66,45 @@ class ChunkManifest(BaseModel):
"0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100},
}

using the .from_dict() and .dict() methods, so users of this class can think of the manifest as if it were a dict.

See the chunk manifest SPEC proposal in https://github.com/zarr-developers/zarr-specs/issues/287 .

Validation is done when this object is instantiated, and this class is immutable,
so it's not possible to have a ChunkManifest object that does not represent a complete valid grid of chunks.
"""

model_config = ConfigDict(frozen=True)
model_config = ConfigDict(
frozen=True,
arbitrary_types_allowed=True, # so pydantic doesn't complain about the numpy array field
)

entries: Mapping[ChunkKey, ChunkEntry]
# shape_chunk_grid: Tuple[int, ...] # TODO do we need this for anything?
# TODO how to type hint to indicate a numpy structured array with specifically-typed fields?
entries: np.ndarray

@classmethod
def from_dict(cls, chunks: ChunkDict) -> "ChunkManifest":
    """
    Create a ChunkManifest from a dictionary mapping chunk keys to entry dicts.

    Parameters
    ----------
    chunks
        Mapping from chunk keys (e.g. "0.1.0") to entries of the form
        dict(path=<str>, offset=<int>, length=<int>) (entry objects exposing
        .path/.offset/.length attributes, such as ChunkEntry, are also accepted).

    Raises
    ------
    ValueError
        If the chunk keys are invalid or an entry is malformed.
    """
    # TODO do some input validation here first?
    validate_chunk_keys(chunks.keys())

    # TODO should we actually pass shape in, in case there are not enough chunks to give correct idea of full shape?
    shape = get_chunk_grid_shape(chunks.keys())

    # Initializing to empty implies that entries with path='' are treated as missing chunks
    entries = np.empty(shape=shape, dtype=MANIFEST_STRUCTURED_ARRAY_DTYPES)

    # populate the array
    for key, entry in chunks.items():
        try:
            if isinstance(entry, dict):
                # Index by explicit field names rather than relying on the dict's
                # insertion order happening to be (path, offset, length)
                values = (entry["path"], entry["offset"], entry["length"])
            else:
                # e.g. a ChunkEntry model, as passed in by from_kerchunk_chunk_dict
                values = (entry.path, entry.offset, entry.length)
            entries[split(key)] = values
        except (KeyError, AttributeError, ValueError, TypeError) as e:
            msg = (
                "Each chunk entry must be of the form dict(path=<str>, offset=<int>, length=<int>), "
                f"but got {entry}"
            )
            raise ValueError(msg) from e

    return ChunkManifest(entries=entries)

@property
def ndim_chunk_grid(self) -> int:
Expand All @@ -81,7 +113,7 @@ def ndim_chunk_grid(self) -> int:

Not the same as the dimension of an array backed by this chunk manifest.
"""
return get_ndim_from_key(list(self.entries.keys())[0])
return self.entries.ndim

@property
def shape_chunk_grid(self) -> Tuple[int, ...]:
Expand All @@ -90,23 +122,57 @@ def shape_chunk_grid(self) -> Tuple[int, ...]:

Not the same as the shape of an array backed by this chunk manifest.
"""
return get_chunk_grid_shape(list(self.entries.keys()))
return self.entries.shape

def __repr__(self) -> str:
    """Concise representation showing only the chunk grid shape."""
    grid_shape = self.shape_chunk_grid
    return f"ChunkManifest<shape={grid_shape}>"

def __getitem__(self, key: ChunkKey) -> ChunkEntry:
    """
    Look up a single chunk's entry by its key, e.g. "0.1.0".

    Indexes the structured array with the key's integer indices and rebuilds
    a ChunkEntry from the record's fields.
    """
    indices = split(key)
    path, offset, length = self.entries[indices].item()
    # ChunkEntry is a pydantic BaseModel, which does not accept a positional
    # argument — it must be constructed with keyword field values.
    return ChunkEntry(path=path, offset=offset, length=length)

def __iter__(self) -> Iterator[ChunkKey]:
    # NOTE(review): `self.chunks` does not appear to exist on this class any more —
    # entries are now stored in the `self.entries` structured array — so this looks
    # stale after the refactor; confirm and derive the keys from `entries` instead.
    return iter(self.chunks.keys())

def __len__(self) -> int:
    """Total number of chunk entries in the manifest's chunk grid."""
    return self.entries.size

def dict(self) -> dict[str, dict[str, Union[str, int]]]:
"""Converts the entire manifest to a nested dictionary."""
return {k: dict(entry) for k, entry in self.entries.items()}
def dict(self) -> ChunkDict:
    """
    Converts the entire manifest to a nested dictionary, of the form

    {
        "0.0.0": {"path": "s3://bucket/foo.nc", "offset": 100, "length": 100},
        "0.0.1": {"path": "s3://bucket/foo.nc", "offset": 200, "length": 100},
        "0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100},
        "0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100},
    }

    Entries whose path is the empty string are treated as missing chunks and
    are omitted from the returned dictionary.
    """

    # Convert one structured-array record (path, offset, length) to a plain dict.
    def _entry_to_dict(entry: Tuple[str, int, int]) -> dict[str, Union[str, int]]:
        return {
            "path": entry[0],
            "offset": entry[1],
            "length": entry[2],
        }

    # One integer coordinate array per chunk-grid dimension, each with shape
    # `self.shape_chunk_grid`; used to recover each entry's grid indices below.
    coord_vectors = np.mgrid[
        tuple(slice(None, length) for length in self.shape_chunk_grid)
    ]

    # Iterate the coordinate arrays and the entries array in lockstep: every
    # iteration yields the chunk's grid indices (*inds) plus its record (entry).
    return cast(
        ChunkDict,
        {
            join(inds): _entry_to_dict(entry.item())
            for *inds, entry in np.nditer([*coord_vectors, self.entries])
            if entry.item()[0]
            != ""  # don't include entry if path='' (i.e. empty chunk)
        },
    )

def __eq__(self, other: Any) -> bool:
    """
    Two manifests are equal if all of their entries are identical.

    Returns NotImplemented for non-ChunkManifest operands (so Python can fall
    back to the other operand's comparison), and False for manifests whose
    chunk grids have different shapes (elementwise comparison would otherwise
    fail to broadcast).
    """
    if not isinstance(other, ChunkManifest):
        return NotImplemented
    if self.entries.shape != other.entries.shape:
        return False
    # Convert the numpy scalar to a plain bool for callers
    return bool((self.entries == other.entries).all())

@staticmethod
def from_zarr_json(filepath: str) -> "ChunkManifest":
Expand All @@ -122,15 +188,15 @@ def from_kerchunk_chunk_dict(cls, kerchunk_chunk_dict) -> "ChunkManifest":
chunkentries = {
k: ChunkEntry.from_kerchunk(v) for k, v in kerchunk_chunk_dict.items()
}
return ChunkManifest(entries=chunkentries)
return ChunkManifest.from_dict(chunkentries)


def split(key: "ChunkKey") -> Tuple[int, ...]:
    """Split a chunk key like "0.1.2" into a tuple of integer indices, e.g. (0, 1, 2)."""
    # A tuple (not a list) so the result can be used directly as a numpy index
    return tuple(int(i) for i in key.split("."))


def join(inds: Iterable[Any]) -> "ChunkKey":
    """Join an iterable of chunk-grid indices into a "."-separated chunk key, e.g. "0.1.2"."""
    # No need to materialize a list first: str.join consumes any iterable
    return cast("ChunkKey", ".".join(str(i) for i in inds))


def get_ndim_from_key(key: str) -> int:
Expand All @@ -154,9 +220,6 @@ def validate_chunk_keys(chunk_keys: Iterable[ChunkKey]):
f"Inconsistent number of dimensions between chunk key {key} and {first_key}: {other_ndim} vs {ndim}"
)

# Check that the keys collectively form a complete grid
check_keys_form_grid(chunk_keys)


def get_chunk_grid_shape(chunk_keys: Iterable[ChunkKey]) -> Tuple[int, ...]:
# find max chunk index along each dimension
Expand All @@ -165,99 +228,3 @@ def get_chunk_grid_shape(chunk_keys: Iterable[ChunkKey]) -> Tuple[int, ...]:
max(indices_along_one_dim) + 1 for indices_along_one_dim in zipped_indices
)
return chunk_grid_shape


def check_keys_form_grid(chunk_keys: Iterable[ChunkKey]):
    """
    Check that the chunk keys collectively form a complete grid.

    Raises ValueError if any position in the implied chunk grid is missing.
    """
    grid_shape = get_chunk_grid_shape(chunk_keys)

    # Every index combination the grid shape implies must be present
    expected_keys: set[ChunkKey] = {
        join(combo)
        for combo in itertools.product(*(range(length) for length in grid_shape))
    }

    if set(chunk_keys) != expected_keys:
        raise ValueError("Chunk keys do not form a complete grid")


def concat_manifests(manifests: List["ChunkManifest"], axis: int) -> "ChunkManifest":
    """
    Concatenate manifests along an existing dimension.

    This only requires adjusting one index of chunk keys along a single dimension.

    Note axis is not expected to be negative.
    """
    if len(manifests) == 1:
        return manifests[0]

    lengths_along_concat_dim = [
        manifest.shape_chunk_grid[axis] for manifest in manifests
    ]

    # The first manifest's keys are already correct; every later manifest's keys
    # are shifted by the cumulative length of the manifests before it.
    chunk_index_offsets = np.cumsum(lengths_along_concat_dim)[:-1]
    shifted_entry_maps = [
        adjust_chunk_keys(manifest.entries, axis, offset)
        for manifest, offset in zip(manifests[1:], chunk_index_offsets)
    ]

    merged_entries: dict = {}
    for entry_map in [manifests[0].entries, *shifted_entry_maps]:
        merged_entries.update(entry_map)

    # Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests
    # Could use pydantic's model_construct classmethod to skip these checks
    # But we should actually performance test it because it might be pointless, and current implementation is safer
    return ChunkManifest(entries=merged_entries)


def adjust_chunk_keys(
    entries: Mapping[ChunkKey, ChunkEntry], axis: int, offset: int
) -> Mapping[ChunkKey, ChunkEntry]:
    """Replace all chunk keys with keys which have been offset along one axis."""

    def _shifted(key: ChunkKey) -> ChunkKey:
        # Bump only the index along the requested axis
        indices = split(key)
        indices[axis] += offset
        return join(indices)

    return {_shifted(key): entry for key, entry in entries.items()}


def stack_manifests(manifests: List[ChunkManifest], axis: int) -> "ChunkManifest":
    """
    Stack manifests along a new dimension.

    This only requires inserting one index into all chunk keys to add a new dimension.

    Note axis is not expected to be negative.
    """
    # Even a single manifest needs the new axis inserted into its keys;
    # each manifest's position in the list becomes its index on the new axis.
    entry_maps = [
        insert_new_axis_into_chunk_keys(manifest.entries, axis, position)
        for position, manifest in enumerate(manifests)
    ]

    merged_entries: dict = {}
    for entry_map in entry_maps:
        merged_entries.update(entry_map)

    # Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests
    # Could use pydantic's model_construct classmethod to skip these checks
    # But we should actually performance test it because it might be pointless, and current implementation is safer
    return ChunkManifest(entries=merged_entries)


def insert_new_axis_into_chunk_keys(
    entries: Mapping[ChunkKey, ChunkEntry], axis: int, new_index_value: int
) -> Mapping[ChunkKey, ChunkEntry]:
    """Replace all chunk keys with keys which have a new axis inserted, with a given value."""

    def _with_new_axis(key: ChunkKey) -> ChunkKey:
        # Splice the new axis's index into the key's index list
        indices = split(key)
        indices.insert(axis, new_index_value)
        return join(indices)

    return {_with_new_axis(key): entry for key, entry in entries.items()}
Loading
Loading