Skip to content

Commit

Permalink
1. add test for ostap.stats.ustat module
Browse files Browse the repository at this point in the history
  1. change the interface for fuctions from `ostap.stats.ustat` module
  1. change the interface for `Ostap::UStat`  class
  • Loading branch information
VanyaBelyaev committed Nov 24, 2023
1 parent f4ab876 commit 01850af
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 46 deletions.
4 changes: 4 additions & 0 deletions ReleaseNotes/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
1. Further optimisation in `Ostap::Math::ChebyshedSum`
1. add new test `ostap/math/tests/test_math.poly.py`
1. Reduce usage of `Ostap::Utils::Iterator`
1. add test for `ostap.stats.ustat` module

## Backward incompatible:

1. change the interface for fuctions from `ostap.stats.ustat` module
1. change the interface for `Ostap::UStat` class

## Bug fixes:


Expand Down
126 changes: 126 additions & 0 deletions ostap/stats/tests/test_stats_ustat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# =============================================================================
# @file ostap/stats/tests/test_stats_ustat.py
# Test uStatistics for Goodness-Of-Fit tests
# Copyright (c) Ostap developpers.
# =============================================================================
""" Test uStatistics for goodness-of-fit tests
"""
# =============================================================================
import ostap.stats.ustat as uStat
import ostap.fitting.models as Models
from ostap.utils.timing import timing
from ostap.core.pyrouts import SE
from ostap.plotting.canvas import use_canvas
import ROOT, random, math
# =============================================================================
from ostap.logger.logger import getLogger
if '__main__' == __name__ : logger = getLogger ( 'tests_stats_ustat' )
else : logger = getLogger ( __name__ )
# ==============================================================================


histos = set()

def test_stats_ustat_G2D () :

logger = getLogger ( "test_stats_ustat_G2D" )

x = ROOT.RooRealVar ( 'x' , 'x-variable' , 0 , 10 )
y = ROOT.RooRealVar ( 'y' , 'y-variable' , 0 , 10 )

pdf = Models.Gauss2D_pdf ( 'G2D' , x , y ,
muX = ( 5 , 4 , 6 ) ,
muY = ( 5 , 4 , 6 ) ,
sigmaX = ( 1 , 0.5 , 1.5 ) ,
sigmaY = ( 2 , 1.5 , 2.5 ) ,
theta = ( math.pi/4 , math.pi/8 , math.pi/2 ) )



for n in ( 100 , 200 , 500 , 1000 ) : ## , 5000 , 10000 ) :


title = "2D: N=%4d test" % n
with timing ( title , logger = logger ) , use_canvas ( title , wait = 5 ) :

pdf.muX = 5
pdf.muY = 5
pdf.sigmaX = 1
pdf.sigmaY = 2
pdf.theta = math.pi/4


data = pdf.generate ( n )

pdf.fitTo ( data , silent = True )

t , h , r = uStat.uPlot ( pdf , data )

histos.add ( h )
h.draw()

data.clear()
del data

def test_stats_ustat_G3D () :

logger = getLogger ( "test_stats_ustat_G3D" )

x = ROOT.RooRealVar ( 'x1' , 'x-variable' , 0 , 10 )
y = ROOT.RooRealVar ( 'y2' , 'y-variable' , 0 , 10 )
z = ROOT.RooRealVar ( 'z1' , 'z-variable' , 0 , 10 )

pdf = Models.Gauss3D_pdf ( 'G3D' , x , y , z ,
muX = ( 5 , 4 , 6 ) ,
muY = ( 5 , 4 , 6 ) ,
muZ = ( 5 , 4 , 6 ) ,
sigmaX = ( 1 , 0.5 , 1.5 ) ,
sigmaY = ( 2 , 1.5 , 2.5 ) ,
sigmaZ = ( 3 , 2.5 , 3.5 ) ,
phi = ( math.pi/4 , math.pi/8 , math.pi/2 ) ,
theta = ( math.pi/4 , math.pi/8 , math.pi/2 ) ,
psi = ( math.pi/4 , math.pi/8 , math.pi/2 ) )

for n in ( 100 , 200 , 500 , 1000 ) : ## , 5000 , 10000 ) :


title = "3D: N=%4d test" % n
with timing ( title , logger = logger ) , use_canvas ( title , wait = 5 ) :

pdf.muX = 5
pdf.muY = 5
pdf.muZ = 5
pdf.sigmaX = 1
pdf.sigmaY = 2
pdf.sigmaZ = 3
pdf.phi = math.pi/4
pdf.theta = math.pi/4
pdf.psi = math.pi/4

data = pdf.generate ( n )

pdf.fitTo ( data , silent = True )

t , h , r = uStat.uPlot ( pdf , data )

histos.add ( h )
h.draw()

data.clear()
del data




# =============================================================================
if '__main__' == __name__ :

test_stats_ustat_G2D ()
test_stats_ustat_G3D ()

# =============================================================================
## The END
# =============================================================================

79 changes: 43 additions & 36 deletions ostap/stats/ustat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
#
# >>> pdf = ... ## pdf
# >>> data = ... ## dataset
# >>> pdf.fitTo( data , ... ) ## fit it!
# >>> pdf.fitTo ( data , ... ) ## fit it!
#
# >>> import ostap.stats.ustat as uStat
#
# >>> r,histo = uStat.uPlot ( pdf , data )
# >>> print r ## print fit results
# >>> print ( r ) ## print fit results
# >>> histo.Draw() ## plot the results
#
# @endcode
Expand All @@ -30,7 +30,7 @@
# @date 2011-09-21
#
# ============================================================================
""" ``U-statistics'' useful for ``Goodness-Of-Fit'' tests
""" `U-statistics' useful for `Goodness-Of-Fit' tests
This is a simple translation of the original C++ lines written by Greig Cowan into Python
Expand All @@ -51,17 +51,16 @@
>>> histo.Draw() ## plot the results
"""
# ============================================================================
from __future__ import print_function
__author__ = "Vanya BELYAEV [email protected]"
__date__ = "2010-09-21"
__version__ = "$Revision$"
__version__ = "$Revision:$"
# ============================================================================
__all__ = (
"uPlot" , ## make plot of U-statistics
"uCalc" , ## calculate U-statistics
)
# ============================================================================
from ostap.core.core import cpp, Ostap, hID
from ostap.core.core import Ostap, hID
import ostap.histos.histos
import ROOT, math, ctypes
# =============================================================================
Expand All @@ -83,22 +82,35 @@
# @see Analysis::UStat::calculate
# @date 2011-09-21
def uCalc ( pdf ,
args ,
data ,
histo ,
args = None ,
histo = None ,
silent = False ) :
"""Calculate U-statistics
"""
import sys

if not isinstance ( pdf , ROOT.RooAbsPdf ) or not pdf :
from ostap.fitting.pdfbasic import APDF1
assert pdf and isinstance ( pdf , APDF1 ) , "Invalid type of `pdf'!"
pdf = pdf.pdf

if not args : args = pdf.getObservables ( data )
if not histo : histo = ROOT.nullptr

##
tStat = ctypes.c_double (-1)
sc = Ostap.UStat.calculate ( pdf ,
data ,
histo ,
tStat ,
histo ,
args )
if sc.isFailure() :
logger.error ( "Error from Ostap::UStat::Calculate %s" % sc )

if not histo : histo = None

tStat = float ( tStat.value )
return histo, tStat
return tStat, histo

# =============================================================================
## make the plot of U-statistics
Expand All @@ -112,8 +124,8 @@ def uCalc ( pdf ,
#
# >>> import ostap.stats.ustat as uStat
#
# >>> r,histo = uStat.uPlot ( pdf , data )
# >>> print r ## print fit results
# >>> t , res , histo = uStat.uPlot ( pdf , data )
# >>> print ( res ) ## print fit results
# >>> histo.Draw() ## plot the results
#
# @endcode
Expand All @@ -136,47 +148,42 @@ def uPlot ( pdf ,
>>> import ostap.stats.ustat as uStat
>>> r,histo = uStat.uPlot ( pdf , data )
>>> print r ## print fit results
>>> t, res , histo = uStat.uPlot ( pdf , data )
>>> print ( res ) ## print fit results
>>> histo.Draw() ## plot the results
"""

if not bins or bins <= 0 :
nEntries = float(data.numEntries())
bins = 10
for nbins in ( 100 ,
50 ,
40 ,
25 ,
20 ,
16 ,
10 ,
8 ,
5 ) :
if nEntries/nbins < 100 : continue
bins = 10
for nbins in ( 1000 , 500 ,
200 , 100 ,
50 , 40 ,
25 , 20 ,
16 , 10 ,
8 , 5 ) :
if nEntries/float(nbins) < 100 : continue
bins = nbins
break

histo = ROOT.TH1F ( hID () ,'U-statistics', bins , 0 , 1 )
histo.Sumw2 ( )
histo.SetMinimum ( 0 )

if not args : args = pdf.getObservables ( data )

h,tStat = uCalc ( pdf ,
args ,
data ,
histo ,
silent )

tStat , hh = uCalc ( pdf ,
data ,
args ,
histo ,
silent )

res = histo.Fit ( 'pol0' , 'SLQ0+' )
func = histo.GetFunction ( 'pol0' )
if func :
func.SetLineWidth ( 3 )
func.SetLineColor ( 2 )
func.ResetBit ( 1 << 9 )

return res , histo, float(tStat)
return float ( tStat ) , histo , res

# ===========================================================================

Expand Down
10 changes: 5 additions & 5 deletions source/include/Ostap/UStat.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ namespace Ostap
* @param args (input) the arguments
*/
static Ostap::StatusCode calculate
( const RooAbsPdf& pdf ,
const RooDataSet& data ,
TH1& hist ,
double& tStat ,
RooArgSet * args = 0 ) ;
( const RooAbsPdf& pdf ,
const RooDataSet& data ,
double& tStat ,
TH1* hist = nullptr ,
RooArgSet* args = nullptr ) ;
// ========================================================================
};
// ==========================================================================
Expand Down
10 changes: 5 additions & 5 deletions source/src/UStat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ namespace
Ostap::StatusCode Ostap::UStat::calculate
( const RooAbsPdf& pdf ,
const RooDataSet& data ,
TH1& hist ,
double& tStat ,
TH1* hist ,
RooArgSet* args )
{
//
Expand Down Expand Up @@ -187,7 +187,7 @@ Ostap::StatusCode Ostap::UStat::calculate
{ return Ostap::StatusCode ( InvalidItem2 ) ; } // RETURN
//
const double distance = getDistance ( event_i.get() , event_j.get() ) ;
if ( 0 > distance ) { return Ostap::StatusCode( InvalidDist ) ; } // RETURN
if ( 0 > distance ) { return Ostap::StatusCode ( InvalidDist ) ; } // RETURN
//
if ( 0 == j || distance < min_distance )
{ min_distance = distance ; }
Expand All @@ -199,7 +199,7 @@ Ostap::StatusCode Ostap::UStat::calculate
//
const double value = std::exp ( -val1 * num * pdfValue ) ;
//
hist.Fill ( value ) ;
if ( hist ) { hist -> Fill ( value ) ; }
//
tstat.push_back ( value ) ;
//
Expand All @@ -208,7 +208,7 @@ Ostap::StatusCode Ostap::UStat::calculate
//
// calculate T-statistics
//
std::sort ( tstat.begin() , tstat.end() ) ;
std::stable_sort ( tstat.begin() , tstat.end() ) ;
double tS = 0 ;
double nD = tstat.size() ;
for ( TStat::const_iterator t = tstat.begin() ; tstat.end() != t ; ++t )
Expand All @@ -224,5 +224,5 @@ Ostap::StatusCode Ostap::UStat::calculate
return Ostap::StatusCode::SUCCESS ;
}
// ============================================================================
// The END
// The END
// ============================================================================

0 comments on commit 01850af

Please sign in to comment.