From cafb30fb331e405dbd5ca1938101c46cf75c49bc Mon Sep 17 00:00:00 2001 From: kaeldai Date: Wed, 16 Feb 2022 10:55:52 -0800 Subject: [PATCH 1/4] fixing plotting functionality in examples dir --- examples/bio_450cells/plot_output.py | 16 ++++++++++------ examples/bio_450cells_compact/plot_output.py | 15 ++++++--------- .../bio_simulated_annealing/calculate_rates.py | 7 +++---- .../bio_simulated_annealing/check_preformance.py | 2 +- examples/bio_simulated_annealing/plot_output.py | 4 ++-- examples/filter_graitings/config.simulation.json | 2 +- examples/filter_graitings/plot_output.py | 9 +++++++++ .../config.simulation.json | 2 +- examples/filter_looming/config.simulation.json | 2 +- examples/point_120cells/plot_output.py | 12 ++++++------ examples/point_450cells/plot_output.py | 7 +++---- examples/point_450glifs/plot_output.py | 6 +++--- examples/point_iclamp/plot_output.py | 8 +++----- 13 files changed, 49 insertions(+), 43 deletions(-) create mode 100644 examples/filter_graitings/plot_output.py diff --git a/examples/bio_450cells/plot_output.py b/examples/bio_450cells/plot_output.py index 78de8964..65c898ad 100644 --- a/examples/bio_450cells/plot_output.py +++ b/examples/bio_450cells/plot_output.py @@ -3,13 +3,17 @@ from bmtk.analyzer.compartment import plot_traces from bmtk.analyzer.spike_trains import plot_raster, plot_rates_boxplot +config_file = 'config.simulation.json' +# config_file = 'config.simulation_vm.json' +# config_file = 'config.simulation_ecp.json' # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='model_name', show=False) -plot_rates_boxplot(config_file='config.json', group_by='model_name', show=False) - - -plot_traces(config_file='config.json', report_name='membrane_potential', group_by='model_name', - times=(0.0, 200.0), show=False) +plot_raster(config_file=config_file, show=False) # , group_by='model_name', show=False) +plot_rates_boxplot(config_file=config_file, group_by='model_name', show=False) +if config_file == 'config.simulation_vm.json': + plot_traces( + config_file='config.simulation_vm.json', report_name='membrane_potential', group_by='model_name', + times=(0.0, 200.0), show=False + ) plt.show() diff --git a/examples/bio_450cells_compact/plot_output.py b/examples/bio_450cells_compact/plot_output.py index 5ec7fff8..c621b975 100644 --- a/examples/bio_450cells_compact/plot_output.py +++ b/examples/bio_450cells_compact/plot_output.py @@ -5,14 +5,11 @@ # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='model_name', show=False) -plot_rates_boxplot(config_file='config.json', group_by='model_name', show=False) - - -plot_traces(config_file='config.json', report_name='membrane_potential', group_by='model_name', - group_excludes=['LIF_exc', 'LIF_inh'], times=(0.0, 200.0), show=False) - -plot_traces(config_file='config.json', report_name='calcium_concentration', group_by='model_name', - group_excludes=['LIF_exc', 'LIF_inh'], times=(0.0, 200.0), show=False) +plot_raster(config_file='config.simulation.json', group_by='model_name', show=False) +plot_rates_boxplot(config_file='config.simulation.json', group_by='model_name', show=False) +plot_traces( + config_file='config.simulation.json', report_name='membrane_potential', group_by='model_name', + group_excludes=['LIF_exc', 'LIF_inh'], times=(0.0, 200.0), show=False +) plt.show() diff --git a/examples/bio_simulated_annealing/calculate_rates.py b/examples/bio_simulated_annealing/calculate_rates.py index e9442f2a..dd3f5ad4 100644 --- a/examples/bio_simulated_annealing/calculate_rates.py +++ b/examples/bio_simulated_annealing/calculate_rates.py @@ -1,15 +1,14 @@ from bmtk.simulator import bionet from bmtk.utils.reports.spike_trains import SpikeTrains -config_file = 'config.json' +config_file = 'config.simulation.json' conf = bionet.Config.from_json(config_file, validate=True) -#conf.build_env() - +# conf.build_env() graph = bionet.BioNetwork.from_config(conf) graph.build() node_props = graph.node_properties() node_ids = {k: v.tolist() for k, v in node_props['v1'].groupby('pop_name').groups.items()} print(node_ids) -st = SpikeTrains.load('output/spikes.h5') \ No newline at end of file +st = SpikeTrains.load('output/spikes.h5') diff --git a/examples/bio_simulated_annealing/check_preformance.py b/examples/bio_simulated_annealing/check_preformance.py index b388014a..c68b754d 100644 --- a/examples/bio_simulated_annealing/check_preformance.py +++ b/examples/bio_simulated_annealing/check_preformance.py @@ -16,6 +16,6 @@ def run(config_file): sim.run() -mem = memory_usage((run, ('config.json',))) +mem = memory_usage((run, ('config.simulation.json',))) print(max(mem)) pc.done() diff --git a/examples/bio_simulated_annealing/plot_output.py b/examples/bio_simulated_annealing/plot_output.py index b28c846f..9077cf43 100644 --- a/examples/bio_simulated_annealing/plot_output.py +++ b/examples/bio_simulated_annealing/plot_output.py @@ -3,6 +3,6 @@ from bmtk.analyzer.spike_trains import plot_raster, plot_rates -plot_raster(config_file='config.json', title='Raster', show=False) -plot_rates(config_file='config.json', title='Rates', show=False) +plot_raster(config_file='config.simulation.json', title='Raster', show=False) +plot_rates(config_file='config.simulation.json', title='Rates', show=False) plt.show() diff --git a/examples/filter_graitings/config.simulation.json b/examples/filter_graitings/config.simulation.json index 50f5c2a7..c9058c11 100644 --- a/examples/filter_graitings/config.simulation.json +++ b/examples/filter_graitings/config.simulation.json @@ -42,7 +42,7 @@ "output_dir": "$OUTPUT_DIR", "rates_csv": "rates.csv", "spikes_csv": "spikes.csv", - "spikes_h5": "spikes.h5", + "spikes_file": "spikes.h5", "overwrite_output_dir": true }, diff --git a/examples/filter_graitings/plot_output.py b/examples/filter_graitings/plot_output.py new file mode 100644 index 00000000..7c471d63 --- /dev/null +++ b/examples/filter_graitings/plot_output.py @@ -0,0 +1,9 @@ +import matplotlib.pyplot as plt +from bmtk.analyzer.spike_trains import plot_raster, plot_rates_boxplot + + +# fig = plot_raster(spikes_file='output/spikes.h5', show=False) # , group_by='model_name', show=False) +fig = plot_raster(config_file='config.simulation.json', group_by='model_template', show=False) + +# fig.savefig('my_raster.png') +plt.show() \ No newline at end of file diff --git a/examples/filter_graitings_preset/config.simulation.json b/examples/filter_graitings_preset/config.simulation.json index 513b556c..7cdbe9b0 100644 --- a/examples/filter_graitings_preset/config.simulation.json +++ b/examples/filter_graitings_preset/config.simulation.json @@ -39,7 +39,7 @@ "output_dir": "$OUTPUT_DIR", "rates_csv": "rates.csv", "spikes_csv": "spikes.csv", - "spikes_h5": "spikes.h5", + "spikes_file": "spikes.h5", "overwrite_output_dir": true }, diff --git a/examples/filter_looming/config.simulation.json b/examples/filter_looming/config.simulation.json index f161533b..cae915bf 100644 --- a/examples/filter_looming/config.simulation.json +++ b/examples/filter_looming/config.simulation.json @@ -33,7 +33,7 @@ "log_file": "log.txt", "rates_csv": "rates.csv", "spikes_csv": "spikes.csv", - "spikes_h5": "spikes.h5", + "spikes_file": "spikes.h5", "overwrite_output_dir": true }, diff --git a/examples/point_120cells/plot_output.py b/examples/point_120cells/plot_output.py index dbe4cce1..bcac1e0d 100644 --- a/examples/point_120cells/plot_output.py +++ b/examples/point_120cells/plot_output.py @@ -5,11 +5,11 @@ # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='pop_name', show=False) -plot_rates_boxplot(config_file='config.json', group_by='pop_name', show=False) - - -plot_traces(config_file='config.json', report_name='membrane_potential', group_by='pop_name', - times=(0.0, 200.0), show=False) +plot_raster(config_file='config.simulation.json', group_by='pop_name', show=False) +plot_rates_boxplot(config_file='config.simulation.json', group_by='pop_name', show=False) +plot_traces( + config_file='config.simulation.json', report_name='membrane_potential', group_by='pop_name', + times=(0.0, 200.0), show=False +) plt.show() diff --git a/examples/point_450cells/plot_output.py b/examples/point_450cells/plot_output.py index ee27edb2..010e6935 100644 --- a/examples/point_450cells/plot_output.py +++ b/examples/point_450cells/plot_output.py @@ -5,9 +5,8 @@ # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='model_name', show=False) -plot_rates(config_file='config.json', group_by='model_name', smoothing=True, show=False) -plot_rates_boxplot(config_file='config.json', group_by='model_name', show=False) - +plot_raster(config_file='config.simulation.json', group_by='model_name', show=False) +plot_rates(config_file='config.simulation.json', group_by='model_name', smoothing=True, show=False) +plot_rates_boxplot(config_file='config.simulation.json', group_by='model_name', show=False) plt.show() diff --git a/examples/point_450glifs/plot_output.py b/examples/point_450glifs/plot_output.py index 9cdea7a9..eda7eb33 100644 --- a/examples/point_450glifs/plot_output.py +++ b/examples/point_450glifs/plot_output.py @@ -4,8 +4,8 @@ # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='model_name', show=False) -plot_rates(config_file='config.json', group_by='model_name', smoothing=True, show=False) -plot_rates_boxplot(config_file='config.json', group_by='model_name', show=False) +plot_raster(config_file='config.simulation.json', group_by='model_name', show=False) +plot_rates(config_file='config.simulation.json', group_by='model_name', smoothing=True, show=False) +plot_rates_boxplot(config_file='config.simulation.json', group_by='model_name', show=False) plt.show() diff --git a/examples/point_iclamp/plot_output.py b/examples/point_iclamp/plot_output.py index edadca88..c28fe5dd 100644 --- a/examples/point_iclamp/plot_output.py +++ b/examples/point_iclamp/plot_output.py @@ -5,10 +5,8 @@ # Setting show to False so we can display all the plots at the same time -plot_raster(config_file='config.json', group_by='pop_name', show=False) -plot_rates_boxplot(config_file='config.json', group_by='pop_name', show=False) - - -plot_traces(config_file='config.json', report_name='membrane_potential', group_by='pop_name', show=False) +plot_raster(config_file='config.simulation.json', group_by='pop_name', show=False) +plot_rates_boxplot(config_file='config.simulation.json', group_by='pop_name', show=False) +plot_traces(config_file='config.simulation.json', report_name='membrane_potential', group_by='pop_name', show=False) plt.show() From 9a9c6d65eee4830e28462b212a52ae632cdc0db6 Mon Sep 17 00:00:00 2001 From: kaeldai Date: Wed, 16 Feb 2022 11:35:03 -0800 Subject: [PATCH 2/4] updating extra-stim to allow for non-mesh electrodes per issue #198 --- bmtk/simulator/bionet/modules/xstim.py | 50 ++++++++++++++------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/bmtk/simulator/bionet/modules/xstim.py b/bmtk/simulator/bionet/modules/xstim.py index f2192ff6..2eb0e519 100644 --- a/bmtk/simulator/bionet/modules/xstim.py +++ b/bmtk/simulator/bionet/modules/xstim.py @@ -26,8 +26,8 @@ def __init__(self, positions_file, waveform, mesh_files_dir=None, cells=None, se self._local_gids = [] self._fih = None - #def __set_extracellular_mechanism(self): - # for gid in self._local_gids: + # def __set_extracellular_mechanism(self): + # for gid in self._local_gids: def initialize(self, sim): if self._cells is None: @@ -47,7 +47,7 @@ def initialize(self, sim): def set_pointers(): for gid in self._local_gids: cell = sim.net.get_cell_gid(gid) - #cell = sim.net.get_local_cell(gid) + # cell = sim.net.get_local_cell(gid) cell.set_ptr2e_extracellular() self._fih = sim.h.FInitializeHandler(0, set_pointers) @@ -72,7 +72,10 @@ def __init__(self, positions_file, waveform, mesh_files_dir, dt): stimelectrode_position_df = pd.read_csv(positions_file, sep=' ') - self.elmesh_files = stimelectrode_position_df['electrode_mesh_file'] + if 'electrode_mesh_file' in stimelectrode_position_df.columns: + self.elmesh_files = stimelectrode_position_df['electrode_mesh_file'] + else: + self.elmesh_files = None self.elpos = stimelectrode_position_df[['pos_x', 'pos_y', 'pos_z']].T.values self.elrot = stimelectrode_position_df[['rotation_x', 'rotation_y', 'rotation_z']].values self.elnsites = self.elpos.shape[1] # Number of electrodes in electrode file @@ -88,23 +91,30 @@ def __init__(self, positions_file, waveform, mesh_files_dir, dt): self.place_the_electrodes() def read_electrode_mesh(self): - el_counter = 0 - for mesh_file in self.elmesh_files: - file_path = mesh_file if os.path.isabs(mesh_file) else os.path.join(self._mesh_files_dir, mesh_file) - mesh = pd.read_csv(file_path, sep=" ") - mesh_size = mesh.shape[0] - self.el_mesh_size.append(mesh_size) - - self.el_mesh[el_counter] = np.zeros((3, mesh_size)) - self.el_mesh[el_counter][0] = mesh['x_pos'] - self.el_mesh[el_counter][1] = mesh['y_pos'] - self.el_mesh[el_counter][2] = mesh['z_pos'] - el_counter += 1 + if self.elmesh_files is None: + # if electrode_mesh_file is missing then we still treat the electrode as a mesh on a grid of size 1 with + # single point at the (0, 0, 0) relative to electrode position + for el_counter in range(self.elnsites): + self.el_mesh_size.append(1) + self.el_mesh[el_counter] = np.zeros((3, 1)) - def place_the_electrodes(self): + else: + # Each electrode has an associate mesh file + el_counter = 0 + for mesh_file in self.elmesh_files: + file_path = mesh_file if os.path.isabs(mesh_file) else os.path.join(self._mesh_files_dir, mesh_file) + mesh = pd.read_csv(file_path, sep=" ") + mesh_size = mesh.shape[0] + self.el_mesh_size.append(mesh_size) + + self.el_mesh[el_counter] = np.zeros((3, mesh_size)) + self.el_mesh[el_counter][0] = mesh['x_pos'] + self.el_mesh[el_counter][1] = mesh['y_pos'] + self.el_mesh[el_counter][2] = mesh['z_pos'] + el_counter += 1 + def place_the_electrodes(self): transfer_vector = np.zeros((self.elnsites, 3)) - for el in range(self.elnsites): mesh_mean = np.mean(self.el_mesh[el], axis=1) transfer_vector[el] = self.elpos[:, el] - mesh_mean[:] @@ -128,17 +138,13 @@ def rotate_the_electrodes(self): self.el_mesh[el] = new_mesh def set_transfer_resistance(self, gid, seg_coords): - rho = 300.0 # ohm cm r05 = seg_coords['p05'] nseg = r05.shape[1] cell_map = np.zeros((self.elnsites, nseg)) for el in six.moves.range(self.elnsites): - mesh_size = self.el_mesh_size[el] - for k in range(mesh_size): - rel = np.expand_dims(self.el_mesh[el][:, k], axis=1) rel_05 = rel - r05 r2 = np.einsum('ij,ij->j', rel_05, rel_05) From 96652f4270f733e7c5fa3b6c2692266a59217103 Mon Sep 17 00:00:00 2001 From: kaeldai Date: Wed, 16 Feb 2022 12:10:37 -0800 Subject: [PATCH 3/4] adding save_as option to plotting as suggested by issue #194 --- bmtk/analyzer/compartment.py | 7 +++- bmtk/analyzer/spike_trains.py | 45 +++++++++++++-------- bmtk/utils/reports/compartment/plotting.py | 29 ++++++++----- bmtk/utils/reports/spike_trains/plotting.py | 22 ++++++++-- 4 files changed, 72 insertions(+), 31 deletions(-) diff --git a/bmtk/analyzer/compartment.py b/bmtk/analyzer/compartment.py index 0224c4e3..18ac9b40 100644 --- a/bmtk/analyzer/compartment.py +++ b/bmtk/analyzer/compartment.py @@ -65,7 +65,7 @@ def _find_nodes(population, config=None, nodes_file=None, node_types_file=None): def plot_traces(config_file=None, report_name=None, population=None, report_path=None, group_by=None, group_excludes=None, nodes_file=None, node_types_file=None, node_ids=None, sections='origin', average=False, times=None, title=None, - show_legend=None, show=True): + show_legend=None, show=True, save_as=None): """Plot compartment variables (eg Membrane Voltage, Calcium conc.) traces from the output of simulation. Will attempt to look in the SONATA simulation configuration json "reports" sections for any matching "membrane_report" outputs with a matching report_name:: @@ -115,6 +115,8 @@ def plot_traces(config_file=None, report_name=None, population=None, report_path :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ @@ -176,5 +178,6 @@ def plot_traces(config_file=None, report_name=None, population=None, report_path times=times, title=title, show_legend=show_legend, - show=show + show=show, + save_as=save_as ) diff --git a/bmtk/analyzer/spike_trains.py b/bmtk/analyzer/spike_trains.py index 0fa08059..7d452614 100644 --- a/bmtk/analyzer/spike_trains.py +++ b/bmtk/analyzer/spike_trains.py @@ -94,7 +94,7 @@ def _find_nodes(population, config=None, nodes_file=None, node_types_file=None): raise ValueError('Could not find nodes file with node population "{}".'.format(population)) -def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title=None, show=True, +def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title=None, show=True, save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None): sonata_config = SonataConfig.from_json(config_file) if config_file else None @@ -136,11 +136,14 @@ def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title= else: node_groups = None - return plot_fnc(spike_trains=spike_trains, node_groups=node_groups, population=pop, times=times, title=title, show=show) + return plot_fnc( + spike_trains=spike_trains, node_groups=node_groups, population=pop, times=times, title=title, show=show, + save_as=save_as + ) def plot_raster(config_file=None, population=None, with_histogram=True, times=None, title=None, show=True, - group_by=None, group_excludes=None, + save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None): """Create a raster plot (plus optional histogram) from the results of the simulation. @@ -170,6 +173,8 @@ def plot_raster(config_file=None, population=None, with_histogram=True, times=No :param title: str, adds a title to the plot. If None (default) then name will be automatically generated using the report_name. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :param group_by: Attribute of the "nodes" file used to group and average subsets of nodes. :param group_excludes: list of strings or None. When using the "group_by", allows users to exclude certain groupings based on the attribute value. @@ -181,15 +186,17 @@ def plot_raster(config_file=None, population=None, with_histogram=True, times=No :return: matplotlib figure.Figure object """ plot_fnc = partial(plotting.plot_raster, with_histogram=with_histogram) - return _plot_helper(plot_fnc, - config_file=config_file, population=population, times=times, title=title, show=show, - group_by=group_by, group_excludes=group_excludes, - spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file + return _plot_helper( + plot_fnc, + config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as, + group_by=group_by, group_excludes=group_excludes, + spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file ) def plot_rates(config_file=None, population=None, smoothing=False, smoothing_params=None, times=None, title=None, - show=True, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None): + show=True, save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, + node_types_file=None): """Calculate and plot the rates of each node recorded during the simulation - averaged across the entirety of the simulation. @@ -221,6 +228,8 @@ def plot_rates(config_file=None, population=None, smoothing=False, smoothing_par :param title: str, adds a title to the plot. If None (default) then name will be automatically generated using the report_name. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :param group_by: Attribute of the "nodes" file used to group and average subsets of nodes. :param group_excludes: list of strings or None. When using the "group_by", allows users to exclude certain groupings based on the attribute value. @@ -232,14 +241,15 @@ def plot_rates(config_file=None, population=None, smoothing=False, smoothing_par :return: matplotlib figure.Figure object """ plot_fnc = partial(plotting.plot_rates, smoothing=smoothing, smoothing_params=smoothing_params) - return _plot_helper(plot_fnc, - config_file=config_file, population=population, times=times, title=title, show=show, - group_by=group_by, group_excludes=group_excludes, - spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file + return _plot_helper( + plot_fnc, + config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as, + group_by=group_by, group_excludes=group_excludes, + spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file ) -def plot_rates_boxplot(config_file=None, population=None, times=None, title=None, show=True, +def plot_rates_boxplot(config_file=None, population=None, times=None, title=None, show=True, save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None): """Creates a box plot of the firing rates taken from nodes recorded during the simulation. @@ -268,6 +278,8 @@ def plot_rates_boxplot(config_file=None, population=None, times=None, title=None :param title: str, adds a title to the plot. If None (default) then name will be automatically generated using the report_name. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :param group_by: Attribute of the "nodes" file used to group and average subsets of nodes. :param group_excludes: list of strings or None. When using the "group_by", allows users to exclude certain groupings based on the attribute value. @@ -279,8 +291,9 @@ def plot_rates_boxplot(config_file=None, population=None, times=None, title=None :return: matplotlib figure.Figure object """ plot_fnc = partial(plotting.plot_rates_boxplot) - return _plot_helper(plot_fnc, - config_file=config_file, population=population, times=times, title=title, show=show, + return _plot_helper( + plot_fnc, + config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as, group_by=group_by, group_excludes=group_excludes, spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file ) @@ -344,8 +357,6 @@ def to_dataframe(config_file, spikes_file=None, population=None): :param population: :return: """ - - # _, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population) pop, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population) diff --git a/bmtk/utils/reports/compartment/plotting.py b/bmtk/utils/reports/compartment/plotting.py index 0628e58b..fa35d977 100644 --- a/bmtk/utils/reports/compartment/plotting.py +++ b/bmtk/utils/reports/compartment/plotting.py @@ -43,7 +43,7 @@ def __get_node_groups(report, node_groups, population): def plot_traces(report, population=None, node_ids=None, sections='origin', average=False, node_groups=None, times=None, - title=None, show_legend=None, show=True): + title=None, show_legend=None, show=True, save_as=None): """Displays the time trace of one or more nodes from a SONATA CompartmentReport file. To plot a group of individual variable traces (based on their soma):: @@ -79,6 +79,8 @@ def plot_traces(report, population=None, node_ids=None, sections='origin', avera :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ if node_groups is not None and node_ids is not None: @@ -89,26 +91,26 @@ def plot_traces(report, population=None, node_ids=None, sections='origin', avera node_groups = {'node_ids': node_ids, 'label': average} return plot_traces_averaged(report=report, population=population, sections=sections, node_groups=node_groups, times=times, title=title, show_legend=show_legend, - show=show) + show=show, save_as=save_as) return plot_traces_individual(report=report, population=population, node_ids=node_ids, sections=sections, - times=times, title=title, show_legend=show_legend, show=show) + times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) elif node_groups is not None: return plot_traces_averaged(report=report, population=population, sections=sections, node_groups=node_groups, - times=times, title=title, show_legend=show_legend, show=show) + times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) elif average is not None: return plot_traces_averaged(report=report, population=population, sections=sections, times=times, title=title, - show_legend=show_legend, show=show) + show_legend=show_legend, show=show, save_as=save_as) else: return plot_traces_individual(report=report, population=population, sections=sections, times=times, title=title, - show_legend=show_legend, show=show) + show_legend=show_legend, show=show, save_as=save_as) def plot_traces_individual(report, population=None, node_ids=None, sections='origin', times=None, title=None, - show_legend=None, show=True): + show_legend=None, show=True, save_as=None): """Used the plot time traces of individual nodes from a SONATA compartment report file. Recommended use plot_traces instead, which will call this function if necessarcy. @@ -124,6 +126,8 @@ def plot_traces_individual(report, population=None, node_ids=None, sections='ori :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ @@ -157,6 +161,9 @@ def plot_traces_individual(report, population=None, node_ids=None, sections='ori if title: axes.set_title(title) + if save_as: + plt.savefig(save_as) + if show: plt.show() @@ -164,7 +171,7 @@ def plot_traces_individual(report, population=None, node_ids=None, sections='ori def plot_traces_averaged(report, population=None, sections='origin', node_groups=None, times=None, title=None, - show_background=True, show_legend=None, show=True): + show_background=True, show_legend=None, show=True, save_as=None): """Used to plot averages across multiple nodes in a SONATA Compartment Report file. Recommended that you use "plot_traces" function. @@ -192,6 +199,8 @@ def plot_traces_averaged(report, population=None, sections='origin', node_groups :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ cr = CompartmentReport.load(report) @@ -220,7 +229,6 @@ def plot_traces_averaged(report, population=None, sections='origin', node_groups background_data = traces_data[:, nodes_indx] axes.plot(trace_times, background_data, c='lightgray') - for node_grp in node_groups: grp_ids = node_grp.pop('node_ids') has_labels = has_labels or 'label' in node_grp @@ -248,6 +256,9 @@ def plot_traces_averaged(report, population=None, sections='origin', node_groups if (show_legend is None or show_legend) and has_labels: axes.legend() + if save_as: + plt.savefig(save_as) + if show: plt.show() diff --git a/bmtk/utils/reports/spike_trains/plotting.py b/bmtk/utils/reports/spike_trains/plotting.py index 8ac497c5..790776ca 100644 --- a/bmtk/utils/reports/spike_trains/plotting.py +++ b/bmtk/utils/reports/spike_trains/plotting.py @@ -81,7 +81,7 @@ def __get_node_groups(spike_trains, node_groups, population): def plot_raster(spike_trains, with_histogram=True, population=None, node_groups=None, times=None, title=None, - show=True): + show=True, save_as=None): """will create a raster plot (plus optional histogram) from a SpikeTrains object or SONATA Spike-Trains file. Will return the figure @@ -107,6 +107,8 @@ def plot_raster(spike_trains, with_histogram=True, population=None, node_groups= data. :param title: str, Use to add a title. Default no tile :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ @@ -172,6 +174,9 @@ def plot_raster(spike_trains, with_histogram=True, population=None, node_groups= hist_axes.axes.get_yaxis().set_visible(False) raster_axes.set_xticks([]) + if save_as: + plt.savefig(save_as) + if show: plt.show() @@ -185,7 +190,7 @@ def moving_average(data, window_size=10): def plot_rates(spike_trains, population=None, node_groups=None, times=None, smoothing=False, - smoothing_params=None, title=None, show=True): + smoothing_params=None, title=None, show=True, save_as=None): """Calculate and plot the rates of each node in a SpikeTrains object or SONATA Spike-Trains file. If start and stop times are not specified from the "times" parameter, will try to parse values from the timestamps data. @@ -210,6 +215,8 @@ def plot_rates(spike_trains, population=None, node_groups=None, times=None, smoo :param smoothing_params: dict, parameters when using a function pointer smoothing value. :param title: str, Use to add a title. Default no tile :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ @@ -254,13 +261,17 @@ def plot_rates(spike_trains, population=None, node_groups=None, times=None, smoo if title: axes.set_title(title) + if save_as: + plt.savefig(save_as) + if show: plt.show() return fig -def plot_rates_boxplot(spike_trains, population=None, node_groups=None, times=None, title=None, show=True): +def plot_rates_boxplot(spike_trains, population=None, node_groups=None, times=None, title=None, show=True, + save_as=None): """Creates a box plot of the firing rates taken from a SpikeTrains object or SONATA Spike-Trains file. If start and stop times are not specified from the "times" parameter, will try to parse values from the timestamps data. @@ -282,6 +293,8 @@ def plot_rates_boxplot(spike_trains, population=None, node_groups=None, times=No label and color. If None all nodes will be labeled and colored the same. :param title: str, Use to add a title. Default no tile :param show: bool to display or not display plot. default True. + :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not + save plot. :return: matplotlib figure.Figure object """ spike_trains = __get_spike_trains(spike_trains=spike_trains) @@ -319,6 +332,9 @@ def plot_rates_boxplot(spike_trains, population=None, node_groups=None, times=No if title: axes.set_title(title) + if save_as: + plt.savefig(save_as) + if show: plt.show() From 66fe904bb249b81b03f4af0bc75bd57d8b091e4d Mon Sep 17 00:00:00 2001 From: kaeldai Date: Wed, 16 Feb 2022 12:11:42 -0800 Subject: [PATCH 4/4] bumping version --- bmtk/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtk/__init__.py b/bmtk/__init__.py index 07d14264..9365dfc6 100644 --- a/bmtk/__init__.py +++ b/bmtk/__init__.py @@ -20,4 +20,4 @@ # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -__version__ = '1.0.2' +__version__ = '1.0.3'