Skip to content

Commit

Permalink
improve handling of decoding results from multiple probes
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Dec 3, 2024
1 parent 40f7a16 commit 5c5ca7e
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,12 +1340,15 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
for rr in range(n_repeats):
if rr in decoder_results[session_id]['results'][aa]['shift'][nu].keys():
temp_bal_acc=[]
temp_bal_acc_all_trials=[]
# else:
# print('n repeats invalid: '+str(rr))
# continue
for sh in half_shift_inds:
if sh in list(decoder_results[session_id]['results'][aa]['shift'][nu][rr].keys()):
temp_bal_acc.append(decoder_results[session_id]['results'][aa]['shift'][nu][rr][sh]['balanced_accuracy_test'])
if sh==0:
temp_bal_acc_all_trials.append(decoder_results[session_id]['results'][aa]['no_shift'][nu][rr]['balanced_accuracy_test'])
if len(temp_bal_acc)>0:
all_bal_acc[session_id][aa][nu].append(np.array(temp_bal_acc))
all_bal_acc[session_id][aa][nu]=np.vstack(all_bal_acc[session_id][aa][nu])
Expand All @@ -1355,6 +1358,9 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
if '_probe' in aa:
area_name=aa.split('_probe')[0]
probe_name=aa.split('_probe')[1]
elif '_all' in aa:
area_name=aa.split('_all')[0]
probe_name='all_probes'
else:
area_name=aa
probe_name=''
Expand Down Expand Up @@ -1582,6 +1588,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'vis_context_dprime':[],
'aud_context_dprime':[],
'overall_dprime':[],
Expand Down Expand Up @@ -1645,6 +1652,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'block':[],
'cross_modal_dprime':[],
'n_good_blocks':[],
Expand All @@ -1664,6 +1672,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'switch_trial':[],
'block':[],
'dprime_before':[],
Expand All @@ -1684,6 +1693,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'trial_index':[],
'trials_since_rewarded_target':[],
'time_since_rewarded_target':[],
Expand All @@ -1709,6 +1719,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'cross_modal_dprime':[],
'n_good_blocks':[],
'rewarded_target':[],
Expand All @@ -1730,6 +1741,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
'session':[],
'area':[],
'project':[],
'probe':[],
'cross_modal_dprime':[],
'n_good_blocks':[],
'trial_index':[],
Expand Down Expand Up @@ -1833,8 +1845,16 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
for aa in areas:
if n_units not in decoder_results[session_id]['results'][aa]['shift'].keys():
continue
if type(aa)==str and 'probe' in aa:
area_name=aa.split('_')[0]
if type(aa)==str:
if '_probe' in aa:
area_name=aa.split('_probe')[0]
probe_name=aa.split('_probe')[1]
elif '_all' in aa:
area_name=aa.split('_all')[0]
probe_name='all_probes'
else:
area_name=aa
probe_name=''
else:
area_name=aa
#make corrected decoder confidence
Expand Down Expand Up @@ -1993,6 +2013,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_versus_response_type['session'].append(session_id_str)
decoder_confidence_versus_response_type['area'].append(area_name)
decoder_confidence_versus_response_type['project'].append(project)
decoder_confidence_versus_response_type['probe'].append(probe_name)
if performance.query('rewarded_modality=="vis"').empty:
decoder_confidence_versus_response_type['vis_context_dprime'].append(np.nan)
else:
Expand Down Expand Up @@ -2084,6 +2105,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_versus_trials_since_rewarded_target['session'].append(session_id_str)
decoder_confidence_versus_trials_since_rewarded_target['area'].append(area_name)
decoder_confidence_versus_trials_since_rewarded_target['project'].append(project)
decoder_confidence_versus_trials_since_rewarded_target['probe'].append(probe_name)
decoder_confidence_versus_trials_since_rewarded_target['trial_index'].append(trials_middle['original_index'].values)
decoder_confidence_versus_trials_since_rewarded_target['trials_since_rewarded_target'].append(trials_since_rewarded_target)
decoder_confidence_versus_trials_since_rewarded_target['time_since_rewarded_target'].append(time_since_rewarded_target)
Expand Down Expand Up @@ -2125,8 +2147,9 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un

#decoder confidence for every trial
decoder_confidence_all_trials['session'].append(session_id_str)
decoder_confidence_all_trials['area'].append(aa)
decoder_confidence_all_trials['area'].append(area_name)
decoder_confidence_all_trials['project'].append(project)
decoder_confidence_all_trials['probe'].append(probe_name)
decoder_confidence_all_trials['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean())
decoder_confidence_all_trials['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0))
decoder_confidence_all_trials['trial_index'].append(trials_middle['original_index'].values)
Expand Down Expand Up @@ -2166,6 +2189,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_dprime_by_block['session'].append(session_id_str)
decoder_confidence_dprime_by_block['area'].append(area_name)
decoder_confidence_dprime_by_block['project'].append(project)
decoder_confidence_dprime_by_block['probe'].append(probe_name)
decoder_confidence_dprime_by_block['block'].append(bb)
decoder_confidence_dprime_by_block['cross_modal_dprime'].append(block_dprime)
decoder_confidence_dprime_by_block['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0))
Expand Down Expand Up @@ -2203,6 +2227,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_by_switch['session'].append(session_id_str)
decoder_confidence_by_switch['area'].append(area_name)
decoder_confidence_by_switch['project'].append(project)
decoder_confidence_by_switch['probe'].append(probe_name)
decoder_confidence_by_switch['switch_trial'].append(switch_trial['id'])
decoder_confidence_by_switch['block'].append(switch_trial_block_index)
decoder_confidence_by_switch['dprime_before'].append(performance.query('block_index==(@switch_trial_block_index-1)')['cross_modal_dprime'].values[0])
Expand Down Expand Up @@ -2266,6 +2291,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
decoder_confidence_before_after_target['session'].append(session_id_str)
decoder_confidence_before_after_target['area'].append(area_name)
decoder_confidence_before_after_target['project'].append(project)
decoder_confidence_before_after_target['probe'].append(probe_name)
decoder_confidence_before_after_target['cross_modal_dprime'].append(performance['cross_modal_dprime'].mean())
decoder_confidence_before_after_target['n_good_blocks'].append(np.sum(performance['cross_modal_dprime']>=1.0))
decoder_confidence_before_after_target['rewarded_target'].append(sign_corrected_decision_function[rewarded_target_trials])
Expand Down

0 comments on commit 5c5ca7e

Please sign in to comment.