diff --git a/ostap/stats/gof_np.py b/ostap/stats/gof_np.py index b4a77ba0..493f486a 100644 --- a/ostap/stats/gof_np.py +++ b/ostap/stats/gof_np.py @@ -40,9 +40,22 @@ import scipy as sp from numpy.lib.recfunctions import structured_to_unstructured as s2u from scipy.spatial.distance import cdist as cdist - sp_version = tuple ( int ( i ) for i in sp.__version__.split('.') ) - if (1,6,0) <= sp_version : qconf = { 'k' : [ 2 ] , 'workers' : -1 } - else : qconf = { 'k' : 2 } + sp_version = tuple ( int ( i ) for i in sp.__version__.split('.') ) + ## + if (1,6,0) <= sp_version : + qconf = { 'k' : [ 2 ] , 'workers' : -1 } + def neigbour_distances ( tree , data ) : + dist , _ = tree.query ( data , **conf ) + dist = dist.flatten() + return dist + else : + qconf = { 'k' : 2 } + def neigbour_distances ( tree , data ) : + dist , _ = tree.query ( data , **conf ) + dist = np.delete ( dist , 0 , axis = 1 ) + dist = dist.flatten() + return dist + # ========================================================================= except ImportError : # ========================================================================= @@ -395,9 +408,9 @@ def t_value ( self , ds1 , vpdf ) : "Invalid arrays: %s , %s" % ( sh1 , sh2 ) tree = sp.spatial.KDTree ( ds1 ) - uvalues , _ = tree.query ( ds1 , **qconf ) - - uvalues = uvalues.flatten () + ## uvalues , _ = tree.query ( ds1 , **qconf ) + ## uvalues = uvalues.flatten () + uvalues = neighbour_distances ( tree , ds1 ) del tree