From 40f7a16cdd9c3ddc852407931774f0d64df98164 Mon Sep 17 00:00:00 2001 From: egmcbride Date: Mon, 2 Dec 2024 15:57:24 -0800 Subject: [PATCH] add table with all trial predict probabilities --- .../decoding_utils.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index c47dbab..d710497 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -1324,6 +1324,10 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= areas=decoder_results[session_id]['areas'] + #TODO: ####option to exclude areas with multiple probes in a session or SC subdivisions#### + + #TODO: add decoder accuracy using all trials (no shift) + #save balanced accuracy by shift for aa in areas: if aa in decoder_results[session_id]['results']: @@ -1721,6 +1725,22 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'n_units':[], } + #TODO: add table with decoder condfidence for all trials, plus other useful session-level information + decoder_confidence_all_trials={ + 'session':[], + 'area':[], + 'project':[], + 'cross_modal_dprime':[], + 'n_good_blocks':[], + 'trial_index':[], + 'confidence':[], + 'predict_proba':[], + 'ccf_ap_mean':[], + 'ccf_dv_mean':[], + 'ccf_ml_mean':[], + 'n_units':[], + } + start_time=time.time() ##loop through sessions## @@ -1830,7 +1850,10 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decision_function_shifts=[] predict_proba_shifts=[] - + + confidence_all_trials=[] + predict_proba_all_trials=[] + for sh in half_shift_inds: temp_shifts=[] temp_proba_shifts=[] @@ -1842,6 +1865,10 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un if sh in list(decoder_results[session_id]['results'][aa]['shift'][n_units][rr].keys()): temp_shifts.append(decoder_results[session_id]['results'][aa]['shift'][n_units][rr][sh]['decision_function']) temp_proba_shifts.append(decoder_results[session_id]['results'][aa]['shift'][n_units][rr][sh]['predict_proba'][:,1]) + + if sh==0: + confidence_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][n_units][rr]['decision_function']) + predict_proba_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][n_units][rr]['predict_proba'][:,1]) else: if sh in list(decoder_results[session_id]['results'][aa]['shift'][rr].keys()): temp_shifts.append(decoder_results[session_id]['results'][aa]['shift'][rr][sh]['decision_function']) @@ -1853,6 +1880,9 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decision_function_shifts.append(np.nan) predict_proba_shifts.append(np.nan) + confidence_all_trials=np.nanmean(np.vstack(confidence_all_trials),axis=0) + predict_proba_all_trials=np.nanmean(np.vstack(predict_proba_all_trials),axis=0) + # true_label=decoder_results[session_id]['results'][aa]['shift'][np.where(shifts==0)[0][0]]['true_label'] try: @@ -2093,6 +2123,23 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_versus_trials_since_rewarded_target['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean()) decoder_confidence_versus_trials_since_rewarded_target['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0)) + #decoder confidence for every trial + decoder_confidence_all_trials['session'].append(session_id_str) + decoder_confidence_all_trials['area'].append(aa) + decoder_confidence_all_trials['project'].append(project) + decoder_confidence_all_trials['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean()) + decoder_confidence_all_trials['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0)) + decoder_confidence_all_trials['trial_index'].append(trials_middle['original_index'].values) + decoder_confidence_all_trials['confidence'].append(confidence_all_trials) + decoder_confidence_all_trials['predict_proba'].append(predict_proba_all_trials) + decoder_confidence_all_trials['n_units'].append(decoder_results[session_id]['results'][aa]['n_units']) + + # 'ccf_ap_mean', 'ccf_dv_mean', 'ccf_ml_mean' + if 'ccf_ap_mean' in decoder_results[session_id]['results'][aa].keys(): + decoder_confidence_all_trials['ccf_ap_mean'].append(decoder_results[session_id]['results'][aa]['ccf_ap_mean']) + decoder_confidence_all_trials['ccf_dv_mean'].append(decoder_results[session_id]['results'][aa]['ccf_dv_mean']) + decoder_confidence_all_trials['ccf_ml_mean'].append(decoder_results[session_id]['results'][aa]['ccf_ml_mean']) + ##loop through blocks## blocks=trials_middle['block_index'].unique() for bb in blocks: @@ -2263,12 +2310,14 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_dprime_by_block_dict=decoder_confidence_dprime_by_block.copy() decoder_confidence_by_switch_dict=decoder_confidence_by_switch.copy() decoder_confidence_versus_trials_since_rewarded_target_dict=decoder_confidence_versus_trials_since_rewarded_target.copy() + decoder_confidence_all_trials_dict=decoder_confidence_all_trials.copy() decoder_confidence_before_after_target_dict=decoder_confidence_before_after_target.copy() decoder_confidence_versus_response_type=pd.DataFrame(decoder_confidence_versus_response_type) decoder_confidence_dprime_by_block=pd.DataFrame(decoder_confidence_dprime_by_block) decoder_confidence_by_switch=pd.DataFrame(decoder_confidence_by_switch) decoder_confidence_versus_trials_since_rewarded_target=pd.DataFrame(decoder_confidence_versus_trials_since_rewarded_target) + decoder_confidence_all_trials=pd.DataFrame(decoder_confidence_all_trials) decoder_confidence_before_after_target=pd.DataFrame(decoder_confidence_before_after_target) if savepath is not None: @@ -2291,6 +2340,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'decoder_confidence_dprime_by_block'+n_units_str:decoder_confidence_dprime_by_block_dict, 'decoder_confidence_by_switch'+n_units_str:decoder_confidence_by_switch_dict, 'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str:decoder_confidence_versus_trials_since_rewarded_target_dict, + 'decoder_confidence_all_trials'+n_units_str:decoder_confidence_all_trials_dict, 'decoder_confidence_before_after_target'+n_units_str:decoder_confidence_before_after_target_dict, }, } @@ -2305,12 +2355,14 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_dprime_by_block.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False) decoder_confidence_by_switch.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False) decoder_confidence_versus_trials_since_rewarded_target.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False) + decoder_confidence_all_trials.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_all_trials'+n_units_str+'.csv'),index=False) decoder_confidence_before_after_target.to_csv(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False) decoder_confidence_versus_response_type.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl')) decoder_confidence_dprime_by_block.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl')) decoder_confidence_by_switch.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl')) decoder_confidence_versus_trials_since_rewarded_target.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl')) + decoder_confidence_all_trials.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_all_trials'+n_units_str+'.pkl')) decoder_confidence_before_after_target.to_pickle(upath.UPath(savepath) / (temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl')) print('saved '+n_units_str+' decoder confidence tables to:',savepath)