From 5c5ca7e9692e94971e82715183ed30331f7cbc96 Mon Sep 17 00:00:00 2001 From: egmcbride Date: Mon, 2 Dec 2024 16:35:43 -0800 Subject: [PATCH] improve handling of decoding results from multiple probes --- .../decoding_utils.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index d710497..ee78ac8 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -1340,12 +1340,15 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= for rr in range(n_repeats): if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys(): temp_bal_acc=[] + temp_bal_acc_all_trials=[] # else: # print('n repeats invalid: '+str(rr)) # continue for sh in half_shift_inds: if sh in list(decoder_results[session_id]['results'][aa]['shift'][nu][rr].keys()): temp_bal_acc.append(decoder_results[session_id]['results'][aa]['shift'][nu][rr][sh]['balanced_accuracy_test']) + if sh==0: + temp_bal_acc_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test']) if len(temp_bal_acc)>0: all_bal_acc[session_id][aa][nu].append(np.array(temp_bal_acc)) all_bal_acc[session_id][aa][nu]=np.vstack(all_bal_acc[session_id][aa][nu]) @@ -1355,6 +1358,9 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= if '_probe' in aa: area_name=aa.split('_probe')[0] probe_name=aa.split('_probe')[1] + elif '_all' in aa: + area_name=aa.split('_all')[0] + probe_name='all_probes' else: area_name=aa probe_name='' @@ -1582,6 +1588,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'vis_context_dprime':[], 'aud_context_dprime':[], 'overall_dprime':[], @@ -1645,6 +1652,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'block':[], 'cross_modal_dprime':[], 'n_good_blocks':[], @@ -1664,6 +1672,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'switch_trial':[], 'block':[], 'dprime_before':[], @@ -1684,6 +1693,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'trial_index':[], 'trials_since_rewarded_target':[], 'time_since_rewarded_target':[], @@ -1709,6 +1719,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'cross_modal_dprime':[], 'n_good_blocks':[], 'rewarded_target':[], @@ -1730,6 +1741,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un 'session':[], 'area':[], 'project':[], + 'probe':[], 'cross_modal_dprime':[], 'n_good_blocks':[], 'trial_index':[], @@ -1833,8 +1845,16 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un for aa in areas: if n_units not in decoder_results[session_id]['results'][aa]['shift'].keys(): continue - if type(aa)==str and 'probe' in aa: - area_name=aa.split('_')[0] + if type(aa)==str: + if '_probe' in aa: + area_name=aa.split('_probe')[0] + probe_name=aa.split('_probe')[1] + elif '_all' in aa: + area_name=aa.split('_all')[0] + probe_name='all_probes' + else: + area_name=aa + probe_name='' else: area_name=aa #make corrected decoder confidence @@ -1993,6 +2013,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_versus_response_type['session'].append(session_id_str) decoder_confidence_versus_response_type['area'].append(area_name) decoder_confidence_versus_response_type['project'].append(project) + decoder_confidence_versus_response_type['probe'].append(probe_name) if performance.query('rewarded_modality=="vis"').empty: decoder_confidence_versus_response_type['vis_context_dprime'].append(np.nan) else: @@ -2084,6 +2105,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_versus_trials_since_rewarded_target['session'].append(session_id_str) decoder_confidence_versus_trials_since_rewarded_target['area'].append(area_name) decoder_confidence_versus_trials_since_rewarded_target['project'].append(project) + decoder_confidence_versus_trials_since_rewarded_target['probe'].append(probe_name) decoder_confidence_versus_trials_since_rewarded_target['trial_index'].append(trials_middle['original_index'].values) decoder_confidence_versus_trials_since_rewarded_target['trials_since_rewarded_target'].append(trials_since_rewarded_target) decoder_confidence_versus_trials_since_rewarded_target['time_since_rewarded_target'].append(time_since_rewarded_target) @@ -2125,8 +2147,9 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un #decoder confidence for every trial decoder_confidence_all_trials['session'].append(session_id_str) - decoder_confidence_all_trials['area'].append(aa) + decoder_confidence_all_trials['area'].append(area_name) decoder_confidence_all_trials['project'].append(project) + decoder_confidence_all_trials['probe'].append(probe_name) decoder_confidence_all_trials['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean()) decoder_confidence_all_trials['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0)) decoder_confidence_all_trials['trial_index'].append(trials_middle['original_index'].values) @@ -2166,6 +2189,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_dprime_by_block['session'].append(session_id_str) decoder_confidence_dprime_by_block['area'].append(area_name) decoder_confidence_dprime_by_block['project'].append(project) + decoder_confidence_dprime_by_block['probe'].append(probe_name) decoder_confidence_dprime_by_block['block'].append(bb) decoder_confidence_dprime_by_block['cross_modal_dprime'].append(block_dprime) decoder_confidence_dprime_by_block['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0)) @@ -2203,6 +2227,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_by_switch['session'].append(session_id_str) decoder_confidence_by_switch['area'].append(area_name) decoder_confidence_by_switch['project'].append(project) + decoder_confidence_by_switch['probe'].append(probe_name) decoder_confidence_by_switch['switch_trial'].append(switch_trial['id']) decoder_confidence_by_switch['block'].append(switch_trial_block_index) decoder_confidence_by_switch['dprime_before'].append(performance.query('block_index==(@switch_trial_block_index-1)')['cross_modal_dprime'].values[0]) @@ -2266,6 +2291,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un decoder_confidence_before_after_target['session'].append(session_id_str) decoder_confidence_before_after_target['area'].append(area_name) decoder_confidence_before_after_target['project'].append(project) + decoder_confidence_before_after_target['probe'].append(probe_name) decoder_confidence_before_after_target['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean()) decoder_confidence_before_after_target['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0)) decoder_confidence_before_after_target['rewarded_target'].append(sign_corrected_decision_function[rewarded_target_trials])