sketch idea for lithops parallelization
TomNicholas committed Dec 15, 2024
1 parent a48e8a4 commit 75c7da3
Showing 1 changed file with 36 additions and 11 deletions.
virtualizarr/backend.py (47 changes: 36 additions & 11 deletions)
@@ -342,7 +342,16 @@ def open_virtual_mfdataset(
         if preprocess is not None:
             preprocess = dask.delayed(preprocess)
     elif parallel == "lithops":
-        raise NotImplementedError()
+        import lithops
+
+        # TODO use RetryingFunctionExecutor instead?
+        # TODO what's the easiest way to pass the lithops config in?
+        fn_exec = lithops.FunctionExecutor()
+
+        # lithops doesn't have a delayed primitive
+        open_ = open_virtual_dataset
+        # TODO I don't know how best to chain this with the getattr, or if that closing stuff is even necessary for virtual datasets
+        # getattr_ = getattr
     elif parallel is not False:
         raise ValueError(
             f"{parallel} is an invalid option for the keyword argument ``parallel``"
@@ -351,25 +360,41 @@ def open_virtual_mfdataset(
         open_ = open_virtual_dataset
         getattr_ = getattr
 
-    datasets = [open_(p, **kwargs) for p in paths1d]
-    closers = [getattr_(ds, "_close") for ds in datasets]
-    if preprocess is not None:
-        datasets = [preprocess(ds) for ds in datasets]
-
     if parallel == "dask":
+        datasets = [open_(p, **kwargs) for p in paths1d]
+        closers = [getattr_(ds, "_close") for ds in datasets]
+        if preprocess is not None:
+            datasets = [preprocess(ds) for ds in datasets]
+
         # calling compute here will return the datasets/file_objs lists,
         # the underlying datasets will still be stored as dask arrays
         datasets, closers = dask.compute(datasets, closers)
     elif parallel == "lithops":
-        raise NotImplementedError()
+
+        def generate_refs(path):
+            # allows passing the open_virtual_dataset function to lithops without evaluating it
+            vds = open_(path, **kwargs)
+            # TODO perhaps we should just load the loadable_vars here and close before returning?
+            return vds
+
+        futures = fn_exec.map(generate_refs, paths1d)
+
+        # wait for all the serverless workers to finish, and send their resulting virtual datasets back to the client
+        completed_futures, _ = fn_exec.wait(futures, download_results=True)
+        virtual_datasets = [future.get_result() for future in completed_futures]
+    elif parallel is False:
+        datasets = [open_(p, **kwargs) for p in paths1d]
+        closers = [getattr_(ds, "_close") for ds in datasets]
+        if preprocess is not None:
+            datasets = [preprocess(ds) for ds in datasets]
 
     # Combine all datasets, closing them in case of a ValueError
     try:
         if combine == "nested":
             # Combined nested list by successive concat and merge operations
             # along each dimension, using structure given by "ids"
             combined = _nested_combine(
-                datasets,
+                virtual_datasets,
                 concat_dims=concat_dim,
                 compat=compat,
                 data_vars=data_vars,
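Two loose ends in the hunk above, as the sketch stands: the dask and serial branches still bind their results to the name datasets while the combine step consumes virtual_datasets, and the lithops branch never applies preprocess. A minimal variant of the same generate_refs helper that folds preprocess into the serverless workers (assuming preprocess is picklable; the preprocess handling is my addition, not part of this commit, and the function still relies on open_, kwargs and preprocess from the enclosing scope, as in the hunk):

        def generate_refs(path):
            # runs remotely on each lithops worker
            vds = open_(path, **kwargs)
            if preprocess is not None:
                # apply the user-supplied preprocess step before shipping the result back
                vds = preprocess(vds)
            # TODO perhaps we should just load the loadable_vars here and close before returning?
            return vds

The dask and serial branches would similarly need to assign their results to virtual_datasets (or the combine calls would need to keep using datasets) for the non-lithops paths to keep working.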
@@ -382,7 +407,7 @@ def open_virtual_mfdataset(
             # Redo ordering from coordinates, ignoring how they were ordered
             # previously
             combined = combine_by_coords(
-                datasets,
+                virtual_datasets,
                 compat=compat,
                 data_vars=data_vars,
                 coords=coords,
@@ -395,7 +420,7 @@ def open_virtual_mfdataset(
                 " ``combine``"
             )
     except ValueError:
-        for ds in datasets:
+        for ds in virtual_datasets:
             ds.close()
         raise
 
@@ -405,7 +430,7 @@ def open_virtual_mfdataset(
     if attrs_file is not None:
         if isinstance(attrs_file, os.PathLike):
             attrs_file = cast(str, os.fspath(attrs_file))
-        combined.attrs = datasets[paths1d.index(attrs_file)].attrs
+        combined.attrs = virtual_datasets[paths1d.index(attrs_file)].attrs
 
     # TODO should we just immediately close everything?
     # TODO We should have already read everything we're ever going to read into memory at this point
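For orientation, a hypothetical call exercising the new code path (assuming open_virtual_mfdataset keeps the xarray.open_mfdataset-style keywords used in the diff; the paths are made up):

from virtualizarr.backend import open_virtual_mfdataset

combined_vds = open_virtual_mfdataset(
    ["s3://my-bucket/2023.nc", "s3://my-bucket/2024.nc"],  # hypothetical paths
    combine="nested",
    concat_dim="time",
    parallel="lithops",  # each file is opened by its own serverless worker
)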