Skip to content

Commit

Permalink
simulation and comparison of classical STDP curves
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubo committed Apr 2, 2024
1 parent dbeed8d commit fe902ae
Show file tree
Hide file tree
Showing 14 changed files with 456 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.so
*.dat
STDP/*.svg
10 changes: 5 additions & 5 deletions STDP/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ BUILD_CATALOGUE_SCRIPT := arbor-build-catalogue
custom-catalogue.so: $(wildcard mechanisms/*.mod)
$(BUILD_CATALOGUE_SCRIPT) custom mechanisms

.PRECIOUS: arbor_traces_%.dat arbor_spikes_%.dat
arbor_traces_%.dat arbor_spikes_%.dat: config_%.json arbor_lif_stdp.py custom-catalogue.so
.PRECIOUS: arbor_traces_%_lif.dat arbor_spikes_%_lif.dat arbor_traces_%_classical.dat
arbor_traces_%_lif.dat arbor_spikes_%_lif.dat: config_%_lif.json arbor_stdp_lif.py arbor_stdp_classical.py custom-catalogue.so config_brian2_arbor_lif.json config_brian2_arbor_classical.json
./run_arbor.sh $*

.PRECIOUS: brian2_traces_%.dat brian2_spikes_%.dat
brian2_traces_%.dat brian2_spikes_%.dat: config_%.json brian2_lif_stdp.py
.PRECIOUS: brian2_traces_%_lif.dat brian2_spikes_%_lif.dat brian2_traces_%_classical.dat
brian2_traces_%_lif.dat brian2_spikes_%_lif.dat brian2_traces_%_classical.dat: config_%_lif.json brian2_stdp_lif.py brian2_stdp_classical.py config_brian2_arbor_lif.json config_brian2_arbor_classical.json
./run_brian2.sh $*

comparison_%.png: arbor_traces_%.dat arbor_spikes_%.dat brian2_traces_%.dat brian2_spikes_%.dat compare.py
comparison_%.png: arbor_traces_%_lif.dat arbor_spikes_%_lif.dat brian2_traces_%_lif.dat brian2_spikes_%_lif.dat brian2_traces_%_classical.dat compare.py
./compare.py $*

.PHONY: clean
Expand Down
206 changes: 206 additions & 0 deletions STDP/arbor_stdp_classical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Arbor simulation of two neuron populations connecting via STDP synapses.
Event generators are not used; instead, the spiking is inherently triggered in mechanisms,
resembling the way of the Brian 2 implementation in 'brian2_stdp_classical.py'.
"""

import json
import arbor
import numpy as np


class SingleRecipe(arbor.recipe):
"""Implementation of Arbor simulation recipe."""

def __init__(self, config):
"""Initialize the recipe from config."""

# The base C++ class constructor must be called first, to ensure that
# all memory in the C++ class is initialized correctly.
arbor.recipe.__init__(self)

self.the_props = arbor.neuron_cable_properties()
self.the_cat = arbor.load_catalogue("./custom-catalogue.so")
self.the_cat.extend(arbor.default_catalogue(), "")
self.the_props.catalogue = self.the_cat

self.config = config
self.N = config["simulation"]["N"]
self.dt = self.config["simulation"]["dt"]
self.t_max = self.config["simulation"]["runtime"]

# arrays of spike time values
self.t_spike_1 = np.array([ ])
self.t_spike_2 = np.array([ ])


def num_cells(self):
"""Return the number of cells."""
return 2*self.N


def num_sources(self, gid):
"""Return the number of spikes sources on gid."""
if gid < self.N:
return 0
else:
return 1


def num_targets(self, gid):
"""Return the number of post-synaptic targets on gid."""
if gid < self.N:
return 1
else:
return 0


def cell_kind(self, gid):
"""Return type of cell with gid."""
return arbor.cell_kind.cable


def cell_description(self, gid):
"""Return cell description of gid."""

# morphology
tree = arbor.segment_tree()
radius = self.config["neuron"]["radius"]

tree.append(arbor.mnpos,
arbor.mpoint(-radius, 0, 0, radius),
arbor.mpoint(radius, 0, 0, radius),
tag=1)

labels = arbor.label_dict({'center': '(location 0 0.5)'})

# cell mechanism
e_thresh = self.the_cat[self.config["neuron"]["type"]].parameters["e_thresh"].default
e_reset = self.the_cat[self.config["neuron"]["type"]].parameters["e_reset"].default
decor = arbor.decor()
decor.set_property(Vm=e_reset)
neuron = arbor.mechanism(self.config["neuron"]["type"])
neuron.set("tau_refrac", self.config["neuron"]["tau_refrac"])
if gid < self.N:
# define spike times for neurons 0 to N-1
t_spike = gid*self.t_max/(self.N-1)
neuron.set("t_spike", t_spike)
try:
self.t_spike_1 = np.column_stack((self.t_spike_1, [t_spike, gid]))
except ValueError:
self.t_spike_1 = [t_spike, gid]
else:
# define spike times for neurons N to 2*N
t_spike = (2*self.N-1-gid)*self.t_max/(self.N-1)
neuron.set("t_spike", t_spike)
try:
self.t_spike_2 = np.column_stack((self.t_spike_2, [t_spike, gid]))
except ValueError:
self.t_spike_2 = [t_spike, gid]

# add incoming plastic synapse
syn_config_stdp = self.config["synapses"]["stdp"]
mech_expsyn = arbor.mechanism('expsyn_stdp')
mech_expsyn.set('taupre', syn_config_stdp["tau_pre"])
mech_expsyn.set('taupost', syn_config_stdp["tau_post"])
mech_expsyn.set('Apre', syn_config_stdp["A_pre"])
mech_expsyn.set('Apost', syn_config_stdp["A_post"])
mech_expsyn.set('max_weight', 50)
decor.place('"center"', arbor.synapse(mech_expsyn), "expsyn_stdp_post")

decor.place('"center"', arbor.threshold_detector(e_thresh), "spike_detector")
decor.paint('(all)', arbor.density(neuron))

return arbor.cable_cell(tree, decor, labels)


def connections_on(self, gid):
"""Defines the list of synaptic connections incoming to the neuron given by gid"""

policy = arbor.selection_policy.univalent
weight = 0
delay = self.dt # may not be <= 0

# neurons with gid 0 to N-1 are presynaptic
if gid < self.N:
conn = [ ]

# neurons with gid N to 2*N are postsynaptic
else:
src = gid - self.N
conn = [arbor.connection((src, "spike_detector"), ('expsyn_stdp_post', policy), weight, delay)]

return conn


def probes(self, gid):
"""Return probes on gid."""

probe_list = []
#probe_list = [arbor.cable_probe_membrane_voltage('"center"')]
#probe_list = [arbor.cable_probe_density_state('"center"', self.config["neuron"]["type"], "t")]

# neurons with gid N to 2*N are postsynaptic
if gid >= self.N and gid < 2*self.N:
probe_list.append(arbor.cable_probe_point_state(0, "expsyn_stdp", "weight_plastic"))

return probe_list


def global_properties(self, kind):
"""Return the global properties."""
assert kind == arbor.cell_kind.cable

return self.the_props


def main(variant):
"""Runs simulation and stores results."""

# set up simulation and run
config = json.load(open(f"config_{variant}_classical.json", 'r'))
recipe = SingleRecipe(config)

context = arbor.context()
domains = arbor.partition_load_balance(recipe, context)
sim = arbor.simulation(recipe, context, domains)

sim.record(arbor.spike_recording.all)
reg_sched = arbor.regular_schedule(config["simulation"]["dt"])
handle_weight_plastic_array = [sim.sample((i, 0), reg_sched) for i in range(recipe.N, 2*recipe.N)]

sim.run(tfinal=config["simulation"]["runtime"] + 1,
dt=config["simulation"]["dt"])

# read out and store weight changes and spike data
data_weight_plastic_final = np.zeros(recipe.N)
for i in range(recipe.N):
if len(sim.samples(handle_weight_plastic_array[i])) > 0:
data_buf, _ = sim.samples(handle_weight_plastic_array[i])[0]
data_weight_plastic_final[i] = data_buf[-1, 1]

t_spike_1_unsorted_T = recipe.t_spike_1.T
t_spike_2_unsorted_T = recipe.t_spike_2.T
t_spike_1 = t_spike_1_unsorted_T[t_spike_1_unsorted_T[:,1].argsort()].T
t_spike_2 = t_spike_2_unsorted_T[t_spike_2_unsorted_T[:,1].argsort()].T

data_stacked = np.column_stack(
[t_spike_2[0] - t_spike_1[0],
data_weight_plastic_final])

spikes = np.column_stack((sim.spikes()['time'], sim.spikes()['source']['gid']))

np.savetxt(f'arbor_traces_{variant}_classical.dat', data_stacked)
np.savetxt(f'arbor_spikes_{variant}_classical.dat', spikes, fmt="%.4f %.0f") # integer formatting for neuron number


if __name__ == '__main__':

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('variant', help="name of variant, e.g., brian2_arbor")
args = parser.parse_args()

main(args.variant)
6 changes: 3 additions & 3 deletions STDP/arbor_lif_stdp.py → STDP/arbor_stdp_lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def main(variant):
"""Runs simulation and stores results."""

# set up simulation and run
config = json.load(open(f"config_{variant}.json", 'r'))
config = json.load(open(f"config_{variant}_lif.json", 'r'))
recipe = SingleRecipe(config)

context = arbor.context()
Expand Down Expand Up @@ -166,8 +166,8 @@ def main(variant):

spike_times = sorted([s[1] for s in sim.spikes()])

numpy.savetxt(f'arbor_traces_{variant}.dat', data_stacked)
numpy.savetxt(f'arbor_spikes_{variant}.dat', spike_times)
numpy.savetxt(f'arbor_traces_{variant}_lif.dat', data_stacked)
numpy.savetxt(f'arbor_spikes_{variant}_lif.dat', spike_times)


if __name__ == '__main__':
Expand Down
82 changes: 82 additions & 0 deletions STDP/brian2_stdp_classical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
Brian 2 simulation of two neuron populations connecting via STDP synapses.
"""

import json
import numpy as np
from brian2 import ms, siemens, uS
from brian2 import NeuronGroup, Synapses, SpikeMonitor
from brian2 import run, defaultclock


def main(variant):
"""Runs simulation of classical STDP curve (based on
https://brian2.readthedocs.io/en/stable/resources/tutorials/2-intro-to-brian-synapses.html)
and stores results."""

config = json.load(open(f"config_{variant}_classical.json"))
neuron_config = config["neuron"]
#start_scope()

syn_config_stdp = config["synapses"]["stdp"]

tau_refrac = neuron_config["tau_refrac"] * ms

defaultclock.dt = config["simulation"]["dt"] * ms

tau_pre = syn_config_stdp["tau_pre"] * ms
tau_post = syn_config_stdp["tau_post"] * ms
A_pre = syn_config_stdp["A_pre"] * uS
A_post = - 1.05 * A_pre * tau_pre / tau_post
t_max = config["simulation"]["runtime"]*ms
N = config["simulation"]["N"]

# Presynaptic neurons (`neurons_1`) spike at times from 0 to t_max
# Postsynaptic neurons (`neurons_2`) spike at times from t_max to 0
# So difference in spike times will vary from -t_max to +t_max
neurons_1 = NeuronGroup(N, 't_spike : second', threshold='t > t_spike', refractory=tau_refrac)
neurons_2 = NeuronGroup(N, 't_spike : second', threshold='t > t_spike', refractory=tau_refrac)
neurons_1.t_spike = 'i*t_max/(N-1)'
neurons_2.t_spike = '(N-1-i)*t_max/(N-1)'

S = Synapses(neurons_1, neurons_2,
'''
w : siemens
dapre/dt = -apre/tau_pre : siemens (event-driven)
dapost/dt = -apost/tau_post : siemens (event-driven)
''',
on_pre='''
apre += A_pre
w = w+apost
''',
on_post='''
apost += A_post
w = w+apre
''')
S.connect(j='i') # as many synapses as neurons in each group
S.w = syn_config_stdp["weight"] * uS

spikemon_1 = SpikeMonitor(neurons_1)
spikemon_2 = SpikeMonitor(neurons_2)

run(t_max + 1 * ms)

np.savetxt(f'brian2_traces_{variant}_classical.dat',
np.column_stack([(neurons_2.t_spike - neurons_1.t_spike) / ms, S.w / uS]))

spike_indices = np.vstack((spikemon_1.i, spikemon_2.i)).flatten()
spike_times = np.vstack((spikemon_1.t / ms, spikemon_2.t / ms)).flatten()
np.savetxt(f'brian2_spikes_{variant}_classical.dat',
np.sort(np.column_stack([spike_times, spike_indices]), axis=0),
fmt="%.4f %.0f") # integer formatting for neuron number


if __name__ == '__main__':

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('variant')
args = parser.parse_args()

main(args.variant)
22 changes: 11 additions & 11 deletions STDP/brian2_lif_stdp.py → STDP/brian2_stdp_lif.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#!/usr/bin/env python3
"""Brian2 simulation of a single cell
Brian2 simulation of a single cell receiving inhibitory and plastic
"""Brian 2 simulation of a single cell receiving inhibitory and plastic
excitatory stimulus.
"""

Expand All @@ -14,9 +12,10 @@


def main(variant):
"""Runs simulation and stores results."""
"""Runs simulation with spikes generated at specific times
and stores results."""

config = json.load(open(f"config_{variant}.json"))
config = json.load(open(f"config_{variant}_lif.json"))
neuron_config = config["neuron"]
# sphere with 200 um radius
area = 4 * np.pi * (neuron_config["radius"] * umeter)**2
Expand Down Expand Up @@ -62,20 +61,21 @@ def main(variant):
A_pre = syn_config_stdp["A_pre"] * uS
A_post = syn_config_stdp["A_post"] * uS

S_exc = Synapses(ssg_exc, neurons, '''w : siemens
S_exc = Synapses(ssg_exc, neurons,
'''w : siemens
dapre/dt = -apre/tau_pre : siemens (event-driven)
dapost/dt = -apost/tau_post : siemens (event-driven)
''',
on_pre='''
on_pre='''
ge += w
apre += A_pre
w += apost
''',
on_post='''
on_post='''
apost += A_post
w += apre
''',
delay=0 * ms)
delay=0 * ms)

S_exc.connect('True')
S_exc.w = syn_config_stdp["weight"] * uS
Expand All @@ -91,11 +91,11 @@ def main(variant):

run(config["simulation"]["runtime"] * ms)

np.savetxt(f'brian2_traces_{variant}.dat', np.column_stack(
np.savetxt(f'brian2_traces_{variant}_lif.dat', np.column_stack(
[neuron_monitor.t / ms, neuron_monitor.v[0] / mV,
neuron_monitor.ge[0] / uS, neuron_monitor.gi[0] / uS,
synapse_monitor.w[0] / uS]))
np.savetxt(f'brian2_spikes_{variant}.dat', spikemon.t / ms)
np.savetxt(f'brian2_spikes_{variant}_lif.dat', spikemon.t / ms)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit fe902ae

Please sign in to comment.