Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move setupmodelparameters from each model to abstractNMM #157

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 11 additions & 79 deletions examples/eg002r__multimodal_simulation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
# -*- coding: utf-8 -*-

r"""
=================================
Fitting S_E Mean to 0.164 using default RWW Parameters
=================================

What is being modeled:

- Created a Sphere'd Cube (chosen points on cube projected onto radius = 1 sphere), so that regions were more evently distributed. All corners of cube chosen as regions, thus there are 8 regions.

- EEG channels located on the center of each face of the cube. Thus there are 6 EEG channels.

- Added some randomness to initial values - to decorrelate the signals a bit. Looking for FC matrix to look similar to SC matrix.

Created a Sphered Cube chosen points on cube projected onto radius = 1 sphere, so that regions were more evently distributed. All corners of cube chosen as regions, thus there are 8 regions.
EEG channels located on the center of each face of the cube. Thus there are 6 EEG channels.
Added some randomness to initial values - to decorrelate the signals a bit. Looking for FC matrix to look similar to SC matrix.
"""

# sphinx_gallery_thumbnail_number = 1
Expand All @@ -21,6 +18,11 @@
# ---------------------------------------------------
#

# os stuff
import os
import sys
sys.path.append('..')

# whobpyt stuff
import whobpyt
from whobpyt.datatypes import par, Recording
Expand Down Expand Up @@ -193,7 +195,7 @@ def loss(self, simData, empData = None, returnLossComponents = False):
# Model Simulation
# ---------------------------------------------------
#
F.simulate(u = 0, numTP = randTS1.length)
F.evaluate(u=0, empRec=randTS1, TPperWindow=TPperWindow , base_window_num=2, transient_num = 10)


# %%
Expand All @@ -207,7 +209,6 @@ def loss(self, simData, empData = None, returnLossComponents = False):
plt.xlabel('Time Steps (multiply by step_size to get msec), step_size = ' + str(step_size))
plt.legend()


# %%
# Plots of EEG PSD
#
Expand All @@ -222,73 +223,4 @@ def loss(self, simData, empData = None, returnLossComponents = False):
plt.plot(sdAxis_dS, sdValues_dS_scaled.detach()[:,n])
plt.xlabel('Hz')
plt.ylabel('PSD')
plt.title("Simulated EEG PSD: After Training")


# %%
# Plots of BOLD FC
#

sim_FC = np.corrcoef(F.lastRec['bold'].npTS()[:,skip_trans:])

plt.figure(figsize = (8, 8))
plt.title("Simulated BOLD FC: After Training")
mask = np.eye(num_regions)
sns.heatmap(sim_FC, mask = mask, center=0, cmap='RdBu_r', vmin=-1.0, vmax = 1.0)


# %%
# CNMM Validation Model
# ---------------------------------------------------
#
# The Multi-Modal Model

model.eeg.params.LF = model.eeg.params.LF.cpu()

val_sim_len = 20*1000 # Simulation length in msecs
model_validate = RWWEI2_EEG_BOLD_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len)

sim_vals, hE = model_validate.forward(external = 0, hx = model_validate.createIC(ver = 0), hE = 0)


# %%
# Plots of S_E and S_I Validation
#

plt.figure(figsize = (16, 8))
plt.title("S_E and S_I")
for n in range(num_regions):
plt.plot(sim_vals['E'], label = "S_E Node = " + str(n))
plt.plot(sim_vals['I'], label = "S_I Node = " + str(n))

plt.xlabel('Time Steps (multiply by step_size to get msec), step_size = ' + str(step_size))
plt.legend()


# %%
# Plots of EEG PSD Validation
#

sampleFreqHz = 1000*(1/step_size)
sdAxis, sdValues = CostsPSD.calcPSD(torch.tensor(sim_vals['eeg']), sampleFreqHz, minFreq = 2, maxFreq = 40)
sdAxis_dS, sdValues_dS = CostsPSD.downSmoothPSD(sdAxis, sdValues, 32)
sdAxis_dS, sdValues_dS_scaled = CostsPSD.scalePSD(sdAxis_dS, sdValues_dS)

plt.figure()
for n in range(num_channels):
plt.plot(sdAxis_dS, sdValues_dS_scaled.detach()[:,n])
plt.xlabel('Hz')
plt.ylabel('PSD')
plt.title("Simulated EEG PSD: After Training")


# %%
# Plots of BOLD FC Validation
#

sim_FC = np.corrcoef((sim_vals['bold'].T)[:,skip_trans:])

plt.figure(figsize = (8, 8))
plt.title("Simulated BOLD FC: After Training")
mask = np.eye(num_regions)
sns.heatmap(sim_FC, mask = mask, center=0, cmap='RdBu_r', vmin=-1.0, vmax = 1.0)
plt.title("Simulated EEG PSD: After Training")
2 changes: 1 addition & 1 deletion examples/eg003r__fitting_rww_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@

# %%
# call model want to fit
model = RNNRWW(node_size, TPperWindow, step_size, repeat_size, tr, sc, True, params)
model = RNNRWW(params, node_size =node_size, TRs_per_window =TPperWindow, step_size=step_size, tr=tr, sc=sc, use_fit_gains=True)

# %%
# create objective function
Expand Down
2 changes: 1 addition & 1 deletion examples/eg004r__fitting_JR_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@

# %%
# call model want to fit
model = RNNJANSEN(node_size, TPperWindow, step_size, output_size, tr, sc, lm, dist, True, False, params)
model = RNNJANSEN(params, node_size=node_size, TRs_per_window=TPperWindow, step_size=step_size, output_size=output_size, tr=tr, sc=sc, lm=lm, dist=dist, use_fit_gains=True, use_fit_lfm = False)

# %%
# create objective function
Expand Down
88 changes: 4 additions & 84 deletions examples/eg005r__gpu_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
# Importage
# ---------------------------------------------------
#
# os stuff
import os
import sys
sys.path.append('..')

# whobpyt stuff
import whobpyt
Expand Down Expand Up @@ -112,10 +116,6 @@
end_time = time.time()
print(str((end_time - start_time)/60) + " minutes")

# %%
# Plots of loss over Training
plt.plot(np.arange(1,len(F.trainingStats.loss)+1), F.trainingStats.loss)
plt.title("Total Loss over Training Epochs")


# %%
Expand All @@ -124,89 +124,9 @@
#

# Simulation Length
step_size = 0.1 # Step Size in msecs
sim_len = 5000 # Simulation length in msecs
model = RWWEI2_EEG_BOLD(num_regions, num_channels, model.params, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device)

targetValue = torch.tensor([0.164]).to(device)
objFun = CostsmmRWWEI2(num_regions, simKey = "E", targetValue = targetValue, device = device)

# Create a Fitting Object
F = Fitting_FNGFPG(model, objFun, device)

# Training Data
empSubject = {}
empSubject['EEG_FC'] = channelFC
empSubject['BOLD_FC'] = sourceFC
num_epochs = 3
num_recordings = 1
block_len = 100 # in msec

# model training
start_time = time.time()
F.train(stim = 0, empDatas = [empSubject], num_epochs = num_epochs, block_len = block_len, learningrate = 0.05, resetIC = False)
end_time = time.time()
print(str((end_time - start_time)/60) + " minutes")

# %%
# Plots of loss over Training
plt.plot(np.arange(1,len(F.trainingStats.loss)+1), F.trainingStats.loss)
plt.title("Total Loss over Training Epochs")


# %%
# CNMM Verification Model
# ---------------------------------------------------
#
# The Multi-Modal Model

model.eeg.params.LF = model.eeg.params.LF.cpu()

val_sim_len = 20*1000 # Simulation length in msecs
model_validate = RWWEI2_EEG_BOLD_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len)

sim_vals, hE = model_validate.forward(external = 0, hx = model_validate.createIC(ver = 0), hE = 0)


# %%
# Plots of S_E and S_I Verification
#

plt.figure(figsize = (16, 8))
plt.title("S_E and S_I")
for n in range(num_regions):
plt.plot(sim_vals['E'][0:10000, n], label = "S_E Node = " + str(n))
plt.plot(sim_vals['I'][0:10000, n], label = "S_I Node = " + str(n))

plt.xlabel('Time Steps (multiply by step_size to get msec), step_size = ' + str(step_size))
plt.legend()


# %%
# Plots of EEG PSD Verification
#

sampleFreqHz = 1000*(1/step_size)
sdAxis, sdValues = CostsPSD.calcPSD(torch.tensor(sim_vals['eeg']), sampleFreqHz, minFreq = 2, maxFreq = 40)
sdAxis_dS, sdValues_dS = CostsPSD.downSmoothPSD(sdAxis, sdValues, 32)
sdAxis_dS, sdValues_dS_scaled = CostsPSD.scalePSD(sdAxis_dS, sdValues_dS)

plt.figure()
for n in range(num_channels):
plt.plot(sdAxis_dS, sdValues_dS_scaled.detach()[:,n])
plt.xlabel('Hz')
plt.ylabel('PSD')
plt.title("Simulated EEG PSD: After Training")


# %%
# Plots of BOLD FC Verification
#

skip_trans = int(500/step_size)
sim_FC = np.corrcoef((sim_vals['bold'].T)[:,skip_trans:])

plt.figure(figsize = (8, 8))
plt.title("Simulated BOLD FC: After Training")
mask = np.eye(num_regions)
sns.heatmap(sim_FC, mask = mask, center=0, cmap='RdBu_r', vmin=-1.0, vmax = 1.0)
35 changes: 32 additions & 3 deletions whobpyt/datatypes/AbstractNMM.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import torch

from whobpyt.datatypes.parameter import par
from torch.nn.parameter import Parameter
class AbstractNMM(torch.nn.Module):
# This is the abstract class for the models (typically neural mass models) that are being trained.
# The neuroimaging modality might be integrated into the model as well.

def __init__(self):
def __init__(self, params):
super(AbstractNMM, self).__init__() # May not want to enherit from torch.nn.Module in the future
self.params = params

self.state_names = ["None"] # The names of the state variables of the model
self.output_names = ["None"] # The variable to be used as output from the NMM, for purposes such as the input to an objective function
Expand All @@ -25,7 +27,34 @@ def setModelParameters(self):
# Setting the parameters that will be optimized as either model parameters or 2ndLevel/hyper
# parameters (for various optional features).
# This should be called in the __init__() function implementation for convenience if possible.
pass
# If use_fit_lfm is True, set lm as an attribute as type Parameter (containing variance information)
param_reg = []
param_hyper = []

var_names = [a for a in dir(self.params) if (type(getattr(self.params, a)) == par)]
for var_name in var_names:
var = getattr(self.params, var_name)
if (var.fit_hyper):
if var_name == 'lm':
size = var.prior_var.shape
var.val = Parameter(var.val.detach() - 1 * torch.ones((size[0], size[1]))) # TODO: This is not consistent with what user would expect giving a variance
param_hyper.append(var.prior_mean)
param_hyper.append(var.prior_var)
elif (var != 'std_in'):
var.randSet() #TODO: This should be done before giving params to model class
param_hyper.append(var.prior_mean)
param_hyper.append(var.prior_var)

if (var.fit_par):
param_reg.append(var.val) #TODO: This should got before fit_hyper, but need to change where randomness gets added in the code first

if (var.fit_par | var.fit_hyper):
self.track_params.append(var_name) #NMM Parameters

if var_name == 'lm':
setattr(self, var_name, var.val)

self.params_fitted = {'modelparameter': param_reg,'hyperparameter': param_hyper}

def createIC(self, ver):
# Create the initial conditions for the model state variables.
Expand Down
7 changes: 5 additions & 2 deletions whobpyt/models/BOLD/BOLD.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from whobpyt.datatypes import AbstractMode

class BOLD_Layer(AbstractMode):
class BOLD_Layer(torch.nn.Module):
'''
Balloon-Windkessel Hemodynamic Response Function Forward Model

Expand Down Expand Up @@ -29,7 +29,10 @@ def __init__(self, num_regions, params, useBC = False, device = torch.device('cp
self.device = device

self.num_blocks = 1

self.params_fitted = {}
self.params_fitted['modelparameter'] =[]
self.params_fitted['hyperparameter'] =[]
self.track_params = []
self.params = params

self.setModelParameters()
Expand Down
Loading
Loading