From fed545c583b6409b28ef9070a9e1dbd7a294e45d Mon Sep 17 00:00:00 2001
From: Shailaja Akella
Date: Wed, 11 Dec 2024 17:45:34 -0800
Subject: [PATCH] added glm utility functions

---
 src/dynamic_routing_analysis/IO_utils.py  | 846 ++++++++++++++++++++++
 src/dynamic_routing_analysis/glm_utils.py | 521 +++++++++++++
 2 files changed, 1367 insertions(+)
 create mode 100644 src/dynamic_routing_analysis/IO_utils.py
 create mode 100644 src/dynamic_routing_analysis/glm_utils.py

diff --git a/src/dynamic_routing_analysis/IO_utils.py b/src/dynamic_routing_analysis/IO_utils.py
new file mode 100644
index 0000000..7ecc61a
--- /dev/null
+++ b/src/dynamic_routing_analysis/IO_utils.py
@@ -0,0 +1,846 @@
+import logging
+import time
+
+import npc_lims
+import npc_sessions
+import numpy as np
+import pandas as pd
+import xarray as xr
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)  # debug < info < warning < error
+
+pd.set_option('display.max_columns', None)
+
+
+def get_data_from_npc_sessions(session_id):
+    """Fetch data from DynamicRoutingSession if cached files are not found."""
+    try:
+        session = npc_sessions.DynamicRoutingSession(session_id)
+        trials = session.trials[:]
+        dprimes = np.array(session.performance.cross_modal_dprime[:])
+        epoch = session.epochs[:]
+        behavior_info = {'trials': trials,
+                         'dprime': dprimes,
+                         'is_good_behavior': np.count_nonzero(dprimes >= 1) >= 4,
+                         'epoch_info': epoch}
+        units_table = session.units[:]
+        return session, units_table, behavior_info, None
+    except Exception as e:
+        raise FileNotFoundError(f"Failed to load data from DynamicRoutingSession: {e}")
+
+
+def get_session_data(session_id, version='0.0.259'):
+    '''
+    :param session_id: ecephys session_id
+    :param version: cache version
+    :return: session object (if found), units_table, trials table,
+             epoch information and session performance
+    '''
+
+    # to get the current cache version:
+    # npc_lims.get_current_cache_version()
+    try:
+        # Attempt to load data from cached files
+        trials = pd.read_parquet(
+            npc_lims.get_cache_path('trials', session_id, version=version)
+        )
+        dprimes = np.array(pd.read_parquet(
+            npc_lims.get_cache_path('performance', session_id, version=version)
+        ).cross_modal_dprime.values)
+        epoch = pd.read_parquet(
+            npc_lims.get_cache_path('epochs', session_id, version=version)
+        )
+        behavior_info = {'trials': trials,
+                         'dprime': dprimes,
+                         'is_good_behavior': np.count_nonzero(dprimes >= 1) >= 4,
+                         'epoch_info': epoch}
+
+        units_table = pd.read_parquet(
+            npc_lims.get_cache_path('units', session_id, version=version)
+        )
+        return None, units_table, behavior_info, None
+
+    except FileNotFoundError:
+        # Attempt to load data from DynamicRoutingSession as a fallback
+        logger.warning(f"File not found for session_id {session_id}. 
Attempting fallback.") + return get_data_from_npc_sessions(session_id) + + except Exception as e: + raise FileNotFoundError(f"Unexpected error occurred: {e}") + + +def setup_units_table(run_params, units_table): + ''' + Returns the units_table with the column indicating QC + Filters the table for area specific runs + ''' + + units_table['good_unit'] = (units_table['isi_violations_ratio'] < run_params['unit_inclusion_criteria'][ + 'isi_violations']) & \ + (units_table['presence_ratio'] > run_params['unit_inclusion_criteria'][ + 'presence_ratio']) & \ + (units_table['amplitude_cutoff'] < run_params['unit_inclusion_criteria'][ + 'amplitude_cutoff']) & \ + (units_table['firing_rate'] > run_params['unit_inclusion_criteria']['firing_rate']) + + if run_params['run_on_qc_units']: + units_table = units_table[units_table.good_unit] + + areas_to_include = run_params.get('areas_to_include', []) + if areas_to_include: + units_table = units_table[units_table.structure.isin(areas_to_include)] + + areas_to_exclude = run_params.get('areas_to_exclude', []) + if areas_to_exclude: + units_table = units_table[~units_table.structure.isin(areas_to_exclude)] + + return units_table + + +def setup_trials_table(run_params, trials_table): + ''' + Returns trials table excluding aborted trials if running encoding on quiescent period + ''' + + # TO-DO: find out how? Find out what else to include in the trials table. + + return trials_table + + +def define_kernels(run_params): + ''' + Returns kernel info for input variables + ''' + + # Define master kernel list + master_kernels_list = { + 'intercept': {'function_call': 'intercept', 'type': 'discrete', 'length': 0, 'offset': 0, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, 'text': 'constant value'}, + 'vis1_vis': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'target stim in rewarded context'}, + 'sound1_vis': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'target stim in non-rewarded context'}, + 'vis2_vis': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'non-target stim in vis context'}, + 'sound2_vis': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'non-target stim in vis context'}, + 'vis1_aud': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'target stim in non-rewarded context'}, + 'sound1_aud': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'target stim in rewarded context'}, + 'vis2_aud': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'non-target stim in aud context'}, + 'sound2_aud': {'function_call': 'stimulus', 'type': 'discrete', 'length': 1, 'offset': 1, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'non-target stim in aud context'}, + 'nose': {'function_call': 'LP_features', 'type': 'continuous', 'length': 1, 'offset': -0.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'Z-scored Euclidean displacement of 
nose movements'}, + 'ears': {'function_call': 'LP_features', 'type': 'continuous', 'length': 1, 'offset': -0.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'Z-scored Euclidean displacement of ear movements'}, + 'jaw': {'function_call': 'LP_features', 'type': 'continuous', 'length': 1, 'offset': -0.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'Z-scored Euclidean displacement of jaw movements'}, + 'whisker_pad': {'function_call': 'LP_features', 'type': 'continuous', 'length': 1, 'offset': -0.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'Z-scored Euclidean displacement of whisker pad movements'}, + 'licks': {'function_call': 'licks', 'type': 'discrete', 'length': 1, 'offset': -0.5, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'lick responses'}, + 'running': {'function_call': 'running', 'type': 'continuous', 'length': 1, 'offset': -0.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, 'text': 'Z-scored running speed'}, + 'pupil': {'function_call': 'pupil', 'type': 'continuous', 'length': 1, 'offset': -0.5, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'Z-scored pupil diameter'}, + 'hit': {'function_call': 'choice', 'type': 'discrete', 'length': 3, 'offset': -1.5, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'lick to GO trial'}, + 'miss': {'function_call': 'choice', 'type': 'discrete', 'length': 3, 'offset': -1.5, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'no lick to GO trial'}, + 'correct_reject': {'function_call': 'choice', 'type': 'discrete', 'length': 3, 'offset': -1.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'no lick to NO-GO trial'}, + 'false_alarm': {'function_call': 'choice', 'type': 'discrete', 'length': 3, 'offset': -1.5, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, 'text': 'lick to NO-GO trial'}, + 'context': {'function_call': 'context', 'type': 'discrete', 'length': 0, 'offset': 0, 'orthogonalize': None, + 'num_weights': None, 'dropout': True, 'text': 'block-wise context'}, + 'session_time': {'function_call': 'session_time', 'type': 'continuous', 'length': 0, 'offset': 0, + 'orthogonalize': None, 'num_weights': None, 'dropout': True, + 'text': 'z-scored time in session'} + } + + # Define categories for input variables + categories = { + 'stimulus': ['vis1_vis', 'sound1_vis', 'vis2_vis', 'sound2_vis', 'vis1_aud', 'sound1_aud', 'vis2_aud', + 'sound2_aud'], + 'movements': ['ears', 'nose', 'jaw', 'whisker_pad', 'running', 'pupil', 'licks'], + 'movements_no_licks': ['ears', 'nose', 'jaw', 'whisker_pad', 'running', 'pupil'], + 'choice': ['hit', 'miss', 'correct_reject', 'false_alarm'], + 'LP_features': ['ears', 'nose', 'jaw', 'whisker_pad'] + } + + # Initialize selected keys list + selected_keys = [] + + # Determine selected input variables based on run_params + time_of_interest = run_params.get('time_of_interest', '') + input_variables = run_params.get('input_variables', []) + + # Choose input variables based on 'time_of_interest' + if not input_variables: + if 'trial' in time_of_interest or time_of_interest == 'full': + selected_keys = categories['stimulus'] + categories['movements'] + categories['choice'] + ['context', + 'session_time'] + elif 'quiescent' in time_of_interest: + selected_keys = categories['movements_no_licks'] + ['context', 'session_time'] + elif 'spontaneous' in time_of_interest: + selected_keys = 
categories['movements_no_licks'] + ['session_time'] + else: + # Extend selected_keys with input variables + for input_variable in input_variables: + selected_keys.extend(categories.get(input_variable, [input_variable])) + + # Add intercept if required + if run_params.get('intercept', False) and 'intercept' not in selected_keys: + selected_keys.append('intercept') + + # Log error if no input variables are selected + if not selected_keys: + raise ValueError("No input variables selected!") # raise value error . + + # remove drop variables if any + drop_keys = run_params.get('drop_variables', []) + if drop_keys and run_params['model_label'] != 'fullmodel': + for drop_key in drop_keys: + sub_keys = categories.get(drop_key, [drop_key]) + for sub_key in sub_keys: + print(get_timestamp() + f': dropping {sub_key}') + selected_keys.remove(sub_key) + + # Build kernels dictionary based on selected keys + kernels = {key: master_kernels_list[key] for key in selected_keys} + + # Update kernel lengths based on run_params + input_window_lengths = run_params.get('input_window_lengths', {}) + if input_window_lengths: + for key, length in input_window_lengths.items(): + if key in kernels: + kernels[key]['length'] = length + else: + raise KeyError(f"Key {key} not found in kernels.") + + # Update orthogonalization keys + input_ortho_keys = run_params.get('orthogonalize_against_context', []) + if input_ortho_keys: + ortho_keys = [] + for input_variable in input_ortho_keys: + ortho_keys.extend(categories.get(input_variable, [input_variable])) + for key in ortho_keys: + if key in kernels: + kernels[key]['orthogonalize'] = True + + return kernels + + +def get_spont_times(run_params, behavior_info): + ''' + Returns timestamps for spontaneous period based on time of interest. + ''' + def pick_values(start, stop, N, L): + iti = 5 # inter-trial interval + arr = np.arange(start, stop - L, iti + L) # Ensure end range does not exceed the stop + if len(arr) < N: # Handle edge case where N exceeds possible choices + logger.warning("Not enough intervals to pick from. 
Reducing number of snippets.") + N = len(arr) + picked_vals = np.zeros((N, 2)) + picked_vals[:, 0] = np.sort(np.random.choice(arr, size=N, replace=False)) + picked_vals[:, 1] = picked_vals[:, 0] + L + return picked_vals + + epoch = behavior_info['epoch_info'] + if 'Spontaneous' not in epoch.script_name.values: + logger.warning("No spontaneous activity recorded for this session.") + return np.empty((0, 2)) # Return an empty array if no spontaneous data exists + + start_times = epoch[epoch.script_name == 'Spontaneous'].start_time.values + stop_times = epoch[epoch.script_name == 'Spontaneous'].stop_time.values + num_snippets = 0 + L = 0 + + if 'full' in run_params['time_of_interest'] or run_params['time_of_interest'] == 'spontaneous': + return np.column_stack((start_times, stop_times)) + + elif 'trial' in run_params['time_of_interest']: + T = run_params['spontaneous_duration'] + L = run_params['trial_stop_time'] - run_params['trial_start_time'] + num_snippets = int(T // L) + + elif 'quiescent' in run_params['time_of_interest']: + T = run_params['spontaneous_duration'] + L = run_params['quiescent_stop_time'] - run_params['quiescent_start_time'] + num_snippets = int(T // L) + + intervals = [] + for i in range(len(start_times)): + snippets_per_epoch = num_snippets // len(start_times) + intervals.append( + pick_values(start_times[i], stop_times[i], snippets_per_epoch, L) + ) + + return np.vstack(intervals) + + +def establish_timebins(run_params, fit, behavior_info): + ''' + Returns the actual timestamps for each time bin that will be used in the regression model + ''' + + bin_starts = [] + epoch_trace = [] + if 'spontaneous' in run_params['time_of_interest'] or run_params['time_of_interest'] == 'full': + spont_times = get_spont_times(run_params, behavior_info) + for n in range(spont_times.shape[0]): + bin_edges = np.arange(spont_times[n, 0], spont_times[n, 1], run_params['spike_bin_width']) + bin_starts.append(bin_edges[:-1]) + epoch_trace.append([f'spontaneous{n}'] * len(bin_edges[:-1])) + + if 'full' in run_params['time_of_interest']: + if 'trial' in run_params['time_of_interest'] or run_params['time_of_interest'] == 'full': + start = behavior_info['trials'].start_time.values + stop = np.append(start[1:], behavior_info['trials'].stop_time.values[-1]) + for n in range(len(behavior_info['trials'])): + bin_edges = np.arange(start[n], stop[n], run_params['spike_bin_width']) + bin_starts.append(bin_edges[:-1]) + epoch_trace.append([f'trial{n}'] * len(bin_edges[:-1])) + + elif 'trial' in run_params['time_of_interest']: + start = behavior_info['trials'].stim_start_time.values + run_params['trial_start_time'] + stop = behavior_info['trials'].stim_start_time.values + run_params['trial_stop_time'] + for n in range(len(behavior_info['trials'])): + bin_edges = np.arange(start[n], stop[n], run_params['spike_bin_width']) + bin_starts.append(bin_edges[:-1]) + epoch_trace.append([f'trial{n}'] * len(bin_edges[:-1])) + + if 'quiescent' in run_params['time_of_interest']: + start = behavior_info['trials'].stim_start_time.values + run_params['quiescent_start_time'] + stop = behavior_info['trials'].stim_start_time.values + run_params['quiescent_stop_time'] + for n in range(len(behavior_info['trials'])): + bin_edges = np.arange(start[n], stop[n], run_params['spike_bin_width']) + bin_starts.append(bin_edges[:-1]) + epoch_trace.append([f'trial{n}'] * len(bin_edges[:-1])) + + epoch_trace = np.concatenate(epoch_trace) + bin_starts = np.concatenate(bin_starts) + + sorted_indices = np.argsort(bin_starts) + bin_starts = 
bin_starts[sorted_indices] + epoch_trace = epoch_trace[sorted_indices] + + bin_ends = bin_starts + run_params['spike_bin_width'] + timebins = np.vstack([bin_starts, bin_ends]).T + + fit['spike_bin_width'] = run_params['spike_bin_width'] + fit['timebins'] = timebins + fit['bin_centers'] = bin_starts + run_params['spike_bin_width'] / 2 + fit['epoch_trace'] = epoch_trace + + # Extend time bins to include trace around existing time bins for time-embedding, to create a fuller trace. + scale_factor = int(1 / run_params['spike_bin_width']) + result = next((x for x in range(2 * scale_factor, 6 * scale_factor) if x % 1 == 0), None) + r = result / scale_factor if result is not None else None + + bin_starts_all = [] + epoch_trace_all = [] + for epoch in np.unique(epoch_trace): + # Extend start by `r` + bins = np.arange(bin_starts[epoch_trace == epoch][0] - r, + bin_ends[epoch_trace == epoch][-1] + r, + run_params['spike_bin_width']) + bin_starts_all.append(bins) + epoch_trace_all.append([epoch]*len(bins)) + bin_starts_all = np.concatenate(bin_starts_all, axis=0) + epoch_trace_all = np.concatenate(epoch_trace_all) + + sorted_indices = np.argsort(bin_starts_all) + bin_starts_all = bin_starts_all[sorted_indices] + epoch_trace_all = epoch_trace_all[sorted_indices] + + bin_ends_all = bin_starts_all + run_params['spike_bin_width'] + timebins_all = np.vstack([bin_starts_all, bin_ends_all]).T + + fit['timebins_all'] = timebins_all + fit['bin_centers_all'] = bin_starts_all + run_params['spike_bin_width'] / 2 + fit['epoch_trace_all'] = epoch_trace_all + precision = 5 + rounded_times = np.round(timebins[:, 0], precision) + fit['mask'] = np.array([index for index, value in enumerate(timebins_all[:, 0]) if np.round(value, precision) in rounded_times]) + + assert len(fit['mask']) == timebins.shape[0], 'Incorrect masking, recheck timebins.' + # potentially a precision problem + + return fit + + +def get_spike_counts(spike_times, timebins): + ''' + spike_times, a list of spike times, sorted + timebins, numpy array of bins X start/stop + timebins[i,0] is the start of bin i + timbins[i,1] is the end of bin i + ''' + + counts = np.zeros([np.shape(timebins)[0]]) + spike_pointer = 0 + bin_pointer = 0 + while (spike_pointer < len(spike_times)) & (bin_pointer < np.shape(timebins)[0]): + if spike_times[spike_pointer] < timebins[bin_pointer, 0]: + # This spike happens before the time bin, advance spike + spike_pointer += 1 + elif spike_times[spike_pointer] >= timebins[bin_pointer, 1]: + # This spike happens after the time bin, advance time bin + bin_pointer += 1 + else: + counts[bin_pointer] += 1 + spike_pointer += 1 + + return counts + + +def process_spikes(units_table, run_params, fit): + ''' + Returns a dictionary including spike counts and unit-specific information. 
+ ''' + + # identifies good units + units_table = setup_units_table(run_params, units_table) + + spikes = np.zeros((fit['timebins'].shape[0], len(units_table))) + + for uu, (_, unit) in tqdm(enumerate(units_table.iterrows()), total=len(units_table), desc='getting spike counts'): + spikes[:, uu] = get_spike_counts(np.array(unit['spike_times']), fit['timebins']) + + spike_count_arr = { + 'spike_counts': spikes, + 'bin_centers': fit['bin_centers'], + 'unit_id': units_table.unit_id.values, + 'structure': units_table.structure.values, + 'location': units_table.location.values, + 'quality': units_table.good_unit.values + } + fit['spike_count_arr'] = spike_count_arr + + # Check to make sure there are no NaNs in the fit_trace + assert np.isnan(fit['spike_count_arr']['spike_counts']).sum() == 0, "Have NaNs in spike_count_arr" + + return fit + + +def extract_unit_data(run_params, units_table, behavior_info): + ''' + Creates the fit dictionary + establishes time bins + processes spike times into spike counts for each time bin + ''' + + fit = dict() + fit = establish_timebins(run_params, fit, behavior_info) + fit = process_spikes(units_table, run_params, fit) + + return fit + + +def add_kernels(design, run_params, session, fit, behavior_info): + ''' + Iterates through the kernels in run_params['kernels'] and adds + each to the design matrix + Each kernel must have fields: + offset: + length: + + design the design matrix for this model + run_params the run_json for this model + session the SDK session object for this experiment + fit the fit object for this model + ''' + + run_params['kernels'] = define_kernels(run_params) + fit['failed_kernels'] = set() + fit['kernel_error_dict'] = dict() + + if session is None: + session = npc_sessions.DynamicRoutingSession(run_params["session_id"]) + + for kernel_name in run_params['kernels']: + if 'num_weights' not in run_params['kernels'][kernel_name]: + run_params['kernels'][kernel_name]['num_weights'] = None + design, fit = add_kernel_by_label(kernel_name, design, run_params, session, fit, behavior_info) + + return design, fit + + +def add_kernel_by_label(kernel_name, design, run_params, session, fit, behavior_info): + ''' + Adds the kernel specified by to the design matrix + kernel_name the label for this kernel, will raise an error if not implemented + design the design matrix for this model + run_params the run_json for this model + session the session object for this experiment + fit the fit object for this model + ''' + + print(get_timestamp() + ' Adding kernel: ' + kernel_name) + + try: + kernel_function = globals().get(run_params['kernels'][kernel_name]['function_call']) + if not callable(kernel_function): + raise ValueError(f"Invalid kernel name: {kernel_name}") + input_x = kernel_function(kernel_name, session, fit, behavior_info) + + if run_params['kernels'][kernel_name]['type'] == 'continuous': + input_x = standardize_inputs(input_x) + + if run_params['kernels'][kernel_name]['orthogonalize']: + context_kernel = context('context', session, fit, behavior_info) \ + if 'context' not in design.events.keys() else design.events['context'] + input_x = orthogonalize_this_kernel(input_x, context_kernel) + input_x = standardize_inputs(input_x) + + except Exception as e: + print(get_timestamp() + f"Exception: {e}") + print('Attempting to continue without this kernel.') + + fit['failed_kernels'].add(kernel_name) + fit['kernel_error_dict'][kernel_name] = { + 'error_type': 'kernel', + 'kernel_name': kernel_name, + 'exception': e.args[0], + } + return design, fit + 
else: + design.add_kernel( + input_x, + run_params['kernels'][kernel_name]['length'], + kernel_name, + offset=run_params['kernels'][kernel_name]['offset'], + num_weights=run_params['kernels'][kernel_name]['num_weights'] + ) + return design, fit + + +def intercept(kernel_name, session, fit, behavior_info): + return np.ones(len(fit['bin_centers_all'])) + + +def context(kernel_name, session, fit, behavior_info): + this_kernel = np.zeros(len(fit['bin_centers_all'])) + epoch_trace = fit['epoch_trace_all'] + + for n, epoch in enumerate(epoch_trace): + if 'trial' in epoch: + trial_no = int(''.join(filter(str.isdigit, epoch))) + this_kernel[n] = 1 if behavior_info['trials'].loc[trial_no, 'is_vis_context'] else -1 + + return this_kernel + + +def pupil(kernel_name, session, fit, behavior_info): + def process_pupil_data(df, behavior_info): + for pos, row in behavior_info['epoch_info'].iterrows(): + # Select rows within the current epoch + epoch_mask = (df.timestamps >= row.start_time) & (df.timestamps < row.stop_time) + epoch_df = df.loc[epoch_mask] + + # Compute the threshold for the current epoch + threshold = np.nanmean(epoch_df['pupil_area']) + 3 * np.nanstd(epoch_df['pupil_area']) + + # Apply threshold and set outliers to NaN within the epoch + df.loc[epoch_mask & (df['pupil_area'] > threshold), 'pupil_area'] = np.nan + df['pupil_area'] = df['pupil_area'].interpolate(method='linear') + return df + + df = process_pupil_data(session._eye_tracking.to_dataframe(), behavior_info) + return bin_timeseries(df.pupil_area.values, df.timestamps.values, fit['timebins_all']) + + +def running(kernel_name, session, fit, behavior_info): + return bin_timeseries(session._running_speed.data, session._running_speed.timestamps, fit['timebins_all']) + + +def licks(kernel_name, session, fit, behavior_info): + lick_times = session._all_licks[0].timestamps + + # Extract the bin edges + bin_starts, bin_stops = fit['timebins_all'][:, 0], fit['timebins_all'][:, 1] + + # Check if any lick times are within each bin + in_bin = (lick_times[:, None] >= bin_starts) & (lick_times[:, None] < bin_stops) + this_kernel = np.any(in_bin, axis=0).astype(int) + + return this_kernel + + +def LP_features(kernel_name, session, fit, behavior_info): + def eu_dist_for_LP(x_c, y_c): + return np.sqrt(x_c ** 2 + y_c ** 2) + + def part_info_LP(part_name, df): + confidence = df[part_name + '_likelihood'].values.astype('float') + temp_norm = df[part_name + '_temporal_norm'].values.astype('float') + x_c = df[part_name + '_x'].values.astype('float') + y_c = df[part_name + '_y'].values.astype('float') + xy = eu_dist_for_LP(x_c, y_c) + + xy[(confidence < 0.98) | (temp_norm > np.nanmean(temp_norm) + 3 * np.nanstd(temp_norm))] = np.nan + xy = pd.Series(xy).interpolate(limit_direction='both').to_numpy() + return xy, confidence + + map_names = {'ears': 'ear_base_l', 'jaw': 'jaw', 'nose': 'nose_tip', 'whisker_pad': 'whisker_pad_l_side'} + try: + df = session._lp[0][:] + except IndexError: + raise IndexError(f'{session.id} is not a session with video.') + timestamps = df['timestamps'].values.astype('float') + lp_part_name = map_names[kernel_name] + part_xy, confidence = part_info_LP(lp_part_name, df) + return bin_timeseries(part_xy, timestamps, fit['timebins_all']) + + +def choice(kernel_name, session, fit, behavior_info): + bin_starts, bin_stops = fit['timebins_all'][:, 0], fit['timebins_all'][:, 1] + if behavior_info['trials']['is_' + kernel_name].any(): + choice_times = behavior_info['trials'][behavior_info['trials']['is_' + 
kernel_name]].stim_start_time.values + in_bin = (choice_times[:, None] >= bin_starts) & (choice_times[:, None] < bin_stops) + this_kernel = np.any(in_bin, axis=0).astype(int) + else: + raise ValueError(f"No trials with is_{kernel_name}") + return this_kernel + + +def stimulus(kernel_name, session, fit, behavior_info): + stim_name, context_name = kernel_name.split('_') + bin_starts, bin_stops = fit['timebins_all'][:, 0], fit['timebins_all'][:, 1] + filtered_trials = behavior_info['trials'][ + (behavior_info['trials'].stim_name == stim_name) & (behavior_info['trials'].context_name == context_name)] + if not filtered_trials.empty: + stim_times = filtered_trials.stim_start_time.values + in_bin = (stim_times[:, None] >= bin_starts) & (stim_times[:, None] < bin_stops) + this_kernel = np.any(in_bin, axis=0).astype(int) + else: + raise ValueError(f"No trials presented with {stim_name} stimulus in {context_name} context") + return this_kernel + + +def session_time(kernel_name, session, fit, behavior_info): + return fit['bin_centers_all'] + + +# def toeplitz_this_kernel(this_kernel, kernel_length_samples, offset_samples, epoch_trace): +# top_kernel = np.zeros((len(epoch_trace), kernel_length_samples)) + np.nan +# unique_epochs = np.array([epoch_trace[i] for i in sorted(np.unique(epoch_trace, return_index=True)[1])]) +# for epoch in unique_epochs: +# ind = np.where(epoch_trace == epoch)[0] +# top_kernel[ind] = toeplitz(this_kernel[ind], kernel_length_samples, offset_samples).T +# return top_kernel.T + + +def toeplitz_this_kernel(input_x, kernel_length_samples, offset_samples): + ''' + Build a toeplitz matrix aligned to events. + + Args: + events (np.array of 1/0): Array with 1 if the event happened at that time, and 0 otherwise. + kernel_length_samples (int): How many kernel parameters + Returns + np.array, size(len(events), kernel_length_samples) of 1/0 + ''' + + total_len = len(input_x) + events = np.concatenate([input_x, np.zeros(kernel_length_samples)]) + arrays_list = [events] + for i in range(kernel_length_samples - 1): + arrays_list.append(np.roll(events, i + 1)) + this_kernel = np.vstack(arrays_list) + + # pad with zeros, roll offset_samples, and truncate to length + if offset_samples < 0: + this_kernel = np.concatenate([np.zeros((this_kernel.shape[0], np.abs(offset_samples))), this_kernel], axis=1) + this_kernel = np.roll(this_kernel, offset_samples)[:, np.abs(offset_samples):] + elif offset_samples > 0: + this_kernel = np.concatenate([this_kernel, np.zeros((this_kernel.shape[0], offset_samples))], axis=1) + this_kernel = np.roll(this_kernel, offset_samples)[:, :-offset_samples] + return this_kernel[:, :total_len] + + +class DesignMatrix: + def __init__(self, fit): + ''' + A toeplitz-matrix builder for running regression with multiple temporal kernels. + + Args + fit_dict, a dictionary with: + event_timestamps: The actual timestamps for each time bin that will be used in the regression model. + ''' + self.X = None + self.kernel_dict = {} + self.running_stop = 0 + self.events = {'timestamps': fit['bin_centers']} + self.spike_bin_width = fit['spike_bin_width'] + self.epochs = fit['epoch_trace'] + self.mask = fit['mask'] + + def make_labels(self, label, num_weights, offset, length): + base = [label] * num_weights + numbers = [str(x) for x in np.array(range(0, length)) + offset] + return [x[0] + '_' + x[1] for x in zip(base, numbers)] + + def add_kernel(self, input_x, kernel_length, label, offset=0, num_weights=None): + ''' + Add a temporal kernel. 
+ + Args: + input_x (np.array): input data points + kernel_length (int): length of the kernel (in SECONDS). + label (string): Name of the kernel. + offset (int) :offset relative to the events. Negative offsets cause the kernel + to overhang before the event (in SECONDS) + ''' + + # Enforce unique labels + if label in self.kernel_dict.keys(): + raise ValueError('Labels must be unique') + + self.events[label] = input_x[self.mask] + + # CONVERT kernel_length to kernel_length_samples + if num_weights is None: + if kernel_length == 0: + kernel_length_samples = 1 + else: + kernel_length_samples = int(np.ceil((1 / self.spike_bin_width) * kernel_length)) + else: + # Some kernels are hard-coded by number of weights + kernel_length_samples = num_weights + + # CONVERT offset to offset_samples + offset_samples = int(np.floor((1 / self.spike_bin_width) * offset)) + + if np.abs(offset_samples) > 0: + this_kernel = toeplitz_this_kernel(input_x, kernel_length_samples, offset_samples) + else: + this_kernel = input_x.reshape(1, -1) + + # keep only the relevant trace for prediction + this_kernel = this_kernel[:, self.mask] + + self.kernel_dict[label] = { + 'kernel': this_kernel, + 'kernel_length_samples': kernel_length_samples, + 'offset_samples': offset_samples, + 'kernel_length_seconds': kernel_length, + 'offset_seconds': offset, + 'ind_start': self.running_stop, + 'ind_stop': self.running_stop + kernel_length_samples + } + self.running_stop += kernel_length_samples + + def get_X(self, kernels=None): + ''' + Get the design matrix. + Args: + kernels (optional list of kernel string names): which kernels to include (for model selection) + Returns: + X (xr.array): The design matrix + ''' + if kernels is None: + kernels = self.kernel_dict.keys() + + kernels_to_use = [] + param_labels = [] + for kernel_name in kernels: + kernels_to_use.append(self.kernel_dict[kernel_name]['kernel']) + param_labels.append(self.make_labels(kernel_name, + np.shape(self.kernel_dict[kernel_name]['kernel'])[0], + self.kernel_dict[kernel_name]['offset_samples'], + self.kernel_dict[kernel_name]['kernel_length_samples'])) + + X = np.vstack(kernels_to_use) + x_labels = np.hstack(param_labels) + + assert np.shape(X)[0] == np.shape(x_labels)[0], 'Weight Matrix must have the same length as the weight labels' + + X_array = xr.DataArray( + X, + dims=('weights', 'timestamps'), + coords={'weights': x_labels, + 'timestamps': self.events['timestamps']} + ) + self.X = X_array.T + return X_array.T + + +# def bin_timeseries2(x, x_timestamps, timebins): +# Tm = timebins.shape[0] +# start_indices = np.searchsorted(x_timestamps, timebins[:, 0], side='left') +# end_indices = np.searchsorted(x_timestamps, timebins[:, 1], side='right') +# +# binned = np.full(Tm, np.nan) +# +# for t_i in range(Tm): +# indices = slice(start_indices[t_i], end_indices[t_i]) +# if indices.start < indices.stop: # Check if the slice is non-empty +# binned[t_i] = np.nanmean(x[indices]) +# elif t_i > 0: # Propagate previous value +# binned[t_i] = binned[t_i - 1] +# return binned + + +def bin_timeseries(x, x_timestamps, timebins_all): + Tm = timebins_all.shape[0] + start_indices = np.searchsorted(x_timestamps, timebins_all[:, 0], side='left') + end_indices = np.searchsorted(x_timestamps, timebins_all[:, 1], side='right') + + binned = np.full(Tm, np.nan) + mask = [] + for t_i in range(Tm): + indices = slice(start_indices[t_i], end_indices[t_i]) + if indices.start < indices.stop: # Check if the slice is non-empty + binned[t_i] = np.nanmean(x[indices]) + elif t_i > 0: # Propagate 
+            binned[t_i] = binned[t_i - 1]
+    return binned
+
+
+def get_timestamp():
+    t = time.localtime()
+    return time.strftime('%Y-%m-%d: %H:%M:%S') + ' '
+
+
+def orthogonalize_this_kernel(this_kernel, y):
+    mat_to_ortho = np.concatenate((y.reshape(-1, 1), this_kernel.reshape(-1, 1)), axis=1)
+    print(get_timestamp() + ' : ' + 'orthogonalizing against context')
+    Q, R = np.linalg.qr(mat_to_ortho)
+    return Q[:, 1]
+
+
+def standardize_inputs(timeseries, mean_center=True, unit_variance=True, max_value=None):
+    '''
+    Performs three different input standardizations to the timeseries:
+    if mean_center, the timeseries is adjusted to have 0-mean. This can be performed with unit_variance.
+    if unit_variance, the timeseries is adjusted to have unit variance. This can be performed with mean_center.
+    if max_value is given, the timeseries is normalized by max_value. This cannot be combined with mean_center or unit_variance.
+    '''
+    if (max_value is not None) and (mean_center or unit_variance):
+        raise Exception(
+            'Cannot perform max_value standardization and mean_center or unit_variance standardizations together.')
+
+    if mean_center:
+        print(get_timestamp() + ' : ' + 'mean centering')
+        timeseries = timeseries - np.mean(timeseries)  # mean center
+    if unit_variance:
+        print(get_timestamp() + ' : ' + 'standardized to unit variance')
+        timeseries = timeseries / np.std(timeseries)
+    if max_value is not None:
+        print(get_timestamp() + ' : ' + 'normalized by max value: ' + str(max_value))
+        timeseries = timeseries / max_value
+
+    return timeseries
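+
+
+# Illustrative sketch (not called anywhere in this module): one possible way to go from a cached
+# session to a design matrix with the helpers above. Assumes run_params carries the keys used in
+# this file (e.g. 'time_of_interest', 'spike_bin_width', 'unit_inclusion_criteria', 'session_id').
+# _, units_table, behavior_info, _ = get_session_data('<session_id>')
+# fit = extract_unit_data(run_params, units_table, behavior_info)
+# design = DesignMatrix(fit)
+# design, fit = add_kernels(design, run_params, None, fit, behavior_info)
+# X = design.get_X()  # xarray DataArray, (timestamps x weights)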
diff --git a/src/dynamic_routing_analysis/glm_utils.py b/src/dynamic_routing_analysis/glm_utils.py
new file mode 100644
index 0000000..c6f36ab
--- /dev/null
+++ b/src/dynamic_routing_analysis/glm_utils.py
@@ -0,0 +1,521 @@
+import logging
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+import numpy as np
+from numpy.linalg import LinAlgError
+from sklearn.model_selection import GridSearchCV, KFold
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)  # debug < info < warning < error
+
+
+class Ridge:
+    def __init__(self, lam=None, W=None):
+        self.lam = lam
+        self.r2 = None
+        self.mean_r2 = None
+        self.W = W
+
+    def fit(self, X, y):
+        '''
+        Analytical OLS solution with added L2 regularization penalty.
+        y: shape (n_timestamps, n_cells)
+        X: shape (n_timestamps, n_kernel_params)
+        lam (float): strength of the L2 regularization (hyperparameter to tune)
+        '''
+
+        # Compute the weights
+        try:
+            if self.lam == 0:
+                self.W = np.dot(np.linalg.inv(np.dot(X.T, X)), np.dot(X.T, y))
+            else:
+                self.W = np.dot(np.linalg.inv(np.dot(X.T, X) + self.lam * np.eye(X.shape[-1])),
+                                np.dot(X.T, y))
+        except LinAlgError as e:
+            logger.info(f"Matrix inversion failed due to a linear algebra error: {e}. Falling back to pseudo-inverse.")
+            # Fallback to pseudo-inverse
+            if self.lam == 0:
+                self.W = np.dot(np.linalg.pinv(np.dot(X.T, X)), np.dot(X.T, y))
+            else:
+                self.W = np.dot(np.linalg.pinv(np.dot(X.T, X) + self.lam * np.eye(X.shape[-1])),
+                                np.dot(X.T, y))
+        except Exception as e:
+            print("Unexpected error encountered:", e)
+            raise  # Re-raise the exception to propagate unexpected errors
+
+        self.mean_r2 = self.score(X, y)
+
+        return self
+
+    def get_params(self, deep=True):
+        return {'lam': self.lam}
+
+    def set_params(self, **parameters):
+        for parameter, value in parameters.items():
+            setattr(self, parameter, value)
+        return self
+
+    def score(self, X, y):
+        '''
+        Computes the fraction of variance in y explained by the linear model y = X*W
+        y: (n_timepoints, n_cells)
+        W: (kernel_params, n_cells)
+        X: (n_timepoints, n_kernel_params)
+        '''
+        y_pred = self.predict(X)
+        var_total = np.var(y, axis=0)  # Total variance in the response of each cell
+        var_resid = np.var(y - y_pred, axis=0)  # Residual variance between the model and data
+        self.r2 = (var_total - var_resid) / var_total
+        return np.nanmean(self.r2)  # Fraction of variance explained by the linear model
+
+    def predict(self, X):
+        y = np.dot(X, self.W)
+        return y
+
+
+def nested_train_and_test(design_mat, spike_counts, L2_grid, folds_outer=10, folds_inner=6):
+    X = design_mat.data
+    y = spike_counts
+
+    kf = KFold(n_splits=folds_outer, shuffle=True, random_state=0)
+    lams = np.zeros(folds_outer) + np.nan
+    train_r2 = np.zeros((y.shape[-1], folds_outer))
+    test_r2 = np.zeros((y.shape[-1], folds_outer))
+
+    # outer CV
+    for k, (train_index, test_index) in enumerate(kf.split(X)):
+        X_train, y_train = X[train_index], y[train_index]
+        X_test, y_test = X[test_index], y[test_index]
+
+        # inner CV
+        cv_inner = KFold(n_splits=folds_inner, shuffle=True, random_state=1)
+        model = Ridge()
+        search = GridSearchCV(model, {'lam': np.array(L2_grid)}, cv=cv_inner, refit=True,
+                              n_jobs=1)
+        try:
+            result = search.fit(X_train, y_train)
+        except LinAlgError:
+            continue
+
+        best_model = result.best_estimator_
+        lams[k] = result.best_params_['lam']
+
+        # score() must be called here because it refreshes best_model.r2
+        # for the data it is given (train split first, then test split)
+        train_mean_score = best_model.score(X_train, y_train)
+        train_r2[:, k] = best_model.r2
+        test_mean_score = best_model.score(X_test, y_test)
+        test_r2[:, k] = best_model.r2
+
+    lam = np.nanmedian(lams)
+    model = Ridge(lam=lam).fit(X, y)
+    weights = model.W
+    y_pred = model.predict(X)
+
+    return clean_r2_vals(train_r2), clean_r2_vals(test_r2), weights, y_pred, lams
+
+
+def simple_train_and_test(design_mat, spike_counts, lam, folds_outer=10):
+    """
+    Train and test a Ridge regression model using cross-validation with specified lambda values.
+
+    Args:
+        design_mat: Input design matrix containing data.
+        spike_counts: Target variable (spike counts).
+        lam: Regularization parameter (single value, or a list with one value per fold).
+        folds_outer: Number of folds for outer cross-validation.
+
+    Returns:
+        train_r2: R2 on training data for each cell and fold.
+        test_r2: R2 on testing data for each cell and fold.
+        weights: Model weights after training on the entire dataset.
+        y_pred: Predictions on the entire dataset.
+    """
+    X = design_mat.data
+    y = spike_counts
+
+    kf = KFold(n_splits=folds_outer, shuffle=True, random_state=0)
+    test_r2 = np.zeros((y.shape[-1], folds_outer))
+    train_r2 = np.zeros((y.shape[-1], folds_outer))
+
+    # If lam is a scalar, convert it to a list with the same value for each fold
+    if not isinstance(lam, list):
+        lam = [lam] * folds_outer
+
+    if len(lam) != folds_outer:
+        raise ValueError(f"Length of lam ({len(lam)}) must match number of folds ({folds_outer}).")
+
+    for k, (train_index, test_index) in enumerate(kf.split(X)):
+        X_train, y_train = X[train_index], y[train_index]
+        X_test, y_test = X[test_index], y[test_index]
+
+        model = Ridge(lam=lam[k])  # Use the k-th lambda value for this fold
+        try:
+            model.fit(X_train, y_train)
+        except LinAlgError:
+            logger.info(f"Ridge fit failed for fold {k}; skipping this fold.")
+            continue
+
+        # score() refreshes model.r2 for the data it is given (train split first, then test split)
+        train_mean_score = model.score(X_train, y_train)
+        train_r2[:, k] = model.r2
+        test_mean_score = model.score(X_test, y_test)
+        test_r2[:, k] = model.r2
+
+    # TODO: if every fold failed, train_r2 and test_r2 are all zeros; consider raising LinAlgError here
+
+    # Train the model on the entire dataset with the median lambda value
+    model = Ridge(lam=np.median(lam))
+    model.fit(X, y)
+    weights = model.W
+    y_pred = model.predict(X)
+
+    return clean_r2_vals(train_r2), clean_r2_vals(test_r2), weights, y_pred
+
+
+# Process a single unit, either for L2 optimization or for final fitting
+def process_unit(unit_no, design_mat, fit, run_params, function):
+    """
+    Process a single unit for optimization or fitting.
+
+    Parameters:
+    - unit_no: int, the unit index to process.
+    - design_mat: ndarray, the design matrix for the model.
+    - fit: dict, contains fitting parameters (e.g., spike counts, L2 grid, regularization).
+    - run_params: dict, contains runtime parameters (e.g., number of folds, no_nested_CV).
+    - function: str, either 'optimize' or 'fit'.
+
+    Returns:
+    - Tuple with results depending on the specified function.
+    """
+    fit_cell = fit['spike_count_arr']['spike_counts'][:, unit_no].reshape(-1, 1)
+
+    if function == 'optimize':
+        # Temporary storage for results
+        unit_train_cv = np.zeros(len(fit['L2_grid']))
+        unit_test_cv = np.zeros(len(fit['L2_grid']))
+
+        for L2_index, L2_value in enumerate(fit['L2_grid']):
+            cv_var_train, cv_var_test, _, _ = simple_train_and_test(
+                design_mat, fit_cell, lam=L2_value, folds_outer=run_params['n_outer_folds']
+            )
+            # Store the mean CV score for this unit and L2 value
+            unit_train_cv[L2_index] = np.nanmean(cv_var_train)
+            unit_test_cv[L2_index] = np.nanmean(cv_var_test)
+
+        return unit_no, unit_train_cv, unit_test_cv
+
+    elif function == 'fit':
+        if run_params['no_nested_CV']:
+            lam_value = fit['cell_L2_regularization'][unit_no]
+        elif 'cell_L2_regularization_nested' in run_params.keys():
+            lam_value = fit['cell_L2_regularization_nested'][unit_no]
+        else:
+            # If nested CV is enabled, use nested_train_and_test
+            cv_train, cv_test, weights, prediction, lams = nested_train_and_test(
+                design_mat,
+                fit_cell,
+                L2_grid=fit['L2_grid'],
+                folds_outer=run_params['n_outer_folds'],
+                folds_inner=run_params['n_inner_folds']
+            )
+            return unit_no, cv_train, cv_test, weights, prediction, lams
+
+        # Perform simple training and testing for regular cases
+        cv_train, cv_test, weights, prediction = simple_train_and_test(
+            design_mat,
+            fit_cell,
+            lam=lam_value,
+            folds_outer=run_params['n_outer_folds']
+        )
+        return unit_no, cv_train, cv_test, weights, prediction
+
+    else:
+        raise ValueError(f"Invalid function type: {function}. Expected 'optimize' or 'fit'.")
+
+
+def evaluate_ridge(fit, design_mat, run_params):
+    '''
+        fit, model dictionary
+        design_mat, design matrix
+        run_params, dictionary of parameters, which needs to include:
+            optimize_penalty_by_cell    # If True, uses the best L2 value for each cell
+            optimize_penalty_by_area    # If True, uses the best L2 value for this session
+            use_fixed_penalty           # If True, uses the hard-coded L2_fixed_lambda
+
+            L2_fixed_lambda             # This value is used if use_fixed_penalty is True
+            L2_grid_range               # Min/Max L2 values for optimization
+            L2_grid_num                 # Number of L2 values for optimization
+            L2_grid_type                # log or linear
+
+        returns fit, with the values added:
+            L2_grid                     # the L2 grid evaluated
+        and, for the case of no nested CV,
+            avg_L2_regularization       # the average optimal L2 value, or the fixed value
+            cell_L2_regularization      # the optimal L2 value for each cell
+    '''
+
+    spike_counts = fit['spike_count_arr']['spike_counts']
+    x_is_continuous = [run_params['kernels'][kernel_name.rsplit('_', 1)[0]]['type'] == 'continuous'
+                       for kernel_name in design_mat.weights.values]
+    num_units = spike_counts.shape[1]
+
+    if run_params['use_fixed_penalty']:
+        print(get_timestamp() + ': using a hard-coded regularization value')
+        fit['L2_regularization'] = run_params['L2_fixed_lambda']
+
+    elif run_params['no_nested_CV']:
+        if run_params['L2_grid_type'] == 'log':
+            fit['L2_grid'] = np.array([0] + list(np.geomspace(run_params['L2_grid_range'][0],
+                                                              run_params['L2_grid_range'][1],
+                                                              num=run_params['L2_grid_num'])))
+        else:
+            fit['L2_grid'] = np.array([0] + list(np.linspace(run_params['L2_grid_range'][0],
+                                                             run_params['L2_grid_range'][1],
+                                                             num=run_params['L2_grid_num'])))
+
+        train_cv = np.full((num_units, len(fit['L2_grid'])), np.nan)
+        test_cv = np.full((num_units, len(fit['L2_grid'])), np.nan)
+
+        if run_params['optimize_penalty_by_cell']:
+            print(get_timestamp() + ': optimizing penalty by cell')
+            with ProcessPoolExecutor(max_workers=10) as executor:
+                futures = {
+                    executor.submit(process_unit, unit_no, design_mat.copy(), fit.copy(), run_params, 'optimize'): unit_no
+                    for unit_no in range(num_units)
+                }
+                for future in tqdm(as_completed(futures), total=num_units, desc='Processing units in parallel'):
+                    unit_no = futures[future]
+                    try:
+                        unit_no, unit_train_cv, unit_test_cv = future.result()
+                    except LinAlgError:
+                        logger.info(f"L2 optimization failed for unit {unit_no}; skipping.")
+                        continue
+                    train_cv[unit_no, :] = unit_train_cv
+                    test_cv[unit_no, :] = unit_test_cv
+
+        elif run_params['optimize_penalty_by_area']:
+            print(get_timestamp() + ': optimizing L2 penalty by area')
+            areas = np.unique(fit['spike_count_arr']['structure'])
+            for area in areas:
+                unit_ids = np.where(fit['spike_count_arr']['structure'] == area)[0]
+                fit_area = spike_counts[:, unit_ids]
+                for L2_index, L2_value in tqdm(enumerate(fit['L2_grid']),
+                                               total=len(fit['L2_grid']), desc=area):
+                    cv_var_train, cv_var_test, _, _ = simple_train_and_test(design_mat, fit_area,
+                                                                            lam=L2_value,
+                                                                            folds_outer=run_params['n_outer_folds'])
+                    train_cv[unit_ids, L2_index] = np.nanmean(cv_var_train, axis=1)
+                    test_cv[unit_ids, L2_index] = np.nanmean(cv_var_test, axis=1)
+        else:
+            print(get_timestamp() + ': optimizing L2 penalty for all cells')
+            for L2_index, L2_value in enumerate(fit['L2_grid']):
+                cv_var_train, cv_var_test, _, _ = simple_train_and_test(design_mat,
+                                                                        spike_counts,
+                                                                        lam=L2_value,
+                                                                        folds_outer=run_params['n_outer_folds'])
+                train_cv[:, L2_index] = np.nanmean(cv_var_train, axis=1)
+                test_cv[:, L2_index] = np.nanmean(cv_var_test, axis=1)
+
+        fit['avg_L2_regularization'] = np.mean([fit['L2_grid'][x] for x in np.argmax(test_cv, 1)])
+        fit['cell_L2_regularization'] = [fit['L2_grid'][x] for x in np.argmax(test_cv, 1)]
+        fit['L2_test_cv'] = test_cv
+        fit['L2_train_cv'] = train_cv
+        fit['L2_at_grid_min'] = [x == 0 for x in np.argmax(test_cv, 1)]
+        fit['L2_at_grid_max'] = [x == (len(fit['L2_grid']) - 1) for x in np.argmax(test_cv, 1)]
+    else:
+        if run_params['L2_grid_type'] == 'log':
+            fit['L2_grid'] = np.array([0] + list(np.geomspace(run_params['L2_grid_range'][0],
+                                                              run_params['L2_grid_range'][1],
+                                                              num=run_params['L2_grid_num'])))
+        else:
+            fit['L2_grid'] = np.array([0] + list(np.linspace(run_params['L2_grid_range'][0],
+                                                             run_params['L2_grid_range'][1],
+                                                             num=run_params['L2_grid_num'])))
+
+    return fit
+
+
+def evaluate_models(fit, design_mat, run_params):
+    X = design_mat.data
+    spike_counts = fit['spike_count_arr']['spike_counts']
+    # x_is_continuous = [run_params['kernels'][kernel_name.rsplit('_', 1)[0]]['type'] == 'continuous'
+    #                    for kernel_name in design_mat.weights.values]
+
+    # Initialize outputs
+    num_units = spike_counts.shape[1]
+    num_outer_folds = run_params['n_outer_folds']
+    cv_var_train = np.full((num_units, num_outer_folds), np.nan)
+    cv_var_test = np.full((num_units, num_outer_folds), np.nan)
+    all_weights = np.full((X.shape[1], num_units), np.nan)
+    all_prediction = np.full(spike_counts.shape, np.nan)
+
+    if isinstance(run_params['cell_L2_regularization'], list):
+        fit['cell_L2_regularization'] = run_params['cell_L2_regularization']
+
+    if isinstance(run_params['cell_L2_regularization_nested'], list):
+        fit['cell_L2_regularization_nested'] = run_params['cell_L2_regularization_nested']
+
+    cell_L2_regularization_nested = np.full((num_units, num_outer_folds), np.nan)
+
+    if run_params['use_fixed_penalty']:
+        cv_var_train, cv_var_test, all_weights, all_prediction = simple_train_and_test(
+            design_mat, spike_counts,
+            lam=fit['L2_regularization'],
+            folds_outer=num_outer_folds
+        )
+    elif run_params['no_nested_CV']:
+        if run_params['optimize_penalty_by_cell']:
+            print(get_timestamp() + ': fitting each cell')
+            with ProcessPoolExecutor(max_workers=10) as executor:
+                futures = {
+                    executor.submit(process_unit, unit_no, design_mat.copy(), fit.copy(), run_params, 'fit'): unit_no
+                    for unit_no in range(num_units)
+                }
+                for future in tqdm(as_completed(futures), total=num_units, desc='progress'):
+                    unit_no, cv_train, cv_test, weights, prediction = future.result()
+                    cv_var_train[unit_no] = cv_train
+                    cv_var_test[unit_no] = cv_test
+                    all_weights[:, unit_no] = weights.reshape(-1)
+                    all_prediction[:, unit_no] = prediction.reshape(-1)
+
+        elif run_params['optimize_penalty_by_area']:
+            areas = np.unique(fit['spike_count_arr']['structure'])
+            print(get_timestamp() + ': fitting units by area')
+            for area in tqdm(areas, total=len(areas), desc='progress'):
+                unit_ids = np.where(fit['spike_count_arr']['structure'] == area)[0]
+                fit_area = spike_counts[:, unit_ids]
+                L2_value = np.unique(np.take(fit['cell_L2_regularization'], unit_ids))[0]
+                cv_train, cv_test, weights, prediction = simple_train_and_test(design_mat, fit_area,
+                                                                               lam=L2_value,
+                                                                               folds_outer=run_params['n_outer_folds'])
+                cv_var_train[unit_ids] = cv_train
+                cv_var_test[unit_ids] = cv_test
+                all_weights[:, unit_ids] = weights
+                all_prediction[:, unit_ids] = prediction
+        else:
+            print(get_timestamp() + ': fitting all units')
+            L2_value = np.unique(np.array(fit['cell_L2_regularization']))[0]
+            cv_var_train, cv_var_test, all_weights, all_prediction = simple_train_and_test(design_mat,
+                                                                                            spike_counts,
+                                                                                            lam=L2_value,
+                                                                                            folds_outer=run_params['n_outer_folds'])
+
+    elif 'cell_L2_regularization_nested' in fit:
+        if run_params['optimize_penalty_by_cell']:
+            print(get_timestamp() + ': fitting each cell')
+            with ProcessPoolExecutor(max_workers=10) as executor:
+                futures = {
+                    executor.submit(process_unit, unit_no, design_mat.copy(), fit.copy(), run_params, 'fit'): unit_no
+                    for unit_no in range(num_units)
+                }
+                for future in tqdm(as_completed(futures), total=num_units, desc='progress'):
+                    unit_no, cv_train, cv_test, weights, prediction = future.result()
+                    cv_var_train[unit_no] = cv_train
+                    cv_var_test[unit_no] = cv_test
+                    all_weights[:, unit_no] = weights.reshape(-1)
+                    all_prediction[:, unit_no] = prediction.reshape(-1)
+
+        elif run_params['optimize_penalty_by_area']:
+            areas = np.unique(fit['spike_count_arr']['structure'])
+            for area in tqdm(areas, total=len(areas), desc='progress'):
+                unit_ids = np.where(fit['spike_count_arr']['structure'] == area)[0]
+                fit_area = spike_counts[:, unit_ids]
+                L2_value = np.unique(np.take(fit['cell_L2_regularization_nested'], unit_ids), axis=0)
+                cv_train, cv_test, weights, prediction = simple_train_and_test(design_mat, fit_area,
+                                                                               lam=L2_value,
+                                                                               folds_outer=run_params['n_outer_folds'])
+                cv_var_train[unit_ids] = cv_train
+                cv_var_test[unit_ids] = cv_test
+                all_weights[:, unit_ids] = weights
+                all_prediction[:, unit_ids] = prediction
+
+        else:
+            L2_value = np.unique(np.array(fit['cell_L2_regularization_nested']), axis=0)
+            cv_var_train, cv_var_test, all_weights, all_prediction = \
+                simple_train_and_test(design_mat, fit['spike_count_arr']['spike_counts'],
+                                      lam=L2_value,
+                                      folds_outer=run_params['n_outer_folds'])
+
+    else:
+        if run_params['optimize_penalty_by_cell']:
+            print(get_timestamp() + ': fitting each cell')
+            with ProcessPoolExecutor(max_workers=10) as executor:
+                futures = {
+                    executor.submit(process_unit, unit_no, design_mat.copy(), fit.copy(), run_params, 'fit'): unit_no
+                    for unit_no in range(num_units)
+                }
+                for future in tqdm(as_completed(futures), total=num_units, desc='progress'):
+                    unit_no, cv_train, cv_test, weights, prediction, lams = future.result()
+                    cv_var_train[unit_no] = cv_train
+                    cv_var_test[unit_no] = cv_test
+                    all_weights[:, unit_no] = weights.reshape(-1)
+                    all_prediction[:, unit_no] = prediction.reshape(-1)
+                    cell_L2_regularization_nested[unit_no] = lams
+
+        elif run_params['optimize_penalty_by_area']:
+            areas = np.unique(fit['spike_count_arr']['structure'])
+            for area in tqdm(areas, total=len(areas), desc='progress'):
+                unit_ids = np.where(fit['spike_count_arr']['structure'] == area)[0]
+                fit_area = spike_counts[:, unit_ids]
+                cv_train, cv_test, weights, prediction, lams = \
+                    nested_train_and_test(design_mat, fit_area, L2_grid=fit['L2_grid'],
+                                          folds_outer=run_params['n_outer_folds'],
+                                          folds_inner=run_params['n_inner_folds'])
+
+                cv_var_train[unit_ids] = cv_train
+                cv_var_test[unit_ids] = cv_test
+                all_weights[:, unit_ids] = weights
+                all_prediction[:, unit_ids] = prediction
+                cell_L2_regularization_nested[unit_ids] = lams
+
+        else:
+            cv_var_train, cv_var_test, all_weights, all_prediction, cell_L2_regularization_nested = \
+                nested_train_and_test(design_mat, spike_counts, L2_grid=fit['L2_grid'],
+                                      folds_outer=run_params['n_outer_folds'],
+                                      folds_inner=run_params['n_inner_folds'])
+
+    model_label = run_params['model_label']
+    fit[model_label] = {
+        'weights': all_weights,
+        'full_model_prediction': all_prediction,
+        'cv_var_train': cv_var_train,
+        'cv_var_test': cv_var_test
+    }
+    if not np.isnan(cell_L2_regularization_nested).all():
+        fit['cell_L2_regularization_nested'] = cell_L2_regularization_nested
+
+    return fit
+
+
+def clean_r2_vals(x):
+    x[np.isinf(x) | np.isnan(x)] = 0
+    return x
+
+
+def get_timestamp():
+    t = time.localtime()
+    return time.strftime('%Y-%m-%d: %H:%M:%S') + ' '
+
+
+# def set_kernel_length(trials, units_table, feature_func=None,
+#                       time_before=None, time_after=None, bin_size=None,
+#                       kernel_lengths=None, kernel_conditions=None):
+#     if not kernel_lengths:
+#         kernel_lengths = [0.1, 0.25, 0.5, 1, 1.5]
+#     if not time_before:
+#         time_before = 2
+#     if not time_after:
+#         time_after = 3
+#     if not bin_size:
+#         bin_size = 0.025
+#
+#     n_units = len(units_table)
+#     r2 = np.zeros((len(kernel_lengths), n_units)) + np.nan
+#     for k, kernel_length in enumerate(kernel_lengths):
+#         if kernel_conditions:
+#             X = feature_func(trials, time_before, time_after, kernel_length, bin_size, kernel_conditions)
+#         else:
+#             X = feature_func(trials, time_before, time_after, kernel_length, bin_size)
+#         X = np.hstack((np.ones((X.shape[0], 1)), X))
+#         r2_k, weights_k = train_and_test(X, spike_counts, folds_outer=5, folds_inner=3)
+#         r2[k, :] = r2_k
+#     return kernel_lengths[np.argmax(np.nanmedian(r2, axis=1))], np.nanmedian(r2, axis=1)
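+
+
+# Illustrative sketch (not called anywhere in this module): one possible end-to-end fitting flow,
+# assuming `design` is a DesignMatrix built with IO_utils.add_kernels and that `fit`/`run_params`
+# carry the keys referenced above ('spike_count_arr', 'L2_grid_range', 'model_label', fold counts, etc.).
+# design_mat = design.get_X()
+# fit = evaluate_ridge(fit, design_mat, run_params)    # choose L2 penalties (fixed, per-cell, or per-area)
+# fit = evaluate_models(fit, design_mat, run_params)   # cross-validated fits, weights and predictions
+# results = fit[run_params['model_label']]             # {'weights', 'full_model_prediction', 'cv_var_train', 'cv_var_test'}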