Skip to content

Commit

Permalink
better handle missing results (i.e. not enough units)
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Nov 5, 2024
1 parent 29ade70 commit 5f4080f
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=

#loop through repeats
for nunits in n_units_input:
if nunits>len(area_units):
if nunits!='all' and nunits>len(area_units):
continue
decoder_results[session_id]['results'][aa]['shift'][nunits]={}
for rr in range(n_repeats):
Expand Down Expand Up @@ -1243,6 +1243,8 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
all_bal_acc[session_id][aa]={}
### ADD LOOP FOR NUNITS ###
for nu in nunits:
if nu not in decoder_results[session_id]['results'][aa]['shift'].keys():
continue
all_bal_acc[session_id][aa][nu]=[]
for rr in range(n_repeats):
if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys():
Expand Down Expand Up @@ -1271,19 +1273,28 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
### LOOP THROUGH NUNITS TO APPEND TO DICT ###

for nu in nunits:
true_acc_ind=np.where(half_shifts==1)[0][0]
null_acc_ind=np.where(half_shifts!=1)[0]
true_accuracy=all_bal_acc[session_id][aa][nu][true_acc_ind]
null_accuracy_mean=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind])
null_accuracy_median=np.median(all_bal_acc[session_id][aa][nu][null_acc_ind])
null_accuracy_std=np.std(all_bal_acc[session_id][aa][nu][null_acc_ind])
p_value=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]>=true_accuracy)

linear_shift_dict['true_accuracy_'+str(nu)].append(true_accuracy)
linear_shift_dict['null_accuracy_mean_'+str(nu)].append(null_accuracy_mean)
linear_shift_dict['null_accuracy_median_'+str(nu)].append(null_accuracy_median)
linear_shift_dict['null_accuracy_std_'+str(nu)].append(null_accuracy_std)
linear_shift_dict['p_value_'+str(nu)].append(p_value)
if nu in all_bal_acc[session_id][aa].keys():

true_acc_ind=np.where(half_shifts==1)[0][0]
null_acc_ind=np.where(half_shifts!=1)[0]
true_accuracy=all_bal_acc[session_id][aa][nu][true_acc_ind]
null_accuracy_mean=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind])
null_accuracy_median=np.median(all_bal_acc[session_id][aa][nu][null_acc_ind])
null_accuracy_std=np.std(all_bal_acc[session_id][aa][nu][null_acc_ind])
p_value=np.mean(all_bal_acc[session_id][aa][nu][null_acc_ind]>=true_accuracy)

linear_shift_dict['true_accuracy_'+str(nu)].append(true_accuracy)
linear_shift_dict['null_accuracy_mean_'+str(nu)].append(null_accuracy_mean)
linear_shift_dict['null_accuracy_median_'+str(nu)].append(null_accuracy_median)
linear_shift_dict['null_accuracy_std_'+str(nu)].append(null_accuracy_std)
linear_shift_dict['p_value_'+str(nu)].append(p_value)

else:
linear_shift_dict['true_accuracy_'+str(nu)].append(np.nan)
linear_shift_dict['null_accuracy_mean_'+str(nu)].append(np.nan)
linear_shift_dict['null_accuracy_median_'+str(nu)].append(np.nan)
linear_shift_dict['null_accuracy_std_'+str(nu)].append(np.nan)
linear_shift_dict['p_value_'+str(nu)].append(np.nan)

#make big dict/dataframe for this:
#save true decoding, mean/median null decoding, and p value for each area/probe
Expand Down Expand Up @@ -1594,7 +1605,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
npc_lims.get_cache_path('performance',session_id,version='any')
)
except:
print('skipping session:',session_id)
print('trials or performance not available; skipping session:',session_id)
continue

trials_since_rewarded_target=[]
Expand Down Expand Up @@ -1681,6 +1692,8 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
temp_shifts=[]
for rr in range(n_repeats):
if n_units is not None:
if n_units not in decoder_results[session_id]['results'][aa]['shift'].keys():
continue
if n_units=='all' and rr>0:
continue
if sh in list(decoder_results[session_id]['results'][aa]['shift'][n_units][rr].keys()):
Expand Down

0 comments on commit 5f4080f

Please sign in to comment.