Skip to content

Commit

Permalink
Permit multiple input parquet files
Browse files Browse the repository at this point in the history
  • Loading branch information
ponyisi committed Jan 12, 2021
1 parent 35bbefa commit 0103345
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
14 changes: 10 additions & 4 deletions parquet_to_root/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from . import parquet_to_root
import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument('infile', help='Input Parquet file')
parser.add_argument('infile', nargs='+', help='Input Parquet file')
parser.add_argument('outfile', help='Output ROOT file')
parser.add_argument('--treename', '-t', default='parquettree',
help='Name of TTree')
parser.add_argument('--verbose', '-v', action='store_true',
help='Verbose output')
opts = parser.parse_args()

parquet_to_root(opts.infile, opts.outfile,
opts.treename,
opts.verbose)
try:
parquet_to_root(opts.infile, opts.outfile,
opts.treename,
opts.verbose)
except Exception as e:
print('Failure...')
print(e)
sys.exit(1)
39 changes: 32 additions & 7 deletions parquet_to_root/parquet_to_root_pyroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,38 @@ def _do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs):
tree.Fill()


def parquet_to_root_pyroot(infile, outfile, treename='parquettree',
def normalize_parquet(infiles):
'''Convert infiles argument to list; verify schema match across all files'''
import pyarrow.parquet as pq
import io

# convert to a list
if isinstance(infiles, str) or isinstance(infiles, io.IOBase):
lfiles = [infiles]
else:
try:
lfiles = list(infiles)
except TypeError:
# This really shouldn't be hit, but maybe there's an edge case
lfiles = [infiles]

schema = pq.read_schema(lfiles[0])
for f in lfiles[1:]:
schema2 = pq.read_schema(f)
if schema != schema2:
raise ValueError(f"Mismatched Parquet schemas between {infiles[0]} and {f}")

return lfiles, schema


def parquet_to_root_pyroot(infiles, outfile, treename='parquettree',
verbose=False):
import pyarrow.parquet as pq
import pyarrow
import ROOT

# Use parquet metadata for schema
table = pq.read_table(infile)
schema = table.schema
# Interpret files
infiles, schema = normalize_parquet(infiles)

fout, local_root_file_creation = _get_outfile(outfile)
tree = ROOT.TTree(treename, 'Parquet tree')
Expand Down Expand Up @@ -129,9 +152,11 @@ def parquet_to_root_pyroot(infile, outfile, treename='parquettree',
raise ValueError(f'Cannot translate field "{branch}" of input Parquet schema. Field is described as {field.type}')

# Fill loop
for entry in range(len(table)):
# trash on every pass through loop; just here to make sure nothing gets garbage collected early
_do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs)
for infile in infiles:
table = pq.read_table(infile)
for entry in range(len(table)):
# trash on every pass through loop; just here to make sure nothing gets garbage collected early
_do_fill(tree, entry, table, numpybufs, stringvars, vectorlens, stringarrs)

tree.Write()
if local_root_file_creation:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,50 @@ def test_run_with_existing_rootfile():
t2 = rf.Get('stars')
assert t2.GetEntries() == 2935
return True


def test_run_with_multiple_inputs():
from parquet_to_root import parquet_to_root
ROOT = pytest.importorskip("ROOT")
parquet_to_root(['tests/samples/HZZ.parquet','tests/samples/HZZ.parquet'],
'HZZ.root', verbose=True)
rdf = ROOT.RDataFrame('parquettree', 'HZZ.root')
assert rdf.Count().GetValue() == 4842
assert rdf.GetColumnNames().size() == 74
assert rdf.Mean("Muon_Px").GetValue() == -0.6551689155476192
assert rdf.Mean("MET_px").GetValue() == 0.23863275654291605
return True


def test_fail_on_incompatible_inputs():
from parquet_to_root import parquet_to_root
ROOT = pytest.importorskip("ROOT")
with pytest.raises(ValueError):
parquet_to_root(['tests/samples/HZZ.parquet','tests/samples/exoplanets.parquet'],
'HZZ.root', verbose=True)


def test_cmdline_multiple_inputs():
ROOT = pytest.importorskip("ROOT")
import subprocess
chk = subprocess.run("python3 -m parquet_to_root tests/samples/HZZ.parquet tests/samples/HZZ.parquet HZZ.root -t newtree",
shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
print(chk.stdout)
chk.check_returncode()

rf = ROOT.TFile.Open('HZZ.root')
t = rf.Get('newtree')
assert t.GetEntries() == 4842
return True


def test_cmdline_incompatible_inputs():
ROOT = pytest.importorskip("ROOT")
import subprocess
chk = subprocess.run("python3 -m parquet_to_root tests/samples/HZZ.parquet tests/samples/exoplanets.parquet HZZ.root -t newtree",
shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
print(chk.stdout)
with pytest.raises(subprocess.CalledProcessError):
chk.check_returncode()

return True

0 comments on commit 0103345

Please sign in to comment.