Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Dec 5, 2023
1 parent 8971914 commit f0c2e5d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 36 deletions.
13 changes: 9 additions & 4 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
)
# =============================================================================
from builtins import range
from collections import defaultdict
from collections import defaultdict
from ostap.core.meta_info import root_info
from ostap.core.core import ( Ostap, VE, SE ,
hID , dsID , strings ,
valid_pointer , split_string ,
Expand Down Expand Up @@ -2272,7 +2273,7 @@ def _ds_symmetrize_ ( ds , var1 , var2 , *vars ) :


# =============================================================================
## get the name of weigth variable in dataset
## get the name of weight variable in dataset
# @code
# dataset = ...
# wname = dataset.wname()
Expand All @@ -2288,8 +2289,12 @@ def _ds_wname_ ( dataset ) :

attr = '_weight_var_name'
if not hasattr ( dataset , attr ) :

wn = Ostap.Utils.getWeight ( dataset )

if ( 6 , 26 ) <= root_info :
wv = dataset.weightVar()
wn = wv.name if wv else Ostap.Utils.getWeight ( dataset )
else : wn = Ostap.Utils.getWeight ( dataset )

setattr ( dataset , attr , wn )

return getattr ( dataset , attr , '' )
Expand Down
34 changes: 2 additions & 32 deletions ostap/fitting/ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,6 @@



# input:
# dataset - initial dataset
# var_lst - name of variables to add in numpy array
# weight - Bool value, which work with weights vars in dataset
#ds = DS_to_Numpy(data, ['evt', 'run'], weight)
#ds = DS_to_Numpy_for_old_version(data, ['evt', 'run']) - for old ROOT package version
# output:
# data - numpy array with values of the required variables

#Check the list of variables for duplicates
def find_dublicates_in_var_list(var_lst):
return len(var_lst) != len(set(var_lst))

#add weight variable in numpy array
def add_weight ( ds , data ):

if not ds.isWeighted() : return data

weight = ds.weightVar().GetName()

## creathe the weight array
weights = np.zeros( len ( ds ) , dtype=np.float64)

## fill it
for i in ds : weight_array[i] = ds.weight()

data [ weight ] = weights

return data

# =============================================================================
if np and ( 6 , 26 ) <= root_info : ## 6.26 <= ROOT
# =============================================================================
Expand Down Expand Up @@ -142,7 +112,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
categories = [ v.name for v in vars if isinstance ( v , ROOT.RooAbsCategory ) ]

## name of weight variable
weight = '' if not dataset.isWeighted() else dataset.weightVar().GetName ()
weight = '' if not dataset.isWeighted() else dataset.wname ()

dtypes = []
for v in vnames :
Expand Down Expand Up @@ -247,7 +217,7 @@ def ds2numpy ( dataset , var_lst , silent = False ) :
categories = [ v.name for v in vars if isinstance ( v , ROOT.RooAbsCategory ) ]

## name of weight variable
weight = '' if not dataset.isWeighted() else dataset.weightVar().GetName ()
weight = '' if not dataset.isWeighted() else dataset.wname ()

dtypes = []
for v in vnames :
Expand Down

0 comments on commit f0c2e5d

Please sign in to comment.