Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Aug 7, 2024
1 parent 7a84335 commit 203933d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 48 deletions.
127 changes: 82 additions & 45 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def _rad_iter_ ( self ) :
"""
_l = len ( self )
for i in range ( 0 , _l ) :
## yield self.get ( i )
yield self[i]

yield self [ i ]

# =============================================================================
## access to the entries in RooAbsData
# @code
Expand Down Expand Up @@ -120,11 +119,12 @@ def _rad_getitem_ ( data , index ) :

if isinstance ( index , integer_types ) and index < 0 : index += N

## simple entry index
if isinstance ( index , integer_types ) and 0 <= index < N :

entry = data.get ( index )
if not data.isWeighted() : return entry

entry = data.get ( index )
if not data.isWeighted() : return entry ## RETURN

weight = data.weight()
if data.store_asym_error () :
wel , weh = data.weightErrors ()
Expand All @@ -133,28 +133,30 @@ def _rad_getitem_ ( data , index ) :
we = data.weightError ()
if 0 <= we : weight = VE ( weight , we * we )

return entry, weight

return entry, weight ## RETUR N

## range -> range
elif isinstance ( index , range ) :

## simple case
start , stop , step = index.start , index.stop , index.step
if 1 == step : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) )
if 1 == step and start <= stop : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) ) ## RETURN
index = range ( start , stop , step )

## slice -> range
elif isinstance ( index , slice ) :

start , stop , step = index.indices ( N )
if 1 == step : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) )
if 1 == step and start <= stop : return data.reduce ( ROOT.RooFit.EventRange ( start , stop ) ) ## RETURN
index = range ( start , stop , step )


## the actual loop over entries
if isinstance ( index , sequence_types ) :

weighted = data.isWeighted ()
se = weighted and data.store_error ()
sae = weighted and data.store_asym_error ()
weighted = data.isWeighted ()
store_errors = weighted and data.store_error ()
store_asym_errors = weighted and data.store_asym_error ()

result = data.emptyClone ( dsID () )
for i in index :
Expand All @@ -169,11 +171,11 @@ def _rad_getitem_ ( data , index ) :

vars = data.get ( j )

if weighted and sae :
wel , weh = data.weight_errors()
if store_asym_errors :
wel , weh = data.weight_errors ()
result.add ( vars , data.weight () , wel , weh )
elif weighted and se :
we = data.weightError()
elif store_errors :
we = data.weightError ()
result.add ( vars , data.weight () , we )
elif weighted :
result.add ( vars , data.weight () )
Expand All @@ -191,7 +193,7 @@ def _rad_getitem_ ( data , index ) :
# weight_error_low, weight_error_high = dataset.weightErrors()
# @endcode
# @see RooAbsData::weightError
def _rad_weight_errors( data , *etype ) :
def _rad_weight_errors ( data , *etype ) :
""" Get (asymmetric) weigth errors for the current entry in dataset
>>> dataset = ...
>>> weight_error_low, weigth_error_high = dataset.weight_errors ()
Expand All @@ -206,7 +208,6 @@ def _rad_weight_errors( data , *etype ) :
#
return float ( wel.value ) , float ( weh.value )


# =============================================================================
## Get variables in form of RooArgList
# @author Vanya BELYAEV [email protected]
Expand Down Expand Up @@ -450,7 +451,6 @@ def _rad_div_ ( self , fraction ) :
>>> dataset = ....
>>> small = dataset / 10
"""
print ( 'I AM RAD-DIV', type(fraction) , fraction )
if isinstance ( fraction , num_types ) :
if 1.0 < fraction : return _rad_mul_ ( self , 1.0 / fraction )
elif 1 == fraction : return self.clone ()
Expand Down Expand Up @@ -840,14 +840,37 @@ def _rds_make_unique_ ( dataset ,
>>> unique = dataset.make_unique ( ( 'evt' , 'run' ) , choice = 'max' , criterium = 'PT' )
- CPU performance is more or less reasonable up to dataset with 10^7 entries
"""

weighted = dataset.isWeighted ()
store_errors = weighted and dataset.store_errors ()
store_asym_errors = weighted and dataset.store_asym_errors ()


ds = dataset.emptyClone()
for i in dataset.unique_entries ( entrytag = entrytag ,
choice = choice ,
criterium = criterium ,
seed = seed ,
report = report ) :
ds.add ( dataset [ i ] )
ds = dataset.emptyClone()
for index in dataset.unique_entries ( entrytag = entrytag ,
choice = choice ,
criterium = criterium ,
seed = seed ,
report = report ) :

if weighted :

entry , weight = dataset [ index ]

if store_asym_errors and isinstance ( weight , VAE ) :
ds.add ( entry , weight.value , weight.neg_error , weight.pos_erroe )
elif store_errors and isinstance ( weight , VE ) :
ds.add ( entry , weight.value () )
elif isinsance ( weight , num_types ) :
ds.add ( entry , float ( weight ) )
else :
raise TypeError ( 'Inconsistent sae/se/w %s/%s/%s ' % ( store_asym_errors ,
store_errors ,
type ( weight ) ) )
else :

entry = dataset [ index ]
ds.add ( entry )

return ds

Expand All @@ -862,6 +885,7 @@ def _rds_make_unique_ ( dataset ,
ROOT.RooAbsData . duplicates ,
]


# =============================================================================
## some decoration over RooDataSet
ROOT.RooAbsData . varlist = _rad_vlist_
Expand All @@ -879,6 +903,7 @@ def _rds_make_unique_ ( dataset ,
ROOT.RooAbsData .content = property ( lambda s : s.get() , None , None , """Variables (as ROOT.RooArgSet)""" )

ROOT.RooAbsData . __contains__ = _rad_contains_

ROOT.RooAbsData . __iter__ = _rad_iter_
ROOT.RooAbsData . __getitem__ = _rad_getitem_

Expand Down Expand Up @@ -1146,7 +1171,6 @@ def ds_draw ( dataset ,
## something else ? e.g. DataFrame
assert not cut_range , "ds_draw: `cut_range' is not allowed!"
assert ( first , last ) == ALL_ENTRIES , "ds_draw: `first'/`last' are not allowed!"
print ( 'DS_DRAW', type ( dataset ), varlst , cuts , delta )
ranges = data_range ( dataset , varlst , cuts = cuts , delta = delta )

if not ranges :
Expand Down Expand Up @@ -1174,16 +1198,17 @@ def ds_draw ( dataset ,

# =============================================================================
## get the attibute for RooDataSet
def _ds_getattr_ ( dataset , aname ) :
def _ds_getattr_ ( dataset , attname ) :
"""Get the attibute from RooDataSet
>>> dset = ...
>>> print dset.pt
"""
_vars = dataset.get()
return getattr ( _vars , aname )
return getattr ( _vars , attname )

# =============================================================================
## get the attibute for RooDataSet
# =============================================================================
def get_var ( self, aname ) :
Expand Down Expand Up @@ -2356,12 +2381,10 @@ def _ds_store_error_ ( dataset ) :

if not dataset.isWeighted() : return False ## UNWEIGHTED!

attr = '_store_weight_error'
if not hasattr ( dataset , attr ) :

attr = '_store_weight_error_'
if not hasattr ( dataset , attr ) :
wn = Ostap.Utils.storeError ( dataset )
wn = True if wn else False

wn = True if wn else False
setattr ( dataset , attr , wn )

return getattr ( dataset , attr , '' )
Expand All @@ -2383,34 +2406,48 @@ def _ds_store_asym_error_ ( dataset ) :
- see Ostap::Utils::storeAsymError
"""

print ( 'STORE_ASYM_ERROR/0' , dataset.isWeighted() )

if not dataset.isWeighted() : return False ## UNWEIGHTED!

attr = '_store_asym_weight_error_'
print ( 'STORE_ASYM_ERROR/1' , attr )

attr = '_store_asym_weight_error'
if not hasattr ( dataset , attr ) :

print ( 'STORE_ASYM_ERROR/2' , attr )

wn = Ostap.Utils.storeAsymError ( dataset )
wn = True if wn else False

print ( 'STORE_ASYM_ERROR/3' , attr , wn )

setattr ( dataset , attr , wn )


print ( 'STORE_ASYM_ERROR/4' , attr , hasattr ( dataset , attr ) )

return getattr ( dataset , attr , '' )

# =============================================================================

ROOT.RooDataSet.wname = _ds_wname_
ROOT.RooDataSet.weight_name = _ds_wname_
ROOT.RooDataSet.store_error = _ds_store_error_
ROOT.RooDataSet.store_asym_error = _ds_store_asym_error_
ROOT.RooDataSet.wname = _ds_wname_
ROOT.RooDataSet.weight_name = _ds_wname_
ROOT.RooAbsData.store_error = _ds_store_error_
ROOT.RooAbsData.store_errors = _ds_store_error_
ROOT.RooAbsData.store_asym_error = _ds_store_asym_error_
ROOT.RooAbsData.store_asym_errors = _ds_store_asym_error_

if not hasattr ( ROOT.RooDataSet , 'weightVar' ) :
ROOT.RooDataSet.weightVar = _ds_weight_var_


_new_methods_ += [
ROOT.RooDataSet.wname ,
ROOT.RooDataSet.weight_name ,
ROOT.RooDataSet.store_error ,
ROOT.RooDataSet.store_asym_error ,
ROOT.RooDataSet.wname ,
ROOT.RooDataSet.weight_name ,
ROOT.RooAbsData.store_error ,
ROOT.RooAbsData.store_errors ,
ROOT.RooAbsData.store_asym_error ,
ROOT.RooAbsData.store_asym_errors ,
]

if (3,0) <= sys.version_info :
Expand Down
3 changes: 2 additions & 1 deletion ostap/fitting/roocollections.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def _ras_getattr_ ( self , aname ) :
>>> print aset.pt
"""
_v = self.find ( aname )
if not _v : raise AttributeError
if not _v :
raise AttributeError("%s: invalid attribute `%s'" % ( type ( self ) , aname ) )
return _v

# =============================================================================
Expand Down
7 changes: 5 additions & 2 deletions ostap/fitting/roostats.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,12 @@ def plot ( self ) :
gr2 = ROOT.TGraph ()
gr1.red ()
gr2.blue ()

ps = fc.GetPointsToScan()
for entry in ps :

weighted = ps.isWeighted()
for entry in ps :
if weighted : entry , _ = entry
point = float ( entry[0] )
if fci.IsInInterval ( entry ) : gr1.append ( point , 1 )
else : gr2.append ( point , 0 )
Expand Down

0 comments on commit 203933d

Please sign in to comment.