Skip to content

Commit

Permalink
Put the setModelParameters method in the AbstractNMM abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnWangDataAnalyst committed Jan 11, 2024
1 parent 2299e4b commit 63393ad
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions whobpyt/datatypes/AbstractNMM.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch

from whobpyt.datatypes.parameter import par
from torch.nn.parameter import Parameter
class AbstractNMM(torch.nn.Module):
# This is the abstract class for the models (typically neural mass models) that are being trained.
# The neuroimaging modality might be integrated into the model as well.

def __init__(self):
def __init__(self, params):
super(AbstractNMM, self).__init__() # May not want to enherit from torch.nn.Module in the future

self.params = params
self.state_names = ["None"] # The names of the state variables of the model
self.output_names = ["None"] # The variable to be used as output from the NMM, for purposes such as the input to an objective function
self.track_params = [] # Which NMM Parameters to track over training
Expand All @@ -25,7 +26,41 @@ def setModelParameters(self):
# Setting the parameters that will be optimized as either model parameters or 2ndLevel/hyper
# parameters (for various optional features).
# This should be called in the __init__() function implementation for convenience if possible.
pass
"""
Sets the parameters of the model.
"""

param_reg = []
param_hyper = []


var_names = [a for a in dir(self.params) if (type(getattr(self.params, a)) == par)]
for var_name in var_names:
var = getattr(self.params, var_name)
if (var.fit_par):
if var_name == 'lm':
size = var.val.shape
var.val = Parameter(- 1 * torch.ones((size[0], size[1])))
var.prior_mean = Parameter(var.prior_mean)
var.prior_var_inv = Parameter(var.prior_var_inv)
param_reg.append(var.val)
if var.fit_hyper:
param_hyper.append(var.prior_mean)
param_hyper.append(var.prior_var_inv)
self.track_params.append(var_name)
else:
var.val = Parameter(var.val) # TODO: This is not consistent with what user would expect giving a variance
var.prior_mean = Parameter(var.prior_mean)
var.prior_var_inv = Parameter(var.prior_var_inv)
param_reg.append(var.val)
if var.fit_hyper:
param_hyper.append(var.prior_mean)
param_hyper.append(var.prior_var_inv)
self.track_params.append(var_name)



self.params_fitted = {'modelparameter': param_reg,'hyperparameter': param_hyper}

def createIC(self, ver):
# Create the initial conditions for the model state variables.
Expand Down

0 comments on commit 63393ad

Please sign in to comment.