Skip to content

Commit

Permalink
fnal fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Aug 12, 2024
1 parent de8b6bc commit 14cfa8d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 73 deletions.
171 changes: 99 additions & 72 deletions ostap/math/tests/test_math_bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import random
import ostap.math.models
import ostap.math.bernstein
from ostap.core.core import Ostap, SE
from ostap.utils.timing import timing
from ostap.core.core import Ostap, SE, isequal
from ostap.plotting.canvas import use_canvas
from ostap.utils.utils import vrange, crange
from ostap.utils.timing import timing
# ============================================================================
# logging
# =============================================================================
Expand All @@ -30,40 +32,42 @@ def test_solve ():
logger = getLogger("test_solve")

# 1) construct function with known roots

## roots in [0,1]
troots = [ random.uniform(0.05,0.95) for i in range ( 4 ) ]
troots.sort()
troots = [ r for r in crange ( 0.05 , 0.95 , 5 ) ]

## roots in [2,9]
roots = troots + [ random.uniform(2,9) for i in range ( 4 ) ]
## roots in [2,10]
roots = troots + [ random.uniform(2,10) for i in range ( 2 ) ]

roots.sort ()

## complex roots
croots = [ complex ( random.uniform(-6,-1) , random.uniform ( -5 , 5 ) ) for i in range ( 3 ) ]
croots += [ complex ( random.uniform( 3, 6) , random.uniform ( -5 , 5 ) ) for i in range ( 3 ) ]

croots = [ complex ( random.uniform(-10,10) , random.uniform ( 1 , 5 ) ) for i in range ( 2 ) ]

## Bernstein polynomial with known roots
bs = Ostap.Math.Bernstein ( 0 , 1 , roots ) ## , croots )
bs = Ostap.Math.Bernstein ( 0 , 1 , roots , croots )

logger.info ( "Bernstein: %s" % bs )

## find roots of Bernstein polynomial
rr = bs.solve()
rr = [ r for r in rr ]
rr.sort()

logger.info ('Roots found : [ %s]' % ( ', '.join ( "%.6f" % r for r in rr ) ) )
logger.info ('Roots true : [ %s]' % ( ', '.join ( "%.6f" % r for r in troots ) ) )

if len ( rr ) != len ( troots ) :
logger.error ( 'Mismatch in number of roots found!' )
else :
diff = 0.0
for i,r in enumerate ( rr ) :
diff += abs ( r - troots[i] )
diff /= len ( rr )
logger.info ( 'Mean root distance is %.4g' % diff )
with use_canvas ( "test_solve" , wait = 5 ) :

bs.draw()
logger.info ('Roots found : [ %s]' % ( ', '.join ( "%.6f" % r for r in rr ) ) )
logger.info ('Roots true : [ %s]' % ( ', '.join ( "%.6f" % r for r in troots ) ) )

if len ( rr ) != len ( troots ) :
logger.error ( 'Mismatch in number of roots found!' )
else :
diff = 0.0
for i,r in enumerate ( rr ) :
diff += abs ( r - troots[i] )
diff /= len ( rr )
logger.info ( 'Mean root distance is %.4g' % diff )

functions.add ( bs )

Expand All @@ -76,44 +80,52 @@ def test_nroots ():
logger = getLogger("test_nroots")

# 1) construct function with known roots

## roots in [0,1]
troots = [ random.uniform(0.01,0.99) for i in range ( 6 ) ]
troots.sort()
troots = [ r for r in crange ( 0.05 , 0.95 , 5 ) ]

## roots in [2,10]
roots = troots + [ random.uniform(2,10) for i in range ( 2 ) ]

## roots in [1,10]
roots = troots + [ random.uniform(1.01,9.99) for i in range ( 4 ) ]
roots.sort ()

## complex roots
croots = [ complex ( random.uniform(-3 , -1 ) , random.gauss ( 0 , 3 ) ) for i in range ( 4 ) ]
croots += [ complex ( random.uniform( 3 , 5 ) , random.gauss ( 0 , 3 ) ) for i in range ( 4 ) ]
croots = [ complex ( random.uniform(-10,10) , random.uniform ( 1 , 5 ) ) for i in range ( 2 ) ]

## Bernstein polynomial with known roots
bs = Ostap.Math.Bernstein ( 0 , 1 , roots , croots )


troots.sort()
logger.info ('Roots true : [%s]' % ( ', '.join ( "%.6f" % r for r in troots ) ) )

delta = 1.e-4 * ( bs.xmax() - bs.xmin() )
for i in range( 20 ) :
with use_canvas ( "test_nroots" , wait = 5 ) :

while True :
x1 = random.uniform ( bs.xmin() , bs.xmax() )
x2 = random.uniform ( x1 , bs.xmax() )
if x2 <= x1 or abs ( x1 - x2 ) < delta : continue
break

nr = bs.nroots ( x1 , x2 )
nt = 0
for r in troots :
if x1 <= r < x2 : nt += 1
bs.draw()

if nr != nt :
logger.error ('Roots between [%.6f, %.6f) : %d [true is %d]' % ( x1 , x2 , nr , nt ) )
else :
logger.info ('Roots between [%.6f, %.6f) : %d [true is %d]' % ( x1 , x2 , nr , nt ) )
## find roots of Bernstein polynomial
rr = bs.solve()
rr = [ r for r in rr ]
rr.sort()
logger.info ('Roots found : [ %s]' % ( ', '.join ( "%.6f" % r for r in rr ) ) )

troots.sort()
logger.info ('Roots true : [%s]' % ( ', '.join ( "%.6f" % r for r in troots ) ) )

delta = 1.e-4 * ( bs.xmax() - bs.xmin() )
for i in range( 20 ) :

while True :
x1 = random.uniform ( bs.xmin() , bs.xmax() )
x2 = random.uniform ( x1 , bs.xmax() )
if x2 <= x1 or abs ( x1 - x2 ) < delta : continue
break

nr = bs.nroots ( x1 , x2 )
nt = 0
for r in troots :
if x1 <= r < x2 : nt += 1

if nr != nt :
logger.error ('Roots between [%.6f, %.6f) : %d [true is %d]' % ( x1 , x2 , nr , nt ) )
else :
logger.info ('Roots between [%.6f, %.6f) : %d [true is %d]' % ( x1 , x2 , nr , nt ) )

functions.add ( bs )

Expand Down Expand Up @@ -243,9 +255,11 @@ def test_poly () :

for i in range(100) :
x = random.uniform ( b.xmin() , b.xmax() )
y = b(x)
if not ymin <= y <= ymax :
raise ValueError ( 'Invalid polynom value y(%s)=%s (%s/%s)' % ( x , y , ymin , ymax ) )
y = b ( x )
if isequal ( ymin , y ) : pass
elif isequal ( ymax , y ) : pass
elif not ymin <= y <= ymax :
raise ValueError ( 'Invalid polynom value y(%s)=%s (%s/%s)' % ( x , y , ymin , ymax ) )

logger.info ('Random poly is OK' )

Expand All @@ -269,11 +283,10 @@ def test_even () :
x1 = random.uniform ( b.xmin() , b.xmax() )
dx = x1 - xmid
x2 = xmid - dx
y1 = b(x1)
y2 = b(x2)
y1 = b ( x1 )
y2 = b ( x2 )
check_equality ( y1 , y2 , 'Invalid BernsteinEven' , 1.e-7 )


logger.info ('Even poly is OK' )

functions.add ( b )
Expand All @@ -300,8 +313,8 @@ def test_monotonic () :
break
y1 , y2 = b ( x1 ) , b ( x2 )

ok1 = ( x1 < x2 ) and ( y1 <= y2 )
ok2 = ( x1 > x2 ) and ( y1 >= y2 )
ok1 = ( x1 < x2 ) and ( ( y1 <= y2 ) or isequal ( y1 , y2 ) )
ok2 = ( x1 > x2 ) and ( ( y1 >= y2 ) or isequal ( y1 , y2 ) )

if not ok1 and not ok2 :
raise ValueError ( 'Invalid Increasing y(%s)=%s>y(%s)=%s' % ( x1 , y1 , x2 , y1 ) )
Expand All @@ -317,13 +330,14 @@ def test_monotonic () :

while True :
x1 = random.uniform ( b.xmin() , b.xmax() )
x2 = random.uniform ( x1 , b.xmax() )
x2 = random.uniform ( x1 , b.xmax() )
if x2 <= x1 or abs ( x1 - x2 ) < 1.e-6 : continue
break
break

y1 , y2 = b ( x1 ) , b ( x2 )

ok1 = ( x1 > x2 ) and ( y1 <= y2 )
ok2 = ( x1 < x2 ) and ( y1 >= y2 )
ok1 = ( x1 > x2 ) and ( ( y1 <= y2 ) or isequal ( y1 , y2 ) )
ok2 = ( x1 < x2 ) and ( ( y1 >= y2 ) or isequal ( y1 , y2 ) )

if not ok1 and not ok2 :
raise ValueError ( 'Invalid Deccreasing y(%s)=%s>y(%s)=%s' % ( x1 , y1 , x2 , y1 ) )
Expand Down Expand Up @@ -372,21 +386,30 @@ def test_convex () :
if b1 < 0 or b2 < 0 :
raise ValueError ( 'Invalid Convex value (b<0)' )

if b.increasing() and b1 > b2 :
if b.increasing() :
if isequal ( b1 , b2 ) : pass
elif b1 > b2 :
raise ValueError ( 'Invalid Convex Increasing' )

if b.decreasing() and b1 < b2 :
raise ValueError ( 'Invalid Convex Decreasing' )
if b.decreasing() :
if isequal ( b1 , b2 ) : pass
elif b1 < b2 :
raise ValueError ( 'Invalid Convex Decreasing' )

bm = b ( xm )
if b.convex () and ( b1 + b2 ) < 2 * bm :
raise ValueError ( 'Invalid Convex!' )
if b.concave () and ( b1 + b2 ) > 2 * bm :
bm = b ( xm )

if b.convex () :
if isequal ( b1 + b2 , 2 * bm ) : pass
elif ( b1 + b2 ) < 2 * bm :
raise ValueError ( 'Invalid Convex!' )

if b.concave () :
if isequal ( b1 + b2 , 2 * bm ) : pass
elif ( b1 + b2 ) > 2 * bm :
raise ValueError ( 'Invalid Concave!' )

logger.info ('Convex poly is OK' )


# =============================================================================
## test for Convex positive polynmomials
def test_convexonly () :
Expand Down Expand Up @@ -424,9 +447,13 @@ def test_convexonly () :
raise ValueError ( 'Invalid Convex value (b<0)' )

bm = b ( xm )
if b.convex () and ( b1 + b2 ) < 2 * bm :
if b.convex () :
if isequal ( b1 + b2 , 2 * bm ) : pass
elif ( b1 + b2 ) < 2 * bm :
raise ValueError ( 'Invalid Convex!' )
if b.concave () and ( b1 + b2 ) > 2 * bm :
if b.concave () :
if isequal ( b1 + b2 , 2 * bm ) : pass
elif ( b1 + b2 ) > 2 * bm :
raise ValueError ( 'Invalid Concave!' )

logger.info ('ConvexOnly poly is OK' )
Expand Down
2 changes: 1 addition & 1 deletion ostap/plotting/makestyles.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def make_styles ( config = None ) :
## generic style
style = root_style ( name )
if not style :
logger.info ( 'Create new generic style %s/%s' % ( name , description ) )
logger.debug ( 'Create new generic style %s/%s' % ( name , description ) )
style = ROOT.TStyle ( name , description )

set_style ( style , section )
Expand Down

0 comments on commit 14cfa8d

Please sign in to comment.