Skip to content

Commit

Permalink
add table with all trial predict probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Dec 2, 2024
1 parent cc14323 commit 40f7a16
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=

areas=decoder_results[session_id]['areas']

#TODO: ####option to exclude areas with multiple probes in a session or SC subdivisions####

#TODO: add decoder accuracy using all trials (no shift)

#save balanced accuracy by shift
for aa in areas:
if aa in decoder_results[session_id]['results']:
Expand Down Expand Up @@ -1721,6 +1725,22 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'n_units':[],
}

#TODO: add table with decoder condfidence for all trials, plus other useful session-level information
decoder_confidence_all_trials={
'session':[],
'area':[],
'project':[],
'cross_modal_dprime':[],
'n_good_blocks':[],
'trial_index':[],
'confidence':[],
'predict_proba':[],
'ccf_ap_mean':[],
'ccf_dv_mean':[],
'ccf_ml_mean':[],
'n_units':[],
}

start_time=time.time()

##loop through sessions##
Expand Down Expand Up @@ -1830,7 +1850,10 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un

decision_function_shifts=[]
predict_proba_shifts=[]


confidence_all_trials=[]
predict_proba_all_trials=[]

for sh in half_shift_inds:
temp_shifts=[]
temp_proba_shifts=[]
Expand All @@ -1842,6 +1865,10 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
if sh in list(decoder_results[session_id]['results'][aa]['shift'][n_units][rr].keys()):
temp_shifts.append(decoder_results[session_id]['results'][aa]['shift'][n_units][rr][sh]['decision_function'])
temp_proba_shifts.append(decoder_results[session_id]['results'][aa]['shift'][n_units][rr][sh]['predict_proba'][:,1])

if sh==0:
confidence_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][n_units][rr]['decision_function'])
predict_proba_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][n_units][rr]['predict_proba'][:,1])
else:
if sh in list(decoder_results[session_id]['results'][aa]['shift'][rr].keys()):
temp_shifts.append(decoder_results[session_id]['results'][aa]['shift'][rr][sh]['decision_function'])
Expand All @@ -1853,6 +1880,9 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decision_function_shifts.append(np.nan)
predict_proba_shifts.append(np.nan)

confidence_all_trials=np.nanmean(np.vstack(confidence_all_trials),axis=0)
predict_proba_all_trials=np.nanmean(np.vstack(predict_proba_all_trials),axis=0)

# true_label=decoder_results[session_id]['results'][aa]['shift'][np.where(shifts==0)[0][0]]['true_label']

try:
Expand Down Expand Up @@ -2093,6 +2123,23 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_versus_trials_since_rewarded_target['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean())
decoder_confidence_versus_trials_since_rewarded_target['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0))

#decoder confidence for every trial
decoder_confidence_all_trials['session'].append(session_id_str)
decoder_confidence_all_trials['area'].append(aa)
decoder_confidence_all_trials['project'].append(project)
decoder_confidence_all_trials['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean())
decoder_confidence_all_trials['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0))
decoder_confidence_all_trials['trial_index'].append(trials_middle['original_index'].values)
decoder_confidence_all_trials['confidence'].append(confidence_all_trials)
decoder_confidence_all_trials['predict_proba'].append(predict_proba_all_trials)
decoder_confidence_all_trials['n_units'].append(decoder_results[session_id]['results'][aa]['n_units'])

# 'ccf_ap_mean', 'ccf_dv_mean', 'ccf_ml_mean'
if 'ccf_ap_mean' in decoder_results[session_id]['results'][aa].keys():
decoder_confidence_all_trials['ccf_ap_mean'].append(decoder_results[session_id]['results'][aa]['ccf_ap_mean'])
decoder_confidence_all_trials['ccf_dv_mean'].append(decoder_results[session_id]['results'][aa]['ccf_dv_mean'])
decoder_confidence_all_trials['ccf_ml_mean'].append(decoder_results[session_id]['results'][aa]['ccf_ml_mean'])

##loop through blocks##
blocks=trials_middle['block_index'].unique()
for bb in blocks:
Expand Down Expand Up @@ -2263,12 +2310,14 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_dprime_by_block_dict=decoder_confidence_dprime_by_block.copy()
decoder_confidence_by_switch_dict=decoder_confidence_by_switch.copy()
decoder_confidence_versus_trials_since_rewarded_target_dict=decoder_confidence_versus_trials_since_rewarded_target.copy()
decoder_confidence_all_trials_dict=decoder_confidence_all_trials.copy()
decoder_confidence_before_after_target_dict=decoder_confidence_before_after_target.copy()

decoder_confidence_versus_response_type=pd.DataFrame(decoder_confidence_versus_response_type)
decoder_confidence_dprime_by_block=pd.DataFrame(decoder_confidence_dprime_by_block)
decoder_confidence_by_switch=pd.DataFrame(decoder_confidence_by_switch)
decoder_confidence_versus_trials_since_rewarded_target=pd.DataFrame(decoder_confidence_versus_trials_since_rewarded_target)
decoder_confidence_all_trials=pd.DataFrame(decoder_confidence_all_trials)
decoder_confidence_before_after_target=pd.DataFrame(decoder_confidence_before_after_target)

if savepath is not None:
Expand All @@ -2291,6 +2340,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'decoder_confidence_dprime_by_block'+n_units_str:decoder_confidence_dprime_by_block_dict,
'decoder_confidence_by_switch'+n_units_str:decoder_confidence_by_switch_dict,
'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str:decoder_confidence_versus_trials_since_rewarded_target_dict,
'decoder_confidence_all_trials'+n_units_str:decoder_confidence_all_trials_dict,
'decoder_confidence_before_after_target'+n_units_str:decoder_confidence_before_after_target_dict,
},
}
Expand All @@ -2305,12 +2355,14 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_dprime_by_block.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False)
decoder_confidence_by_switch.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False)
decoder_confidence_versus_trials_since_rewarded_target.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False)
decoder_confidence_all_trials.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_all_trials'+n_units_str+'.csv'),index=False)
decoder_confidence_before_after_target.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False)

decoder_confidence_versus_response_type.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl'))
decoder_confidence_dprime_by_block.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl'))
decoder_confidence_by_switch.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl'))
decoder_confidence_versus_trials_since_rewarded_target.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl'))
decoder_confidence_all_trials.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_all_trials'+n_units_str+'.pkl'))
decoder_confidence_before_after_target.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl'))

print('saved '+n_units_str+' decoder confidence tables to:',savepath)
Expand Down

0 comments on commit 40f7a16

Please sign in to comment.