Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 6, 2024
1 parent 1666427 commit 2678206
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
args: [ --fix ]
# Run the formatter.
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
Expand Down
18 changes: 8 additions & 10 deletions virtualizarr/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from enum import Enum, auto


class AutoName(Enum):
# Recommended by official Python docs for auto naming:
# https://docs.python.org/3/library/enum.html#using-automatic-values
def _generate_next_value_(name, start, count, last_values):
return name


class FileType(AutoName):
netcdf3 = auto()
netcdf4 = auto()
Expand All @@ -34,6 +36,7 @@ class FileType(AutoName):
fits = auto()
zarr = auto()


def read_kerchunk_references_from_file(
filepath: str, filetype: Optional[FileType]
) -> KerchunkStoreRefs:
Expand All @@ -57,6 +60,7 @@ def read_kerchunk_references_from_file(

if filetype.name.lower() == "netcdf3":
from kerchunk.netCDF3 import NetCDF3ToZarr

refs = NetCDF3ToZarr(filepath, inline_threshold=0).translate()

elif filetype.name.lower() == "netcdf4":
Expand Down Expand Up @@ -87,7 +91,7 @@ def _automatically_determine_filetype(filepath: str) -> FileType:

if file_extension == ".nc":
# based off of: https://github.com/TomNicholas/VirtualiZarr/pull/43#discussion_r1543415167
with open(filepath, 'rb') as f:
with open(filepath, "rb") as f:
magic = f.read()
if magic[0:3] == b"CDF":
filetype = FileType.netcdf3
Expand Down Expand Up @@ -119,9 +123,7 @@ def find_var_names(ds_reference_dict: KerchunkStoreRefs) -> list[str]:
return found_var_names


def extract_array_refs(
ds_reference_dict: KerchunkStoreRefs, var_name: str
) -> KerchunkArrRefs:
def extract_array_refs(ds_reference_dict: KerchunkStoreRefs, var_name: str) -> KerchunkArrRefs:
"""Extract only the part of the kerchunk reference dict that is relevant to this one zarr array"""

found_var_names = find_var_names(ds_reference_dict)
Expand All @@ -131,9 +133,7 @@ def extract_array_refs(
# TODO these function probably have more loops in them than they need to...

arr_refs = {
key.split("/")[1]: refs[key]
for key in refs.keys()
if var_name == key.split("/")[0]
key.split("/")[1]: refs[key] for key in refs.keys() if var_name == key.split("/")[0]
}

return fully_decode_arr_refs(arr_refs)
Expand Down Expand Up @@ -175,9 +175,7 @@ def dataset_to_kerchunk_refs(ds: xr.Dataset) -> KerchunkStoreRefs:
for var_name, var in ds.variables.items():
arr_refs = variable_to_kerchunk_arr_refs(var)

prepended_with_var_name = {
f"{var_name}/{key}": val for key, val in arr_refs.items()
}
prepended_with_var_name = {f"{var_name}/{key}": val for key, val in arr_refs.items()}

all_arr_refs.update(prepended_with_var_name)

Expand Down
8 changes: 2 additions & 6 deletions virtualizarr/manifests/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def _from_kerchunk_refs(cls, arr_refs: KerchunkArrRefs) -> "ManifestArray":

zarray = ZArray.from_kerchunk_refs(decoded_arr_refs[".zarray"])

kerchunk_chunk_dict = {
k: v for k, v in decoded_arr_refs.items() if re.match(_CHUNK_KEY, k)
}
kerchunk_chunk_dict = {k: v for k, v in decoded_arr_refs.items() if re.match(_CHUNK_KEY, k)}
chunkmanifest = ChunkManifest._from_kerchunk_chunk_dict(kerchunk_chunk_dict)

obj = object.__new__(cls)
Expand Down Expand Up @@ -206,9 +204,7 @@ def __getitem__(
indexer = _possibly_expand_trailing_ellipsis(key, self.ndim)

if len(indexer) != self.ndim:
raise ValueError(
f"Invalid indexer for array with ndim={self.ndim}: {indexer}"
)
raise ValueError(f"Invalid indexer for array with ndim={self.ndim}: {indexer}")

if all(
isinstance(axis_indexer, slice) and axis_indexer == slice(None)
Expand Down
12 changes: 3 additions & 9 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,12 @@ def _check_same_ndims(ndims: list[int]) -> None:

def _check_same_shapes_except_on_concat_axis(shapes: list[tuple[int, ...]], axis: int):
"""Check that shapes are compatible for concatenation"""
shapes_without_concat_axis = [
_remove_element_at_position(shape, axis) for shape in shapes
]
shapes_without_concat_axis = [_remove_element_at_position(shape, axis) for shape in shapes]

first_shape, *other_shapes = shapes_without_concat_axis
for other_shape in other_shapes:
if other_shape != first_shape:
raise ValueError(
f"Cannot concatenate arrays with shapes {[shape for shape in shapes]}"
)
raise ValueError(f"Cannot concatenate arrays with shapes {[shape for shape in shapes]}")


def _remove_element_at_position(t: tuple[int, ...], pos: int) -> tuple[int, ...]:
Expand Down Expand Up @@ -273,9 +269,7 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
# concatenate same array upon itself d_requested number of times along existing axis
result = concatenate([result] * d_requested, axis=axis)
else:
raise ValueError(
f"Array with shape {x.shape} cannot be broadcast to shape {shape}"
)
raise ValueError(f"Array with shape {x.shape} cannot be broadcast to shape {shape}")

return result

Expand Down
20 changes: 5 additions & 15 deletions virtualizarr/manifests/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

from ..types import ChunkKey

_INTEGER = (
r"([1-9]+\d*|0)" # matches 0 or an unsigned integer that does not begin with zero
)
_INTEGER = r"([1-9]+\d*|0)" # matches 0 or an unsigned integer that does not begin with zero
_SEPARATOR = r"\."
_CHUNK_KEY = rf"^{_INTEGER}+({_SEPARATOR}{_INTEGER})*$" # matches 1 integer, optionally followed by more integers each separated by a separator (i.e. a period)

Expand All @@ -32,9 +30,7 @@ def __repr__(self) -> str:
return f"ChunkEntry(path='{self.path}', offset={self.offset}, length={self.length})"

@classmethod
def from_kerchunk(
cls, path_and_byte_range_info: List[Union[str, int]]
) -> "ChunkEntry":
def from_kerchunk(cls, path_and_byte_range_info: List[Union[str, int]]) -> "ChunkEntry":
path, offset, length = path_and_byte_range_info
return ChunkEntry(path=path, offset=offset, length=length)

Expand Down Expand Up @@ -127,9 +123,7 @@ def to_zarr_json(self, filepath: str) -> None:

@classmethod
def _from_kerchunk_chunk_dict(cls, kerchunk_chunk_dict) -> "ChunkManifest":
chunkentries = {
k: ChunkEntry.from_kerchunk(v) for k, v in kerchunk_chunk_dict.items()
}
chunkentries = {k: ChunkEntry.from_kerchunk(v) for k, v in kerchunk_chunk_dict.items()}
return ChunkManifest(entries=chunkentries)


Expand Down Expand Up @@ -181,12 +175,8 @@ def check_keys_form_grid(chunk_keys: Iterable[ChunkKey]):
chunk_grid_shape = get_chunk_grid_shape(chunk_keys)

# create every possible combination
all_possible_combos = itertools.product(
*[range(length) for length in chunk_grid_shape]
)
all_required_chunk_keys: set[ChunkKey] = set(
join(inds) for inds in all_possible_combos
)
all_possible_combos = itertools.product(*[range(length) for length in chunk_grid_shape])
all_required_chunk_keys: set[ChunkKey] = set(join(inds) for inds in all_possible_combos)

# check that every possible combination is represented once in the list of chunk keys
if set(chunk_keys) != all_required_chunk_keys:
Expand Down
1 change: 1 addition & 0 deletions virtualizarr/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import xarray as xr


@pytest.fixture
def netcdf4_file(tmpdir):
# Set up example xarray dataset
Expand Down
46 changes: 27 additions & 19 deletions virtualizarr/tests/test_kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from virtualizarr.manifests import ChunkEntry, ChunkManifest, ManifestArray
from virtualizarr.xarray import dataset_from_kerchunk_refs


def gen_ds_refs(
zgroup: str = '{"zarr_format":2}',
zarray: str = '{"chunks":[2,3],"compressor":null,"dtype":"<i8","fill_value":null,"filters":null,"order":"C","shape":[2,3],"zarr_format":2}',
zattrs: str = '{"_ARRAY_DIMENSIONS":["x","y"]}',
chunk: list = ["test1.nc", 6144, 48],
zgroup: str = '{"zarr_format":2}',
zarray: str = '{"chunks":[2,3],"compressor":null,"dtype":"<i8","fill_value":null,"filters":null,"order":"C","shape":[2,3],"zarr_format":2}',
zattrs: str = '{"_ARRAY_DIMENSIONS":["x","y"]}',
chunk: list = ["test1.nc", 6144, 48],
):
return {
"version": 1,
Expand All @@ -25,9 +26,10 @@ def gen_ds_refs(
},
}


def test_dataset_from_df_refs():
ds_refs = gen_ds_refs()
ds = dataset_from_kerchunk_refs(ds_refs)
ds = dataset_from_kerchunk_refs(ds_refs)
assert "a" in ds
da = ds["a"]
assert isinstance(da.data, ManifestArray)
Expand All @@ -41,15 +43,23 @@ def test_dataset_from_df_refs():
assert da.data.zarray.fill_value is None
assert da.data.zarray.order == "C"

assert da.data.manifest.dict() == {
"0.0": {"path": "test1.nc", "offset": 6144, "length": 48}
}
assert da.data.manifest.dict() == {"0.0": {"path": "test1.nc", "offset": 6144, "length": 48}}


def test_dataset_from_df_refs_with_filters():
filters = [{"elementsize":4,"id":"shuffle"},{"id":"zlib","level":4}]
zarray = {"chunks":[2,3],"compressor":None,"dtype":"<i8","fill_value":None,"filters":filters,"order":"C","shape":[2,3],"zarr_format":2}
filters = [{"elementsize": 4, "id": "shuffle"}, {"id": "zlib", "level": 4}]
zarray = {
"chunks": [2, 3],
"compressor": None,
"dtype": "<i8",
"fill_value": None,
"filters": filters,
"order": "C",
"shape": [2, 3],
"zarr_format": 2,
}
ds_refs = gen_ds_refs(zarray=ujson.dumps(zarray))
ds = dataset_from_kerchunk_refs(ds_refs)
ds = dataset_from_kerchunk_refs(ds_refs)
da = ds["a"]
assert da.data.zarray.filters == filters

Expand Down Expand Up @@ -163,15 +173,13 @@ def test_automatically_determine_filetype_netcdf3_netcdf4():
assert FileType("netcdf4") == _automatically_determine_filetype(netcdf4_file_path)




def test_FileType():
# tests if FileType converts user supplied strings to correct filetype
assert 'netcdf3' == FileType("netcdf3").name
assert 'netcdf4' == FileType("netcdf4").name
assert 'grib' == FileType("grib").name
assert 'tiff' == FileType("tiff").name
assert 'fits' == FileType("fits").name
assert 'zarr' == FileType("zarr").name
assert "netcdf3" == FileType("netcdf3").name
assert "netcdf4" == FileType("netcdf4").name
assert "grib" == FileType("grib").name
assert "tiff" == FileType("tiff").name
assert "fits" == FileType("fits").name
assert "zarr" == FileType("zarr").name
with pytest.raises(ValueError):
FileType(None)
3 changes: 1 addition & 2 deletions virtualizarr/tests/test_manifests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def test_not_equal_chunk_entries(self):
assert not (marr1 == marr2).all()

@pytest.mark.skip(reason="Not Implemented")
def test_partly_equals(self):
...
def test_partly_equals(self): ...


# TODO we really need some kind of fixtures to generate useful example data
Expand Down
6 changes: 2 additions & 4 deletions virtualizarr/tests/test_manifests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ def test_stack(self):

@pytest.mark.skip(reason="Not implemented")
class TestSerializeManifest:
def test_serialize_manifest_to_zarr(self):
...
def test_serialize_manifest_to_zarr(self): ...

def test_deserialize_manifest_from_zarr(self):
...
def test_deserialize_manifest_from_zarr(self): ...
5 changes: 1 addition & 4 deletions virtualizarr/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,6 @@ def test_concat_dim_coords_along_existing_dim(self):
assert result.data.zarray.zarr_format == zarray.zarr_format





class TestOpenVirtualDatasetIndexes:
def test_no_indexes(self, netcdf4_file):
vds = open_virtual_dataset(netcdf4_file, indexes={})
Expand Down Expand Up @@ -273,7 +270,7 @@ def test_combine_by_coords(self, netcdf4_files):

class TestLoadVirtualDataset:
def test_loadable_variables(self, netcdf4_file):
vars_to_load = ['air', 'time']
vars_to_load = ["air", "time"]
vds = open_virtual_dataset(netcdf4_file, loadable_variables=vars_to_load)

for name in vds.variables:
Expand Down
24 changes: 12 additions & 12 deletions virtualizarr/tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@

def test_zarr_v3_roundtrip(tmpdir):
arr = ManifestArray(
chunkmanifest={"0.0": ChunkEntry(path="test.nc", offset=6144, length=48)},
zarray=dict(
shape=(2, 3),
dtype=np.dtype("<i8"),
chunks=(2, 3),
compressor=None,
filters=None,
fill_value=None,
order="C",
zarr_format=3,
),
)
chunkmanifest={"0.0": ChunkEntry(path="test.nc", offset=6144, length=48)},
zarray=dict(
shape=(2, 3),
dtype=np.dtype("<i8"),
chunks=(2, 3),
compressor=None,
filters=None,
fill_value=None,
order="C",
zarr_format=3,
),
)
original = xr.Dataset({"a": (["x", "y"], arr)}, attrs={"something": 0})

original.virtualize.to_zarr(tmpdir / "store.zarr")
Expand Down
Loading

0 comments on commit 2678206

Please sign in to comment.