Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Jul 24, 2024
1 parent 9f5ab0f commit 2437cb3
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 292 deletions.
20 changes: 3 additions & 17 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,17 +763,9 @@ def _rds_make_unique_ ( dataset ,
ROOT.RooAbsData . sample = _rad_sample_
ROOT.RooAbsData . shuffle = _rad_shuffle_

from ostap.trees.trees import ( _stat_var_ , _stat_vars_ ,
_stat_cov_ , _stat_covs_ , _stat_nEff_ ,
_sum_var_ , _sum_var_old_ , _stat_vct_ )
ROOT.RooAbsData . sumVar = _sum_var_
ROOT.RooAbsData . sumVar_ = _sum_var_old_
ROOT.RooAbsData . statVar = _stat_var_
ROOT.RooAbsData . statVars = _stat_vars_
ROOT.RooAbsData . statCov = _stat_cov_
from ostap.trees.trees import _stat_covs_

ROOT.RooAbsData . statCovs = _stat_covs_
ROOT.RooAbsData . statVct = _stat_vct_
ROOT.RooAbsData . nEff = _stat_nEff_

from ostap.stats.statvars import data_the_moment
ROOT.RooAbsData. the_moment = data_the_moment
Expand Down Expand Up @@ -810,13 +802,7 @@ def _rds_make_unique_ ( dataset ,
ROOT.RooAbsData . sample ,
ROOT.RooAbsData . shuffle ,
#
ROOT.RooAbsData . statVar ,
ROOT.RooAbsData . sumVar ,
ROOT.RooAbsData . sumVar_ ,
#
ROOT.RooAbsData . statCov ,
ROOT.RooAbsData . statCovs ,
ROOT.RooAbsData . statVct ,
]


Expand Down Expand Up @@ -3078,7 +3064,7 @@ def _rad_rows_ ( dataset , variables = [] , cuts = '' , cutrange = '' , first =
# ============================================================================

from ostap.stats.statvars import data_decorate as _dd
_dd ( ROOT.RooAbsData )
_new_methods_ += list ( _dd ( ROOT.RooAbsData ) )

_decorated_classes_ = (
ROOT.RooAbsData ,
Expand Down
138 changes: 106 additions & 32 deletions ostap/stats/statvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
- data_get_stat - get the moment-based statistics
- data_central_moment - get the central moment (with uncertainty)
- data_mean - get the mean (with uncertainty)
- data_nEff - get the effective number of entries
- data_sum - get the (weighted) sum
- data_variance - get the variance (with uncertainty)
- data_dispersion - get the dispersion (with uncertainty)
- data_rms - get the RMS (with uncertainty)
Expand Down Expand Up @@ -46,6 +48,8 @@
'data_moment' , ## get the moment (with uncertainty)
'data_get_stat' , ## get the momentt-based statistics
'data_central_moment' , ## get the central moment (with uncertainty)
'data_nEff' , ## get the effective number of entries
'data_sum' , ## get the (weighted) sum
'data_mean' , ## get the mean (with uncertainty)
'data_variance' , ## get the variance (with uncertainty)
'data_dispersion' , ## get the dispersion (with uncertainty)
Expand Down Expand Up @@ -77,9 +81,11 @@
# =============================================================================
from builtins import range
from ostap.math.base import isequal, iszero
from ostap.core.core import Ostap, rootException, strings, WSE
from ostap.core.ostap_types import string_types, integer_types, num_types
from ostap.trees.cuts import expression_types, vars_and_cuts
from ostap.core.core import Ostap, rootException, WSE
from ostap.core.ostap_types import ( string_types , integer_types ,
num_types , dictlike_types )
from ostap.trees.cuts import expression_types, vars_and_cuts
from ostap.utils.basic import loop_items
import ostap.stats.moment
import ostap.logger.table as T
import ROOT
Expand Down Expand Up @@ -110,8 +116,8 @@ def data_get_moment ( data , order , center , expression , cuts = '' , *args ) :
>>> print data.get_moment ( 3 , 0.0 , 'mass' , 'pt>1' ) ## ditto
- see Ostap::StatVar::get_moment
"""
assert isinstance ( order , integer_types ) and 0<= order , 'Invalid order %s' % order
assert isinstance ( center , num_types ) , 'Invalid center!'
assert isinstance ( order , integer_types ) and 0 <= order , 'Invalid order %s' % order
assert isinstance ( center , num_types ) , 'Invalid center!'

assert isinstance ( expression , expression_types ) , 'Invalid type of expression!'
assert isinstance ( cuts , expression_types ) , 'Invalid type of cuts/weight!'
Expand Down Expand Up @@ -189,19 +195,19 @@ def data_central_moment ( data , order , expression , cuts = '' , *args ) :
cuts , *args )

# ==============================================================================
## Get the statistics from data
## Get the (W)Statistic-based statistics/counters from data
# @code
# statobj = Ostap.Math.MinValue()
# data = ...
# result = data.get_stat( statobj , 'x+y' , 'pt>1' )
# @encode
# @see Ostap::Math::Moment
# @see Ostap::Math::WMoment
# @see Ostap::Math::Statistic
# @see Ostap::Math::WStatistic
# @see Ostap::statVar::the_moment
def data_get_stat ( data , statobj , expression , cuts = '' , *args ) :
"""Get the (w)moments -based statistics
"""Get the (W)Statistic-based statistics/counters from data
>>> data = ...
>>> stat = Ostap.Math.MinValue()
>>> stat = Ostap.Math.HarmonicMean()
>>> result = data.get_stat ( stat , 'x/y+z' , '0<qq' )
- see Ostap.Math.Moment
- see Ostap.Math.WMoment
Expand Down Expand Up @@ -230,44 +236,45 @@ def data_get_stat ( data , statobj , expression , cuts = '' , *args ) :
## Get the statistics from data
# @code
# data = ...
# result = data_statistics ( data , 'x+y' , 'pt>1' )
# result = data_statistics ( data , 'x+y' , 'pt>1' )
# results = data_statistics ( data , 'x;y;z' , 'pt>1' ) ## result is dictionary
# @encode
# @see Ostap::StatEntity
# @see Ostap::WStatEntity
# @see Ostap::StatVar::statVar
def data_statistics ( data , expressions , cuts = '' , *args ) :
"""Get statistics from data
>>> data = ...
>>> result = data_statistics ( data , 'x/y+z' , '0<qq' )
>>> data = ...
>>> result = data_statistics ( data , 'x/y+z' , '0<qq' )
>>> results = data_statistics ( data , 'x/y;z' , '0<qq' ) ## result is dictionary
- see Ostap.Math.StatEntity
- see Ostap.Math.WStatEntity
- see Ostap.StatVar.statVar
"""

## decode expressions & cuts
var_lst, cuts, input_string = vars_and_cuts ( expressions , cuts )

if not input_string :

names = strings ( *var_lst )
results = StatVar.Statistics()

## only one name is specified
if input_string :
with rootException() :
if cuts : StatVar.statVars ( data , results , names , cuts , *args )
else : StatVar.statVars ( data , results , names , *args )
return StatVar.statVar ( data , var_lst[0] , cuts , *args )

## several variables are specified
from ostap.core.core import strings
names = strings ( *var_lst )
results = StatVar.Statistics()
with rootException() :
if cuts : StatVar.statVars ( data , results , names , cuts , *args )
else : StatVar.statVars ( data , results , names , *args )
assert len ( var_lst ) == len ( results ) , \
'Invalid output from StatVar::statVars!'

result = {}
for v,r in zip ( var_lst , results ) : result [ v ] = r
return result
result = {}
for v,r in zip ( var_lst , results ) : result [ v ] = r
return result


## single name
var_name = var_lst [ 0 ]

with rootException() :
return StatVar.statVar ( data , var_name , cuts , *args )



# ==============================================================================
## Get the covariance for two expressions from data
# @code
Expand Down Expand Up @@ -298,6 +305,52 @@ def data_covariance ( data , expression1 , expression2 , cuts = '' , *args ) :
return StatVar.statCov ( data , expression1 , expression2 , *args )
else :
return StatVar.statCov ( data , expression1 , expression2 , cuts , *args )


# ==============================================================================
## Get the (weighted) sum over the variable(s)
# @code
# data = ...
# result = data_sum ( data , 'x+y' , 'pt>1' )
# results = data_sum ( data , 'x;y;z' , 'pt>1' ) ## result is dictionary
# @encode
# @see Ostap::StatVar::statVar
def data_sum ( data , expressions , cuts = '' , *args ) :
    """Get the (weighted) sum over the variable(s)
    >>> data    = ...
    >>> result  = data_sum ( data , 'x/y+z' , '0<qq' )
    >>> results = data_sum ( data , 'x/y;z' , '0<qq' ) ## result is dictionary
    - the arguments are forwarded unchanged to `data_statistics`
    - for a dictionary-like outcome of `data_statistics` each value is
      replaced by `VE ( sum , sum2 )` and the same dictionary is returned
    - otherwise a single `VE ( sum , sum2 )` is returned
    - see Ostap.StatVar.statVar
    """
    ## bring VE into scope explicitly: it is not part of the
    ## module-level `from ostap.core.core import ...` statement
    from ostap.core.core import VE

    result = data_statistics ( data , expressions , cuts , *args )

    if isinstance ( result , dictlike_types ) :
        ## several expressions: only the values of existing keys are replaced,
        ## the set of keys is not changed while looping
        for key , counter in loop_items ( result ) :
            result [ key ] = VE ( counter.sum() , counter.sum2() )
    else :
        ## single expression: same (sum,sum2) pair as for the dictionary case
        result = VE ( result.sum() , result.sum2() )

    return result

# ==============================================================================
## Get the effective number of entries in data
# @code
# data = ...
# result = data_nEff ( data , 'x+y' )
# @encode
# @see Ostap::StatVar::nEff
def data_nEff ( data , expression = '' , *args ) :
    """Get the effective number of entries in data
    >>> data   = ...
    >>> result = data_nEff ( data , 'x/y+z' )
    - `expression` is converted to a stripped string before the call
    - `data`, `expression` and all extra positional arguments are
      forwarded to `StatVar.nEff`
    - see Ostap.StatVar.nEff
    """

    assert isinstance ( expression , expression_types ) , 'Invalid type of expression!'
    expression = str ( expression ).strip()

    ## the dataset itself must be passed to the backend together with the expression;
    ## the call is wrapped into `rootException` as for the other StatVar calls in this module
    with rootException() :
        return StatVar.nEff ( data , expression , *args )

# =============================================================================
## Get harmonic mean over the data
Expand Down Expand Up @@ -925,6 +978,7 @@ def data_decorate ( klass ) :
if hasattr ( klass , 'get_moment' ) : klass.orig_get_moment = klass.get_moment
if hasattr ( klass , 'moment' ) : klass.orig_moment = klass.moment
if hasattr ( klass , 'central_moment' ) : klass.orig_central_moment = klass.central_moment
if hasattr ( klass , 'nEff' ) : klass.orig_nEff = klass.nEff
if hasattr ( klass , 'mean' ) : klass.orig_mean = klass.mean
if hasattr ( klass , 'variance' ) : klass.orig_variance = klass.variance
if hasattr ( klass , 'dispersion' ) : klass.orig_dispersion = klass.dispersion
Expand All @@ -944,7 +998,9 @@ def data_decorate ( klass ) :
klass.get_moment = data_get_moment
klass.moment = data_moment
klass.central_moment = data_central_moment

klass.mean = data_mean
klass.nEff = data_nEff
klass.variance = data_variance
klass.dispersion = data_dispersion
klass.rms = data_rms
Expand All @@ -961,7 +1017,19 @@ def data_decorate ( klass ) :
klass.deciles = data_deciles

if hasattr ( klass , 'get_stats' ) : klass.orig_get_stats = klass.get_stats


if hasattr ( klass , 'statVar' ) : klass.orig_statVar = klass.statVar
if hasattr ( klass , 'statVars' ) : klass.orig_statVars = klass.statVars
if hasattr ( klass , 'sumVar' ) : klass.orig_sumVar = klass.sumVar
if hasattr ( klass , 'sumVars' ) : klass.orig_sumVars = klass.sumVars
if hasattr ( klass , 'statCov' ) : klass.orig_statCov = klass.statCov

klass.statVar = data_statistics
klass.statVars = data_statistics
klass.sumVar = data_sum
klass.sumVars = data_sum
klass.statCov = data_covariance

if hasattr ( klass , 'the_moment' ) : klass.orig_the_moment = klass.the_moment
if hasattr ( klass , 'the_mean' ) : klass.orig_the_mean = klass.the_mean
if hasattr ( klass , 'the_rms' ) : klass.orig_the_rms = klass.the_rms
Expand Down Expand Up @@ -993,6 +1061,7 @@ def data_decorate ( klass ) :
klass.moment ,
klass.central_moment ,
klass.mean ,
klass.nEff ,
klass.variance ,
klass.dispersion ,
klass.rms ,
Expand All @@ -1007,6 +1076,11 @@ def data_decorate ( klass ) :
klass.quintiles ,
klass.deciles ,
klass.get_stats ,
klass.statVar ,
klass.statVars ,
klass.sumVar ,
klass.sumVars ,
klass.statCov ,
klass.the_moment ,
klass.the_mean ,
klass.the_rms ,
Expand Down
4 changes: 2 additions & 2 deletions ostap/tools/tests/test_tools_reweight2.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ def prepare_data ( ) :

## 4e) 2D-statistics
mcstat = mcds.statCov('x','y','weight')
logger.info ( tag + ': x/y covariance DATA (unbinned):\n# %s' % ( str ( datastat [2] ).replace ( '\n' , '\n# ' ) ) )
logger.info ( tag + ': x/y covariance MC (unbinned):\n# %s' % ( str ( mcstat [2] ).replace ( '\n' , '\n# ' ) ) )
logger.info ( tag + ': x/y correlation DATA (unbinned): %+.2f' % datastat.correlation () )
logger.info ( tag + ': x/y correlation MC (unbinned): %+.2f' % mcstat.correlation () )

if not active and 3 < iter :
logger.info ( allright ( 'No more iterations, converged after #%d' % iter ) )
Expand Down
Loading

0 comments on commit 2437cb3

Please sign in to comment.