diff --git a/parquet_to_root/__main__.py b/parquet_to_root/__main__.py index 4d65a9f..1b52a44 100644 --- a/parquet_to_root/__main__.py +++ b/parquet_to_root/__main__.py @@ -1,8 +1,9 @@ from . import parquet_to_root import argparse +import sys parser = argparse.ArgumentParser() -parser.add_argument('infile', help='Input Parquet file') +parser.add_argument('infile', nargs='+', help='Input Parquet file') parser.add_argument('outfile', help='Output ROOT file') parser.add_argument('--treename', '-t', default='parquettree', help='Name of TTree') @@ -10,6 +11,11 @@ help='Verbose output') opts = parser.parse_args() -parquet_to_root(opts.infile, opts.outfile, - opts.treename, - opts.verbose) +try: + parquet_to_root(opts.infile, opts.outfile, + opts.treename, + opts.verbose) +except Exception as e: + print('Failure...') + print(e) + sys.exit(1) diff --git a/parquet_to_root/parquet_to_root_pyroot.py b/parquet_to_root/parquet_to_root_pyroot.py index 2556b5b..f3abfc9 100644 --- a/parquet_to_root/parquet_to_root_pyroot.py +++ b/parquet_to_root/parquet_to_root_pyroot.py @@ -91,15 +91,38 @@ def _do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs): tree.Fill() -def parquet_to_root_pyroot(infile, outfile, treename='parquettree', +def normalize_parquet(infiles): + '''Convert infiles argument to list; verify schema match across all files''' + import pyarrow.parquet as pq + import io + + # convert to a list + if isinstance(infiles, str) or isinstance(infiles, io.IOBase): + lfiles = [infiles] + else: + try: + lfiles = list(infiles) + except TypeError: + # This really shouldn't be hit, but maybe there's an edge case + lfiles = [infiles] + + schema = pq.read_schema(lfiles[0]) + for f in lfiles[1:]: + schema2 = pq.read_schema(f) + if schema != schema2: + raise ValueError(f"Mismatched Parquet schemas between {infiles[0]} and {f}") + + return lfiles, schema + + +def parquet_to_root_pyroot(infiles, outfile, treename='parquettree', verbose=False): import pyarrow.parquet as pq import pyarrow import ROOT - # Use parquet metadata for schema - table = pq.read_table(infile) - schema = table.schema + # Interpret files + infiles, schema = normalize_parquet(infiles) fout, local_root_file_creation = _get_outfile(outfile) tree = ROOT.TTree(treename, 'Parquet tree') @@ -129,9 +152,11 @@ def parquet_to_root_pyroot(infile, outfile, treename='parquettree', raise ValueError(f'Cannot translate field "{branch}" of input Parquet schema. Field is described as {field.type}') # Fill loop - for entry in range(len(table)): - # trash on every pass through loop; just here to make sure nothing gets garbage collected early - _do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs) + for infile in infiles: + table = pq.read_table(infile) + for entry in range(len(table)): + # trash on every pass through loop; just here to make sure nothing gets garbage collected early + _do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs) tree.Write() if local_root_file_creation: diff --git a/tests/test_run.py b/tests/test_run.py index 51e52f8..0a9963a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -64,3 +64,50 @@ def test_run_with_existing_rootfile(): t2 = rf.Get('stars') assert t2.GetEntries() == 2935 return True + + +def test_run_with_multiple_inputs(): + from parquet_to_root import parquet_to_root + ROOT = pytest.importorskip("ROOT") + parquet_to_root(['tests/samples/HZZ.parquet','tests/samples/HZZ.parquet'], + 'HZZ.root', verbose=True) + rdf = ROOT.RDataFrame('parquettree', 'HZZ.root') + assert rdf.Count().GetValue() == 4842 + assert rdf.GetColumnNames().size() == 74 + assert rdf.Mean("Muon_Px").GetValue() == -0.6551689155476192 + assert rdf.Mean("MET_px").GetValue() == 0.23863275654291605 + return True + + +def test_fail_on_incompatible_inputs(): + from parquet_to_root import parquet_to_root + ROOT = pytest.importorskip("ROOT") + with pytest.raises(ValueError): + parquet_to_root(['tests/samples/HZZ.parquet','tests/samples/exoplanets.parquet'], + 'HZZ.root', verbose=True) + + +def test_cmdline_multiple_inputs(): + ROOT = pytest.importorskip("ROOT") + import subprocess + chk = subprocess.run("python3 -m parquet_to_root tests/samples/HZZ.parquet tests/samples/HZZ.parquet HZZ.root -t newtree", + shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + print(chk.stdout) + chk.check_returncode() + + rf = ROOT.TFile.Open('HZZ.root') + t = rf.Get('newtree') + assert t.GetEntries() == 4842 + return True + + +def test_cmdline_incompatible_inputs(): + ROOT = pytest.importorskip("ROOT") + import subprocess + chk = subprocess.run("python3 -m parquet_to_root tests/samples/HZZ.parquet tests/samples/exoplanets.parquet HZZ.root -t newtree", + shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + print(chk.stdout) + with pytest.raises(subprocess.CalledProcessError): + chk.check_returncode() + + return True