Skip to content

Commit

Permalink
?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Aug 10, 2024
1 parent 82740a8 commit 9654455
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
40 changes: 24 additions & 16 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@ def _rad_iter_ ( self ) :
# for index , entry, weight in dataset.loop ( 'pt>1' ) :
# print (index, entry, weight)
# @endcode
def _rad_loop_ ( dataset , cuts = '' , cut_range = '' , first = 0 , last = LAST_ENTRY , progress = False ) :
def _rad_loop_ ( dataset ,
cuts = '' ,
cut_range = '' ,
first = 0 ,
last = LAST_ENTRY ,
progress = False ) :
"""Iterator for `good' events in dataset
>>> dataset = ...
>>> for index, entry, weight in dataset.loop ( ''pt>1' ) :
Expand All @@ -104,7 +109,7 @@ def _rad_loop_ ( dataset , cuts = '' , cut_range = '' , first = 0 , last = LAST_
cut_range = str(cut_range).strip() if cut_range else ''

fcuts = None
if cuts : fcuts = make_formula ( cuts , cuts , dataset.varlist() )
if cuts : fcuts = make_formula ( cuts , cuts , dataset.varlist() )

weighted = dataset.isWeighted ()
store_errors = weighted and dataset.store_errors ()
Expand All @@ -118,16 +123,14 @@ def _rad_loop_ ( dataset , cuts = '' , cut_range = '' , first = 0 , last = LAST_
nevents = 0
for event in source :

vars , ww = dataset [ event ]
if not vars : break
entry, weight = dataset [ event ]
if not entry: break

if cut_range and not vars.allInRange ( cut_range ) : continue
if cut_range and not entry.allInRange ( cut_range ) : continue

wc = fcuts.getVal() if fcuts else 1.0
if not wc : continue

entry , weight = dataset [ event ]

if weight is None : weight = wc if cuts else weight
else : weight = weight * wc

Expand Down Expand Up @@ -203,26 +206,31 @@ def _rad_getitem_ ( data , index ) :
if 1 == step and start <= stop : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) ) ## RETURN
index = range ( start , stop , step )


## "list" of indices
assert isinstance ( index , sequence_types ) , "Invalid type of `index':%s" % type ( index )
## require sequence of indices hee
if not isinstance ( index , sequence_types ) :
raise IndexError ( "Invalid type of `index':%s" % type ( index ) )


weighted = data.isWeighted ()
store_errors = weighted and data.store_errors ()
store_asym_errors = weighted and data.store_asym_errors ()

## preare the resul
## preare the result
result = data.emptyClone ( dsID () )

## the actual loop over entries
## the actual loop over set of entries
for i , j in enumerate ( index ) :

if not isinstance ( j , integer_types ) :
raise IndexError ( 'Invalid index [%s]=%s,' % ( i , j ) )

j = int ( j ) ## the content must be convertible to integers
if j < 0 : j += N ## allow negative indices
if not 0 <= j < N : ## is adjusted integer in the proper range ?
if j < 0 : j += N ## allow `slightly-negative' indices
if not 0 <= j < N : ## adjusted integer in the proper range ?
raise IndexError ( 'Invalid index [%s]=%s,' % ( i , j ) )

vars = data.get ( j )
if not vars : raise IndexError ( 'Invalid index %s' % j ) ##

if store_asym_errors :
wel , weh = data.weight_errors ()
Expand All @@ -240,7 +248,6 @@ def _rad_getitem_ ( data , index ) :
# ==============================================================================



# ==============================================================================
## Get (asymmetric) weigth errors for the current entry in dataset
# @code
Expand Down Expand Up @@ -1703,6 +1710,7 @@ def _rds_makeWeighted_ ( dataset ,
cuts ,
wvarname )


ROOT.RooDataSet.makeWeighted = _rds_makeWeighted_

# =============================================================================
Expand Down
4 changes: 3 additions & 1 deletion ostap/fitting/tests/test_fitting_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@

mass.setVal ( random.uniform ( 0 , 10 ) )
dataset.add ( varset )


"""
# =============================================================================
weighted = dataset.makeWeighted ( 'Weight' )
Expand All @@ -66,7 +69,6 @@
logger.info ( 'Print unweighted dataset:\n%s' % dataset .table ( prefix = '# ' ) )
logger.info ( 'Print weighted dataset:\n%s' % weighted.table ( prefix = '# ' ) )
"""
# =============================================================================
## (2) loop over some subset of entries
Expand Down

0 comments on commit 9654455

Please sign in to comment.