-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. add
loop
method for RooAbsData
and implement rows
in terms…
… of `loop` 1. allow more recursion in `vars_and_cuts` function 1. add new test 1. From now on, for weighted datasets `dataset[i]` returns an `(entry,weight)` tuple 1. from now on, iteration over a weighted dataset gives an `(entry,weight)` tuple 1. change signature of `dataset.loop` , `dataset.rows` methods to return triplets `index, entry, weight`
- Loading branch information
1 parent
719419b
commit e7bea8c
Showing
8 changed files
with
261 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
) | ||
# ============================================================================= | ||
from builtins import range | ||
from ostap.utils.progress_bar import progress_bar | ||
from collections import defaultdict | ||
from ostap.core.meta_info import root_info, ostap_version | ||
from ostap.core.core import ( Ostap, VE, SE , | ||
|
@@ -61,60 +62,113 @@ | |
_minv = -0.99 * sys.float_info.max | ||
# ============================================================================= | ||
## iterator for RooAbsData entries | ||
# - For unweighted dataset `entry` is a RooArgSet | ||
# @code | ||
# dataset = ... | ||
# for entry in dataset : ... | ||
# @endcode | ||
# @code | ||
# - For weighted datasets, `entry` is a tuple of (RooArgSet,weight) | ||
# weighted = ... | ||
# for entry, weight in weighted : ... | ||
# for entry, weight in dataset : ... | ||
# @endcode | ||
# - For unweighted datasets, `weight` is `None` | ||
# @author Vanya BELYAEV [email protected] | ||
# @date 2011-06-07 | ||
def _rad_iter_ ( self ) : | ||
"""Iterator for RooAbsData | ||
- for unweighted dataset `entry` is a RooArsSet | ||
>>> dataset = ... | ||
>>> for entry in dataset : ... | ||
- for weighted datatsets, `entry` is a tuple of (RooArgSet,weight) | ||
>>> weighted = ... | ||
>>> for entry, weight in weighted : ... | ||
>>> for entry,weight in dataset : ... | ||
- for unweighted datatsets, `weight` is `None` | ||
""" | ||
_l = len ( self ) | ||
for i in range ( 0 , _l ) : | ||
yield self [ i ] | ||
|
||
# =========================================================================== | ||
## Iterator over "good" events in dataset | ||
# @code | ||
# dataset = ... | ||
# for index , entry, weight in dataset.loop ( 'pt>1' ) : | ||
# print (index, entry, weight) | ||
# @endcode | ||
def _rad_loop_ ( dataset , cuts = '' , cutrange = '' , first = 0 , last = LAST_ENTRY , progress = False ) :
    """Iterator for `good' events in dataset
    >>> dataset = ...
    >>> for index, entry, weight in dataset.loop ( 'pt>1' ) :
    >>>    print (index, entry, weight)

    - dataset  : the dataset to loop over
    - cuts     : optional selection expression; entries for which it evaluates
                 to zero are skipped, otherwise its value is folded into the weight
    - cutrange : optional range name; entries failing `allInRange ( cutrange )` are skipped
    - first    : the first entry to process (normalised by `evt_range`)
    - last     : the last entry (not inclusive, normalised by `evt_range`)
    - progress : show the progress bar and log a summary at the end of the loop

    Yields triplets `( index , entry , weight )`:
    - `weight` is `None` when item access gives no weight and no cuts are specified
    - `weight` is the value of `cuts` when item access gives no weight and cuts are specified
    - otherwise `weight` is the entry weight multiplied by the value of `cuts`
    """

    first, last = evt_range ( len ( dataset ) , first , last )

    assert isinstance ( cuts , expression_types ) or not cuts, \
        "Invalid type of cuts: %s" % type ( cuts )

    assert isinstance ( cutrange , expression_types ) or not cutrange, \
        "Invalid type of cutrange: %s" % type ( cutrange )

    cuts     = str ( cuts     ).strip() if cuts     else ''
    cutrange = str ( cutrange ).strip() if cutrange else ''

    ## the formula is needed only for non-trivial cuts
    fcuts = make_formula ( cuts , cuts , dataset.varlist() ) if cuts else None

    ## loop over dataset
    source = range ( first , last )
    if progress : source = progress_bar ( source )

    nevents = 0
    for event in source :

        ## single lookup per event: the same pair is used for the selection and for the result
        ## NOTE(review): presumably this lookup also makes `event` the current
        ## entry, so that `fcuts` is evaluated for it - confirm
        entry , weight = dataset [ event ]
        if not entry : break

        if cutrange and not entry.allInRange ( cutrange ) : continue

        wc = fcuts.getVal() if fcuts else 1.0
        if not wc : continue

        if weight is None :
            ## no weight from the dataset: the cut value (if any) plays the role of the weight
            if cuts : weight = wc
        else :
            weight = weight * wc

        nevents += 1
        yield event , entry , weight

    del fcuts
    ## report summary
    if progress : logger.info ( 'loop: %d from %d entries' % ( nevents , last - first ) )
|
||
|
||
ROOT.RooAbsData.loop = _rad_loop_ | ||
|
||
|
||
# ============================================================================= | ||
## access to the entries in RooAbsData | ||
# @code | ||
# dataset = ... | ||
# weighted = ... ## weighted dataset | ||
# event = dataset [4] ## index | ||
# event , weight = weighted [4] ## index | ||
# | ||
# event , weight = dataset [4] ## index | ||
# events = dataset[0:1000] ## slice | ||
# events = dataset[0:-1:10] ## slice | ||
# events = dataset[ (1,2,3,10) ] ## sequence of indices | ||
# @eendcode | ||
# @endcode | ||
# - For unweighted datasets `weight` is `None` | ||
# @author Vanya BELYAEV [email protected] | ||
# @date 2013-03-31 | ||
def _rad_getitem_ ( data , index ) : | ||
"""Get the entry from RooDataSet | ||
>>> dataset = ... ## normal dataset | ||
>>> weighted = ... ## weighted dataset | ||
>>> event = dataset [4] ## index | ||
>>> event, weight = weighted [4] ## index | ||
>>> dataset = ... | ||
>>> event, weight = dataset [4] ## index | ||
- For unweighted datasets `weight` is `None` | ||
>>> events = dataset[0:1000] ## slice | ||
>>> events = dataset[0:-1:10] ## slice | ||
>>> events = dataset[ (1,2,3,4,10) ] ## sequnce of indices | ||
""" | ||
|
||
N = len ( data ) | ||
|
||
if isinstance ( index , integer_types ) and index < 0 : index += N | ||
|
@@ -123,7 +177,7 @@ def _rad_getitem_ ( data , index ) : | |
if isinstance ( index , integer_types ) and 0 <= index < N : | ||
|
||
entry = data.get ( index ) | ||
if not data.isWeighted() : return entry ## RETURN | ||
if not data.isWeighted() : return entry, None ## RETURN | ||
|
||
weight = data.weight() | ||
if data.store_asym_error () : | ||
|
@@ -186,6 +240,10 @@ def _rad_getitem_ ( data , index ) : | |
|
||
raise IndexError ( 'Invalid index %s'% index ) | ||
|
||
# ============================================================================== | ||
|
||
|
||
|
||
# ============================================================================== | ||
## Get (asymmetric) weight errors for the current entry in dataset | ||
# @code | ||
|
@@ -405,7 +463,7 @@ def _rad_mul_ ( ds1 , ds2 ) : | |
wel , weh = ds1.weightErrors() | ||
entry = ds1.get ( i ) , ds1.weight () , wel , weh | ||
elif store_error : | ||
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError () | ||
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError () | ||
elif weighted : | ||
entry = ds1.get ( i ) , ds1.weight () | ||
else: | ||
|
@@ -852,24 +910,20 @@ def _rds_make_unique_ ( dataset , | |
criterium = criterium , | ||
seed = seed , | ||
report = report ) : | ||
|
||
if weighted : | ||
|
||
entry , weight = dataset [ index ] | ||
if store_asym_errors and isinstance ( weight , VAE ) : | ||
ds.add ( entry , weight.value , weight.neg_error , weight.pos_erroe ) | ||
elif store_errors and isinstance ( weight , VE ) : | ||
ds.add ( entry , weight.value () ) | ||
elif isinsance ( weight , num_types ) : | ||
ds.add ( entry , float ( weight ) ) | ||
else : | ||
entry, weight = dataset [ index ] | ||
|
||
if store_asym_errors and isinstance ( weight , VAE ) : | ||
ds.add ( entry , weight.value , weight.neg_error , weight.pos_erroe ) | ||
elif store_errors and isinstance ( weight , VE ) : | ||
ds.add ( entry , weight.value () ) | ||
elif weighted and isinstance ( weight , num_types ) : | ||
ds.add ( entry , float ( weight ) ) | ||
elif weighted : | ||
raise TypeError ( 'Inconsistent sae/se/w %s/%s/%s ' % ( store_asym_errors , | ||
store_errors , | ||
type ( weight ) ) ) | ||
else : | ||
|
||
entry = dataset [ index ] | ||
ds.add ( entry ) | ||
|
||
return ds | ||
|
@@ -3058,62 +3112,43 @@ def get_result ( data ) : return _array ( 'd' , data ) | |
## Iterator for rows in dataset | ||
# @code | ||
# dataset = ... | ||
# for row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) : | ||
# print (row, weight) | ||
# for index, row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) : | ||
# print (index, row, weight) | ||
# @endcode | ||
def _rad_rows_ ( dataset           ,
                 variables = []    ,
                 cuts      = ''    ,
                 cutrange  = ''    ,
                 first     = 0     ,
                 last      = LAST_ENTRY ,
                 progress  = False ) :
    """Iterator for rows in dataset
    >>> dataset = ...
    >>> for index, row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) :
    >>>    print (index, row, weight)

    - variables : expressions to be evaluated for each selected entry
    - cuts      : optional selection, forwarded to `_rad_loop_`
    - cutrange  : optional range name, forwarded to `_rad_loop_`
    - first     : the first entry to process
    - last      : the last entry (not inclusive)
    - progress  : show the progress bar? (forwarded to `_rad_loop_`)

    Yields triplets `( index , row , weight )`, where `row` is the result of
    `get_result` applied to the tuple of values of `variables`, and
    `index` and `weight` are those produced by `_rad_loop_`
    """

    first  , last     = evt_range     ( len ( dataset ) , first , last )
    varlst , cuts , _ = vars_and_cuts ( variables , cuts )

    ## one formula per requested expression
    varlist  = dataset.varlist ()
    formulas = [ make_formula ( v , v , varlist ) for v in varlst ]

    ## loop over dataset: the selection and the weights are handled by `_rad_loop_`
    for index, entry, weight in _rad_loop_ ( dataset ,
                                             cuts     = cuts     ,
                                             cutrange = cutrange ,
                                             first    = first    ,
                                             last     = last     ,
                                             progress = progress ) :

        ## NOTE(review): the formulas are presumably evaluated for the entry
        ## that `_rad_loop_` has just produced - confirm
        result = tuple ( float ( f ) for f in formulas )
        yield index , get_result ( result ) , weight

    del formulas
|
||
ROOT.RooAbsData.rows = _rad_rows_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.