Skip to content

Commit

Permalink
add option to keep crossval folds constant for a session
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Nov 8, 2024
1 parent 306f288 commit 89109fd
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def dump_dict_to_zarr(group, data):
print(f'Could not save {key} of type {type(value)}')

# decoder_type options: 'linearSVC', 'LDA', 'RandomForest', or 'LogisticRegression'
def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',crossval_index=None,labels_as_index=False):
def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',
crossval_index=None,labels_as_index=False,train_test_split_input=None):
#helper function to decode labels from input data using different decoder models

if decoder_type=='linearSVC':
Expand Down Expand Up @@ -172,6 +173,11 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',
skf = StratifiedKFold(n_splits=5,shuffle=True)
train_test_split = skf.split(input_data, labels)

elif crossval=='5_fold_constant':
if train_test_split_input is None:
raise ValueError('Must provide train_test_split_input')
train_test_split = train_test_split_input

for train,test in train_test_split:

clf.fit(X[train],y[train])
Expand Down Expand Up @@ -199,7 +205,7 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',
if decoder_type == 'LDA' or decoder_type == 'RandomForest' or decoder_type=='LogisticRegression':
ypred_proba[test,:] = clf.predict_proba(X[test])
else:
ypred_proba[test,:] = np.full((len(test),len(np.unique(labels))), fill_value=np.nan)
ypred_proba[test,:] = np.full((len(test),len(np.unique(labels))), fill_value=False)

models.append(clf)

Expand All @@ -210,7 +216,7 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',
if decoder_type == 'LDA' or decoder_type == 'linearSVC':
coefs = clf.coef_
else:
coefs = np.full((X.shape[1]), fill_value=np.nan)
coefs = np.full((X.shape[1]), fill_value=False)

output['cr']=cr_dict_test

Expand Down Expand Up @@ -878,6 +884,7 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
decoder_type=params['decoder_type']
# use_coefs=params['use_coefs']
# generate_labels=params['generate_labels']

if 'only_use_all_units' in params:
only_use_all_units=params['only_use_all_units']
else:
Expand All @@ -887,6 +894,12 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
return_results=params['return_results']
else:
return_results=False

if crossval=='5_fold_constant':
skf = StratifiedKFold(n_splits=5,shuffle=True)
train_test_split = skf.split(input_data, labels)
else:
train_test_split=None

if session is not None:
session_info=npc_lims.get_session_info(session)
Expand Down Expand Up @@ -1122,7 +1135,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
decoder_type=decoder_type,
crossval=crossval,
crossval_index=None,
labels_as_index=labels_as_index)
labels_as_index=labels_as_index,
train_test_split_input=train_test_split)

if nunits=='all':
break
Expand Down Expand Up @@ -1414,11 +1428,11 @@ def compute_significant_decoding_by_area(all_decoder_results):
diff_from_null_DR['n_expts_DR'].append(len(DR_linear_shift_df.query('area==@area')))
for nu in n_units:
diff_from_null_DR['diff_from_null_mean_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]-
DR_linear_shift_df.query('area==@area')['null_accuracy_mean'+nu]).mean())
DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).mean())
diff_from_null_DR['diff_from_null_median_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]-
DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).median())
diff_from_null_DR['diff_from_null_sem_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]-
DR_linear_shift_df.query('area==@area')['null_accuracy_mean'+nu]).sem())
DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).sem())
diff_from_null_DR['true_accuracy_DR'+nu].append(DR_linear_shift_df.query('area==@area')['true_accuracy'+nu].median())
diff_from_null_DR['true_accuracy_sem_DR'+nu].append(DR_linear_shift_df.query('area==@area')['true_accuracy'+nu].sem())
diff_from_null_DR['null_median_DR'+nu].append(DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu].median())
Expand Down

0 comments on commit 89109fd

Please sign in to comment.