Skip to content

Commit

Permalink
reduce usage of Ostap::Utils::Iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Nov 22, 2023
1 parent c6109e8 commit b7b1097
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 106 deletions.
32 changes: 21 additions & 11 deletions ostap/fitting/roocollections.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,28 @@ def _rac_names_ ( self ) :
## iterator for RooArgSet
# @author Vanya BELYAEV [email protected]
# @date 2011-06-07
def _ras_iter_ ( self ) :
"""Simple iterator for RootArgSet:
>>> arg_set = ...
>>> for i in arg_set : print i
"""
it = Ostap.Utils.Iterator ( self )
val = it.Next()
while val :
yield val
if root_info < (6,31) :
def _ras_iter_ ( self ) :
"""Simple iterator for RooArgSet:
>>> arg_set = ...
>>> for i in arg_set : print i
"""
it = Ostap.Utils.Iterator ( self ) ## only for ROOT < 6.31
val = it.Next()

del it
while val :
yield val
val = it.Next()
del it
else :
def _ras_iter_ ( self ) :
"""Simple iterator for RooArgSet:
>>> arg_set = ...
>>> for i in arg_set : print i
"""
cnt = self.get()
for v in cnt : yield v



# =============================================================================
## get the attibute for RooArgSet
Expand Down
87 changes: 5 additions & 82 deletions ostap/stats/ustat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
# ============================================================================
__all__ = (
"uPlot" , ## make plot of U-statistics
"uDist" , ## calculate U-statistics
"uCalc" , ## calclulate the distance between two data points
"uCalc" , ## calculate U-statistics
)
# ============================================================================
from ostap.core.core import cpp, Ostap, hID
Expand All @@ -72,34 +71,6 @@
if '__main__' == __name__ : logger = getLogger ( 'ostap.stats.ustat' )
else : logger = getLogger ( __name__ )
# =============================================================================
## calculate the distance between two data points
# @author Vanya Belyaev [email protected]
# @date 2011-09-21
def uDist ( x , y ) :
"""Calculate the distance between two data points
"""

ix = Ostap.Utils.Iterator ( x )
iy = Ostap.Utils.Iterator ( y )

dist = 0.0

xv = ix.Next()
yv = iy.Next()
while xv and yv :

if not hasattr ( xv , 'getVal' ) : break
if not hasattr ( yv , 'getVal' ) : break

d = xv.getVal()-yv.getVal()
dist += d*d
xv = ix.Next()
yv = iy.Next()

del ix
del iy

return math.sqrt( dist )

# =============================================================================
## calculate U-statistics
Expand Down Expand Up @@ -129,54 +100,6 @@ def uCalc ( pdf ,
tStat = float ( tStat.value )
return histo, tStat

numEntries = data.numEntries ()
dim = args.getSize ()
data_clone = data.Clone ()

from ostap.logger.progress_bar import progress_bar
for i in progress_bar ( xrange ( numEntries ) ) :

event_x = data_clone.get ( i )
event_i = event_x.selectCommon ( args )

## fill args and evaluate PDF
for a in args : a.setVal ( event_i.getRealValue ( a.GetName () ) )
pdfValue = pdf.getVal ( args )

small_v = 1.e+100
small_j = 0
for j in xrange ( 0 , numEntries ) :

if j == i : continue

event_y = data.get( j )
event_j = event_y.selectCommon ( args )

dist = uDist ( event_i , event_j )

if 0 == j or dist < small_v :
small_v = dist
small_j = j

value = 0
if 1 == dim :
value = small_v
value *= numEntries * pdfValue
elif 2 == dim :
value = small_v**2
value *= numEntries * pdfValue
value *= math.pi
else :
logger.error ( ' Not-implemented (yet) %s ' % dim )
continue

value = math.exp ( -1 * value )

histo.Fill ( value )

del data_clone
return histo

# =============================================================================
## make the plot of U-statistics
#
Expand Down Expand Up @@ -221,7 +144,8 @@ def uPlot ( pdf ,
if not bins or bins <= 0 :
nEntries = float(data.numEntries())
bins = 10
for nbins in ( 50 ,
for nbins in ( 100 ,
50 ,
40 ,
25 ,
20 ,
Expand All @@ -232,8 +156,7 @@ def uPlot ( pdf ,
if nEntries/nbins < 100 : continue
bins = nbins
break
print('#bins %s' % bins)


histo = ROOT.TH1F ( hID () ,'U-statistics', bins , 0 , 1 )
histo.Sumw2 ( )
histo.SetMinimum ( 0 )
Expand Down Expand Up @@ -263,5 +186,5 @@ def uPlot ( pdf ,
docme ( __name__ , logger = logger )

# ===========================================================================
# The END
## The END
# ===========================================================================
2 changes: 1 addition & 1 deletion source/src/Iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ bool Ostap::Utils::Iterator::reset () const
return true ;
}
// ============================================================================
// The END
// The END
// ============================================================================
39 changes: 27 additions & 12 deletions source/src/UStat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,31 @@ namespace
}
#else
//
RooArgSet::const_iterator xIter = x -> begin () ;
RooArgSet::const_iterator yIter = y -> begin () ;
RooArgSet::const_iterator xEnd = x -> end () ;
RooArgSet::const_iterator yEnd = y -> begin () ;
for ( ; xIter != xEnd && yIter != yEnd ; ++xIter, ++yIter )
// RooArgSet::const_iterator xIter = x -> begin () ;
// RooArgSet::const_iterator yIter = y -> begin () ;
// RooArgSet::const_iterator xEnd = x -> end () ;
// RooArgSet::const_iterator yEnd = y -> begin () ;
// for ( ; xIter != xEnd && yIter != yEnd ; ++xIter, ++yIter )
// {
// const RooRealVar* xVar = static_cast<const RooRealVar*> ( *xIter ) ;
// const RooRealVar* yVar = static_cast<const RooRealVar*> ( *yIter ) ;
// //
// const double val = xVar->getVal() - yVar->getVal() ;
// result += val*val ;
// }
//
for ( auto* xa : *x )
{
const RooRealVar* xVar = static_cast<const RooRealVar*> ( *xIter ) ;
const RooRealVar* yVar = static_cast<const RooRealVar*> ( *yIter ) ;
const double val = xVar->getVal() - yVar->getVal() ;
result += val*val ;
if ( nullptr == xa ) { continue ; }
const RooAbsArg* ya = y -> find ( *xa ) ;
if ( nullptr == ya ) { continue ; }
//
const RooRealVar* xv = static_cast<const RooRealVar*> ( xa ) ;
const RooRealVar* yv = static_cast<const RooRealVar*> ( ya ) ;
//
const double val = xv->getVal() - yv->getVal() ;
result += val*val ;
//
}
//
#endif
Expand Down Expand Up @@ -122,8 +137,8 @@ Ostap::StatusCode Ostap::UStat::calculate
//
const unsigned int num = data.numEntries () ;
//
const RooArgSet * event_x = 0 ;
const RooArgSet * event_y = 0 ;
const RooArgSet* event_x = 0 ;
const RooArgSet* event_y = 0 ;
//
for ( unsigned int i = 0 ; i < num ; ++i )
{
Expand All @@ -133,7 +148,7 @@ Ostap::StatusCode Ostap::UStat::calculate
if ( 0 == event_x || 0 == event_x->getSize() )
{ return Ostap::StatusCode ( InvalidItem1 ) ; } // RETURN
//
std::unique_ptr<RooArgSet> event_i ( ( RooArgSet*)event_x->selectCommon( *args ) ) ;
std::unique_ptr<RooArgSet> event_i ( ( RooArgSet*)event_x->selectCommon ( *args ) ) ;
if ( !event_i || 0 == event_i->getSize() )
{ return Ostap::StatusCode ( InvalidItem2 ) ; } // RETURN
//
Expand Down

0 comments on commit b7b1097

Please sign in to comment.