From 63393ad2886ff45656a4cb367a768f9cbc779e77 Mon Sep 17 00:00:00 2001
From: John Wang
Date: Thu, 11 Jan 2024 19:26:04 +0000
Subject: [PATCH] put set model parameters method in abstract

---
 whobpyt/datatypes/AbstractNMM.py | 43 +++++++++++++++++++++++++++++---
 1 file changed, 39 insertions(+), 4 deletions(-)

diff --git a/whobpyt/datatypes/AbstractNMM.py b/whobpyt/datatypes/AbstractNMM.py
index 9793ad5a..559b789b 100644
--- a/whobpyt/datatypes/AbstractNMM.py
+++ b/whobpyt/datatypes/AbstractNMM.py
@@ -1,12 +1,13 @@
 import torch
-
+from whobpyt.datatypes.parameter import par
+from torch.nn.parameter import Parameter
 
 class AbstractNMM(torch.nn.Module):
     # This is the abstract class for the models (typically neural mass models) that are being trained.
     # The neuroimaging modality might be integrated into the model as well.
-    def __init__(self):
+    def __init__(self, params):
         super(AbstractNMM, self).__init__() # May not want to enherit from torch.nn.Module in the future
-        
+        self.params = params
         self.state_names = ["None"] # The names of the state variables of the model
         self.output_names = ["None"] # The variable to be used as output from the NMM, for purposes such as the input to an objective function
         self.track_params = [] # Which NMM Parameters to track over training
@@ -25,7 +26,41 @@ def setModelParameters(self):
         # Setting the parameters that will be optimized as either model parameters or 2ndLevel/hyper
         # parameters (for various optional features).
         # This should be called in the __init__() function implementation for convenience if possible.
-        pass
+        """
+        Sets the parameters of the model.
+        """
+
+        param_reg = []
+        param_hyper = []
+
+
+        var_names = [a for a in dir(self.params) if (type(getattr(self.params, a)) == par)]
+        for var_name in var_names:
+            var = getattr(self.params, var_name)
+            if (var.fit_par):
+                if var_name == 'lm':
+                    size = var.val.shape
+                    var.val = Parameter(- 1 * torch.ones((size[0], size[1])))
+                    var.prior_mean = Parameter(var.prior_mean)
+                    var.prior_var_inv = Parameter(var.prior_var_inv)
+                    param_reg.append(var.val)
+                    if var.fit_hyper:
+                        param_hyper.append(var.prior_mean)
+                        param_hyper.append(var.prior_var_inv)
+                    self.track_params.append(var_name)
+                else:
+                    var.val = Parameter(var.val) # TODO: This is not consistent with what user would expect giving a variance
+                    var.prior_mean = Parameter(var.prior_mean)
+                    var.prior_var_inv = Parameter(var.prior_var_inv)
+                    param_reg.append(var.val)
+                    if var.fit_hyper:
+                        param_hyper.append(var.prior_mean)
+                        param_hyper.append(var.prior_var_inv)
+                    self.track_params.append(var_name)
+
+
+
+        self.params_fitted = {'modelparameter': param_reg,'hyperparameter': param_hyper}
 
     def createIC(self, ver):
         # Create the initial conditions for the model state variables.