Skip to content

Commit

Permalink
1. add loop methdod for RooAbsData and implement rows in terms…
Browse files Browse the repository at this point in the history
… of `loop`

  1. allow more recusion in `vars_and_cuts` function
  1. add new test
  1. From now for weighted datasets `dataset[i]` returns `(entry,weight)` tuple
  1. from now iteration over weighted dataset gives `(entry,weight)` tuple
  1. change sinature of `dataset.loop` , `dataset.rows` methods to return triplets `index, entry, weight`
  • Loading branch information
VanyaBelyaev committed Aug 8, 2024
1 parent 719419b commit e7bea8c
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 119 deletions.
6 changes: 5 additions & 1 deletion ReleaseNotes/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
1. Some tweaks for style configuration
1. update `ostap.utils.valerrors` & and new test
1. allow to use `width` keyword when `line_width` is not specified for `XXX.draw` method
1. add `loop` methdod for `RooAbsData` and implement `rows` in terms of `loop`
1. allow more recusion in `vars_and_cuts` function
1. add new test

## Backward incompatibl

Expand All @@ -23,7 +26,8 @@
`cut_low` and the argument is not optionl anymore
1. From now for weighted datasets `dataset[i]` returns `(entry,weight)` tuple
1. from now iteration over weighted dataset gives `(entry,weight)` tuple

1. change sinature of `dataset.loop` , `dataset.rows` methods to return triplets `index, entry, weight`

## Bug fixes

1. fix a typo in `ostap.ploting.canvas`
Expand Down
195 changes: 115 additions & 80 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
# =============================================================================
from builtins import range
from ostap.utils.progress_bar import progress_bar
from collections import defaultdict
from ostap.core.meta_info import root_info, ostap_version
from ostap.core.core import ( Ostap, VE, SE ,
Expand Down Expand Up @@ -61,60 +62,113 @@
_minv = -0.99 * sys.float_info.max
# =============================================================================
## iterator for RooAbsData entries
# - For unweighted dataset `entry` is a RooArsSet
# @cdoe
# dataset = ...
# for entry in dataset : ...
# @endcode
# @code
# - For weighted datatsets, `entry` is a tuple of (RooArgSet,weight)
# weighted = ...
# for enry, weight in weighted : ...
# for entry, weight in dataset : ...
# @endcode
# - For unweighted datatsets, `weight` is `None`
# @author Vanya BELYAEV [email protected]
# @date 2011-06-07
def _rad_iter_ ( self ) :
"""Iterator for RooAbsData
- for unweighted dataset `entry` is a RooArsSet
>>> dataset = ...
>>> for entry in dataset : ...
- for weighted datatsets, `entry` is a tuple of (RooArgSet,weight)
>>> weighted = ...
>>> for entry, weight in weighted : ...
>>> for entry,weight in dataset : ...
- for unweighted datatsets, `weight` is `None`
"""
_l = len ( self )
for i in range ( 0 , _l ) :
yield self [ i ]

# ===========================================================================
## Iterator over "good" events in dataset
# @code
# dataset = ...
# for index , entry, weight in dataset.loop ( 'pt>1' ) :
# print (index, entry, weight)
# @endcode
def _rad_loop_ ( dataset , cuts = '' , cutrange = '' , first = 0 , last = LAST_ENTRY , progress = False ) :
"""Iterator for `good' events in dataset
>>> dataset = ...
>>> for index, entry, weight in dataset.loop ( ''pt>1' ) :
>>> print (index, entry, weight)
"""

first, last = evt_range ( len ( dataset ) , first , last )

assert isinstance ( cuts , expression_types ) or not cuts, \
"Invalid type of cuts: %s" % type ( cuts )

assert isinstance ( cutrange , expression_types ) or not cutrange, \
"Invalid type of cutrange: %s" % type ( cutrange )

cuts = str(cuts).strip() if cuts else ''
cutrange = str(cutrange).strip() if cutrange else ''

fcuts = None
if cuts : fcuts = make_formula ( cuts , cuts , dataset.varlist() )

weighted = dataset.isWeighted ()
store_errors = weighted and dataset.store_errors ()
store_asym_errors = weighted and dataset.store_asym_errors ()
simple_weight = weighted and ( not store_errors ) and ( not store_asym_errors )

## loop over dataset
source = range ( first , last )
if progress : source = progress_bar ( source )

nevents = 0
for event in source :

vars , ww = dataset [ event ]
if not vars : break

if cutrange and not vars.allInRange ( cutrange ) : continue

wc = fcuts.getVal() if fcuts else 1.0
if not wc : continue

entry , weight = dataset [ event ]

if weight is None : weight = wc if cuts else weight
else : weight = weight * wc

nevents += 1
yield event , entry , weight

del fcuts
## report summary
if progress : logger.info ( 'loop: %d from %d entries' % ( nevents , last - first ) )


ROOT.RooAbsData.loop = _rad_loop_


# =============================================================================
## access to the entries in RooAbsData
# @code
# dataset = ...
# weighted = ... ## weighte dataset
# event = dataset [4] ## index
# event , weight = weighted [4] ## index
#
# event , weight = dataset [4] ## index
# events = dataset[0:1000] ## slice
# events = dataset[0:-1:10] ## slice
# events = dataset[ (1,2,3,10) ] ## sequence of indices
# @eendcode
# @endcode
# - For unweighted ddatasets `weight` is `None`
# @author Vanya BELYAEV [email protected]
# @date 2013-03-31
def _rad_getitem_ ( data , index ) :
"""Get the entry from RooDataSet
>>> dataset = ... ## normal dataset
>>> weighted = ... ## weighted dataset
>>> event = dataset [4] ## index
>>> event, weight = weighted [4] ## index
>>> dataset = ...
>>> event, weight = dataset [4] ## index
- For unweighted ddatsets `weight` is `None`
>>> events = dataset[0:1000] ## slice
>>> events = dataset[0:-1:10] ## slice
>>> events = dataset[ (1,2,3,4,10) ] ## sequnce of indices
"""

N = len ( data )

if isinstance ( index , integer_types ) and index < 0 : index += N
Expand All @@ -123,7 +177,7 @@ def _rad_getitem_ ( data , index ) :
if isinstance ( index , integer_types ) and 0 <= index < N :

entry = data.get ( index )
if not data.isWeighted() : return entry ## RETURN
if not data.isWeighted() : return entry, None ## RETURN

weight = data.weight()
if data.store_asym_error () :
Expand Down Expand Up @@ -186,6 +240,10 @@ def _rad_getitem_ ( data , index ) :

raise IndexError ( 'Invalid index %s'% index )

# ==============================================================================



# ==============================================================================
## Get (asymmetric) weigth errors for the current entry in dataset
# @code
Expand Down Expand Up @@ -405,7 +463,7 @@ def _rad_mul_ ( ds1 , ds2 ) :
wel , weh = ds1.weightErrors()
entry = ds1.get ( i ) , ds1.weight () , wel , weh
elif store_error :
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError ()
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError ()
elif weighted :
entry = ds1.get ( i ) , ds1.weight ()
else:
Expand Down Expand Up @@ -852,24 +910,20 @@ def _rds_make_unique_ ( dataset ,
criterium = criterium ,
seed = seed ,
report = report ) :

if weighted :

entry , weight = dataset [ index ]
if store_asym_errors and isinstance ( weight , VAE ) :
ds.add ( entry , weight.value , weight.neg_error , weight.pos_erroe )
elif store_errors and isinstance ( weight , VE ) :
ds.add ( entry , weight.value () )
elif isinsance ( weight , num_types ) :
ds.add ( entry , float ( weight ) )
else :
entry, weight = dataset [ index ]

if store_asym_errors and isinstance ( weight , VAE ) :
ds.add ( entry , weight.value , weight.neg_error , weight.pos_erroe )
elif store_errors and isinstance ( weight , VE ) :
ds.add ( entry , weight.value () )
elif weighted and isinstance ( weight , num_types ) :
ds.add ( entry , float ( weight ) )
elif weighted :
raise TypeError ( 'Inconsistent sae/se/w %s/%s/%s ' % ( store_asym_errors ,
store_errors ,
type ( weight ) ) )
else :

entry = dataset [ index ]
ds.add ( entry )

return ds
Expand Down Expand Up @@ -3058,62 +3112,43 @@ def get_result ( data ) : return _array ( 'd' , data )
## Iterator for rows in dataset
# @code
# dataset = ...
# for row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) :
# print (row, weight)
# for index, row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) :
# print (index, row, weight)
# @endcode
def _rad_rows_ ( dataset , variables = [] , cuts = '' , cutrange = '' , first = 0 , last = LAST_ENTRY ) :
def _rad_rows_ ( dataset ,
variables = [] ,
cuts = '' ,
cutrange = '' ,
first = 0 ,
last = LAST_ENTRY ,
progress = False ) :
"""Iterator for rows in dataset
>>> dataset = ...
>>> for row , weight in dataset.rows ( 'pt, pt/p, mass ' , 'pt>1' ) :
>>> print (row, weight)
"""

first, last = evt_range ( len ( dataset ) , first , last )

varlst, cuts, _ = vars_and_cuts ( variables , cuts )

vars = strings ( varlst )

first , last = evt_range ( len ( dataset ) , first , last )
varlst , cuts, _ = vars_and_cuts ( variables , cuts )

formulas = []
varlist = dataset.varlist ()
for v in vars :
for v in varlst :
f0 = make_formula ( v , v , varlist )
formulas.append ( f0 )

fcuts = None
if cuts : fcuts = make_formula ( cuts , cuts , varlist )

weighted = dataset.isWeighted ()
store_errors = weighted and daatset.store_errors ()
store_asym_errors = weighted and daatset.store_asym_errors ()
simple_weight = weighted and ( not srore_errors ) and ( not store_asym_errors )
## loop over dataset
for event in range ( first , last ) :

ww = None
if weighted : vars, ww = dataset [ event ]
else : vars = dataset [ event ]

if not vars : break

if cutrange and not vars.allInRange ( cutrange ) : continue

wc = fcuts.getVal() if fcuts else 1.0
if not wc : continue

wd = ww if weighted and not ( ww is None ) else 1.0
## loop over dataset
for index, entry, weight in _rad_loop_ ( dataset ,
cuts = cuts ,
cutrange = cutrange ,
first = first ,
last = last ,
progress = progress ) :

w = wc * wd
if simple_weight and not w : continue

weight = w if weighted else None

result = tuple ( tuple ( float ( f ) for f in formulas ) )
yield get_result ( result ) , weight

del fcuts
del formulas
yield index , get_result ( result ) , weight

del formulas

ROOT.RooAbsData.rows = _rad_rows_

Expand Down
9 changes: 5 additions & 4 deletions ostap/fitting/ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def ds2numpy ( dataset , var_lst , silent = True , more_vars = {} ) :

## add PDF values
if funcs :
for i, entry in enumerate ( source ) :
for i, item in enumerate ( source ) :
entry , weight = item
for vname , func , obsvars in funcs :
obsvars.assign ( entry )
data [ vname ] [ i ] = func.getVal()
Expand Down Expand Up @@ -298,15 +299,15 @@ def ds2numpy ( dataset , var_lst , silent = False , more_vars = {} ) :
## make an explict loop:
for i , item in enumerate ( progress_bar ( dataset , silent = silent ) ) :

if weighted : evt, the_weight = item
else : evt = item
evt, the_weight = item

for v in evt :
vname = v.name
if vname in doubles : data [ vname ] [ i ] = float ( v )
elif vname in categories : data [ vname ] [ i ] = int ( v )

if weighted and weight : data [ weight ] [ i ] = float ( the_weight )
if weighted and weight and not ( the_weigth is None ) :
data [ weight ] [ i ] = float ( the_weight )

## add PDF values
for vname , func , obsvars in funcs :
Expand Down
5 changes: 1 addition & 4 deletions ostap/fitting/roostats.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,7 @@ def plot ( self ) :
gr2.blue ()

ps = fc.GetPointsToScan()

weighted = ps.isWeighted()
for entry in ps :
if weighted : entry , _ = entry
for entry, _ in ps :
point = float ( entry[0] )
if fci.IsInInterval ( entry ) : gr1.append ( point , 1 )
else : gr2.append ( point , 0 )
Expand Down
Loading

0 comments on commit e7bea8c

Please sign in to comment.