From 5f4080f210c6f1d183b66b3fc185189f9f3b14f4 Mon Sep 17 00:00:00 2001 From: egmcbride Date: Tue, 5 Nov 2024 12:10:20 -0800 Subject: [PATCH] better handle missing results (i.e. not enough units) --- .../decoding_utils.py | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index c160a71..5fb9aeb 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -1064,7 +1064,7 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= #loop through repeats for nunits in n_units_input: - if nunits>len(area_units): + if nunits!='all' and nunits>len(area_units): continue decoder_results[session_id]['results'][aa]['shift'][nunits]={} for rr in range(n_repeats): @@ -1243,6 +1243,8 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= all_bal_acc[session_id][aa]={} ### ADD LOOP FOR NUNITS ### for nu in nunits: + if nu not in decoder_results[session_id]['results'][aa]['shift'].keys(): + continue all_bal_acc[session_id][aa][nu]=[] for rr in range(n_repeats): if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys(): @@ -1271,19 +1273,28 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= ### LOOP THROUGH NUNITS TO APPEND TO DICT ### for nu in nunits: - true_acc_ind=np.where(half_shifts==1)[0][0] - null_acc_ind=np.where(half_shifts!=1)[0] - true_accuracy=all_bal_acc[session_id][aa][nu][true_acc_ind] - null_accuracy_mean=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]) - null_accuracy_median=np.median(all_bal_acc[session_id][aa][nu][null_acc_ind]) - null_accuracy_std=np.std(all_bal_acc[session_id][aa][nu][null_acc_ind]) - p_value=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]>=true_accuracy) - - linear_shift_dict['true_accuracy_'+str(nu)].append(true_accuracy) - linear_shift_dict['null_accuracy_mean_'+str(nu)].append(null_accuracy_mean) - linear_shift_dict['null_accuracy_median_'+str(nu)].append(null_accuracy_median) - linear_shift_dict['null_accuracy_std_'+str(nu)].append(null_accuracy_std) - linear_shift_dict['p_value_'+str(nu)].append(p_value) + if nu in all_bal_acc[session_id][aa].keys(): + + true_acc_ind=np.where(half_shifts==1)[0][0] + null_acc_ind=np.where(half_shifts!=1)[0] + true_accuracy=all_bal_acc[session_id][aa][nu][true_acc_ind] + null_accuracy_mean=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]) + null_accuracy_median=np.median(all_bal_acc[session_id][aa][nu][null_acc_ind]) + null_accuracy_std=np.std(all_bal_acc[session_id][aa][nu][null_acc_ind]) + p_value=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]>=true_accuracy) + + linear_shift_dict['true_accuracy_'+str(nu)].append(true_accuracy) + linear_shift_dict['null_accuracy_mean_'+str(nu)].append(null_accuracy_mean) + linear_shift_dict['null_accuracy_median_'+str(nu)].append(null_accuracy_median) + linear_shift_dict['null_accuracy_std_'+str(nu)].append(null_accuracy_std) + linear_shift_dict['p_value_'+str(nu)].append(p_value) + + else: + linear_shift_dict['true_accuracy_'+str(nu)].append(np.nan) + linear_shift_dict['null_accuracy_mean_'+str(nu)].append(np.nan) + linear_shift_dict['null_accuracy_median_'+str(nu)].append(np.nan) + linear_shift_dict['null_accuracy_std_'+str(nu)].append(np.nan) + linear_shift_dict['p_value_'+str(nu)].append(np.nan) #make big dict/dataframe for this: #save true decoding, mean/median null decoding, and p value for each area/probe @@ -1594,7 +1605,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un npc_lims.get_cache_path('performance',session_id,version='any') ) except: - print('skipping session:',session_id) + print('trials or performance not available; skipping session:',session_id) continue trials_since_rewarded_target=[] @@ -1681,6 +1692,8 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un temp_shifts=[] for rr in range(n_repeats): if n_units is not None: + if n_units not in decoder_results[session_id]['results'][aa]['shift'].keys(): + continue if n_units=='all' and rr>0: continue if sh in list(decoder_results[session_id]['results'][aa]['shift'][n_units][rr].keys()):