
Commit ab184a0
add fadvise to open zarr for reading; optimize segment zarr batch with workers; add segmentation to filter reporting; add ls_ds
misko committed Dec 1, 2024
1 parent 19744c8 commit ab184a0
Showing 4 changed files with 124 additions and 17 deletions.
79 changes: 79 additions & 0 deletions spf/scripts/ls_ds.py
@@ -0,0 +1,79 @@
import argparse
import concurrent.futures
import json
import multiprocessing
import os
import time

from tqdm import tqdm

from spf.dataset.spf_dataset import v5spfdataset

LS_VERSION = 1.0


def ls_zarr(ds_fn, force=False):
ls_fn = ds_fn + ".ls.json"
if force or not os.path.exists(ls_fn):
ds = v5spfdataset(
ds_fn,
nthetas=65,
precompute_cache=None,
skip_fields=set(["signal_matrix"]),
ignore_qc=True,
segment_if_not_exist=False,
temp_file=True,
temp_file_suffix="",
)
ls_info = {
"ds_fn": ds_fn,
"frequency": ds.carrier_frequencies[0],
"rx_spacing": ds.rx_spacing,
"samples": len(ds),
"routine": ds.yaml_config["routine"],
"version": LS_VERSION,
}
with open(ls_fn, "w") as fp:
json.dump(ls_info, fp, indent=4)
with open(ls_fn, "r") as file:
ls_info = json.load(file)
if "version" not in ls_info or ls_info["version"] != LS_VERSION:
assert not force, f"LS_DS version not found(?) {ls_info} vs {LS_VERSION}"
return ls_zarr(ds_fn, force=True)
return ls_info
    raise ValueError(f"Could not process file {ds_fn}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input-zarrs", type=str, nargs="+", help="input zarr", required=True
)
parser.add_argument("--debug", default=False, action=argparse.BooleanOptionalAction)
parser.add_argument("-w", "--workers", type=int, default=4, help="n workers")
args = parser.parse_args()

if args.debug:
results = list(map(ls_zarr, args.input_zarrs))
else:
with concurrent.futures.ProcessPoolExecutor(
max_workers=args.workers
) as executor:
results = list(
tqdm(
executor.map(ls_zarr, args.input_zarrs), total=len(args.input_zarrs)
)
)
print(results[0], len(results))

# aggregate results
merged_stats = {}
for result in results:
key = f"{result['frequency']},{result['rx_spacing']},{result['routine']}"
if key not in merged_stats:
merged_stats[key] = 0
merged_stats[key] += result["samples"]

print("frequency,rx_spacing,routine,samples")
for key in sorted(merged_stats.keys()):
print(f"{key},{merged_stats[key]}")
2 changes: 1 addition & 1 deletion spf/scripts/run_filters_report.py
@@ -117,7 +117,7 @@ def report_workdir_to_csv(workdir, output_csv_fn):
total=len(fns),
)
)
header = ["type", "movement"]
header = ["type", "movement", "segmentation_version"]
for field in results[0][0].keys():
if field == "ds_fn":
pass
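
The report header now carries a segmentation_version column alongside type and movement. A self-contained sketch of writing a report row with that column; the field names and values below are placeholders, not the report's actual schema:

import csv

# Illustrative only: each result dict is assumed to carry a segmentation_version field.
header = ["type", "movement", "segmentation_version", "mse"]
rows = [{"type": "pf", "movement": "bounce", "segmentation_version": 1, "mse": 0.0}]
with open("filters_report.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=header)
    writer.writeheader()
    writer.writerows(rows)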
53 changes: 37 additions & 16 deletions spf/scripts/segment_zarr.py
@@ -1,30 +1,51 @@
import argparse
import concurrent.futures

from spf.dataset.spf_dataset import v5spfdataset


def process_zarr(args):
ds = v5spfdataset(
args["input_zarr"],
nthetas=65,
precompute_cache=args["precompute_cache"],
gpu=args["gpu"],
skip_fields=set(["signal_matrix"]),
ignore_qc=True,
# readahead=True, #this is hard coded in the segmentation code
n_parallel=args["parallel"],
segment_if_not_exist=True,
)
print(args["input_zarr"], ds.phi_drifts[0], ds.phi_drifts[1])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input-zarr", type=str, help="input zarr", required=True
"-i", "--input-zarrs", type=str, nargs="+", help="input zarr", required=True
)
parser.add_argument(
"-c", "--precompute-cache", type=str, help="precompute cache", required=True
)
parser.add_argument("--gpu", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument(
"-p", "--parallel", type=int, default=24, help="precompute cache"
)
parser.add_argument("--debug", default=False, action=argparse.BooleanOptionalAction)
parser.add_argument("-p", "--parallel", type=int, default=12, help="parallel")
parser.add_argument("-w", "--workers", type=int, default=2, help="n workers")
args = parser.parse_args()

ds = v5spfdataset(
args.input_zarr,
nthetas=65,
precompute_cache=args.precompute_cache,
gpu=args.gpu,
skip_fields=set(["signal_matrix"]),
ignore_qc=True,
# readahead=True, #this is hard coded in the segmentation code
n_parallel=args.parallel,
segment_if_not_exist=True,
)
print(args.input_zarr, ds.phi_drifts[0], ds.phi_drifts[1])
jobs = [
{
"input_zarr": zarr_fn,
"precompute_cache": args.precompute_cache,
"parallel": args.parallel,
"gpu": args.gpu,
}
for zarr_fn in args.input_zarrs
]
if args.debug:
list(map(process_zarr, jobs))
else:
with concurrent.futures.ProcessPoolExecutor(
max_workers=args.workers
) as executor:
executor.map(process_zarr, jobs)
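
A hypothetical direct call to the new process_zarr helper for a single dataset, bypassing argparse; the paths are placeholders, while the parallel and gpu values mirror the new defaults:

from spf.scripts.segment_zarr import process_zarr

# Segment one dataset; the job dict mirrors the entries built from the CLI arguments above.
process_zarr(
    {
        "input_zarr": "example_rx.zarr",          # placeholder path
        "precompute_cache": "precompute_cache/",  # placeholder cache directory
        "parallel": 12,
        "gpu": True,
    }
)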
7 changes: 7 additions & 0 deletions spf/utils.py
@@ -64,10 +64,17 @@ def zarr_shrink(filename):

@contextmanager
def zarr_open_from_lmdb_store_cm(filename, mode="r", readahead=False):
f = None
try:
if mode == "r":
f = open(filename + "/data.mdb", "rb")
os.posix_fadvise(f.fileno(), 0, 0, os.POSIX_FADV_WILLNEED)
z = zarr_open_from_lmdb_store(filename, mode, readahead=readahead)
yield z
finally:
if f:
os.posix_fadvise(f.fileno(), 0, 0, os.POSIX_FADV_DONTNEED)
f.close()
z.store.close()


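The context manager above pairs POSIX_FADV_WILLNEED on open with POSIX_FADV_DONTNEED on exit. A minimal standalone sketch of that pattern, assuming a Unix platform where os.posix_fadvise is available; the path is a placeholder:

import os

def read_with_prefetch(path):
    with open(path, "rb") as f:
        # Hint the kernel to start pulling the whole file into the page cache now.
        os.posix_fadvise(f.fileno(), 0, 0, os.POSIX_FADV_WILLNEED)
        try:
            return f.read()
        finally:
            # Tell the kernel the cached pages are no longer needed.
            os.posix_fadvise(f.fileno(), 0, 0, os.POSIX_FADV_DONTNEED)

# data = read_with_prefetch("example_rx.zarr/data.mdb")  # placeholder path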
