Skip to content

Commit

Permalink
moved back and cleaned up bernoulli and poisson cells
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 25, 2024
1 parent 9afaadf commit bf72094
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 123 deletions.
77 changes: 5 additions & 72 deletions ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,10 @@ def _update_times(t, s, tols):
_tols = (1. - s) * tols + (s * t)
return _tols

@jit
def _sample_bernoulli(dkey, data):
"""
Samples a Bernoulli spike train on-the-fly
Args:
dkey: JAX key to drive stochasticity/noise
data: sensory data (vector/matrix)
Returns:
binary spikes
"""
s_t = random.bernoulli(dkey, p=data).astype(jnp.float32)
return s_t

@partial(jit, static_argnums=[3])
def _sample_constrained_bernoulli(dkey, data, dt, fmax=63.75):
"""
Samples a Bernoulli spike train on-the-fly that is constrained to emit
at a particular rate over a time window.
Args:
dkey: JAX key to drive stochasticity/noise
data: sensory data (vector/matrix)
dt: integration time constant
fmax: maximum frequency (Hz)
Returns:
binary spikes
"""
pspike = data * (dt/1000.) * fmax
eps = random.uniform(dkey, data.shape, minval=0., maxval=1., dtype=jnp.float32)
s_t = (eps < pspike).astype(jnp.float32)
return s_t

class BernoulliCell(JaxComponent):
"""
A Bernoulli cell that produces variations of Bernoulli-distributed spikes
on-the-fly (including constrained-rate trains).
A Bernoulli cell that produces spikes by sampling a Bernoulli distribution
on-the-fly (to produce data-scaled Bernoulli spike trains).
| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
Expand All @@ -80,17 +41,11 @@ class BernoulliCell(JaxComponent):
name: the string name of this cell
n_units: number of cellular entities (neural population size)
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
"""

@deprecate_args(max_freq="target_freq")
def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
def __init__(self, name, n_units, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Constrained Bernoulli meta-parameters
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)

## Layer Size Setup
self.batch_size = batch_size
self.n_units = n_units
Expand All @@ -101,32 +56,10 @@ def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

def validate(self, dt=None, **validation_kwargs):
valid = super().validate(**validation_kwargs)
if dt is None:
warn(f"{self.name} requires a validation kwarg of `dt`")
return False
## check for unstable combinations of dt and target-frequency meta-params
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
f"{self.name} will be unable to make as many temporal events as "
f"requested! ({events_per_timestep} events/timestep) Unstable "
f"combination of dt = {dt} and target_freq = {self.target_freq} "
f"being used!"
)
return valid

@staticmethod
def _advance_state(t, dt, target_freq, key, inputs, tols):
def _advance_state(t, key, inputs, tols):
key, *subkeys = random.split(key, 2)
if target_freq > 0.:
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
subkeys[0], data=inputs, dt=dt, fmax=target_freq
)
else:
outputs = _sample_bernoulli(subkeys[0], data=inputs)
outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
tols = _update_times(t, outputs, tols)
return outputs, tols, key

Expand Down
97 changes: 46 additions & 51 deletions ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from ngclearn.utils import tensorstats
from jax import numpy as jnp, random, jit, scipy
from functools import partial
from ngcsimlib.deprecators import deprecate_args
from ngcsimlib.logger import info, warn

@jit
def _update_times(t, s, tols):
"""
Updates time-of-last-spike (tols) variable.
Args:
t: current time (a scalar/int value)
s: binary spike vector
tols: current time-of-last-spike variable
Returns:
updated tols variable
"""
_tols = (1. - s) * tols + (s * t)
return _tols

class PoissonCell(JaxComponent):
"""
A Poisson cell that produces approximately Poisson-distributed spikes
on-the-fly.
A Poisson cell that samples a homogeneous Poisson process on-the-fly to
produce a spike train.
| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
Expand All @@ -24,45 +42,33 @@ class PoissonCell(JaxComponent):
n_units: number of cellular entities (neural population size)
max_freq: maximum frequency (in Hertz) of this Poisson spike train (
must be > 0.)
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
"""

# Define Functions
@deprecate_args(max_freq="target_freq")
def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
**kwargs):
def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Poisson meta-parameters
## Constrained Bernoulli meta-parameters
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)

## Layer Size Setup
self.batch_size = batch_size
self.n_units = n_units

_key, subkey = random.split(self.key.value, 2)
self.key.set(_key)
## Compartment setup
# Compartments (state of the cell, parameters, will be updated through stateless calls)
restVals = jnp.zeros((self.batch_size, self.n_units))
self.inputs = Compartment(restVals,
display_name="Input Stimulus") # input
# compartment
self.outputs = Compartment(restVals,
display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") # time of last spike
self.targets = Compartment(
random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
maxval=1.))
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

def validate(self, dt=None, **validation_kwargs):
valid = super().validate(**validation_kwargs)
if dt is None:
warn(f"{self.name} requires a validation kwarg of `dt`")
return False
## check for unstable combinations of dt and target-frequency meta-params
events_per_timestep = (dt / 1000.) * self.target_freq ## compute scaled probability
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
Expand All @@ -74,54 +80,43 @@ def validate(self, dt=None, **validation_kwargs):
return valid

@staticmethod
def _advance_state(t, dt, target_freq, key, inputs, targets, tols):
ms_per_second = 1000 # ms/s
events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
ms_per_event = 1 / events_per_ms # ms/e
time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e

cdf = scipy.special.gammaincc((t + dt) - tols,
time_step_per_event/inputs)
outputs = (targets < cdf).astype(jnp.float32)

key, subkey = random.split(key, 2)
targets = (targets * (1 - outputs) + random.uniform(subkey,
targets.shape) *
outputs)

tols = tols * (1. - outputs) + t * outputs
return outputs, tols, key, targets
def _advance_state(t, dt, target_freq, key, inputs, tols):
key, *subkeys = random.split(key, 2)
pspike = inputs * (dt / 1000.) * target_freq
eps = random.uniform(subkeys[0], inputs.shape, minval=0., maxval=1.,
dtype=jnp.float32)
outputs = (eps < pspike).astype(jnp.float32)
tols = _update_times(t, outputs, tols)
return outputs, tols, key

@resolver(_advance_state)
def advance_state(self, outputs, tols, key, targets):
def advance_state(self, outputs, tols, key):
self.outputs.set(outputs)
self.tols.set(tols)
self.key.set(key)
self.targets.set(targets)

@staticmethod
def _reset(batch_size, n_units, key):
def _reset(batch_size, n_units):
restVals = jnp.zeros((batch_size, n_units))
key, subkey = random.split(key, 2)
targets = random.uniform(subkey, (batch_size, n_units))
return restVals, restVals, restVals, targets, key
return restVals, restVals, restVals

@resolver(_reset)
def reset(self, inputs, outputs, tols, targets, key):
def reset(self, inputs, outputs, tols):
self.inputs.set(inputs)
self.outputs.set(outputs)
self.outputs.set(outputs) #None
self.tols.set(tols)
self.key.set(key)
self.targets.set(targets)

def save(self, directory, **kwargs):
target_freq = (self.target_freq if isinstance(self.target_freq, float)
else jnp.ones([[self.target_freq]]))
file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name, key=self.key.value)
jnp.savez(file_name, key=self.key.value, target_freq=target_freq)

def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.key.set(data['key'])
self.target_freq = data['target_freq']

@classmethod
def help(cls): ## component help function
Expand Down

0 comments on commit bf72094

Please sign in to comment.