From 89109fded3380b733beb8fb16109d2f71ce54fb4 Mon Sep 17 00:00:00 2001 From: egmcbride Date: Fri, 8 Nov 2024 15:38:43 -0800 Subject: [PATCH] add option to keep crossval folds constant for a session --- .../decoding_utils.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index 66518f1..d8721b6 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -36,7 +36,8 @@ def dump_dict_to_zarr(group, data): print(f'Could not save {key} of type {type(value)}') # 'linearSVC' or 'LDA' or 'RandomForest' -def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',crossval_index=None,labels_as_index=False): +def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold', + crossval_index=None,labels_as_index=False,train_test_split_input=None): #helper function to decode labels from input data using different decoder models if decoder_type=='linearSVC': @@ -172,6 +173,11 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold', skf = StratifiedKFold(n_splits=5,shuffle=True) train_test_split = skf.split(input_data, labels) + elif crossval=='5_fold_constant': + if train_test_split_input is None: + raise ValueError('Must provide train_test_split_input') + train_test_split = train_test_split_input + for train,test in train_test_split: clf.fit(X[train],y[train]) @@ -199,7 +205,7 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold', if decoder_type == 'LDA' or decoder_type == 'RandomForest' or decoder_type=='LogisticRegression': ypred_proba[test,:] = clf.predict_proba(X[test]) else: - ypred_proba[test,:] = np.full((len(test),len(np.unique(labels))), fill_value=np.nan) + ypred_proba[test,:] = np.full((len(test),len(np.unique(labels))), fill_value=False) models.append(clf) @@ -210,7 +216,7 @@ def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold', if decoder_type == 'LDA' or decoder_type == 'linearSVC': coefs = clf.coef_ else: - coefs = np.full((X.shape[1]), fill_value=np.nan) + coefs = np.full((X.shape[1]), fill_value=False) output['cr']=cr_dict_test @@ -878,6 +884,7 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= decoder_type=params['decoder_type'] # use_coefs=params['use_coefs'] # generate_labels=params['generate_labels'] + if 'only_use_all_units' in params: only_use_all_units=params['only_use_all_units'] else: @@ -887,6 +894,12 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= return_results=params['return_results'] else: return_results=False + + if crossval=='5_fold_constant': + skf = StratifiedKFold(n_splits=5,shuffle=True) + train_test_split = skf.split(input_data, labels) + else: + train_test_split=None if session is not None: session_info=npc_lims.get_session_info(session) @@ -1122,7 +1135,8 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= decoder_type=decoder_type, crossval=crossval, crossval_index=None, - labels_as_index=labels_as_index) + labels_as_index=labels_as_index, + train_test_split_input=train_test_split) if nunits=='all': break @@ -1414,11 +1428,11 @@ def compute_significant_decoding_by_area(all_decoder_results): diff_from_null_DR['n_expts_DR'].append(len(DR_linear_shift_df.query('area==@area'))) for nu in n_units: diff_from_null_DR['diff_from_null_mean_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]- - DR_linear_shift_df.query('area==@area')['null_accuracy_mean'+nu]).mean()) + DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).mean()) diff_from_null_DR['diff_from_null_median_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]- DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).median()) diff_from_null_DR['diff_from_null_sem_DR'+nu].append((DR_linear_shift_df.query('area==@area')['true_accuracy'+nu]- - DR_linear_shift_df.query('area==@area')['null_accuracy_mean'+nu]).sem()) + DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu]).sem()) diff_from_null_DR['true_accuracy_DR'+nu].append(DR_linear_shift_df.query('area==@area')['true_accuracy'+nu].median()) diff_from_null_DR['true_accuracy_sem_DR'+nu].append(DR_linear_shift_df.query('area==@area')['true_accuracy'+nu].sem()) diff_from_null_DR['null_median_DR'+nu].append(DR_linear_shift_df.query('area==@area')['null_accuracy_median'+nu].median())