diff --git a/ostap/fitting/dataset.py b/ostap/fitting/dataset.py index 30107ffe..9cc1af2b 100644 --- a/ostap/fitting/dataset.py +++ b/ostap/fitting/dataset.py @@ -181,7 +181,7 @@ def _rad_getitem_ ( data , index ) : weight = data.weight() if data.store_asym_error () : - wel , weh = data.weightErrors () + wel , weh = data.weight_errors () weight = VAE ( weight , wel , weh ) elif data.store_error () : we = data.weightError () @@ -204,41 +204,39 @@ def _rad_getitem_ ( data , index ) : if 1 == step and start <= stop : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) ) ## RETURN index = range ( start , stop , step ) + + ## "list" of indices + assert isinstance ( index , sequence_types ) , "Invalid type of `index':%s" % type ( index ) - ## the actual loop over entries - if isinstance ( index , sequence_types ) : - - weighted = data.isWeighted () - store_errors = weighted and data.store_error () - store_asym_errors = weighted and data.store_asym_error () - - result = data.emptyClone ( dsID () ) - for i in index : - - j = int ( i ) ## the content must be convertible to integers + weighted = data.isWeighted () + store_errors = weighted and data.store_errors () + store_asym_errors = weighted and data.store_asym_errors () - if j < 0 : j += N ## allow negative indices - - if not 0 <= j < N : ## is adjusted integer in the proper range ? 
- logger.error ( 'Invalid index [%s]=%s, skip it' % ( i , j ) ) - continue - - vars = data.get ( j ) - - if store_asym_errors : - wel , weh = data.weight_errors () - result.add ( vars , data.weight () , wel , weh ) - elif store_errors : - we = data.weightError () - result.add ( vars , data.weight () , we ) - elif weighted : - result.add ( vars , data.weight () ) - else : - result.add ( vars ) + ## prepare the result + result = data.emptyClone ( dsID () ) + + ## the actual loop over entries + for i , j in enumerate ( index ) : + + j = int ( j ) ## the content must be convertible to integers + if j < 0 : j += N ## allow negative indices + if not 0 <= j < N : ## is adjusted integer in the proper range ? + raise IndexError ( 'Invalid index [%s]=%s,' % ( i , j ) ) + + vars = data.get ( j ) + + if store_asym_errors : + wel , weh = data.weight_errors () + result.add ( vars , data.weight () , wel , weh ) + elif store_errors : + we = data.weightError () + result.add ( vars , data.weight () , we ) + elif weighted : + result.add ( vars , data.weight () ) + else : + result.add ( vars ) - return result - - raise IndexError ( 'Invalid index %s'% index ) + return result # ============================================================================== @@ -248,7 +246,7 @@ def _rad_getitem_ ( data , index ) : ## Get (asymmetric) weigth errors for the current entry in dataset # @code # dataset = ... 
-# weight_error_low, weight_error_high = dataset.weightErrors() +# weight_error_low, weight_error_high = dataset.weight_errors () # @endcode # @see RooAbsData::weightError def _rad_weight_errors ( data , *etype ) : @@ -460,7 +458,7 @@ def _rad_mul_ ( ds1 , ds2 ) : for i in range ( l ) : if random.uniform ( 0 , 1 ) <= fraction : if store_asym_errors : - wel , weh = ds1.weightErrors() + wel , weh = ds1.weight_errors () entry = ds1.get ( i ) , ds1.weight () , wel , weh elif store_error : entry = ds1.get ( i ) , ds1.weight () , ds1.weightError () @@ -629,36 +627,62 @@ def _rds_bootstrap_ ( dataset , size = 100 , extended = False ) : del ds # ============================================================================= -## get (random) sub-sample from the dataset +## get (random) unique sub-sample from the dataset # @code # data = ... # subset = data.sample ( 100 ) ## get 100 events # subset = data.sample ( 0.01 ) ## get 1% of events # @endcode def _rad_sample_ ( self , num ) : - """Get (random) sub-sample from the dataset + """Get (random) unique sub-sample from the dataset >>> data = ... 
>>> subset = data.sample ( 100 ) ## get 100 events >>> subset = data.sample ( 0.01 ) ## get 1% of events """ - if 0 == num : return self.emptyClone ( dsID () ) - elif isinstance ( num , integer_types ) and 0 < num : - num = min ( num , len ( self ) ) + N = len ( self ) + if 0 == num : return self.emptyClone ( dsID () ) + elif num == N : return _rad_shuffle_ ( self ) + elif isinstance ( num , integer_types ) and 0 < num < N : pass elif isinstance ( num , float ) and 0 < num < 1 : from ostap.math.random_ext import poisson - num = poisson ( num * len ( self ) ) + num = poisson ( num * N ) return _rad_sample_ ( self , num ) else : - raise TypeError("Unknown ``num''=%s" % num ) - - result = self.emptyClone ( dsID () ) - indices = random.sample ( range ( len ( self ) ) , num ) - - while indices : - i = indices.pop() - result.add ( self[i] ) - - return result + raise TypeError("Unknown `num':%s" % num ) + ## + indices = random.sample ( range ( N ) , num ) + return self [ indices ] + +# ============================================================================= +## get (random) sub-sample from the dataset with replacement +# @code +# data = ... +# subset = data.choice ( 100 ) ## get 100 events +# subset = data.choice ( 0.01 ) ## get 1% of events +# @endcode +def _rad_choice_ ( self , num ) : + """Get (random) sub-sample from the dataset with replacement + >>> data = ... 
+ >>> subset = data.choice ( 100 ) ## get 100 events + >>> subset = data.choice ( 0.01 ) ## get 1% of events + """ + N = len ( self ) + if 0 == num or 0 == N : return self.emptyClone ( dsID () ) + elif isinstance ( num , integer_types ) and 0 < num <= N : pass + elif isinstance ( num , float ) and 0 < num < 1 : + from ostap.math.random_ext import poisson + num = poisson ( num * N ) + return _rad_choice_ ( self , num ) + else : + raise TypeError("Unknown `num':%s" % num ) + ## + if ( 3 , 6 ) <= sys.version_info : + indices = random.choices ( range ( N ) , k = num ) + else : + indices = [ random.randrange ( N ) for i in range ( num ) ] + ## + return self [ indices ] + # ============================================================================= ## get the shuffled sample @@ -671,16 +695,10 @@ def _rad_shuffle_ ( self ) : >>> data = .... >>> shuffled = data.shuffle() """ - result = self.emptyClone ( dsID () ) - indices = [ i for i in range( len ( self ) ) ] + indices = [ i for i in range ( len ( self ) ) ] random.shuffle ( indices ) - - while indices : - i = indices.pop() - result.add ( self[i] ) - - return result + return self [ indices ] # ============================================================================== ## Imporved reduce @@ -975,6 +993,7 @@ def _rds_make_unique_ ( dataset , ROOT.RooDataSet . sample = _rad_sample_ +ROOT.RooDataSet . choice = _rad_choice_ ROOT.RooDataSet . 
shuffle = _rad_shuffle_ @@ -2564,17 +2583,16 @@ def ds_to_csv ( dataset , fname , vars = () , more_vars = () , weight_var = '' , writer.writerow ( vnames ) ## loop over entries in the dataset - for entry in progress_bar ( dataset , max_value = len ( dataset ) , silent = not progress ) : + for entry, _ in progress_bar ( dataset , max_value = len ( dataset ) , silent = not progress ) : values = [ entry [ a ].getVal() for a in vnames1 ] values += [ v.getVal() for v in mvars ] - if weighted and sae : + if sae : e1 , e2 = dataset.weight_errors () values += [ dataset.weight() , e1 , e2 ] - elif weighted and se : - e1 , e2 = dataset.weight_errors () - values += [ dataset.weight() , 0.5*(e1+e2) ] + elif se : + values += [ dataset.weight() , dataset.weightError () ] elif weighted : values += [ dataset.weight() ] diff --git a/ostap/fitting/tests/test_fitting_dataset.py b/ostap/fitting/tests/test_fitting_dataset.py index 09f3e84f..eda0f1b8 100644 --- a/ostap/fitting/tests/test_fitting_dataset.py +++ b/ostap/fitting/tests/test_fitting_dataset.py @@ -30,9 +30,10 @@ evt = ROOT.RooRealVar ( 'Evt' , '#event' , 0 , 1000000 ) run = ROOT.RooRealVar ( 'Run' , '#run' , 0 , 1000000 ) mass = ROOT.RooRealVar ( 'Mass' , 'mass-variable' , 0 , 100 ) +pt = ROOT.RooRealVar ( 'Pt' , 'pt-variable' , 0 , 100 ) weight = ROOT.RooRealVar ( 'Weight' , 'some weight' , -10 , 10 ) -varset = ROOT.RooArgSet ( evt , run , mass , weight ) +varset = ROOT.RooArgSet ( evt , run , mass , pt , weight ) dataset = ROOT.RooDataSet ( dsID () , 'Test Data set-0' , varset ) for r in range ( 100 ) : @@ -41,9 +42,9 @@ for e in range ( 100 ) : evt .setVal ( e ) - mass.setVal ( random.uniform ( 0 , 10 ) ) - weight.setVal ( random.gauss ( 1 , 0.1 ) ) - + mass.setVal ( random.uniform ( 0 , 10 ) ) + weight.setVal ( random.gauss ( 1 , 0.1 ) ) + pt.setVal ( random.uniform ( 1 , 99 ) ) dataset.add ( varset ) if ( 1 <= r <= 2 ) and 3 <= e <= 3 : @@ -57,8 +58,8 @@ # 
========================================================================================= ## (1) print datasets # ========================================================================================= -logger.info ( 'Print unweighted dataset:\n%s' % dataset .table ( prefix = '# ' ) ) -logger.info ( 'Print weighted dataset:\n%s' % weighted.table ( prefix = '# ' ) ) +logger.info ( 'Print unweighted dataset:\n%s' % dataset .table ( prefix = '# ' ) ) +logger.info ( 'Print weighted dataset:\n%s' % weighted.table ( prefix = '# ' ) ) # ========================================================================================= ## (2) loop over some subset of entries @@ -79,6 +80,40 @@ for index, row , weight in weighted .rows ( 'Mass, 2*Mass, Mass/2' , '(Evt<5) && (Mass<5)' , first = 100 , last = 1000 , progress = False ) : print ( index, row , weight ) +# ========================================================================================= +## (4) subset +# ========================================================================================= +ss1 = dataset [ 1:500:10 ] +ss2 = weighted [ 1:500:10 ] +logger.info ( 'Print small unweighted dataset:\n%s' % ss1.table ( prefix = '# ' ) ) +logger.info ( 'Print small weighted dataset:\n%s' % ss2.table ( prefix = '# ' ) ) + +# ========================================================================================= +## (5) subset/sample (unique) +# ========================================================================================= +ss1 = dataset . sample ( 100 ) +ss2 = weighted . sample ( 100 ) +logger.info ( 'Print small unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) ) +logger.info ( 'Print small weighted sample:\n%s' % ss2.table ( prefix = '# ' ) ) + +# ========================================================================================= +## (6) subset/sample (with replacement) +# ========================================================================================= +ss1 = dataset . 
choice ( 100 ) +ss2 = weighted . choice ( 100 ) +logger.info ( 'Print small unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) ) +logger.info ( 'Print small weighted sample:\n%s' % ss2.table ( prefix = '# ' ) ) + +# ========================================================================================= +## (7) shuffle +# ========================================================================================= +ss1 = ss1. shuffle () +ss2 = ss2. shuffle () +logger.info ( 'Print shuffle unweighted sample:\n%s' % ss1.table ( prefix = '# ' ) ) +logger.info ( 'Print shuffle weighted sample:\n%s' % ss2.table ( prefix = '# ' ) ) + + + # ============================================================================= ## The END diff --git a/ostap/tools/tests/test_tools_tmva.py b/ostap/tools/tests/test_tools_tmva.py index 866b25a5..a3a88307 100755 --- a/ostap/tools/tests/test_tools_tmva.py +++ b/ostap/tools/tests/test_tools_tmva.py @@ -264,7 +264,7 @@ def test_tmva () : counters = {} methods = reader.methods[:] for m in methods : counters[m] = SE() - for evt in ds_S1 : + for evt , _ in ds_S1 : for method in methods : counters[method] += reader ( method , evt ) title = 'Signal response (RooDataSet)' table = counters_table ( counters , title = title , prefix = '# ' ) @@ -274,7 +274,7 @@ def test_tmva () : counters = {} methods = reader.methods[:] for m in methods : counters[m] = SE() - for evt in ds_B1 : + for evt, _ in ds_B1 : for method in methods : counters[method] += reader ( method , evt ) title = 'Background response (RooDataSet)' table = counters_table ( counters , title = title , prefix = '# ' ) diff --git a/ostap/tools/tests/test_tools_tmva2.py b/ostap/tools/tests/test_tools_tmva2.py index 1aed0e73..a8cfee81 100755 --- a/ostap/tools/tests/test_tools_tmva2.py +++ b/ostap/tools/tests/test_tools_tmva2.py @@ -255,7 +255,7 @@ def test_tmva2() : counters = {} methods = reader.methods[:] for m in methods : counters[m] = SE() - for evt in ds_S1 : + for evt, _ in ds_S1 : for 
method in methods : counters[method] += reader ( method , evt ) title = 'Signal response (RooDataSet)' table = counters_table ( counters , title = title , prefix = '# ' ) @@ -265,7 +265,7 @@ def test_tmva2() : counters = {} methods = reader.methods[:] for m in methods : counters[m] = SE() - for evt in ds_B1 : + for evt, _ in ds_B1 : for method in methods : counters[method] += reader ( method , evt ) title = 'Background response (RooDataSet)' table = counters_table ( counters , title = title , prefix = '# ' )