From d35f8ea0d3389683c27107795101fa10f4933457 Mon Sep 17 00:00:00 2001 From: Frithjof Gressmann Date: Mon, 18 Mar 2024 09:38:21 -0500 Subject: [PATCH] Apply black formatting --- src/miv_simulator/cells.py | 104 ++++++++++---- src/miv_simulator/clamps/network.py | 156 +++++++++++++++------ src/miv_simulator/env.py | 71 +++++++--- src/miv_simulator/network.py | 193 ++++++++++++++++++++------ src/miv_simulator/optimization.py | 59 ++++++-- src/miv_simulator/optimize_network.py | 27 +++- src/miv_simulator/synapses.py | 118 +++++++++++----- src/scripts/analysis/network_clamp.py | 20 ++- 8 files changed, 548 insertions(+), 200 deletions(-) diff --git a/src/miv_simulator/cells.py b/src/miv_simulator/cells.py index eed326f..f952934 100644 --- a/src/miv_simulator/cells.py +++ b/src/miv_simulator/cells.py @@ -102,7 +102,9 @@ def get_soma_xyz( pt_swc_types = neurotree_dict["swc_type"] soma_pts = np.where(pt_swc_types == swc_type_defs["soma"])[0] - soma_coords = np.column_stack((pt_xs[soma_pts], pt_ys[soma_pts], pt_zs[soma_pts])) + soma_coords = np.column_stack( + (pt_xs[soma_pts], pt_ys[soma_pts], pt_zs[soma_pts]) + ) return soma_coords[0] @@ -230,7 +232,8 @@ def make_input_cell( param_values = input_gen["params"] template = getattr(h, template_name) params = [ - param_values[p] for p in env.netclamp_config.template_params[template_name] + param_values[p] + for p in env.netclamp_config.template_params[template_name] ] cell = template(gid, *params) else: @@ -283,7 +286,9 @@ class BRKneuron: """ - def __init__(self, gid, pop_name, env=None, cell_config=None, mech_dict=None): + def __init__( + self, gid, pop_name, env=None, cell_config=None, mech_dict=None + ): """ :param gid: int @@ -348,7 +353,9 @@ def __init__(self, gid, pop_name, env=None, cell_config=None, mech_dict=None): h.define_shape() soma_node = insert_section_node(self, "soma", index=0, sec=BRK_nrn.soma) - apical_node = insert_section_node(self, "apical", index=1, sec=BRK_nrn.dend) + apical_node = insert_section_node( + self, "apical", index=1, sec=BRK_nrn.dend + ) connect_nodes( self.tree, self.soma[0], self.apical[0], connect_hoc_sections=False ) @@ -422,7 +429,9 @@ class PRneuron: Conforms to the same API as BiophysCell. 
""" - def __init__(self, gid, pop_name, env=None, cell_config=None, mech_dict=None): + def __init__( + self, gid, pop_name, env=None, cell_config=None, mech_dict=None + ): """ :param gid: int @@ -451,7 +460,9 @@ def __init__(self, gid, pop_name, env=None, cell_config=None, mech_dict=None): self.spike_onset_delay = 0.0 self.is_reduced = True if not isinstance(cell_config, PRconfig): - raise RuntimeError("PRneuron: argument cell_attrs must be of type PRconfig") + raise RuntimeError( + "PRneuron: argument cell_attrs must be of type PRconfig" + ) param_dict = { "pp": cell_config.pp, @@ -478,7 +489,9 @@ def __init__(self, gid, pop_name, env=None, cell_config=None, mech_dict=None): h.define_shape() soma_node = insert_section_node(self, "soma", index=0, sec=PR_nrn.soma) - apical_node = insert_section_node(self, "apical", index=1, sec=PR_nrn.dend) + apical_node = insert_section_node( + self, "apical", index=1, sec=PR_nrn.dend + ) connect_nodes( self.tree, self.soma[0], self.apical[0], connect_hoc_sections=False ) @@ -679,7 +692,9 @@ def __init__( neurotree_dict: Optional[ Dict[ str, - Union[ndarray, Dict[str, Union[int, Dict[int, ndarray], ndarray]]], + Union[ + ndarray, Dict[str, Union[int, Dict[int, ndarray], ndarray]] + ], ] ] = None, mech_file_path: None = None, @@ -704,7 +719,9 @@ def __init__( self.template_class = env.template_dict[population_name] for sec_type in env.SWC_Types: if sec_type not in default_ordered_sec_types: - raise AttributeError("Unexpected SWC Type definitions found in Env") + raise AttributeError( + "Unexpected SWC Type definitions found in Env" + ) self.nodes = {key: [] for key in default_ordered_sec_types} self.mech_file_path = mech_file_path @@ -718,7 +735,9 @@ def __init__( hoc_cell, section_content = make_neurotree_hoc_cell( self.template_class, gid, neurotree_dict, section_content=True ) - import_morphology_from_hoc(self, hoc_cell, section_content=section_content) + import_morphology_from_hoc( + self, hoc_cell, section_content=section_content + ) if (mech_dict is None) and (mech_file_path is not None): import_mech_dict_from_file(self, self.mech_file_path) elif mech_dict is None: @@ -799,7 +818,9 @@ def get_distance_to_node( return length if loc is not None: length += loc * node.section.L - rpath = list(reversed(nx.shortest_path(cell.tree, source=root, target=node))) + rpath = list( + reversed(nx.shortest_path(cell.tree, source=root, target=node)) + ) while not len(rpath) == 0: node = rpath.pop() if not len(rpath) == 0: @@ -961,7 +982,9 @@ def import_morphology_from_hoc( if root_sec: insert_section_tree(cell, [root_sec], sec_info_dict) else: - raise RuntimeError(f"import_morphology_from_hoc: unable to locate root section") + raise RuntimeError( + f"import_morphology_from_hoc: unable to locate root section" + ) def import_mech_dict_from_file(cell, mech_file_path=None): @@ -973,7 +996,9 @@ def import_mech_dict_from_file(cell, mech_file_path=None): """ if mech_file_path is None: if cell.mech_file_path is None: - raise ValueError("import_mech_dict_from_file: missing mech_file_path") + raise ValueError( + "import_mech_dict_from_file: missing mech_file_path" + ) elif not os.path.isfile(cell.mech_file_path): raise OSError( "import_mech_dict_from_file: invalid mech_file_path: %s" @@ -989,7 +1014,9 @@ def import_mech_dict_from_file(cell, mech_file_path=None): cell.mech_dict = copy.deepcopy(cell.init_mech_dict) -def init_cable(cell: Union[BiophysCell, SCneuron], verbose: bool = False) -> None: +def init_cable( + cell: Union[BiophysCell, SCneuron], verbose: bool = False +) -> 
None: for sec_type in cell.nodes: for node in cell.nodes[sec_type]: reset_cable_by_node(cell, node, verbose=verbose) @@ -1010,7 +1037,9 @@ def reset_cable_by_node( if sec_type in cell.mech_dict and "cable" in cell.mech_dict[sec_type]: mech_content = cell.mech_dict[sec_type]["cable"] if mech_content is not None: - update_mechanism_by_node(cell, node, "cable", mech_content, verbose=verbose) + update_mechanism_by_node( + cell, node, "cable", mech_content, verbose=verbose + ) else: init_nseg(node.section, verbose=verbose) reinit_diam(node.section, node.diam_bounds) @@ -1099,7 +1128,9 @@ def init_spike_detector( sec_seg_locs = [seg.x for seg in node.sec] for loc in sec_seg_locs: if ( - get_distance_to_node(cell, node, root=cell.root, loc=loc) + get_distance_to_node( + cell, node, root=cell.root, loc=loc + ) >= distance ): break @@ -1167,7 +1198,9 @@ def update_mechanism_by_node( cell: BiophysCell, node: SectionNode, mech_name: str, - mech_content: Optional[Dict[str, Union[Dict[str, float], Dict[str, int]]]] = None, + mech_content: Optional[ + Dict[str, Union[Dict[str, float], Dict[str, int]]] + ] = None, verbose: bool = True, ) -> None: """ @@ -1313,7 +1346,11 @@ def filter_nodes( for swc_type in swc_types: nodes.extend(cell.nodes[swc_type]) - result = [v for v in nodes if matches([(layers, v.get_layer()), (sections, v.sec)])] + result = [ + v + for v in nodes + if matches([(layers, v.get_layer()), (sections, v.sec)]) + ] return result @@ -1373,16 +1410,16 @@ def report_topology( ) ) - diams_str = ", ".join(f"{node.sec.diam3d(i):.2f}" for i in range(node.sec.n3d())) + diams_str = ", ".join( + f"{node.sec.diam3d(i):.2f}" for i in range(node.sec.n3d()) + ) report = ( f"node: {node.name}, L: {node.sec.L:.1f}, diams: [{diams_str}], nseg: {node.sec.nseg}, " f"children: {len(node.sec.children())}, exc_syns: {num_exc_syns}, inh_syns: {num_inh_syns}" ) parent, edge_data = get_node_parent(cell, node, return_edge_data=True) if parent is not None: - report += ( - f", parent: {parent.name}; connection_loc: {edge_data['parent_loc']:.1f}" - ) + report += f", parent: {parent.name}; connection_loc: {edge_data['parent_loc']:.1f}" logger.info(report) children = get_node_children(cell, node) for child in children: @@ -1706,7 +1743,9 @@ def init_circuit_context( % (pop_name, gid) ) else: - raise RuntimeError("init_circuit_context: invalid synapses parameters") + raise RuntimeError( + "init_circuit_context: invalid synapses parameters" + ) if init_weights and has_weights: for weight_config_dict in weight_config: @@ -1740,10 +1779,14 @@ def init_circuit_context( cell_weights_dict, ) in cell_weights_iter: assert cell_weights_gid == gid - cell_weights_dicts[weights_namespace] = cell_weights_dict + cell_weights_dicts[ + weights_namespace + ] = cell_weights_dict else: - raise RuntimeError("init_circuit_context: invalid weights parameters") + raise RuntimeError( + "init_circuit_context: invalid weights parameters" + ) if len(weights_namespaces) != len(cell_weights_dicts): logger.warning( "init_circuit_context: Unable to load all weights namespaces: %s" @@ -1902,7 +1945,8 @@ def correct_node_for_spines_g_pas( num_spines = node.spine_count[i] g_pas_correction_factor = ( - SA_seg * node.sec(segment.x).g_pas + num_spines * SA_spine * soma_g_pas + SA_seg * node.sec(segment.x).g_pas + + num_spines * SA_spine * soma_g_pas ) / (SA_seg * node.sec(segment.x).g_pas) node.sec(segment.x).g_pas *= g_pas_correction_factor if verbose: @@ -1930,7 +1974,9 @@ def correct_node_for_spines_cm(node, env: AbstractEnv, gid, verbose=True): for i, 
segment in enumerate(node.sec): SA_seg = segment.area() num_spines = node.spine_count[i] - cm_correction_factor = (SA_seg + cm_fraction * num_spines * SA_spine) / SA_seg + cm_correction_factor = ( + SA_seg + cm_fraction * num_spines * SA_spine + ) / SA_seg node.sec(segment.x).cm *= cm_correction_factor if verbose: logger.info( @@ -2122,7 +2168,9 @@ def make_BRK_cell( pop_name=pop_name, env=env, cell_config=BRKconfig(**mech_dict["BoothRinzelKiehn"]), - mech_dict={k: mech_dict[k] for k in mech_dict if k != "BoothRinzelKiehn"}, + mech_dict={ + k: mech_dict[k] for k in mech_dict if k != "BoothRinzelKiehn" + }, ) circuit_flag = ( diff --git a/src/miv_simulator/clamps/network.py b/src/miv_simulator/clamps/network.py index b5218ca..9e67074 100644 --- a/src/miv_simulator/clamps/network.py +++ b/src/miv_simulator/clamps/network.py @@ -112,7 +112,9 @@ def generate_weights(env, weight_source_rules, this_syn_attrs): weights_name = weight_rule["name"] rule_params = weight_rule["params"] fraction = rule_params["fraction"] - seed_offset = int(env.model_config["Random Seeds"]["Sparse Weights"]) + seed_offset = int( + env.model_config["Random Seeds"]["Sparse Weights"] + ) seed = int(seed_offset + 1) weights_dict[presyn_id] = synapses.generate_sparse_weights( weights_name, fraction, seed, source_syn_dict @@ -137,7 +139,9 @@ def generate_weights(env, weight_source_rules, this_syn_attrs): rule_params = weight_rule["params"] mu = rule_params["mu"] sigma = rule_params["sigma"] - seed_offset = int(env.model_config["Random Seeds"]["GC Normal Weights"]) + seed_offset = int( + env.model_config["Random Seeds"]["GC Normal Weights"] + ) seed = int(seed_offset + 1) weights_dict[presyn_id] = synapses.generate_normal_weights( weights_name, mu, sigma, seed, source_syn_dict @@ -254,7 +258,9 @@ def init_inputs_from_features( if time_range[0] is None: time_range[0] = 0.0 - equilibration_duration = float(env.stimulus_config["Equilibration Duration"]) + equilibration_duration = float( + env.stimulus_config["Equilibration Duration"] + ) spatial_resolution = float(env.stimulus_config["Spatial Resolution"]) temporal_resolution = float(env.stimulus_config["Temporal Resolution"]) @@ -337,7 +343,9 @@ def init_inputs_from_features( if phase_mod_config_dict is not None: phase_mod_config = phase_mod_config_dict[gid] - spikes_attr_dict[gid] = stimulus.generate_stimulus_from_spike_trains( + spikes_attr_dict[ + gid + ] = stimulus.generate_stimulus_from_spike_trains( env, population, selectivity_type_names, @@ -354,7 +362,9 @@ def init_inputs_from_features( comm=env.comm, seed=seed, ) - spikes_attr_dict[gid][spike_train_attr_name] += equilibration_duration + spikes_attr_dict[gid][ + spike_train_attr_name + ] += equilibration_duration input_source_dict[pop_index] = {"spiketrains": spikes_attr_dict} @@ -483,12 +493,17 @@ def init( ) env.comm.barrier() if env.comm.rank == 0: - presyn_gid_rank_dict = {rank: set() for rank in range(env.comm.size)} + presyn_gid_rank_dict = { + rank: set() for rank in range(env.comm.size) + } for i, gid in enumerate(presyn_gid_set): rank = i % env.comm.size presyn_gid_rank_dict[rank].add(gid) presyn_sources[presyn_name] = env.comm.scatter( - [presyn_gid_rank_dict[rank] for rank in sorted(presyn_gid_rank_dict)], + [ + presyn_gid_rank_dict[rank] + for rank in sorted(presyn_gid_rank_dict) + ], root=0, ) else: @@ -503,12 +518,12 @@ def init( if env.comm.rank == 0: with h5py.File(coords_path, "r") as coords_f: reference_u_arc_distance_bounds = ( - coords_f["Populations"][population][distances_namespace].attrs[ - 
"Reference U Min" - ], - coords_f["Populations"][population][distances_namespace].attrs[ - "Reference U Max" - ], + coords_f["Populations"][population][ + distances_namespace + ].attrs["Reference U Min"], + coords_f["Populations"][population][ + distances_namespace + ].attrs["Reference U Max"], ) env.comm.barrier() reference_u_arc_distance_bounds = env.comm.bcast( @@ -596,7 +611,8 @@ def init( ## if spike_generator_dict contains an entry for the respective presynaptic population, ## then use the given generator to generate spikes. if not ( - (presyn_gid in env.gidset) or (is_cell_registered(env, presyn_gid)) + (presyn_gid in env.gidset) + or (is_cell_registered(env, presyn_gid)) ): cell = make_input_cell( env, @@ -745,7 +761,9 @@ def update_params(env, pop_param_dict): syn_name = param_tuple.syn_name param_path = param_tuple.param_path - if isinstance(param_path, list) or isinstance(param_path, tuple): + if isinstance(param_path, list) or isinstance( + param_path, tuple + ): p, s = param_path else: p, s = param_path, None @@ -768,8 +786,12 @@ def update_params(env, pop_param_dict): this_sec_type, syn_name, param_name=p, - value={s: param_value} if (s is not None) else param_value, - filters={"sources": sources} if sources is not None else None, + value={s: param_value} + if (s is not None) + else param_value, + filters={"sources": sources} + if sources is not None + else None, update_targets=True, ) @@ -923,7 +945,9 @@ def init_state_objfun( ) time_step = env.stimulus_config["Temporal Resolution"] - equilibration_duration = float(env.stimulus_config["Equilibration Duration"]) + equilibration_duration = float( + env.stimulus_config["Equilibration Duration"] + ) opt_param_config = optimization_params( env.netclamp_config.optimize_parameters, @@ -967,7 +991,9 @@ def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict): for rec in state_recs: vec = np.asarray(rec["vec"].to_python(), dtype=np.float32) if filter_fun is None: - data = np.asarray([np.mean(vec[t_inds]) for t_inds in t_trial_inds]) + data = np.asarray( + [np.mean(vec[t_inds]) for t_inds in t_trial_inds] + ) else: data = np.asarray( [ @@ -1003,12 +1029,18 @@ def eval_problem(cell_param_dict, **kwargs): elif trial_regime == "best": return { gid: -( - np.min(np.abs(np.asarray(state_values_dict[gid]) - target_value)) + np.min( + np.abs( + np.asarray(state_values_dict[gid]) - target_value + ) + ) ) for gid in my_cell_index_set } else: - raise RuntimeError(f"state_objfun: unknown trial regime {trial_regime}") + raise RuntimeError( + f"state_objfun: unknown trial regime {trial_regime}" + ) return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem) @@ -1105,8 +1137,12 @@ def init_rate_objfun( env, population, gid, recording_profile=recording_profile ) - target_v_threshold = opt_targets[f"{population} state"]["v"].get("threshold", None) - target_v_margin = opt_targets[f"{population} state"]["v"].get("margin", -1.0) + target_v_threshold = opt_targets[f"{population} state"]["v"].get( + "threshold", None + ) + target_v_margin = opt_targets[f"{population} state"]["v"].get( + "margin", -1.0 + ) if target_v_threshold is None: raise RuntimeError( f"network_clamp: network clamp optimization configuration for population {population} " @@ -1211,7 +1247,9 @@ def eval_problem(cell_param_dict, **kwargs): for gid in my_cell_index_set } else: - raise RuntimeError(f"rate_objfun: unknown trial regime {trial_regime}") + raise RuntimeError( + f"rate_objfun: unknown trial regime {trial_regime}" + ) N_objectives = 1 
opt_rate_feature_dtypes = [
@@ -1229,7 +1267,9 @@ def eval_problem(cell_param_dict, **kwargs):
             )
             rates_array = np.asarray(firing_rates_dict[gid])
             nz_idxs = np.argwhere(
-                np.logical_not(np.isclose(rates_array, 0.0, rtol=1e-4, atol=1e-4))
+                np.logical_not(
+                    np.isclose(rates_array, 0.0, rtol=1e-4, atol=1e-4)
+                )
             )
             feature_array["mean_rate"] = 0.0
             if len(nz_idxs) > 0:
@@ -1334,7 +1374,9 @@ def init_rate_dist_objfun(
     ] = 0.0
 
     trj_d, trj_t = stimulus.read_stimulus(
-        input_features_path if input_features_path is not None else spike_events_path,
+        input_features_path
+        if input_features_path is not None
+        else spike_events_path,
         target_features_arena,
         target_features_stimulus,
     )
@@ -1372,7 +1414,9 @@ def gid_firing_rate_vectors(spkdict, cell_index_set):
         )
         for gid in cell_index_set:
             rate_vector = spike_density_dict[gid]["rate"]
-            idxs = np.where(np.isclose(rate_vector, 0.0, atol=1e-3, rtol=1e-3))[0]
+            idxs = np.where(
+                np.isclose(rate_vector, 0.0, atol=1e-3, rtol=1e-3)
+            )[0]
             rate_vector[idxs] = 0.0
             rates_dict[gid].append(rate_vector)
         for gid in spkdict[population]:
@@ -1393,7 +1437,9 @@ def mean_trial_rate_mse(gid, rate_vectors, target_rate_vector):
             f"{np.min(mean_rate_vector):.02f} / {np.max(mean_rate_vector):.02f} Hz"
         )
 
-        return np.square(np.subtract(mean_rate_vectore, target_rate_vector)).mean()
+        return np.square(
+            np.subtract(mean_rate_vector, target_rate_vector)
+        ).mean()
 
     def best_trial_rate_mse(gid, rate_vectors, target_rate_vector):
         mses = []
@@ -1443,7 +1489,9 @@ def eval_problem(cell_param_dict, **kwargs):
                 for gid in my_cell_index_set
             }
         else:
-            raise RuntimeError(f"firing_rate_dist: unknown trial regime {trial_regime}")
+            raise RuntimeError(
+                f"firing_rate_dist: unknown trial regime {trial_regime}"
+            )
 
     return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
 
@@ -1500,9 +1548,7 @@ def optimize_run(
 
     if results_file is None:
         if env.results_path is not None:
-            file_path = (
-                f"{env.results_path}/distgfs.network_clamp.{env.results_file_id}.h5"
-            )
+            file_path = f"{env.results_path}/distgfs.network_clamp.{env.results_file_id}.h5"
         else:
             file_path = f"distgfs.network_clamp.{env.results_file_id}.h5"
     else:
@@ -1526,7 +1572,9 @@ def optimize_run(
         reduce_fun_name = "opt_reduce_max"
         feature_dtypes = None
     else:
-        raise RuntimeError(f"optimize_run: unknown problem regime {problem_regime}")
+        raise RuntimeError(
+            f"optimize_run: unknown problem regime {problem_regime}"
+        )
 
     distgfs_params = {
         "opt_id": "network_clamp.optimize",
@@ -1562,7 +1610,9 @@ def optimize_run(
             params_dict = dict(opt_result[0])
             result_value = opt_result[1]
             results_config_tuples = []
-            for param_pattern, param_tuple in zip(param_names, param_tuples):
+            for param_pattern, param_tuple in zip(
+                param_names, param_tuples
+            ):
                 results_config_tuples.append(
                     (
                         param_tuple.population,
@@ -1616,7 +1666,9 @@ def dist_ctrl(
     for this_param_path, pop_param_tuple_dict in zip(
         param_path, pop_param_tuple_dicts
    ):
-        params_basename = os.path.splitext(os.path.basename(this_param_path))[0]
+        params_basename = os.path.splitext(
+            os.path.basename(this_param_path)
+        )[0]
         this_results_file_id = f"{results_file_id}_{params_basename}"
         task_id = controller.submit_call(
             "dist_run",
@@ -1889,7 +1941,9 @@ def go(
     pop_params_tuple_dicts = None
     if rank == 0:
         if results_file_id is None:
-            results_file_id = generate_results_file_id(population, seed=input_seed)
+            results_file_id = generate_results_file_id(
+                population, seed=input_seed
+            )
         if len(params_path) > 0:
             pop_params_tuple_dicts = []
             if len(params_id) == 0:
@@ -1907,7 +1961,9 @@ def 
go( this_gid_params, ) in this_pop_param_dict.items(): if this_param_id is not None: - this_gid_params_list = this_gid_params[this_param_id] + this_gid_params_list = this_gid_params[ + this_param_id + ] else: this_gid_params_list = this_gid_params for this_gid_param in this_gid_params_list: @@ -1961,7 +2017,9 @@ def go( comm=comm0, ) cell_index = None - attr_name, attr_cell_index = next(iter(attr_info_dict[population]["Trees"])) + attr_name, attr_cell_index = next( + iter(attr_info_dict[population]["Trees"]) + ) cell_index_set = set(attr_cell_index) comm.barrier() cell_index_set = comm.bcast(cell_index_set, root=0) @@ -2018,9 +2076,9 @@ def go( for this_params_path, pop_params_tuple_dict in zip( params_path, pop_params_tuple_dicts ): - params_basename = os.path.splitext(os.path.basename(this_params_path))[ - 0 - ] + params_basename = os.path.splitext( + os.path.basename(this_params_path) + )[0] env.results_file_id = f"{results_file_id}_{params_basename}" env.results_file_path = f"{env.results_path}/{env.modelName}_results_{env.results_file_id}.h5" run_with(env, pop_params_tuple_dict) @@ -2123,7 +2181,9 @@ def optimize( comm=comm0, ) cell_index = None - attr_name, attr_cell_index = next(iter(attr_info_dict[population]["Trees"])) + attr_name, attr_cell_index = next( + iter(attr_info_dict[population]["Trees"]) + ) cell_index_set = set(attr_cell_index) comm.barrier() cell_index_set = comm.bcast(cell_index_set, root=0) @@ -2164,7 +2224,9 @@ def optimize( ) if population in env.netclamp_config.optimize_parameters[param_type]: - opt_params = env.netclamp_config.optimize_parameters[param_type][population] + opt_params = env.netclamp_config.optimize_parameters[param_type][ + population + ] else: raise RuntimeError( f"network_clamp.optimize: population {population} does not have optimization configuration" @@ -2178,7 +2240,9 @@ def optimize( constraint_names = ["mean_v_below_threshold"] elif target == "state": assert target_state_variable is not None - opt_target = opt_params["Targets"]["state"][target_state_variable]["mean"] + opt_target = opt_params["Targets"]["state"][target_state_variable][ + "mean" + ] init_params["target_value"] = opt_target init_params["state_variable"] = target_state_variable init_params["state_filter"] = target_state_filter @@ -2215,6 +2279,8 @@ def optimize( ) if results_config_dict is not None: if results_path is not None: - file_path = f"{results_path}/network_clamp.optimize.{results_file_id}.yaml" + file_path = ( + f"{results_path}/network_clamp.optimize.{results_file_id}.yaml" + ) write_to_yaml(file_path, results_config_dict) comm.barrier() diff --git a/src/miv_simulator/env.py b/src/miv_simulator/env.py index 8394e77..3a69165 100644 --- a/src/miv_simulator/env.py +++ b/src/miv_simulator/env.py @@ -46,7 +46,9 @@ ["template_params", "weight_generators", "optimize_parameters"], ) -ArenaConfig = namedtuple("Arena", ["name", "domain", "trajectories", "properties"]) +ArenaConfig = namedtuple( + "Arena", ["name", "domain", "trajectories", "properties"] +) DomainConfig = namedtuple("Domain", ["vertices", "simplices"]) @@ -267,11 +269,15 @@ def __init__( if "Definitions" in self.model_config: self.parse_definitions() - self.SWC_Type_index = {item[1]: item[0] for item in self.SWC_Types.items()} + self.SWC_Type_index = { + item[1]: item[0] for item in self.SWC_Types.items() + } self.Synapse_Type_index = { item[1]: item[0] for item in self.Synapse_Types.items() } - self.layer_type_index = {item[1]: item[0] for item in self.layers.items()} + self.layer_type_index = { + item[1]: 
item[0] for item in self.layers.items() + } if "Global Parameters" in self.model_config: self.parse_globals() @@ -317,7 +323,9 @@ def __init__( self.spike_input_attribute_info = None if self.spike_input_path is not None: if rank == 0: - self.logger.info(f"env.spike_input_path = {str(self.spike_input_path)}") + self.logger.info( + f"env.spike_input_path = {str(self.spike_input_path)}" + ) self.spike_input_attribute_info = read_cell_attribute_info( self.spike_input_path, sorted(self.Populations.keys()), @@ -351,7 +359,9 @@ def __init__( self.parse_gapjunction_config() if self.dataset_prefix is not None: - self.dataset_path = os.path.join(self.dataset_prefix, self.datasetName) + self.dataset_path = os.path.join( + self.dataset_prefix, self.datasetName + ) if "Cell Data" in self.model_config: self.data_file_path = os.path.join( self.dataset_path, self.model_config["Cell Data"] @@ -422,7 +432,9 @@ def __init__( ): projection_dict[dst].append(src) self.projection_dict = dict(projection_dict) - self.logger.info(f"projection_dict = {str(self.projection_dict)}") + self.logger.info( + f"projection_dict = {str(self.projection_dict)}" + ) self.projection_dict = self.comm.bcast(self.projection_dict, root=0) # If True, instantiate as spike source those cells that do not @@ -443,16 +455,20 @@ def __init__( "nstates": int(config["nstates"]), "opsin type": config["opsin type"], "protocol": config["protocol"], - "protocol parameters": config.get("protocol parameters", dict()), + "protocol parameters": config.get( + "protocol parameters", dict() + ), "rho parameters": config.get("rho parameters", dict()), } # Configuration profile for recording intracellular quantities self.recording_profile = None - if ("Recording" in self.model_config) and (recording_profile is not None): - self.recording_profile = self.model_config["Recording"]["Intracellular"][ - recording_profile - ] + if ("Recording" in self.model_config) and ( + recording_profile is not None + ): + self.recording_profile = self.model_config["Recording"][ + "Intracellular" + ][recording_profile] self.recording_profile["label"] = recording_profile for recvar, recdict in self.recording_profile.get( "synaptic quantity", {} @@ -548,7 +564,10 @@ def init_stimulus_config( if stimulus_id is None: self.stimulus_id = None else: - if stimulus_id in self.stimulus_config["Arena"][arena_id].trajectories: + if ( + stimulus_id + in self.stimulus_config["Arena"][arena_id].trajectories + ): self.stimulus_id = stimulus_id else: raise RuntimeError( @@ -571,7 +590,9 @@ def parse_stimulus_config(self) -> None: pop_selectivity_type_prob_dict[ int(self.selectivity_types[selectivity_type_name]) ] = float(selectivity_type_prob) - selectivity_type_prob_dict[pop] = pop_selectivity_type_prob_dict + selectivity_type_prob_dict[ + pop + ] = pop_selectivity_type_prob_dict stimulus_config[ "Selectivity Type Probabilities" ] = selectivity_type_prob_dict @@ -721,13 +742,17 @@ def parse_connection_config(self) -> None: if layer_name == "default": pop_connection_extents[layer_name] = { "width": extent_config[population][layer_name]["width"], - "offset": extent_config[population][layer_name]["offset"], + "offset": extent_config[population][layer_name][ + "offset" + ], } else: layer_index = self.layers[layer_name] pop_connection_extents[layer_index] = { "width": extent_config[population][layer_name]["width"], - "offset": extent_config[population][layer_name]["offset"], + "offset": extent_config[population][layer_name][ + "offset" + ], } self.connection_extents[population] = 
pop_connection_extents @@ -766,7 +791,9 @@ def parse_connection_config(self) -> None: if swctype_mechparams_dict is not None: for swc_type in swctype_mechparams_dict: swc_type_index = self.SWC_Types[swc_type] - res_mechparams[swc_type_index] = self.parse_syn_mechparams( + res_mechparams[ + swc_type_index + ] = self.parse_syn_mechparams( swctype_mechparams_dict[swc_type] ) else: @@ -949,7 +976,9 @@ def load_celltypes(self) -> None: population_names = None if rank == 0: population_names = read_population_names(self.data_file_path, comm0) - (population_ranges, _) = read_population_ranges(self.data_file_path, comm0) + (population_ranges, _) = read_population_ranges( + self.data_file_path, comm0 + ) self.cell_attribute_info = read_cell_attribute_info( self.data_file_path, population_names, comm=comm0 ) @@ -958,7 +987,9 @@ def load_celltypes(self) -> None: self.logger.info(f"attribute info: {str(self.cell_attribute_info)}") population_ranges = self.comm.bcast(population_ranges, root=0) population_names = self.comm.bcast(population_names, root=0) - self.cell_attribute_info = self.comm.bcast(self.cell_attribute_info, root=0) + self.cell_attribute_info = self.comm.bcast( + self.cell_attribute_info, root=0 + ) comm0.Free() for k in typenames: @@ -968,7 +999,9 @@ def load_celltypes(self) -> None: celltypes[k]["num"] = population_ranges[k][1] if "mechanism file" in celltypes[k]: if isinstance(celltypes[k]["mechanism file"], str): - celltypes[k]["mech_file_path"] = celltypes[k]["mechanism file"] + celltypes[k]["mech_file_path"] = celltypes[k][ + "mechanism file" + ] mech_dict = None if rank == 0: mech_file_path = celltypes[k]["mech_file_path"] diff --git a/src/miv_simulator/network.py b/src/miv_simulator/network.py index 6b91ce7..d7a8849 100644 --- a/src/miv_simulator/network.py +++ b/src/miv_simulator/network.py @@ -140,7 +140,9 @@ def connect_cells(env: Env) -> None: weight_dicts = synapse_config["weights"] if rank == 0: - logger.info(f"*** Reading synaptic attributes of population {postsyn_name}") + logger.info( + f"*** Reading synaptic attributes of population {postsyn_name}" + ) cell_attr_namespaces = ["Synapse Attributes"] @@ -168,7 +170,9 @@ def connect_cells(env: Env) -> None: node_allocation=env.node_allocation, ) - for iter_count, (gid, gid_attr_data) in enumerate(synapses_attr_gen): + for iter_count, (gid, gid_attr_data) in enumerate( + synapses_attr_gen + ): if gid is not None: (attr_tuple, attr_tuple_index) = gid_attr_data syn_ids_ind = attr_tuple_index.get("syn_ids", None) @@ -289,7 +293,9 @@ def connect_cells(env: Env) -> None: syn_id_index = weight_attr_info.get("syn_id", None) syn_name_inds = [ (syn_name, attr_index) - for syn_name, attr_index in sorted(weight_attr_info.items()) + for syn_name, attr_index in sorted( + weight_attr_info.items() + ) if syn_name != "syn_id" ] for gid, cell_weights_tuple in syn_weights_iter: @@ -304,19 +310,30 @@ def connect_cells(env: Env) -> None: "not found in network configuration" ) else: - weights_values = cell_weights_tuple[syn_name_index] - assert len(weights_syn_ids) == len(weights_values) + weights_values = cell_weights_tuple[ + syn_name_index + ] + assert len(weights_syn_ids) == len( + weights_values + ) syn_attrs.add_mech_attrs_from_iter( gid, syn_name, zip_longest( weights_syn_ids, [ - {"weight": Promise(expr_closure, [x])} + { + "weight": Promise( + expr_closure, [x] + ) + } for x in weights_values ] if expr_closure - else [{"weight": x} for x in weights_values], + else [ + {"weight": x} + for x in weights_values + ], ), 
multiple=multiple_weights, append=append_weights, @@ -369,7 +386,9 @@ def connect_cells(env: Env) -> None: lambda edgeset: presyn_input_sources.update(edgeset[1][0]), edge_iter, ) - env.microcircuit_input_sources[presyn_name] = presyn_input_sources + env.microcircuit_input_sources[ + presyn_name + ] = presyn_input_sources else: syn_edge_iter = edge_iter syn_attrs.init_edge_attrs_from_iter( @@ -552,7 +571,9 @@ def connect_cell_selection(env): presyn_names = sorted(env.projection_dict[postsyn_name]) gid_range = [ - gid for gid in env.cell_selection[postsyn_name] if env.pc.gid_exists(gid) + gid + for gid in env.cell_selection[postsyn_name] + if env.pc.gid_exists(gid) ] synapse_config = env.celltypes[postsyn_name]["synapses"] @@ -573,7 +594,9 @@ def connect_cell_selection(env): weight_dicts = synapse_config["weights"] if rank == 0: - logger.info(f"*** Reading synaptic attributes of population {postsyn_name}") + logger.info( + f"*** Reading synaptic attributes of population {postsyn_name}" + ) syn_attrs_iter, syn_attrs_info = read_cell_attribute_selection( forest_file_path, @@ -630,7 +653,9 @@ def connect_cell_selection(env): syn_id_index = weight_attr_info.get("syn_id", None) syn_name_inds = [ (syn_name, attr_index) - for syn_name, attr_index in sorted(weight_attr_info.items()) + for syn_name, attr_index in sorted( + weight_attr_info.items() + ) if syn_name != "syn_id" ] @@ -646,18 +671,27 @@ def connect_cell_selection(env): "not found in network configuration" ) else: - weights_values = cell_weights_tuple[syn_name_index] + weights_values = cell_weights_tuple[ + syn_name_index + ] syn_attrs.add_mech_attrs_from_iter( gid, syn_name, zip_longest( weights_syn_ids, [ - {"weight": Promise(expr_closure, [x])} + { + "weight": Promise( + expr_closure, [x] + ) + } for x in weights_values ] if expr_closure - else [{"weight": x} for x in weights_values], + else [ + {"weight": x} + for x in weights_values + ], ), multiple=multiple_weights, append=append_weights, @@ -676,7 +710,8 @@ def connect_cell_selection(env): connectivity_file_path, selection=gid_range, projections=[ - (presyn_name, postsyn_name) for presyn_name in sorted(presyn_names) + (presyn_name, postsyn_name) + for presyn_name in sorted(presyn_names) ], comm=env.comm, namespaces=["Synapses", "Connections"], @@ -698,7 +733,9 @@ def connect_cell_selection(env): syn_attrs.init_edge_attrs_from_iter( postsyn_name, presyn_name, a, syn_edge_iter ) - env.microcircuit_input_sources[presyn_name] = presyn_input_sources + env.microcircuit_input_sources[ + presyn_name + ] = presyn_input_sources del graph[postsyn_name][presyn_name] first_gid = None @@ -894,7 +931,9 @@ def make_cells(env: Env) -> None: pop_names = sorted(env.celltypes.keys()) if rank == 0: - logger.info(f"Population attributes: {pprint.pformat(env.cell_attribute_info)}") + logger.info( + f"Population attributes: {pprint.pformat(env.cell_attribute_info)}" + ) for pop_name in pop_names: if rank == 0: logger.info(f"*** Creating population {pop_name}") @@ -977,7 +1016,11 @@ def make_cells(env: Env) -> None: mech_dict=mech_dict, ) # cells.init_spike_detector(biophys_cell) - if rank == 0 and gid == first_gid and mech_file_path is not None: + if ( + rank == 0 + and gid == first_gid + and mech_file_path is not None + ): logger.info( f"*** make_cells: population: {pop_name}; gid: {gid}; loaded biophysics from path: {mech_file_path}" ) @@ -999,7 +1042,9 @@ def make_cells(env: Env) -> None: "Coordinates" in env.cell_attribute_info[pop_name] ): if rank == 0: - logger.info(f"*** Reading coordinates for 
population {pop_name}") + logger.info( + f"*** Reading coordinates for population {pop_name}" + ) if env.node_allocation is None: cell_attr_dict = scatter_read_cell_attributes( @@ -1021,7 +1066,9 @@ def make_cells(env: Env) -> None: return_type="tuple", ) if rank == 0: - logger.info(f"*** Done reading coordinates for population {pop_name}") + logger.info( + f"*** Done reading coordinates for population {pop_name}" + ) coords_iter, coords_attr_info = cell_attr_dict["Coordinates"] @@ -1069,12 +1116,16 @@ def make_cells(env: Env) -> None: recording_fraction = env.recording_profile.get("fraction", 1.0) recording_limit = env.recording_profile.get("limit", -1) all_pop_biophys_gids = sorted( - item for sublist in pop_biophys_gids_per_rank for item in sublist + item + for sublist in pop_biophys_gids_per_rank + for item in sublist ) for gid in all_pop_biophys_gids: if ranstream_recording.uniform() <= recording_fraction: recording_set.add(gid) - if (recording_limit > 0) and (len(recording_set) > recording_limit): + if (recording_limit > 0) and ( + len(recording_set) > recording_limit + ): break logger.info(f"recording_set = {recording_set}") recording_set = env.comm.bcast(recording_set, root=0) @@ -1109,7 +1160,9 @@ def make_cell_selection(env): for pop_name in pop_names: if rank == 0: - logger.info(f"*** Creating selected cells from population {pop_name}") + logger.info( + f"*** Creating selected cells from population {pop_name}" + ) template_name = env.celltypes[pop_name]["template"] template_name_lower = template_name.lower() @@ -1213,7 +1266,9 @@ def make_cell_selection(env): "Coordinates" in env.cell_attribute_info[pop_name] ): if rank == 0: - logger.info(f"*** Reading coordinates for population {pop_name}") + logger.info( + f"*** Reading coordinates for population {pop_name}" + ) coords_iter, coords_attr_info = read_cell_attribute_selection( data_file_path, @@ -1228,7 +1283,9 @@ def make_cell_selection(env): z_index = coords_attr_info.get("Z Coordinate", None) if rank == 0: - logger.info(f"*** Done reading coordinates for population {pop_name}") + logger.info( + f"*** Done reading coordinates for population {pop_name}" + ) for i, (gid, cell_coords_tuple) in enumerate(coords_iter): if rank == 0: @@ -1299,8 +1356,12 @@ def make_input_cell_selection(env): rank = int(env.pc.id()) nhosts = int(env.pc.nhost()) - created_input_sources = {pop_name: set() for pop_name in env.celltypes.keys()} - for pop_name, input_gid_range in sorted(env.microcircuit_input_sources.items()): + created_input_sources = { + pop_name: set() for pop_name in env.celltypes.keys() + } + for pop_name, input_gid_range in sorted( + env.microcircuit_input_sources.items() + ): pop_index = int(env.Populations[pop_name]) has_spike_train = False @@ -1322,7 +1383,9 @@ def make_input_cell_selection(env): spike_generator = None else: if pop_name in env.netclamp_config.input_generators: - spike_generator = env.netclamp_config.input_generators[pop_name] + spike_generator = env.netclamp_config.input_generators[ + pop_name + ] else: raise RuntimeError( f"make_input_cell_selection: population {pop_name} has neither input spike trains nor input generator configuration" @@ -1333,7 +1396,9 @@ def make_input_cell_selection(env): else: input_source_dict = {pop_index: {"spiketrains": {}}} - if (env.cell_selection is not None) and (pop_name in env.cell_selection): + if (env.cell_selection is not None) and ( + pop_name in env.cell_selection + ): local_input_gid_range = input_gid_range.difference( set(env.cell_selection[pop_name]) ) @@ -1408,7 +1473,9 
@@ def init_input_cells(env: Env) -> None: if env.arena_id and env.stimulus_id: vecstim_namespace = f"{env.celltypes[pop_name]['spike train']['namespace']} {env.arena_id} {env.stimulus_id}" else: - vecstim_namespace = env.celltypes[pop_name]["spike train"]["namespace"] + vecstim_namespace = env.celltypes[pop_name]["spike train"][ + "namespace" + ] vecstim_attr = env.celltypes[pop_name]["spike train"]["attribute"] has_vecstim = False @@ -1417,7 +1484,8 @@ def init_input_cells(env: Env) -> None: env.spike_input_ns is not None ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + env.spike_input_ns + in env.spike_input_attribute_info[pop_name] ): has_vecstim = True vecstim_source_loc.append( @@ -1478,7 +1546,9 @@ def init_input_cells(env: Env) -> None: return_type="tuple", ) - vecstim_iter, vecstim_attr_info = cell_vecstim_dict[input_ns] + vecstim_iter, vecstim_attr_info = cell_vecstim_dict[ + input_ns + ] else: if pop_name in env.cell_selection: gid_range = [ @@ -1508,11 +1578,15 @@ def init_input_cells(env: Env) -> None: else: vecstim_iter = [] - vecstim_attr_index = vecstim_attr_info.get(vecstim_attr, None) + vecstim_attr_index = vecstim_attr_info.get( + vecstim_attr, None + ) trial_index_attr_index = vecstim_attr_info.get( trial_index_attr, None ) - trial_dur_attr_index = vecstim_attr_info.get(trial_dur_attr, None) + trial_dur_attr_index = vecstim_attr_info.get( + trial_dur_attr, None + ) for gid, vecstim_tuple in vecstim_iter: if not (env.pc.gid_exists(gid)): continue @@ -1533,11 +1607,17 @@ def init_input_cells(env: Env) -> None: env.n_trials, ) spiketrain += ( - float(env.stimulus_config["Equilibration Duration"]) + float( + env.stimulus_config[ + "Equilibration Duration" + ] + ) + env.stimulus_onset ) if len(spiketrain) > 0: - cell.play(h.Vector(spiketrain.astype(np.float64))) + cell.play( + h.Vector(spiketrain.astype(np.float64)) + ) if rank == 0: logger.info( f"*** Spike train for {pop_name} gid {gid} is of length {len(spiketrain)} ({spiketrain[0]} : {spiketrain[-1]} ms)" @@ -1549,8 +1629,12 @@ def init_input_cells(env: Env) -> None: for pop_name in sorted(env.microcircuit_input_sources.keys()): gid_range = env.microcircuit_input_sources.get(pop_name, set()) - if (env.cell_selection is not None) and (pop_name in env.cell_selection): - this_gid_range = gid_range.difference(set(env.cell_selection[pop_name])) + if (env.cell_selection is not None) and ( + pop_name in env.cell_selection + ): + this_gid_range = gid_range.difference( + set(env.cell_selection[pop_name]) + ) else: this_gid_range = gid_range @@ -1560,7 +1644,8 @@ def init_input_cells(env: Env) -> None: env.spike_input_ns is not None ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + env.spike_input_ns + in env.spike_input_attribute_info[pop_name] ): has_spike_train = True spike_input_source_loc.append( @@ -1573,7 +1658,9 @@ def init_input_cells(env: Env) -> None: env.spike_input_ns in env.cell_attribute_info[pop_name] ): has_spike_train = True - spike_input_source_loc.append((input_file_path, env.spike_input_ns)) + spike_input_source_loc.append( + (input_file_path, env.spike_input_ns) + ) if rank == 0: logger.info( @@ -1620,7 +1707,9 @@ def init_input_cells(env: Env) -> None: env.spike_input_attr, None ) elif "t" in cell_spikes_attr_info.keys(): - spike_train_attr_index = cell_spikes_attr_info.get("t", None) + spike_train_attr_index = cell_spikes_attr_info.get( + "t", None + ) elif 
"Spike Train" in cell_spikes_attr_info.keys(): spike_train_attr_index = cell_spikes_attr_info.get( "Spike Train", None @@ -1653,11 +1742,17 @@ def init_input_cells(env: Env) -> None: env.n_trials, ) spiketrain += ( - float(env.stimulus_config["Equilibration Duration"]) + float( + env.stimulus_config[ + "Equilibration Duration" + ] + ) + env.stimulus_onset ) if len(spiketrain) > 0: - input_cell.play(h.Vector(spiketrain.astype(np.float64))) + input_cell.play( + h.Vector(spiketrain.astype(np.float64)) + ) if rank == 0: logger.info( f"*** Spike train for {pop_name} gid {gid} is of length {len(spiketrain)} ({spiketrain[0]} : {spiketrain[-1]} ms)" @@ -1715,7 +1810,9 @@ def init(env: Env, subworld_size: Optional[int] = None) -> None: env.pc.setup_transfer() env.connectgjstime = time.time() - st if rank == 0: - logger.info(f"*** Gap junctions created in {env.connectgjstime:.02f} s") + logger.info( + f"*** Gap junctions created in {env.connectgjstime:.02f} s" + ) if env.opsin_config is not None: st = time.time() @@ -1764,7 +1861,9 @@ def init(env: Env, subworld_size: Optional[int] = None) -> None: dt_lfp=lfp_config_dict["dt"], fdst=lfp_config_dict["fraction"], maxEDist=lfp_config_dict["maxEDist"], - seed=int(env.model_config["Random Seeds"]["Local Field Potential"]), + seed=int( + env.model_config["Random Seeds"]["Local Field Potential"] + ), ) if rank == 0: logger.info( @@ -1937,7 +2036,9 @@ def run( ) if env.recording_profile is not None: if rank == 0: - logger.info(f"*** Writing intracellular data up to {h.t:.2f} ms") + logger.info( + f"*** Writing intracellular data up to {h.t:.2f} ms" + ) io_utils.recsout( env, env.results_file_path, diff --git a/src/miv_simulator/optimization.py b/src/miv_simulator/optimization.py index f3e77bd..34af701 100644 --- a/src/miv_simulator/optimization.py +++ b/src/miv_simulator/optimization.py @@ -67,10 +67,16 @@ def parse_optimization_param_dict( ): for source, source_dict in sorted(param_dict.items(), key=keyfun): for sec_type, sec_type_dict in sorted(source_dict.items(), key=keyfun): - for syn_name, syn_mech_dict in sorted(sec_type_dict.items(), key=keyfun): - for param_fst, param_rst in sorted(syn_mech_dict.items(), key=keyfun): + for syn_name, syn_mech_dict in sorted( + sec_type_dict.items(), key=keyfun + ): + for param_fst, param_rst in sorted( + syn_mech_dict.items(), key=keyfun + ): if isinstance(param_rst, dict): - for const_name, const_range in sorted(param_rst.items()): + for const_name, const_range in sorted( + param_rst.items() + ): param_path = (param_fst, const_name) param_tuples.append( SynParam( @@ -141,7 +147,9 @@ def parse_optimization_param_dict( param_name, ) - param_initial_value = (param_range[1] - param_range[0]) / 2.0 + param_initial_value = ( + param_range[1] - param_range[0] + ) / 2.0 param_initial_dict[param_key] = param_initial_value param_bounds[param_key] = param_range param_names.append(param_key) @@ -208,7 +216,9 @@ def parse_optimization_param_entries( param_names=param_names, ) else: - raise RuntimeError(f"Invalid optimization parameter object: {param_entries}") + raise RuntimeError( + f"Invalid optimization parameter object: {param_entries}" + ) def optimization_params( @@ -298,7 +308,9 @@ def update_network_params(env, param_tuples): biophys_cell_dict = env.biophys_cells[population] for gid in biophys_cell_dict: - if (phenotype_dict is not None) and (param_phenotype is not None): + if (phenotype_dict is not None) and ( + param_phenotype is not None + ): gid_phenotype = phenotype_dict.get(gid, None) if gid_phenotype is not None: 
if gid_phenotype != param_phenotype: @@ -316,8 +328,12 @@ def update_network_params(env, param_tuples): this_sec_type, syn_name, param_name=p, - value={s: param_value} if (s is not None) else param_value, - filters={"sources": sources} if sources is not None else None, + value={s: param_value} + if (s is not None) + else param_value, + filters={"sources": sources} + if sources is not None + else None, origin=None if is_reduced else "soma", update_targets=True, ) @@ -368,8 +384,12 @@ def update_run_params(env, param_tuples): this_sec_type, syn_name, param_name=p, - value={s: param_value} if (s is not None) else param_value, - filters={"sources": sources} if sources is not None else None, + value={s: param_value} + if (s is not None) + else param_value, + filters={"sources": sources} + if sources is not None + else None, origin=None if is_reduced else "soma", update_targets=True, ) @@ -411,8 +431,12 @@ def network_features( target_trj_rate_map = pop_target_trj_rate_map_dict[gid] rate_map_len = len(target_trj_rate_map) if gid in spike_density_dict: - measured_rate = spike_density_dict[gid]["rate"][:rate_map_len] - ref_signal = target_trj_rate_map - np.mean(target_trj_rate_map) + measured_rate = spike_density_dict[gid]["rate"][ + :rate_map_len + ] + ref_signal = target_trj_rate_map - np.mean( + target_trj_rate_map + ) signal = measured_rate - np.mean(measured_rate) noise = signal - ref_signal snr = np.var(signal) / max(np.var(noise), 1e-6) @@ -440,7 +464,9 @@ def distgfs_broker_bcast(broker, tag): nprocs = broker.nprocs_per_worker data_dict = {} while len(data_dict) < nprocs: - if broker.merged_comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG): + if broker.merged_comm.Iprobe( + source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG + ): data = broker.merged_comm.recv( source=MPI.ANY_SOURCE, tag=tag, status=status ) @@ -547,14 +573,17 @@ def opt_reduce_max(xs): return {k: np.max(vs[k]) for k in ks} -def opt_eval_fun(problem_regime, cell_index_set, eval_problem_fun, feature_dtypes=None): +def opt_eval_fun( + problem_regime, cell_index_set, eval_problem_fun, feature_dtypes=None +): problem_regime = ProblemRegime[problem_regime] def f(pp, **kwargs): if problem_regime == ProblemRegime.every: results_dict = eval_problem_fun(pp, **kwargs) elif ( - problem_regime == ProblemRegime.mean or problem_regime == ProblemRegime.max + problem_regime == ProblemRegime.mean + or problem_regime == ProblemRegime.max ): mpp = {gid: pp for gid in cell_index_set} results_dict = eval_problem_fun(mpp, **kwargs) diff --git a/src/miv_simulator/optimize_network.py b/src/miv_simulator/optimize_network.py index d847c43..7c5396e 100644 --- a/src/miv_simulator/optimize_network.py +++ b/src/miv_simulator/optimize_network.py @@ -112,7 +112,9 @@ def dmosopt_get_best(file_path, opt_id): epochs=epochs, feasible=True, ) - best_x_items = tuple((param_names[i], best_x[:, i]) for i in range(best_x.shape[1])) + best_x_items = tuple( + (param_names[i], best_x[:, i]) for i in range(best_x.shape[1]) + ) best_y_items = tuple( (objective_names[i], best_y[:, i]) for i in range(best_y.shape[1]) ) @@ -215,9 +217,12 @@ def optimize_network( resample_fraction = 0.1 # Create an optimizer - feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names] + feature_dtypes = [ + (feature_name, np.float32) for feature_name in objective_names + ] constraint_names = [ - f"{target_pop_name} positive rate" for target_pop_name in target_populations + f"{target_pop_name} positive rate" + for target_pop_name in target_populations ] dmosopt_params = { 
"opt_id": "miv_simulator.optimize_network", @@ -255,7 +260,9 @@ def optimize_network( } if get_best: - best = dmosopt_get_best(dmosopt_params["file_path"], dmosopt_params["opt_id"]) + best = dmosopt_get_best( + dmosopt_params["file_path"], dmosopt_params["opt_id"] + ) else: best = dmosopt.run( dmosopt_params, @@ -279,7 +286,9 @@ def optimize_network( results_config_dict = {} for i in range(n_res): result_param_list = [] - for param_pattern, param_tuple in zip(param_names, param_tuples): + for param_pattern, param_tuple in zip( + param_names, param_tuples + ): result_param_list.append( ( param_tuple.population, @@ -298,7 +307,9 @@ def optimize_network( def init_network_objfun( operational_config, opt_targets, param_names, param_tuples, worker, **kwargs ): - param_tuples = [syn_param_from_dict(param_tuple) for param_tuple in param_tuples] + param_tuples = [ + syn_param_from_dict(param_tuple) for param_tuple in param_tuples + ] objective_names = operational_config["objective_names"] target_populations = operational_config["target_populations"] @@ -427,7 +438,9 @@ def compute_objectives(local_features, operational_config, opt_targets): constraints.append(rate_constr) objective_names = operational_config["objective_names"] - feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names] + feature_dtypes = [ + (feature_name, np.float32) for feature_name in objective_names + ] target_vals = opt_targets target_ranges = opt_targets diff --git a/src/miv_simulator/synapses.py b/src/miv_simulator/synapses.py index 62c1818..13d893e 100644 --- a/src/miv_simulator/synapses.py +++ b/src/miv_simulator/synapses.py @@ -260,7 +260,9 @@ def synapse_seg_counts( ran = None if ran is not None: l = L / nseg - dens = ran.normal(density_dict["mean"], density_dict["variance"]) + dens = ran.normal( + density_dict["mean"], density_dict["variance"] + ) rc = dens * l segcount_total += rc segcounts.append(rc) @@ -360,15 +362,18 @@ def distribute_uniform_synapses( int_seg_count = math.floor(seg_count) syn_count = 0 while syn_count < int_seg_count: - syn_loc = seg_start + seg_range * (syn_count + 1) / math.ceil( - seg_count - ) + syn_loc = seg_start + seg_range * ( + syn_count + 1 + ) / math.ceil(seg_count) assert (syn_loc <= 1) & (syn_loc >= 0) if syn_loc < 1.0: syn_cdist = math.sqrt( reduce( lambda a, b: a + b, - (interp_loc[i](syn_loc) ** 2 for i in range(3)), + ( + interp_loc[i](syn_loc) ** 2 + for i in range(3) + ), ) ) syn_cdists.append(syn_cdist) @@ -522,7 +527,9 @@ def distribute_poisson_synapses( else: while True: sample = r.exponential(beta) - if (sample >= L_seg_start) and (sample < L_seg_end): + if (sample >= L_seg_start) and ( + sample < L_seg_end + ): break interval += sample while interval < L_seg_end: @@ -646,7 +653,9 @@ def __init__( self, env: AbstractEnv, syn_mech_names: Dict[str, str], - syn_param_rules: Dict[str, Dict[str, Union[str, List[str], Dict[str, int]]]], + syn_param_rules: Dict[ + str, Dict[str, Union[str, List[str], Dict[str, int]]] + ], ) -> None: """An Env object containing imported network configuration metadata uses an instance of SynapseAttributes to track all metadata @@ -665,7 +674,9 @@ def __init__( self.env = env self.syn_mech_names = syn_mech_names self.syn_config = { - k: v["synapses"] for k, v in env.celltypes.items() if "synapses" in v + k: v["synapses"] + for k, v in env.celltypes.items() + if "synapses" in v } self.syn_param_rules = syn_param_rules self.syn_name_index_dict = { @@ -766,7 +777,9 @@ def init_syn_id_attrs( """ if gid in self.syn_id_attr_dict: - 
raise RuntimeError(f"Entry {gid} exists in synapse attribute dictionary") + raise RuntimeError( + f"Entry {gid} exists in synapse attribute dictionary" + ) else: syn_dict = self.syn_id_attr_dict[gid] sec_dict = self.sec_dict[gid] @@ -1288,7 +1301,9 @@ def modify_mech_attrs( attr_dict[k] = new_val else: - raise RuntimeError(f"modify_mech_attrs: unknown type of parameter {k}") + raise RuntimeError( + f"modify_mech_attrs: unknown type of parameter {k}" + ) syn.attr_dict[syn_index] = attr_dict def add_mech_attrs_from_iter( @@ -1440,10 +1455,12 @@ def partition_synapses_by_source( source_names = {id: name for name, id in self.env.Populations.items()} source_order = { - id: i for i, (name, id) in enumerate(sorted(self.env.Populations.items())) + id: i + for i, (name, id) in enumerate(sorted(self.env.Populations.items())) } source_inverse_order = { - i: id for i, (name, id) in enumerate(sorted(self.env.Populations.items())) + i: id + for i, (name, id) in enumerate(sorted(self.env.Populations.items())) } if syn_ids is None: @@ -1460,9 +1477,9 @@ def partition_synapses_by_source( ) return { - source_names[source_inverse_order[source_id_x[0]]]: generator_ifempty( - source_id_x[1] - ) + source_names[ + source_inverse_order[source_id_x[0]] + ]: generator_ifempty(source_id_x[1]) for source_id_x in enumerate(source_parts) } @@ -1517,10 +1534,12 @@ def partition_syn_ids_by_source( source_names = {id: name for name, id in self.env.Populations.items()} source_order = { - id: i for i, (name, id) in enumerate(sorted(self.env.Populations.items())) + id: i + for i, (name, id) in enumerate(sorted(self.env.Populations.items())) } source_inverse_order = { - i: id for i, (name, id) in enumerate(sorted(self.env.Populations.items())) + i: id + for i, (name, id) in enumerate(sorted(self.env.Populations.items())) } syn_id_attr_dict = self.syn_id_attr_dict[gid] @@ -1534,9 +1553,9 @@ def partition_pred(syn_id): source_iter = partitionn(syn_ids, partition_pred, n=len(source_names)) return { - source_names[source_inverse_order[source_id_x[0]]]: generator_ifempty( - source_id_x[1] - ) + source_names[ + source_inverse_order[source_id_x[0]] + ]: generator_ifempty(source_id_x[1]) for source_id_x in enumerate(source_iter) } @@ -1616,11 +1635,21 @@ def insert_hoc_cell_syns( swc_type_ais = env.SWC_Types["ais"] swc_type_hill = env.SWC_Types["hillock"] - syns_dict_dend = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None))) - syns_dict_axon = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None))) - syns_dict_ais = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None))) - syns_dict_hill = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None))) - syns_dict_soma = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None))) + syns_dict_dend = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: None)) + ) + syns_dict_axon = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: None)) + ) + syns_dict_ais = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: None)) + ) + syns_dict_hill = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: None)) + ) + syns_dict_soma = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: None)) + ) syns_dict_by_type = { swc_type_apical: syns_dict_dend, @@ -1650,7 +1679,9 @@ def insert_hoc_cell_syns( syn_attrs = env.synapse_attributes syn_id_attr_dict = syn_attrs.syn_id_attr_dict[gid] - make_syn_mech = make_unique_synapse_mech if unique else make_shared_synapse_mech + make_syn_mech = ( + 
make_unique_synapse_mech if unique else make_shared_synapse_mech + ) syn_count = 0 nc_count = 0 @@ -1715,7 +1746,9 @@ def insert_hoc_cell_syns( this_nc, this_vecstim = mknetcon_vecstim( syn_pps, delay=syn.source.delay ) - syn_attrs.add_vecstim(gid, syn_id, syn_name, this_vecstim, this_nc) + syn_attrs.add_vecstim( + gid, syn_id, syn_name, this_vecstim, this_nc + ) if insert_netcons: if this_nc is None: this_nc = mknetcon( @@ -1783,7 +1816,9 @@ def insert_biophys_cell_syns( cell = env.biophys_cells[postsyn_name][gid] - connection_syn_params = env.connection_config[postsyn_name][presyn_name].mechanisms + connection_syn_params = env.connection_config[postsyn_name][ + presyn_name + ].mechanisms synapse_config = env.celltypes[postsyn_name]["synapses"] @@ -1849,7 +1884,9 @@ def config_biophys_cell_syns( syn_ids = list(syn_id_attr_dict.keys()) if insert: - source_syn_ids_dict = syn_attrs.partition_syn_ids_by_source(gid, syn_ids) + source_syn_ids_dict = syn_attrs.partition_syn_ids_by_source( + gid, syn_ids + ) if not (gid in env.biophys_cells[postsyn_name]): raise KeyError( f"config_biophys_cell_syns: insert: biophysical cell with gid {gid} does not exist" @@ -2148,7 +2185,9 @@ def config_syn( def syn_in_seg( syn_name: str, seg: Segment, - syns_dict: DefaultDict[Section, DefaultDict[float, DefaultDict[str, "HocObject"]]], + syns_dict: DefaultDict[ + Section, DefaultDict[float, DefaultDict[str, "HocObject"]] + ], ) -> Optional["HocObject"]: """ If a synaptic mechanism of the specified type already exists in the specified segment, it is returned. Otherwise, @@ -2181,7 +2220,9 @@ def make_syn_mech(mech_name: str, seg: Segment) -> "HocObject": def make_shared_synapse_mech( syn_name: str, seg: Segment, - syns_dict: DefaultDict[Section, DefaultDict[float, DefaultDict[str, "HocObject"]]], + syns_dict: DefaultDict[ + Section, DefaultDict[float, DefaultDict[str, "HocObject"]] + ], mech_names: Optional[Dict[str, str]] = None, ) -> "HocObject": """ @@ -2384,7 +2425,9 @@ def modify_syn_param( :param verbose: bool """ if sec_type not in cell.nodes: - raise ValueError(f"modify_syn_mech_param: sec_type: {sec_type} not in cell") + raise ValueError( + f"modify_syn_mech_param: sec_type: {sec_type} not in cell" + ) if param_name is None: raise ValueError( f"modify_syn_mech_param: missing required parameter to modify synaptic mechanism: {syn_name} " @@ -2420,7 +2463,9 @@ def modify_syn_param( cell.mech_dict[sec_type]["synapses"][syn_name] = mech_content # This parameter of this syn_name has already been specified in this type of section, and the user wants to append # a new rule set - elif param_name in cell.mech_dict[sec_type]["synapses"][syn_name] and append: + elif ( + param_name in cell.mech_dict[sec_type]["synapses"][syn_name] and append + ): cell.mech_dict[sec_type]["synapses"][syn_name][param_name].append(rules) # This syn_name has been specified, but not this parameter, or the user wants to replace an existing rule set else: @@ -2503,7 +2548,9 @@ def update_syn_mech_param_by_sec_type( """ new_rules = copy.deepcopy(rules) if "filters" in new_rules: - synapse_filters = get_syn_filter_dict(env, new_rules["filters"], convert=True) + synapse_filters = get_syn_filter_dict( + env, new_rules["filters"], convert=True + ) del new_rules["filters"] else: synapse_filters = {} @@ -2728,7 +2775,8 @@ def write_syn_spike_count( syn_names = list(syn_attrs.syn_name_index_dict.keys()) output_dict = { - syn_name: defaultdict(lambda: defaultdict(int)) for syn_name in syn_names + syn_name: defaultdict(lambda: defaultdict(int)) + for 
syn_name in syn_names } gids = [] diff --git a/src/scripts/analysis/network_clamp.py b/src/scripts/analysis/network_clamp.py index 90cca69..50970c4 100644 --- a/src/scripts/analysis/network_clamp.py +++ b/src/scripts/analysis/network_clamp.py @@ -33,7 +33,9 @@ def cli(): default="GC", help="target population", ) -@click.option("--gid", "-g", required=True, type=int, default=0, help="target cell gid") +@click.option( + "--gid", "-g", required=True, type=int, default=0, help="target cell gid" +) @click.option( "--arena-id", "-a", @@ -217,7 +219,9 @@ def show( multiple=True, help="generate weights for the given presynaptic population", ) -@click.option("--t-max", "-t", type=float, default=150.0, help="simulation end time") +@click.option( + "--t-max", "-t", type=float, default=150.0, help="simulation end time" +) @click.option("--t-min", type=float) @click.option( "--template-paths", @@ -339,7 +343,9 @@ def show( default="Network clamp default", help="recording profile to use", ) -@click.option("--input-seed", type=int, help="seed for generation of spike trains") +@click.option( + "--input-seed", type=int, help="seed for generation of spike trains" +) def go( config_file, config_prefix, @@ -443,7 +449,9 @@ def go( help="file containing target cell gids", ) @click.option("--arena-id", "-a", type=str, required=False, help="arena id") -@click.option("--stimulus-id", "-s", type=str, required=False, help="stimulus id") +@click.option( + "--stimulus-id", "-s", type=str, required=False, help="stimulus id" +) @click.option( "--generate-weights", "-w", @@ -452,7 +460,9 @@ def go( multiple=True, help="generate weights for the given presynaptic population", ) -@click.option("--t-max", "-t", type=float, default=150.0, help="simulation end time") +@click.option( + "--t-max", "-t", type=float, default=150.0, help="simulation end time" +) @click.option("--t-min", type=float) @click.option( "--nprocs-per-worker",