Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Aug 7, 2024
1 parent 23ba7c8 commit 31740bc
Showing 1 changed file with 182 additions and 67 deletions.
249 changes: 182 additions & 67 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,32 @@

# =============================================================================
## iterator for RooAbsData
# - For unweighted dataset `entry` is a RooArsSet
# @cdoe
# dataset = ...
# for entry in dataset : ...
# @endcode
# @code
# - For weighted datatsets, `entry` is a tuple of (RooArgSet,weight)
# weighted = ...
# for enry, weight in weighted : ...
# @endcode
# @author Vanya BELYAEV [email protected]
# @date 2011-06-07
def _rad_iter_ ( self ) :
"""Iterator for RooAbsData
- for unweighted dataset `entry` is a RooArsSet
>>> dataset = ...
>>> for i in dataset : ...
>>> for entry in dataset : ...
- for weighted datatsets, `entry` is a tuple of (RooArgSet,weight)
>>> weighted = ...
>>> for entry, weight in weighted : ...
"""
_l = len ( self )
for i in range ( 0 , _l ) :
yield self.get ( i )


## yield self.get ( i )
yield self[i]
# =============================================================================
## access to the entries in RooAbsData
# @code
Expand Down Expand Up @@ -109,13 +123,13 @@ def _rad_getitem_ ( data , index ) :

if isinstance ( index , integer_types ) and 0 <= index < N :

entry = daat.get ( index )

entry = data.get ( index )
if not data.isWeighted() : return entry

weight = data.weight()

if data.store_asym_error () : pass
weight = data.weight()
if data.store_asym_error () :
wel, weh = data.weightErrors ()
return entry, weight , wel, weh ## RETURN
elif data.store_error () :
we = data.weightError ()
if 0 <= we : weight = VE ( weight , we * we )
Expand Down Expand Up @@ -228,18 +242,37 @@ def _rad_contains_ ( self , aname ) :
# dset2 = ...
# dset1 += dset2
# @endcode
def _rad_iadd_ ( self , another ) :
def _rad_iadd_ ( ds1 , ds2 ) :
""" Merge/append two datasets into a single one
- two datasets must have identical structure
>>> dset1 = ...
>>> dset2 = ...
>>> dset1 += dset2
"""
if isinstance ( self , ROOT.RooDataSet ) :
if isinstance ( another , ROOT.RooDataSet ) :
self.append ( another )
return self

if isinstance ( ds1 , ROOT.RooDataSet ) and isinstance ( ds2 , ROOT.RooDataSet ) :
##
w1 = ds1.isWeighted()
w2 = ds2.isWeighted()
##
if w1 and w2 : pass
elif w1 : return NotImplemented
elif w2 : return NotImplemented
##
npw1 = ds1.IsNonPoissonWeighted()
npw2 = ds2.IsNonPoissonWeighted()
##
if npw1 and npw2 : pass
elif npw1 : return NotImplemented
elif npw2 : return NotImplemented
##
vs1 = set ( v.name for v in ds1.get() )
vs2 = set ( v.name for v in ds2.get() )
##
if vs1 != vs2 : return NotImplemented
##
ds1.append ( ds2 )
return ds1

return NotImplemented

# =============================================================================
Expand All @@ -249,37 +282,71 @@ def _rad_iadd_ ( self , another ) :
# dset2 = ...
# dset = dset1 + dset2
# @endcode
def _rad_add_ ( self , another ) :
def _rad_add_ ( ds1 , ds2 ) :
""" Merge/append two datasets into a single one
- two datasets must have identical structure
>>> dset1 = ...
>>> dset2 = ...
>>> dset = dset1 + dset2
"""
if isinstance ( self , ROOT.RooDataSet ) :
if isinstance ( another , ROOT.RooDataSet ) :
result = self.emptyClone( dsID() )
result.append ( self )
result.append ( another )
return result
if isinstance ( ds1 , ROOT.RooDataSet ) and isinstance ( ds2 , ROOT.RooDataSet ) :
##
w1 = ds1.isWeighted()
w2 = ds2.isWeighted()
##
if w1 and w2 : pass
elif w1 : return NotImplemented
elif w2 : return NotImplemented
##
npw1 = ds1.isNonPoissonWeighted()
npw2 = ds2.isNonPoissonWeighted()
##
if npw1 and npw2 : pass
elif npw1 : return NotImplemented
elif npw2 : return NotImplemented
##
vs1 = set ( v.name for v in ds1.get() )
vs2 = set ( v.name for v in ds2.get() )
##
if vs1 != vs2 : return NotImplemented
##
result = ds1.emptyClone( dsID() )
result.append ( ds1 )
result.append ( ds2 )
return result

return NotImplemented


# =============================================================================
# merge/append two datasets into a single one
def _rad_imul_ ( self , another ) :
def _rad_imul_ ( ds1 , ds2 ) :
""" Merge/append two datasets into a single one
- two datasets must have the same number of entries!
>>> dset1 = ...
>>> dset2 = ...
>>> dset1 *= dset2
"""
if isinstance ( another , ROOT.RooAbsData ) :
if len ( self ) == len ( another ) :
self.merge ( another )
return self

if isinstance ( ds1 , ROOT.RooDataSet ) and isinstance ( ds2 , ROOT.RooDataSet ) :
if len ( ds1 ) != len ( ds2 ) : return NotImplemented
##
w1 = ds1.isWeighted()
w2 = ds2.isWeighted()
##
if w1 and w2 : pass
elif w1 : return NotImplemented
elif w2 : return NotImplemented
##
npw1 = ds1.isNonPoissonWeighted()
npw2 = ds2.isNonPoissonWeighted()
##
if npw1 and npw2 : pass
elif npw1 : return NotImplemented
elif npw2 : return NotImplemented
##
ds1.merge ( ds2 )
return ds1

return NotImplemented

# =============================================================================
Expand All @@ -291,7 +358,7 @@ def _rad_imul_ ( self , another ) :
# ## merge two dataset of the same lenth
# merged = dataset1 * dataset2
# @endcode
def _rad_mul_ ( self , another ) :
def _rad_mul_ ( ds1 , ds2 ) :
"""
- (1) Get small (random) fraction of dataset:
>>> dataset = ....
Expand All @@ -300,32 +367,79 @@ def _rad_mul_ ( self , another ) :
>>> dataset3 = dataset1 * dataset2
"""

if isinstance ( another , ROOT.RooAbsData ) :

if len ( self ) == len ( another ) :

result = self.emptyClone( dsID() )
result.append ( self )
result.merge ( another )
return result

return NotImplemented

fraction = another
if isinstance ( ds1 , ROOT.RooDataSet ) and isinstance ( ds2 , ROOT.RooDataSet ) :
if len ( ds1 ) != len ( ds2 ) : return NotImplemented
##
w1 = ds1.isWeighted()
w2 = ds2.isWeighted()
##
if w1 and w2 : pass
elif w1 : return NotImplemented
elif w2 : return NotImplemented
##
npw1 = ds1.isNonPoissonWeighted()
npw2 = ds2.isNonPoissonWeighted()
##
if npw1 and npw2 : pass
elif npw1 : return NotImplemented
elif npw2 : return NotImplemented
##
result = ds1.emptyClone( dsID() )
result.append ( ds1 )
result.merge ( ds2 )
return ds1

# =======================================================================
fraction = ds2
if isinstance ( fraction , float ) and 0 < fraction < 1 :

res = self.emptyClone()
l = len ( self )
for i in range ( l ) :
if random.uniform(0,1) < fraction : res.add ( self[i] )
weighted = ds1.isWeighted ()
store_errors = weighted and ds1.store_error ()
store_asym_errors = weighted and ds1.store_asym_error ()

res = ds1.emptyClone()
l = len ( ds1 )
for i in range ( l ) :
if random.uniform ( 0 , 1 ) <= fraction :
if store_asym_errors :
wel , weh = ds1.weightErrors()
entry = ds1.get ( i ) , ds1.weight () , wel , weh
elif store_error :
entry = ds1.get ( i ) , ds1.weight () , ds1.weightError ()
elif weighted :
entry = ds1.get ( i ) , ds1.weight ()
else:
entry = ds1.get ( i ) ,
##
res.append ( *entry )

return res

elif 1 == fraction : return self.clone ()
elif 0 == fraction : return self.emptyClone ()
elif 1 == fraction : return ds1.clone ()
elif 0 == fraction : return ds1.emptyClone ()

return NotImplemented

# =============================================================================
## merge two dataset (of same length) OR get small (random) fraction of dataset
# @code
# ## get smaller dataset:
# dataset = ....
# small = 0.1 * dataset
# ## merge two dataset of the same lenth
# merged = dataset1 * dataset2
# @endcode
def _rad_rmul_ ( ds1 , ds2 ) :
"""
- (1) Get small (random) fraction of dataset:
>>> dataset = ....
>>> small = 0.1 * dataset
- (2) Merge two dataset (of the same length)
>>> dataset3 = dataset1 * dataset2
"""
return _rad_mul_ ( ds1 , ds2 )


# =============================================================================
## get small (random) fraction of dataset
# @code
Expand All @@ -337,9 +451,10 @@ def _rad_div_ ( self , fraction ) :
>>> dataset = ....
>>> small = dataset / 10
"""
if isinstance ( fraction , integer_types ) and 1 < fraction :
return _rad_mul_ ( self , 1.0 / fraction )
elif 1 == fraction : return self.clone ()
print ( 'I AM RAD-DIV', type(fraction) , fraction )
if isinstance ( fraction , num_types ) :
if 1.0 < fraction : return _rad_mul_ ( self , 1.0 / fraction )
elif 1 == fraction : return self.clone ()

return NotImplemented

Expand Down Expand Up @@ -770,19 +885,19 @@ def _rds_make_unique_ ( dataset ,

ROOT.RooAbsData . subset = _rad_subset_

ROOT.RooAbsData . __add__ = _rad_add_
ROOT.RooDataSet . __add__ = _rad_add_
ROOT.RooDataSet . __iadd__ = _rad_iadd_

ROOT.RooAbsData . __mul__ = _rad_mul_
ROOT.RooAbsData . __rmul__ = _rad_mul_
ROOT.RooAbsData . __imul__ = _rad_imul_
ROOT.RooAbsData . __mod__ = _rad_mod_
ROOT.RooAbsData . __div__ = _rad_div_
ROOT.RooAbsData . __truediv__ = ROOT.RooAbsData . __div__
ROOT.RooDataSet . __mul__ = _rad_mul_
ROOT.RooDataSet . __rmul__ = _rad_rmul_
ROOT.RooDataSet . __imul__ = _rad_imul_
ROOT.RooDataSet . __mod__ = _rad_mod_
ROOT.RooDataSet . __div__ = _rad_div_
ROOT.RooDataSet . __truediv__ = ROOT.RooDataSet . __div__


ROOT.RooAbsData . sample = _rad_sample_
ROOT.RooAbsData . shuffle = _rad_shuffle_
ROOT.RooDataSet . sample = _rad_sample_
ROOT.RooDataSet . shuffle = _rad_shuffle_



Expand All @@ -807,15 +922,15 @@ def _rds_make_unique_ ( dataset ,
ROOT.RooDataSet . __add__ ,
ROOT.RooDataSet . __iadd__ ,
#
ROOT.RooAbsData . __mul__ ,
ROOT.RooAbsData . __rmul__ ,
ROOT.RooAbsData . __imul__ ,
ROOT.RooAbsData . __div__ ,
ROOT.RooAbsData . __mod__ ,
ROOT.RooAbsData . __truediv__ ,
ROOT.RooDataSet . __mul__ ,
ROOT.RooDataSet . __rmul__ ,
ROOT.RooDataSet . __imul__ ,
ROOT.RooDataSet . __div__ ,
ROOT.RooDataSet . __mod__ ,
ROOT.RooDataSet . __truediv__ ,
#
ROOT.RooAbsData . sample ,
ROOT.RooAbsData . shuffle ,
ROOT.RooDataSet . sample ,
ROOT.RooDataSet . shuffle ,
#
]

Expand Down

0 comments on commit 31740bc

Please sign in to comment.