Skip to content

Commit

Permalink
fix ?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Aug 9, 2024
1 parent 489fe54 commit ca341b1
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 74 deletions.
146 changes: 82 additions & 64 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _rad_getitem_ ( data , index ) :

weight = data.weight()
if data.store_asym_error () :
wel , weh = data.weightErrors ()
wel , weh = data.weight_errors ()
weight = VAE ( weight , wel , weh )
elif data.store_error () :
we = data.weightError ()
Expand All @@ -204,41 +204,39 @@ def _rad_getitem_ ( data , index ) :
if 1 == step and start <= stop : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) ) ## RETURN
index = range ( start , stop , step )


## "list" of indices
assert isinstance ( index , sequence_types ) , "Invalid type of `index':%s" % type ( index )

## the actual loop over entries
if isinstance ( index , sequence_types ) :

weighted = data.isWeighted ()
store_errors = weighted and data.store_error ()
store_asym_errors = weighted and data.store_asym_error ()

result = data.emptyClone ( dsID () )
for i in index :

j = int ( i ) ## the content must be convertible to integers
weighted = data.isWeighted ()
store_errors = weighted and data.store_errors ()
store_asym_errors = weighted and data.store_asym_errors ()

if j < 0 : j += N ## allow negative indices

if not 0 <= j < N : ## is adjusted integer in the proper range ?
logger.error ( 'Invalid index [%s]=%s, skip it' % ( i , j ) )
continue

vars = data.get ( j )

if store_asym_errors :
wel , weh = data.weight_errors ()
result.add ( vars , data.weight () , wel , weh )
elif store_errors :
we = data.weightError ()
result.add ( vars , data.weight () , we )
elif weighted :
result.add ( vars , data.weight () )
else :
result.add ( vars )
## preare the resul
result = data.emptyClone ( dsID () )

## the actual loop over entries
for i , j in enumerate ( index ) :

j = int ( j ) ## the content must be convertible to integers
if j < 0 : j += N ## allow negative indices
if not 0 <= j < N : ## is adjusted integer in the proper range ?
raise IndexError ( 'Invalid index [%s]=%s,' % ( i , j ) )

vars = data.get ( j )

if store_asym_errors :
wel , weh = data.weight_errors ()
result.add ( vars , data.weight () , wel , weh )
elif store_errors :
we = data.weightError ()
result.add ( vars , data.weight () , we )
elif weighted :
result.add ( vars , data.weight () )
else :
result.add ( vars )

return result

raise IndexError ( 'Invalid index %s'% index )
return result

# ==============================================================================

Expand All @@ -248,7 +246,7 @@ def _rad_getitem_ ( data , index ) :
## Get (asymmetric) weigth errors for the current entry in dataset
# @code
# dataset = ...
# weight_error_low, weight_error_high = dataset.weightErrors()
# weight_error_low, weight_error_high = dataset.weight_errors ()
# @endcode
# @see RooAbsData::weightError
def _rad_weight_errors ( data , *etype ) :
Expand Down Expand Up @@ -460,7 +458,7 @@ def _rad_mul_ ( ds1 , ds2 ) :
for i in range ( l ) :
if random.uniform ( 0 , 1 ) <= fraction :
if store_asym_errors :
wel , weh = ds1.weightErrors()
wel , weh = ds1.weight_errors ()
entry = ds1.get ( i ) , ds1.weight () , wel , weh
elif store_error :
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError ()
Expand Down Expand Up @@ -629,36 +627,62 @@ def _rds_bootstrap_ ( dataset , size = 100 , extended = False ) :
del ds

# =============================================================================
## get (random) sub-sample from the dataset
## get (random) unique sub-sample from the dataset
# @code
# data = ...
# subset = data.sample ( 100 ) ## get 100 events
# subset = data.sample ( 0.01 ) ## get 1% of events
# @endcode
def _rad_sample_ ( self , num ) :
"""Get (random) sub-sample from the dataset
"""Get (random) unique sub-sample from the dataset
>>> data = ...
>>> subset = data.sample ( 100 ) ## get 100 events
>>> subset = data.sample ( 0.01 ) ## get 1% of events
"""
if 0 == num : return self.emptyClone ( dsID () )
elif isinstance ( num , integer_types ) and 0 < num :
num = min ( num , len ( self ) )
N = len ( self )
if 0 == num : return self.emptyClone ( dsID () )
elif num == N : return _rad_shuffle_ ( self )
elif isinstance ( num , integer_types ) and 0 < num < N : pass
elif isinstance ( num , float ) and 0 < num < 1 :
from ostap.math.random_ext import poisson
num = poisson ( num * len ( self ) )
num = poisson ( num * N )
return _rad_sample_ ( self , num )
else :
raise TypeError("Unknown ``num''=%s" % num )

result = self.emptyClone ( dsID () )
indices = random.sample ( range ( len ( self ) ) , num )

while indices :
i = indices.pop()
result.add ( self[i] )

return result
raise TypeError("Unknown `num':%s" % num )
##
indices = random.sample ( range ( N ) , num )
return self [ indices ]

# =============================================================================
## get (random) sub-sample from the dataset with replacement
# @code
# data = ...
# subset = data.choince ( 100 ) ## get 100 events
# subset = data.choince ( 0.01 ) ## get 1% of events
# @endcode
def _rad_choice_ ( self , num ) :
"""Get (random) sub-sample from the dataset with replacement
>>> data = ...
>>> subset = data.chince ( 100 ) ## get 100 events
>>> subset = data.choice ( 0.01 ) ## get 1% of events
"""
N = len ( self )
if 0 == num or 0 == N : return self.emptyClone ( dsID () )
elif isinstance ( num , integer_types ) and 0 < num <= N : pass
elif isinstance ( num , float ) and 0 < num < 1 :
from ostap.math.random_ext import poisson
num = poisson ( num * N )
return _rad_choice_ ( self , num )
else :
raise TypeError("Unknown `num':%s" % num )
##
if ( 3 , 6 ) <= sys.version_info :
indices = random.choices ( range ( N ) , k = num )
else :
indices = [ random.randrange ( N ) for i in range ( num ) ]
##
return self [ indices ]


# =============================================================================
## get the shuffled sample
Expand All @@ -671,16 +695,10 @@ def _rad_shuffle_ ( self ) :
>>> data = ....
>>> shuffled = data.shuffle()
"""
result = self.emptyClone ( dsID () )

indices = [ i for i in range( len ( self ) ) ]
indices = [ i for i in range ( len ( self ) ) ]
random.shuffle ( indices )

while indices :
i = indices.pop()
result.add ( self[i] )

return result
return self [ indices ]

# ==============================================================================
## Imporved reduce
Expand Down Expand Up @@ -975,6 +993,7 @@ def _rds_make_unique_ ( dataset ,


ROOT.RooDataSet . sample = _rad_sample_
ROOT.RooDataSet . choice = _rad_choice_
ROOT.RooDataSet . shuffle = _rad_shuffle_


Expand Down Expand Up @@ -2564,17 +2583,16 @@ def ds_to_csv ( dataset , fname , vars = () , more_vars = () , weight_var = '' ,
writer.writerow ( vnames )

## loop over entries in the dataset
for entry in progress_bar ( dataset , max_value = len ( dataset ) , silent = not progress ) :
for entry, _ in progress_bar ( dataset , max_value = len ( dataset ) , silent = not progress ) :

values = [ entry [ a ].getVal() for a in vnames1 ]
values += [ v.getVal() for v in mvars ]

if weighted and sae :
if sae :
e1 , e2 = dataset.weight_errors ()
values += [ dataset.weight() , e1 , e2 ]
elif weighted and se :
e1 , e2 = dataset.weight_errors ()
values += [ dataset.weight() , 0.5*(e1+e2) ]
elif se :
values += [ dataset.weight() , dataset.weightError () ]
elif weighted :
values += [ dataset.weight() ]

Expand Down
47 changes: 41 additions & 6 deletions ostap/fitting/tests/test_fitting_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
evt = ROOT.RooRealVar ( 'Evt' , '#event' , 0 , 1000000 )
run = ROOT.RooRealVar ( 'Run' , '#run' , 0 , 1000000 )
mass = ROOT.RooRealVar ( 'Mass' , 'mass-variable' , 0 , 100 )
pt = ROOT.RooRealVar ( 'Pt' , 'pt-variable' , 0 , 100 )
weight = ROOT.RooRealVar ( 'Weight' , 'some weight' , -10 , 10 )

varset = ROOT.RooArgSet ( evt , run , mass , weight )
varset = ROOT.RooArgSet ( evt , run , mass , pt , weight )
dataset = ROOT.RooDataSet ( dsID () , 'Test Data set-0' , varset )

for r in range ( 100 ) :
Expand All @@ -41,9 +42,9 @@
for e in range ( 100 ) :

evt .setVal ( e )
mass.setVal ( random.uniform ( 0 , 10 ) )
weight.setVal ( random.gauss ( 1 , 0.1 ) )

mass.setVal ( random.uniform ( 0 , 10 ) )
weight.setVal ( random.gauss ( 1 , 0.1 ) )
pt.setVal ( random.uniform ( 1 , 99 ) )
dataset.add ( varset )

if ( 1 <= r <= 2 ) and 3 <= e <= 3 :
Expand All @@ -57,8 +58,8 @@
# =========================================================================================
## (1) print datasets
# =========================================================================================
logger.info ( 'Print unweighted dataset:\n%s' % dataset .table ( prefix = '# ' ) )
logger.info ( 'Print weighted dataset:\n%s' % weighted.table ( prefix = '# ' ) )
logger.info ( 'Print unweighted dataset:\n%s' % dataset .table ( prefix = '# ' ) )
logger.info ( 'Print weighted dataset:\n%s' % weighted.table ( prefix = '# ' ) )

# =========================================================================================
## (2) loop over some subset of entries
Expand All @@ -79,6 +80,40 @@
for index, row , weight in weighted .rows ( 'Mass, 2*Mass, Mass/2' , '(Evt<5) && (Mass<5)' , first = 100 , last = 1000 , progress = False ) :
print ( index, row , weight )

# =========================================================================================
## (4) subset
# =========================================================================================
ss1 = dataset [ 1:500:10 ]
ss2 = weighted [ 1:500:10 ]
logger.info ( 'Print small unweighted dataset:\n%s' % ss1.table ( prefix = '# ' ) )
logger.info ( 'Print small weighted dataset:\n%s' % ss2.table ( prefix = '# ' ) )

# =========================================================================================
## (5) subset/sample (unique)
# =========================================================================================
ss1 = dataset . sample ( 100 )
ss2 = weighted . sample ( 100 )
logger.info ( 'Print small unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) )
logger.info ( 'Print small weighted sample:\n%s' % ss2.table ( prefix = '# ' ) )

# =========================================================================================
## (6) subset/sample (wih r
# =========================================================================================
ss1 = dataset . choice ( 100 )
ss2 = weighted . choice ( 100 )
logger.info ( 'Print small unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) )
logger.info ( 'Print small weighted sample:\n%s' % ss2.table ( prefix = '# ' ) )

# =========================================================================================
## (7) shuffle
# =========================================================================================
ss1 = ss1. shuffle ()
ss2 = ss2. shuffle ()
logger.info ( 'Print shuffle unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) )
logger.info ( 'Print shuffle weighted sample:\n%s' % ss2.table ( prefix = '# ' ) )




# =============================================================================
## The END
Expand Down
4 changes: 2 additions & 2 deletions ostap/tools/tests/test_tools_tmva.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def test_tmva () :
counters = {}
methods = reader.methods[:]
for m in methods : counters[m] = SE()
for evt in ds_S1 :
for evt , _ in ds_S1 :
for method in methods : counters[method] += reader ( method , evt )
title = 'Signal response (RooDataSet)'
table = counters_table ( counters , title = title , prefix = '# ' )
Expand All @@ -274,7 +274,7 @@ def test_tmva () :
counters = {}
methods = reader.methods[:]
for m in methods : counters[m] = SE()
for evt in ds_B1 :
for evt, _ in ds_B1 :
for method in methods : counters[method] += reader ( method , evt )
title = 'Background response (RooDataSet)'
table = counters_table ( counters , title = title , prefix = '# ' )
Expand Down
4 changes: 2 additions & 2 deletions ostap/tools/tests/test_tools_tmva2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_tmva2() :
counters = {}
methods = reader.methods[:]
for m in methods : counters[m] = SE()
for evt in ds_S1 :
for evt, _ in ds_S1 :
for method in methods : counters[method] += reader ( method , evt )
title = 'Signal response (RooDataSet)'
table = counters_table ( counters , title = title , prefix = '# ' )
Expand All @@ -265,7 +265,7 @@ def test_tmva2() :
counters = {}
methods = reader.methods[:]
for m in methods : counters[m] = SE()
for evt in ds_B1 :
for evt, _ in ds_B1 :
for method in methods : counters[method] += reader ( method , evt )
title = 'Background response (RooDataSet)'
table = counters_table ( counters , title = title , prefix = '# ' )
Expand Down

0 comments on commit ca341b1

Please sign in to comment.