Commit

1. modify `point-to-point-dissimilarity` GoF method: split into chunks for large datasets, use parallel processing for permutations
VanyaBelyaev committed Oct 4, 2024
1 parent 4f32c43 commit 3ab3dc9
Showing 3 changed files with 184 additions and 114 deletions.
2 changes: 1 addition & 1 deletion ReleaseNotes/release_notes.md
@@ -8,7 +8,7 @@
1. Add very simple "efficiency-counter" `ostap.stats.counters.EffCounter`
1. suppress `ostap.core.config.config_goodby` prints for non-interactive sessions
1. add the most primitive splitter `ostap.utils.utils.splitter`

1. modify `point-to-point-dissimilarity` GoF method: split into chunks for large datasets, use parallel processing for permutations

## Backward incompatible

242 changes: 159 additions & 83 deletions ostap/stats/gof_np.py
@@ -23,37 +23,30 @@
'PPDNP' , ## Point-to-Point Dissimilarity Goodness-of-fit method
)
# =============================================================================
import sys, os, warnings
from ostap.core.ostap_types import string_types
from ostap.stats.gof import normalize2
from ostap.core.core import SE, VE
from ostap.utils.progress_bar import progress_bar
from ostap.utils.utils import split_n_range, splitter
from ostap.utils.basic import numcpu
from ostap.stats.counters import EffCounter
import os, abc, warnings
# =============================================================================
try : # =======================================================================
# =========================================================================
import numpy as np
with warnings.catch_warnings():
warnings.simplefilter ( "ignore" , category = UserWarning )
with warnings.catch_warnings():
warnings.simplefilter ( "ignore" , category = UserWarning )
import numpy as np
import scipy as sp
from numpy.lib.recfunctions import structured_to_unstructured as s2u
from scipy.spatial.distance import cdist as cdist
# =========================================================================
except ImportError :
# =========================================================================
np = None
sp = None
s2u = None
# =============================================================================
try : # =======================================================================
# =========================================================================
import scipy as sp
with warnings.catch_warnings():
warnings.simplefilter ( "ignore" , category = UserWarning )
from scipy.spatial.distance import cdist as cdist
# =========================================================================
except ImportError :
# =========================================================================
sp = None
cdist = None
# =============================================================================
from ostap.core.ostap_types import string_types
from ostap.stats.gof import normalize2
from ostap.core.core import SE, VE
from ostap.utils.progress_bar import progress_bar
import abc
cdist = None
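
This guarded import implements the module's optional-dependency pattern: if numpy/scipy are absent, the names fall back to `None` and the capability warnings at the bottom of the file report what is missing. A minimal sketch of the same pattern (names hypothetical):

```python
## Minimal sketch of the optional-dependency pattern used above (names hypothetical)
try :
    import numpy as np            ## heavy, optional dependency
except ImportError :
    np = None                     ## degrade gracefully; callers check before use

def mean_sketch ( values ) :
    """ Fail only when the optional feature is actually used """
    if np is None : raise RuntimeError ( 'numpy is required for this method' )
    return np.mean ( values )
```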
# =============================================================================
# logging
# =============================================================================
@@ -159,23 +152,23 @@ def normalize ( self , data1 , data2 ) :
# @code
# ds1, ds2 = ...
# gof = ...
# for d1,d2 in gof.permulations ( 100 , ds1 , ds2 ) :
# for d1,d2 in gof.permutations ( 100 , ds1 , ds2 ) :
# ...
# @endcode
def permutations ( self , data1 , data2 ) :
def permutations ( self , Nperm , data1 , data2 ) :
""" Generator of permutations
>>> ds1, ds2 = ...
>>> gof = ...
>>> for d1,d2 in gof.permulations ( 100 , ds1 , ds2 ) :
>>> for d1,d2 in gof.permutations ( 100 , ds1 , ds2 ) :
...
"""
n1 = len ( data1 )
pooled = np.concatenate ( [ data1 , data2 ] )

for i in progress_bar ( self.Nperm , silent = self.silent ) :
##
for i in progress_bar ( Nperm , silent = self.silent ) :
np.random.shuffle ( pooled )
yield pooled [ : n1 ] , pooled [ n1 : ]

del pooled
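
The generator pools the two samples, reshuffles the pool in place, and yields it split back at the original boundary, so every iteration is one draw from the permutation null hypothesis. A standalone sketch of that shuffle-and-split step on toy arrays (sizes and shapes are illustrative only):

```python
import numpy as np

ds1 = np.random.normal ( 0.0 , 1.0 , size = ( 100 , 2 ) )  ## toy "data" sample
ds2 = np.random.normal ( 0.1 , 1.0 , size = ( 200 , 2 ) )  ## toy "model" sample

n1     = len ( ds1 )
pooled = np.concatenate ( [ ds1 , ds2 ] )
for _ in range ( 5 ) :
    np.random.shuffle ( pooled )                   ## permute the pooled sample in place
    d1 , d2 = pooled [ : n1 ] , pooled [ n1 : ]    ## one permutation replica
```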
# =========================================================================
@property
@@ -213,6 +206,65 @@ def psi_conf ( psi , scale = 1.0 ) :
return psi , None , True

raise TypeError ( "Unknown `psi':%s" % psi )
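
Only the tail of `psi_conf` is visible in this hunk; its role is to translate the chosen ψ into a `cdist` metric name plus an optional element-wise transform of the distances. A hedged sketch of what the gaussian branch could look like -- the metric name and the return convention here are assumptions, not the committed code:

```python
import numpy as np

def psi_conf_sketch ( psi , scale = 1.0 ) :
    """ Map `psi' to ( cdist-metric , transform , increasing? ) -- assumed convention """
    if psi in ( 'gaussian' , 'gauss' ) :
        ## squared distances transformed as exp ( scale * d^2 ) , scale = -1/(2*sigma**2)
        return 'sqeuclidean' , ( lambda d : np.exp ( scale * d ) ) , False
    elif psi in ( 'euclidean' , 'linear' ) :
        return 'euclidean' , None , True
    raise TypeError ( "Unknown `psi':%s" % psi )
```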

try :
import joblib as jl
except ImportError :
jl = None

# =====================================================================
## @class PERMUTATOR
#  Helper class that allows running the permutation test in parallel
class PERMUTATOR(object) :
""" Helper class that allows running the permutation test in parallel
"""
def __init__ ( self, gof, t_value , ds1 , ds2 ) :
self.gof = gof
self.ds1 = ds1
self.ds2 = ds2
self.t_value = t_value

# =========================================================================
## run N-permutations
def __call__ ( self , N , silent = True ) :
np.random.seed()
n1 = len ( self.ds1 )
pooled = np.concatenate ( [ self.ds1 , self.ds2 ] )
counter = EffCounter()
for i in progress_bar ( N , silent = silent ) :
np.random.shuffle ( pooled )
tv = self.gof.t_value ( pooled [ : n1 ] , pooled [ n1: ] )
counter += bool ( self.t_value < tv )

del pooled
return counter
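
`PERMUTATOR` packages picklable state (the GoF object, the observed t-value and the two unstructured datasets) behind a `__call__` that runs N permutations and counts how often the permutation t-value exceeds the observed one; the counter's efficiency is then the p-value estimate. A serial usage sketch, assuming `ppd`, `t_obs`, `uds1` and `uds2` already exist:

```python
## serial fallback: run all permutations in a single process (sketch)
permutator = PERMUTATOR ( ppd , t_obs , uds1 , uds2 )  ## ppd/t_obs/uds1/uds2 assumed
counter    = permutator ( 1000 , silent = False )      ## 1000 permutations
p_value    = counter.eff                               ## fraction with t_obs < t_perm
```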

# =============================================================================
if jl : # =====================================================================
# =========================================================================
## Run NN-permutations in parallel using joblib
def jl_run ( self , NN , silent = True ) :
""" Run NN-permutations in parallel using joblib """
nj = 4 * numcpu () + 4
lst = splitter ( NN , nj )
##
conf = { 'n_jobs' : -1 , 'verbose' : 0 }
if '1.3.0' <= jl.__version__ < '1.4.0' : conf [ 'return_as' ] = 'generator'
elif '1.4.0' <= jl.__version__ : conf [ 'return_as' ] = 'unordered_generator'
##
input = ( jl.delayed (self)( N ) for N in lst )
results = jl.Parallel ( **conf ) ( input )
counter = EffCounter()
for r in progress_bar ( results , max_value = nj , silent = silent ) :
counter += r
#
return counter

PERMUTATOR.run = jl_run
# =========================================================================
else : # ======================================================================
# =========================================================================
PERMUTATOR.run = None
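
The joblib branch splits the `Nperm` permutations into roughly `4*numcpu()+4` chunks, maps one `PERMUTATOR` call per chunk across all cores, and merges the per-chunk counters; on joblib >= 1.3 the results are streamed back as a generator, so the progress bar ticks as chunks complete. The same split/map/merge shape, reduced to a runnable toy (the chunk arithmetic stands in for ostap's `splitter`):

```python
import joblib as jl
import random

def count_heads ( n ) :
    """ Toy per-chunk job standing in for PERMUTATOR.__call__ """
    return sum ( random.random () < 0.5 for _ in range ( n ) )

NN     = 10000
nj     = 16                                            ## stand-in for 4*numcpu()+4
chunks = [ NN // nj + ( 1 if i < NN % nj else 0 ) for i in range ( nj ) ]
parts  = jl.Parallel ( n_jobs = -1 ) ( jl.delayed ( count_heads ) ( n ) for n in chunks )
total  = sum ( parts )                                 ## merge, as with EffCounter
```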

# =============================================================================
## @class PPD
@@ -230,7 +282,7 @@ class PPDNP(AGoFNP,GoFNP) :
"""
def __init__ ( self ,
mc2mc = False ,
Nperm = 100 ,
Nperm = 1000 ,
psi = 'gaussian' ,
sigma = 0.05 ,
silent = False ) :
@@ -263,6 +315,61 @@ def sigma ( self ) :
return self.__sigma

# =========================================================================
## Calculate `sum-of-(transformed)-distances' between all elements in data1 & data2
def sum_distances ( self, data1 , data2 ) :
""" Calculate `sum-of-(transformed)-distances' between all elements in data1 & data2
"""
n1 = len ( data1 )
n2 = len ( data2 )
## if too many distances, process them in chunks
nnmax = 1000000
if n1 * n2 > nnmax :
# ================================================================
if n1 > n2 : ## swap datasets
data1 , data2 = data2 , data1
n1 , n2 = n2 , n1
# =================================================================
result = 0.0
nsplit = ( n1 * n2 ) // nnmax + 2
## split the second (larger) dataset into `nsplit` parts
for f , l in split_n_range ( 0 , n2 , nsplit ) :
result += self.sum_distances ( data1 , data2 [ f : l ] )
return result
##
## how to build distances?
scale = -0.5/(self.sigma**2)
distance_type , transform , _ = psi_conf ( self.psi , scale )
##
## calculate all pair-wise distances
distances = cdist ( data1 , data2 , distance_type ) .flatten () ## data <-> data
distances = distances [ distances > 0 ]
if transform : distances = transform ( distances )
##
return np.sum ( distances )
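
The cap matters because `cdist` materialises the full `(n1, n2)` matrix: at `nnmax = 10**6` doubles a chunk stays near 8 MB whatever the dataset sizes, and keeping the smaller dataset whole while slicing the larger one keeps each chunk close to that bound. A standalone sketch of the same chunked reduction, with plain Euclidean distances and no ψ-transform:

```python
import numpy as np
from scipy.spatial.distance import cdist

def chunked_sum_distances ( a , b , nnmax = 1000000 ) :
    """ Sum of all pairwise distances, materialising at most ~nnmax at a time """
    if len ( a ) > len ( b ) : a , b = b , a       ## always slice the larger dataset
    step   = max ( 1 , nnmax // len ( a ) )        ## rows of `b` per chunk
    result = 0.0
    for f in range ( 0 , len ( b ) , step ) :
        result += cdist ( a , b [ f : f + step ] ) . sum ()
    return result
```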
# =========================================================================
## Calculate t-value for (non-structured) 2D arrays
def t_value ( self , ds1 , ds2 ) :
""" Calculate t-value for (non-structured) 2D arrays
"""
##
sh1 = ds1.shape
sh2 = ds2.shape
assert 2 == len ( sh1 ) and 2 == len ( sh2 ) and sh1[1] == sh2[1] , \
"Invalid arrays: %s , %s" % ( sh1 , sh2 )

n1 = len ( ds1 )
n2 = len ( ds2 )
##

## calculate sums of distances, Eq (3.7)
result = self.sum_distances ( ds1 , ds1 ) / ( n1 * ( n1 - 1 ) )
result -= self.sum_distances ( ds1 , ds2 ) / ( n1 * n2 )
if self.mc2mc :
## add the distances from the second dataset?
result += self.sum_distances ( ds2 , ds2 ) / ( n2 * ( n2 - 1 ) )
##
return result
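
The comment refers to Eq. (3.7); written out, with ψ the transform fixed by `psi_conf`, the statistic the three `sum_distances` terms implement is (a reconstruction from the code, using the usual point-to-point-dissimilarity convention):

```latex
T \;=\; \frac{1}{n_1(n_1-1)} \sum_{i \neq j} \psi\bigl(d(\vec x_i,\vec x_j)\bigr)
   \;-\; \frac{1}{n_1 n_2} \sum_{i,j} \psi\bigl(d(\vec x_i,\vec y_j)\bigr)
   \;+\; \frac{1}{n_2(n_2-1)} \sum_{i \neq j} \psi\bigl(d(\vec y_i,\vec y_j)\bigr)\,,
```

where the last term enters only for `mc2mc = True`, and the gaussian choice uses ψ(d) = exp(-d²/2σ²).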
# =========================================================================
## Calculate T-value for two datasets
# @code
# ppd = ...
@@ -279,51 +386,17 @@ def __call__ ( self , data1 , data2 , normalize = True ) :
>>> t = ppd ( data1 , data1 , normalize = False )
>>> t = ppd ( data1 , data1 , normalize = True )
"""

## transform/normalize ?
if normalize : ds1 , ds2 = self.normalize ( data1 , data2 )
else : ds1 , ds2 = data1 , data2

n1 = len ( ds1 )
n2 = len ( ds2 )

## convert to unstructured datasets
uds1 = s2u ( ds1 , copy = False )
uds2 = s2u ( ds2 , copy = False )

## how to build distances?
scale = -0.5/(self.sigma**2)
distance_type , transform , _ = psi_conf ( self.psi , scale )

## list of all distances between points in the first dataset
dist_11 = cdist ( uds1 , uds1 , distance_type ) .flatten () ## data <-> data
dist_11 = dist_11 [ dist_11 > 0 ]
if transform : dist_11 = transform ( dist_11 )

## result = np.sum ( dist_11 ) / ( n1 * ( n1 - 1 ) )
result = np.mean ( dist_11 )
del dist_11

## list of all distances between points in the 1st and 2nd datasets
dist_12 = cdist ( uds1 , uds2 , distance_type ) .flatten () ## data <-> mc
dist_12 = dist_12 [ dist_12 > 0 ]
if transform : dist_12 = transform ( dist_12 )

## result -= np.sum ( dist_12 ) / ( n1 * n2 )
result -= np.mean ( dist_12 )
del dist_12
return self.t_value ( uds1 , uds2 )

## add the distances from the second dataset?
if self.mc2mc :
## list of all distances between points in the 2nd dataset
dist_22 = cdist ( uds2 , uds2 , distance_type ) . flatten() ## mc<-> mc
dist_22 = dist_22 [ dist_22 > 0 ]
if transform : dist_22 = transform ( dist_22 )
## result += np.sum ( dist_22 ) / ( n2 * ( n2 - 1 ) )
result += np.mean ( dist_22 )
del dist_22

return result
# =========================================================================
## Calculate the t & p-values
# @code
@@ -342,32 +415,35 @@ def pvalue ( self , data1 , data2 , normalize = True ) :
if normalize : ds1 , ds2 = self.normalize ( data1 , data2 )
else : ds1 , ds2 = data1 , data2

## calculate t-value
tvalue = self
t = tvalue ( ds1 , ds2 , normalize = False )

pv = SE ()
## start the permutation test
for d1 , d2 in self.permutations ( ds1 , ds2 ) :
ti = tvalue ( d1 , d2 , normalize = False )
pv += float ( t < ti )
## convert to unstructured datasets
uds1 = s2u ( ds1 , copy = False )
uds2 = s2u ( ds2 , copy = False )

t_value = self.t_value ( uds1 , uds2 )

p = VE ( pv.eff () , pv.effErr() ** 2 )
permutator = PERMUTATOR ( self , t_value , uds1 , uds2 )

if permutator.run : eff = permutator.run ( self.Nperm , silent = self.silent )
else : eff = permutator ( self.Nperm , silent = self.silent )

p = eff.eff

if self.__increasing : p = 1 - p

return t , p
return t_value , p
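
End to end, `pvalue` normalises the inputs, converts them to unstructured arrays, computes the observed t-value, and estimates the p-value as the exceedance fraction over `Nperm` permutations, parallelised whenever joblib is importable. A hedged usage sketch on toy structured arrays (field names and sizes are assumptions):

```python
import numpy as np
from ostap.stats.gof_np import PPDNP

## two toy 2D samples as structured arrays ( field names 'x'/'y' assumed )
dt    = [ ( 'x' , np.float64 ) , ( 'y' , np.float64 ) ]
data1 = np.zeros ( 500 , dtype = dt )
data2 = np.zeros ( 800 , dtype = dt )
data1 [ 'x' ] = np.random.normal ( 0.0 , 1.0 , 500 )
data1 [ 'y' ] = np.random.normal ( 0.0 , 1.0 , 500 )
data2 [ 'x' ] = np.random.normal ( 0.0 , 1.0 , 800 )
data2 [ 'y' ] = np.random.normal ( 0.0 , 1.0 , 800 )

ppd  = PPDNP ( Nperm = 500 , psi = 'gaussian' , sigma = 0.05 , silent = True )
t, p = ppd.pvalue ( data1 , data2 )   ## observed t-value and permutation p-value
```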

# =============================================================================
if '__main__' == __name__ :

from ostap.utils.docme import docme
docme ( __name__ , logger = logger )

if not np : logger.warning ('Numpy is not available')
if not sp : logger.warning ('Scipy is not available')
if not s2u : logger.warning ('s2u is not available')
if not cdist : logger.warning ('cdist is not available')
if not jl : logger.warning ('joblib is not available')

# =============================================================================
## The END
# =============================================================================