Skip to content

Commit

Permalink
cleaned up raf
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Aug 8, 2024
1 parent 8882208 commit ee50f33
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions ngclearn/components/neurons/spiking/RAFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,17 @@ class RAFCell(JaxComponent):
thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes) (Default: 5 mV)
omega: angular frequency (Default: 10)
b: oscillation dampening factor (Default: -1)
v_reset: membrane reset potential condition (Default: 0 mV)
w_reset: reset condition for angular driver (Default: 0 mV)
w_reset: reset condition for angular driver (Default: 0)
v0: membrane potential initial condition (Default: 0 mV)
b: oscillation dampening factor (Default: -1.)
w0: angular driver initial condition (Default: 0)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
Expand All @@ -110,8 +116,8 @@ class RAFCell(JaxComponent):

# Define Functions
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
omega=10., thr=5., v_reset=0., w_reset=0., b=-1.,
integration_type="euler", batch_size=1, **kwargs):
thr=5., omega=10., b=-1., v_reset=0., w_reset=0.,
v0=0., w0=0., integration_type="euler", batch_size=1, **kwargs):
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0.,
super().__init__(name, **kwargs)

Expand All @@ -129,6 +135,8 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
#self.v_rest = v_rest
self.v_reset = v_reset
self.w_reset = w_reset
self.v0 = v0
self.w0 = w0
self.thr = thr

## Layer Size Setup
Expand All @@ -147,9 +155,6 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
@staticmethod
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b,
v_reset, w_reset, intgFlag, j, v, w, tols):
## center variables before running dynamics
v = v - v_reset
w = w - w_reset
## continue with centered dynamics
j_ = j * resist_m
if intgFlag == 1: ## RK-2/midpoint
Expand All @@ -164,11 +169,8 @@ def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b,
_, _v = step_euler(0., v, _dfv, dt, v_params)
s = _emit_spike(_v, thr)
## hyperpolarize/reset/snap variables
v = _v * (1. - s) + s #* v_reset
w = _w * (1. - s) + s #* w_reset
## artificially shift variables back to rest/reset values
v = v + v_reset
w = w + w_reset
v = _v * (1. - s) + s * v_reset
w = _w * (1. - s) + s * w_reset
tols = _update_times(t, s, tols)
return j, v, w, s, tols

Expand All @@ -181,11 +183,11 @@ def advance_state(self, j, v, w, s, tols):
self.tols.set(tols)

@staticmethod
def _reset(batch_size, n_units, v_reset, w_reset):
def _reset(batch_size, n_units, v0, w0):
restVals = jnp.zeros((batch_size, n_units))
j = restVals # None
v = restVals + v_reset
w = restVals + w_reset
v = restVals + v0
w = restVals + w0
s = restVals #+ 0
tols = restVals #+ 0
return j, v, w, s, tols
Expand Down Expand Up @@ -221,9 +223,8 @@ def help(cls): ## component help function
"tau_m": "Cell membrane time constant",
"resist_m": "Membrane resistance value",
"tau_w": "Recovery variable time constant",
"v_thr": "Base voltage threshold value",
"v_rest": "Resting membrane potential value",
"v_reset": "Reset membrane potential value",
"w_reset": "Reset angular driver value",
"b": "Exponential dampening factor applied to oscillations",
"omega": "Angular frequency of neuronal progress per second (radians)",
"v0": "Initial condition for membrane potential/voltage",
Expand Down

0 comments on commit ee50f33

Please sign in to comment.