diff --git a/doc/API/models.rst b/doc/API/models.rst index 4bf67e47..75cd571f 100644 --- a/doc/API/models.rst +++ b/doc/API/models.rst @@ -11,7 +11,7 @@ Models :undoc-members: :special-members: __init__ -.. autoclass:: whobpyt.models.RWW2.RWW2 +.. autoclass:: whobpyt.models.RWWEI2.RWWEI2 :members: info, createIC, setModelParameters, forward :undoc-members: :special-members: __init__ diff --git a/examples/eg002r__multimodal_simulation.py b/examples/eg002r__multimodal_simulation.py index 98dcaca8..89523b1f 100644 --- a/examples/eg002r__multimodal_simulation.py +++ b/examples/eg002r__multimodal_simulation.py @@ -24,7 +24,7 @@ # whobpyt stuff import whobpyt from whobpyt.datatypes import par, Recording -from whobpyt.models.RWW2 import mmRWW2, mmRWW2_np, RWW2, RWW2_np, ParamsRWW2 +from whobpyt.models.RWWEI2 import RWWEI2_EEG_BOLD, RWWEI2_EEG_BOLD_np, RWWEI2, RWWEI2_np, ParamsRWWEI2 from whobpyt.models.BOLD import BOLD_Layer, BOLD_np, BOLD_Params from whobpyt.models.EEG import EEG_Layer, EEG_np, EEG_Params from whobpyt.optimization import CostsFC, CostsPSD, CostsMean, CostsFixedFC, CostsFixedPSD @@ -67,7 +67,7 @@ init_state = (init_state + torch.randn_like(init_state)/30).to(device) # Randomizing initial values # Create a RWW Params -paramsNode = ParamsRWW2(num_regions) +paramsNode = ParamsRWWEI2(num_regions) #Create #EEG Params paramsEEG = EEG_Params(torch.eye(num_regions)) @@ -105,7 +105,7 @@ # The Multi-Modal Model -model = mmRWW2(num_regions, num_channels, paramsNode, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device = device) +model = RWWEI2_EEG_BOLD(num_regions, num_channels, paramsNode, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device = device) # %% @@ -135,15 +135,15 @@ def __init__(self): #self.BOLD_PSD = CostsPSD(...) # Not Currently Used #self.BOLD_FC = CostsFC(num_regions, varIdx = 4, targetValue = SC_mtx_norm) - def loss(self, node_history, EEG_history, BOLD_history, temp, returnLossComponents = False): + def loss(self, simData, empData = None, returnLossComponents = False): # sim, ts_window, self.model, next_window - S_E_mean_loss = self.S_E_mean.calcLoss(node_history) - S_I_mean_loss = torch.tensor([0]).to(device) #self.S_I_mean.calcLoss(node_history) - EEG_PSD_loss = torch.tensor([0]).to(device) #self.EEG_PSD.calcLoss(EEG_history) - EEG_FC_loss = torch.tensor([0]).to(device) #self.EEG_FC.calcLoss(EEG_history) - BOLD_PSD_loss = torch.tensor([0]).to(device) #self.BOLD_PS.calcLoss(BOLD_history) - BOLD_FC_loss = torch.tensor([0]).to(device) #self.BOLD_FC.calcLoss(BOLD_history) + S_E_mean_loss = self.S_E_mean.loss(simData) + S_I_mean_loss = torch.tensor([0]).to(device) #self.S_I_mean.loss(simData) + EEG_PSD_loss = torch.tensor([0]).to(device) #self.EEG_PSD.loss(simData) + EEG_FC_loss = torch.tensor([0]).to(device) #self.EEG_FC.loss(simData) + BOLD_PSD_loss = torch.tensor([0]).to(device) #self.BOLD_PS.loss(simData) + BOLD_FC_loss = torch.tensor([0]).to(device) #self.BOLD_FC.loss(simData) totalLoss = self.S_E_mean_weight*S_E_mean_loss + self.S_I_mean_weight*S_I_mean_loss \ + self.EEG_PSD_weight*EEG_PSD_loss + self.EEG_FC_weight*EEG_FC_loss \ @@ -246,7 +246,7 @@ def loss(self, node_history, EEG_history, BOLD_history, temp, returnLossComponen model.eeg.params.LF = model.eeg.params.LF.cpu() val_sim_len = 20*1000 # Simulation length in msecs -model_validate = mmRWW2_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len) +model_validate = RWWEI2_EEG_BOLD_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len) sim_vals, hE = model_validate.forward(external = 0, hx = model_validate.createIC(ver = 0), hE = 0) diff --git a/examples/eg003r__fitting_rww_example.py b/examples/eg003r__fitting_rww_example.py index 6c1ca4b0..67bdad7d 100644 --- a/examples/eg003r__fitting_rww_example.py +++ b/examples/eg003r__fitting_rww_example.py @@ -114,7 +114,7 @@ # %% # create objective function -ObjFun = CostsRWW() +ObjFun = CostsRWW(model) # %% # call model fit diff --git a/examples/eg004r__fitting_JR_example.py b/examples/eg004r__fitting_JR_example.py index 4164494e..d8a43f53 100644 --- a/examples/eg004r__fitting_JR_example.py +++ b/examples/eg004r__fitting_JR_example.py @@ -112,7 +112,7 @@ # %% # create objective function -ObjFun = CostsJR() +ObjFun = CostsJR(model) # %% # call model fit diff --git a/examples/eg005r__gpu_support.py b/examples/eg005r__gpu_support.py index 09b4334b..e173cac2 100644 --- a/examples/eg005r__gpu_support.py +++ b/examples/eg005r__gpu_support.py @@ -20,11 +20,11 @@ # whobpyt stuff import whobpyt from whobpyt.datatypes import par, Recording -from whobpyt.models.RWW2 import mmRWW2, mmRWW2_np, RWW2, RWW2_np, ParamsRWW2 +from whobpyt.models.RWWEI2 import RWWEI2_EEG_BOLD, RWWEI2_EEG_BOLD_np, RWWEI2, RWWEI2_np, ParamsRWWEI2 from whobpyt.models.BOLD import BOLD_Layer, BOLD_np, BOLD_Params from whobpyt.models.EEG import EEG_Layer, EEG_np, EEG_Params from whobpyt.optimization import CostsFC, CostsPSD, CostsMean, CostsFixedFC, CostsFixedPSD -from whobpyt.optimization.custom_cost_mmRWW2 import CostsmmRWW2 +from whobpyt.optimization.custom_cost_mmRWW2 import CostsmmRWWEI2 from whobpyt.run import Model_fitting, Fitting_FNGFPG, Fitting_Batch from whobpyt.data.generators import gen_cube @@ -67,7 +67,7 @@ plt.title("SC of Artificial Data") # Create a RWW Params -paramsNode = ParamsRWW2(num_regions) +paramsNode = ParamsRWWEI2(num_regions) paramsNode.J = par((0.15 * np.ones(num_regions)), fit_par = True, asLog = True) #This is a parameter that will be updated during training paramsNode.G = par(torch.tensor(1.0), None, None, True, False, False) @@ -92,7 +92,7 @@ # Simulation Length step_size = 0.1 # Step Size in msecs sim_len = 1500 # Simulation length in msecs -model = RWW2(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size, sim_len, device = device) +model = RWWEI2(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size, sim_len, device = device) demoPSD = torch.rand(100).to(device) objFun = CostsFixedPSD(num_regions = num_regions, simKey = "E", sampleFreqHz = 10000, minFreq = 1, maxFreq = 100, targetValue = demoPSD, rmTransient = 5000, device = device) @@ -126,10 +126,10 @@ # Simulation Length step_size = 0.1 # Step Size in msecs sim_len = 5000 # Simulation length in msecs -model = mmRWW2(num_regions, num_channels, model.params, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device) +model = RWWEI2_EEG_BOLD(num_regions, num_channels, model.params, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device) targetValue = torch.tensor([0.164]).to(device) -objFun = CostsmmRWW2(num_regions, simKey = "E", targetValue = targetValue, device = device) +objFun = CostsmmRWWEI2(num_regions, simKey = "E", targetValue = targetValue, device = device) # Create a Fitting Object F = Fitting_FNGFPG(model, objFun, device) @@ -163,7 +163,7 @@ model.eeg.params.LF = model.eeg.params.LF.cpu() val_sim_len = 20*1000 # Simulation length in msecs -model_validate = mmRWW2_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len) +model_validate = RWWEI2_EEG_BOLD_np(num_regions, num_channels, model.params, model.eeg.params, model.bold.params, Con_Mtx.detach().cpu().numpy(), dist_mtx.detach().cpu().numpy(), step_size, val_sim_len) sim_vals, hE = model_validate.forward(external = 0, hx = model_validate.createIC(ver = 0), hE = 0) diff --git a/whobpyt/datatypes/AbstractLoss.py b/whobpyt/datatypes/AbstractLoss.py index 0c0c3cd8..82c9a3f7 100644 --- a/whobpyt/datatypes/AbstractLoss.py +++ b/whobpyt/datatypes/AbstractLoss.py @@ -3,11 +3,16 @@ class AbstractLoss: # This is the abstract class for objective function components, or for a custom objective function with multiple components. - def __init__(self): + def __init__(self, simKey = None, device = torch.device('cpu')): - self.simKey = None #This is a string key to extract from the dictionary of simulation outputs the time series used by the objective function + self.simKey = simKey #This is a string key to extract from the dictionary of simulation outputs the time series used by the objective function + device = device - def loss(self, sim, emp, simKey, model: torch.nn.Module, state_vals): + def loss(self, simData, empData): # Calculates a loss to be backpropagated through - # TODO: In some classes this function is called calcLoss, need to make consistent + # If the objective function needs additional info, it should be defined at initialization so that the parameter fitting paradigms don't need to change + + # simData: is a dictionary of simulated state variable/neuroimaging modality time series. Typically accessed as simData[self.simKey]. + # empData: is the target either as a time series or a calculated phenomena metric + pass diff --git a/whobpyt/models/RWW2/__init__.py b/whobpyt/models/RWW2/__init__.py deleted file mode 100644 index fbd417bf..00000000 --- a/whobpyt/models/RWW2/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .RWW2 import RWW2 -from .RWW2_validate import RWW2_np - -from .ParamsRWW2 import ParamsRWW2 - -from .mmRWW2 import mmRWW2 -from .mmRWW2_validate import mmRWW2_np \ No newline at end of file diff --git a/whobpyt/models/RWW2/mmRWW2.py b/whobpyt/models/RWWEI2/Multimodal_RWWEI2.py similarity index 86% rename from whobpyt/models/RWW2/mmRWW2.py rename to whobpyt/models/RWWEI2/Multimodal_RWWEI2.py index 3454d693..3b027d56 100644 --- a/whobpyt/models/RWW2/mmRWW2.py +++ b/whobpyt/models/RWWEI2/Multimodal_RWWEI2.py @@ -2,19 +2,19 @@ import torch from whobpyt.datatypes import AbstractNMM, AbstractMode, AbstractParams -from whobpyt.models.RWW2 import RWW2, ParamsRWW2 +from whobpyt.models.RWWEI2 import RWWEI2, ParamsRWWEI2 from whobpyt.models.BOLD import BOLD_Layer, BOLD_Params from whobpyt.models.EEG import EEG_Layer, EEG_Params -class mmRWW2(RWW2): +class RWWEI2_EEG_BOLD(RWWEI2): - model_name = "mmRWW2" + model_name = "RWWEI2_EEG_BOLD" def __init__(self, num_regions, num_channels, paramsNode, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len, device = torch.device('cpu')): self.eeg = EEG_Layer(num_regions, paramsEEG, num_channels, device = device) self.bold = BOLD_Layer(num_regions, paramsBOLD, device = device) - super(mmRWW2, self).__init__(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size, useBC = False, device = device) + super(RWWEI2_EEG_BOLD, self).__init__(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size, useBC = False, device = device) self.node_size = num_regions self.step_size = step_size @@ -52,12 +52,12 @@ def forward(self, external, hx, hE, setNoise = None): return forward(self, external, hx, hE, setNoise) def setModelParameters(self): - super(mmRWW2, self).setModelParameters() # Currently this is the only one with parameters being fitted + super(RWWEI2_EEG_BOLD, self).setModelParameters() # Currently this is the only one with parameters being fitted self.eeg.setModelParameters() self.bold.setModelParameters() def createIC(self, ver): - #super(mmRWW2, self).createIC() + #super(RWWEI2_EEG_BOLD, self).createIC() #self.eeg.createIC() #self.bold.createIC() @@ -69,7 +69,7 @@ def createIC(self, ver): def forward(self, external, hx, hE, setNoise): - NMM_vals, hE = super(mmRWW2, self).forward(external, self.next_start_state[:, 0:2, :], hE, setNoise) #TODO: Fix the hx in the future + NMM_vals, hE = super(RWWEI2_EEG_BOLD, self).forward(external, self.next_start_state[:, 0:2, :], hE, setNoise) #TODO: Fix the hx in the future EEG_vals, hE = self.eeg.forward(self.step_size, self.sim_len, NMM_vals["E"].permute((1,0,2))) BOLD_vals, hE = self.bold.forward(self.next_start_state[:, 2:6, :], self.step_size, self.sim_len, NMM_vals["E"].permute((1,0,2))) diff --git a/whobpyt/models/RWW2/mmRWW2_validate.py b/whobpyt/models/RWWEI2/Multimodal_RWWEI2_validate.py similarity index 78% rename from whobpyt/models/RWW2/mmRWW2_validate.py rename to whobpyt/models/RWWEI2/Multimodal_RWWEI2_validate.py index c58762e4..18bf809d 100644 --- a/whobpyt/models/RWW2/mmRWW2_validate.py +++ b/whobpyt/models/RWWEI2/Multimodal_RWWEI2_validate.py @@ -2,16 +2,16 @@ import numpy as np from whobpyt.datatypes import AbstractNMM, AbstractMode, AbstractParams -from whobpyt.models.RWW2 import RWW2_np, ParamsRWW2 +from whobpyt.models.RWWEI2 import RWWEI2_np, ParamsRWWEI2 from whobpyt.models.BOLD import BOLD_np, BOLD_Params from whobpyt.models.EEG import EEG_np, EEG_Params -class mmRWW2_np(RWW2_np): +class RWWEI2_EEG_BOLD_np(RWWEI2_np): - model_name = "mmRWW2_np" + model_name = "RWWEI2_EEG_BOLD_np" def __init__(self, num_regions, num_channels, paramsNode, paramsEEG, paramsBOLD, Con_Mtx, dist_mtx, step_size, sim_len): - super(mmRWW2_np, self).__init__(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size) + super(RWWEI2_EEG_BOLD_np, self).__init__(num_regions, paramsNode, Con_Mtx, dist_mtx, step_size) self.eeg = EEG_np(num_regions, paramsEEG, num_channels) self.bold = BOLD_np(num_regions, paramsBOLD) @@ -39,12 +39,12 @@ def forward(self, external, hx, hE): return forward(self, external, hx, hE) def setModelParameters(self): - super(mmRWW2, self).setModelParameters() # Currently this is the only one with parameters being fitted + super(RWWEI2_EEG_BOLD, self).setModelParameters() # Currently this is the only one with parameters being fitted self.eeg.setModelParameters() self.bold.setModelParameters() def createIC(self, ver): - #super(mmRWW2, self).createIC() + #super(RWWEI2_EEG_BOLD, self).createIC() #self.eeg.createIC() #self.bold.createIC() @@ -55,7 +55,7 @@ def createIC(self, ver): def forward(self, external, hx, hE): - NMM_vals, hE = super(mmRWW2_np, self).forward(external, self.next_start_state[:, 0:2], hE) #TODO: Fix the hx in the future + NMM_vals, hE = super(RWWEI2_EEG_BOLD_np, self).forward(external, self.next_start_state[:, 0:2], hE) #TODO: Fix the hx in the future EEG_vals, hE = self.eeg.forward(self.step_size, self.sim_len, NMM_vals["E"]) BOLD_vals, hE = self.bold.forward(self.next_start_state[:, 2:6], self.step_size, self.sim_len, NMM_vals["E"]) diff --git a/whobpyt/models/RWW2/ParamsRWW2.py b/whobpyt/models/RWWEI2/ParamsRWWEI2.py similarity index 96% rename from whobpyt/models/RWW2/ParamsRWW2.py rename to whobpyt/models/RWWEI2/ParamsRWWEI2.py index d31cf46d..0fda2596 100644 --- a/whobpyt/models/RWW2/ParamsRWW2.py +++ b/whobpyt/models/RWWEI2/ParamsRWWEI2.py @@ -1,7 +1,7 @@ import torch from whobpyt.datatypes import AbstractParams, par -class ParamsRWW2(AbstractParams): +class ParamsRWWEI2(AbstractParams): ## EQUATIONS & BIOLOGICAL VARIABLES FROM: # # Deco G, Ponce-Alvarez A, Hagmann P, Romani GL, Mantini D, Corbetta M. How local excitation–inhibition ratio impacts the whole brain dynamics. Journal of Neuroscience. 2014 Jun 4;34(23):7886-98. diff --git a/whobpyt/models/RWW2/README.md b/whobpyt/models/RWWEI2/README.md similarity index 89% rename from whobpyt/models/RWW2/README.md rename to whobpyt/models/RWWEI2/README.md index 93739995..55cde9b8 100644 --- a/whobpyt/models/RWW2/README.md +++ b/whobpyt/models/RWWEI2/README.md @@ -1,4 +1,4 @@ -# Reduced Wong Wang Neural Mass Model (Implementation 2) +# Reduced Wong Wang Excitatory Inhibitory Neural Mass Model (Implementation 2) ## Description: diff --git a/whobpyt/models/RWW2/RWW2.py b/whobpyt/models/RWWEI2/RWWEI2.py similarity index 97% rename from whobpyt/models/RWW2/RWW2.py rename to whobpyt/models/RWWEI2/RWWEI2.py index 52af7928..de2fc22d 100644 --- a/whobpyt/models/RWW2/RWW2.py +++ b/whobpyt/models/RWWEI2/RWWEI2.py @@ -2,7 +2,7 @@ from whobpyt.datatypes import AbstractNMM, AbstractParams, par from math import sqrt -class RWW2(AbstractNMM): +class RWWEI2(AbstractNMM): ''' Reduced Wong Wang Excitatory Inhibatory (RWWEXcInh) Model - Version 2 @@ -16,7 +16,7 @@ class RWW2(AbstractNMM): Attributes ------------- - params : ParamsRWW2 + params : ParamsRWWEI2 An AbstractParams object which contains the model's parameters step_size : Float The step size of numerical integration (in msec) @@ -61,7 +61,7 @@ class RWW2(AbstractNMM): def __init__(self, num_regions, params, Con_Mtx, Dist_Mtx, step_size = 0.1, sim_len = 1000, useBC = False, device = torch.device('cpu')): ''' ''' - super(RWW2, self).__init__() # To inherit parameters attribute + super(RWWEI2, self).__init__() # To inherit parameters attribute # Initialize the RWW Model # diff --git a/whobpyt/models/RWW2/RWW2_validate.py b/whobpyt/models/RWWEI2/RWWEI2_validate.py similarity index 97% rename from whobpyt/models/RWW2/RWW2_validate.py rename to whobpyt/models/RWWEI2/RWWEI2_validate.py index 598798ac..091c6b7b 100644 --- a/whobpyt/models/RWW2/RWW2_validate.py +++ b/whobpyt/models/RWWEI2/RWWEI2_validate.py @@ -15,7 +15,7 @@ import numpy from math import sqrt -class RWW2_np(): +class RWWEI2_np(): def __init__(self, num_regions, params, Con_Mtx, Dist_Mtx, step_size = 0.1): # Initialize the RWW Model diff --git a/whobpyt/models/RWWEI2/__init__.py b/whobpyt/models/RWWEI2/__init__.py new file mode 100644 index 00000000..4bb0b9f1 --- /dev/null +++ b/whobpyt/models/RWWEI2/__init__.py @@ -0,0 +1,7 @@ +from .RWWEI2 import RWWEI2 +from .RWWEI2_validate import RWWEI2_np + +from .ParamsRWWEI2 import ParamsRWWEI2 + +from .Multimodal_RWWEI2 import RWWEI2_EEG_BOLD +from .Multimodal_RWWEI2_validate import RWWEI2_EEG_BOLD_np \ No newline at end of file diff --git a/whobpyt/models/__init__.py b/whobpyt/models/__init__.py index a2813f16..657c2b55 100644 --- a/whobpyt/models/__init__.py +++ b/whobpyt/models/__init__.py @@ -3,4 +3,4 @@ from . import JansenRit from . import Linear from . import RWW -from . import RWW2 \ No newline at end of file +from . import RWWEI2 \ No newline at end of file diff --git a/whobpyt/optimization/cost_FC.py b/whobpyt/optimization/cost_FC.py index f81e23f9..d11871b2 100644 --- a/whobpyt/optimization/cost_FC.py +++ b/whobpyt/optimization/cost_FC.py @@ -32,17 +32,17 @@ def __init__(self, simKey): simKey: str type of cost function to be used """ - super(CostsFC, self).__init__() + super(CostsFC, self).__init__(simKey) self.simKey = simKey - def loss(self, sim: torch.Tensor, emp: torch.Tensor): + def loss(self, simData: dict, empData: torch.Tensor): """Function to calculate the cost function for Functional Connectivity (FC) fitting. It initially calculates the FC matrix using the data from the BOLD time series, makes that mean-zero, and then calculates the Pearson Correlation between the simulated FC and empirical FC. The FC matrix values are then transposed to the 0-1 range. We then use this FC matrix as a probability matrix and use it to get the cross-entropy-like loss using negative log likelihood. Parameters ---------- - sim: torch.Tensor with node_size X datapoint + simData: dict of torch.Tensor with node_size X datapoint simulated BOLD - emp: torch.Tensor with node_size X datapoint + empData: torch.Tensor with node_size X datapoint empirical BOLD Returns @@ -52,8 +52,10 @@ def loss(self, sim: torch.Tensor, emp: torch.Tensor): """ method_arg_type_check(self.loss) # Check that the passed arguments (excluding self) abide by their expected data types + sim = simData[self.simKey] + logits_series_tf = sim - labels_series_tf = emp + labels_series_tf = empData # get node_size() and TRs_per_window() node_size = logits_series_tf.shape[0] truncated_backprop_length = logits_series_tf.shape[1] @@ -114,7 +116,7 @@ class CostsFixedFC(AbstractLoss): Whether to run on GPU or CPU Methods: -------- - calcLoss: function + loss: function calculates functional connectivity and uses it to calculate the loss """ def __init__(self, simKey, device = torch.device('cpu')): @@ -130,7 +132,7 @@ def __init__(self, simKey, device = torch.device('cpu')): self.simKey = simKey self.device = device - def calcLoss(self, simTS, empFC): + def loss(self, simData, empData): """Function to calculate the cost function for Functional Connectivity (FC) fitting. It initially calculates the FC matrix using the data from the time series, makes that mean-zero, and then calculates the Pearson Correlation between the simulated FC and empirical FC. @@ -139,9 +141,9 @@ def calcLoss(self, simTS, empFC): Parameters ---------- - simTS: torch.tensor with node_size X time_point + simData: dict of torch.tensor with node_size X time_point Simulated Time Series - empFC: torch.tensor with node_size X node_size + empData: torch.tensor with node_size X node_size Empirical Functional Connectivity Returns @@ -149,6 +151,9 @@ def calcLoss(self, simTS, empFC): losses_corr: torch.tensor cost function value """ + simTS = simData[self.simKey] + empFC = empData + logits_series_tf = simTS # get node_size() and TRs_per_window() diff --git a/whobpyt/optimization/cost_Mean.py b/whobpyt/optimization/cost_Mean.py index c38c43ee..cd0f79f0 100644 --- a/whobpyt/optimization/cost_Mean.py +++ b/whobpyt/optimization/cost_Mean.py @@ -1,7 +1,8 @@ import torch +from whobpyt.datatypes.AbstractLoss import AbstractLoss -class CostsMean(): +class CostsMean(AbstractLoss): ''' Target Mean Value of a Variable @@ -30,6 +31,7 @@ def __init__(self, num_regions, simKey, targetValue = None, empiricalData = None targetValue : Tensor The target value either as single number or vector ''' + super(CostsMean, self).__init__(simKey) self.num_regions = num_regions self.simKey = simKey # This is the key from the numerical simulation used to select the time series @@ -48,18 +50,18 @@ def __init__(self, num_regions, simKey, targetValue = None, empiricalData = None if empiricalData != None: # In the future, if given empiricalData then will calculate the target value in this initialization function. - # That will possibly involve a time series of targets, for which then the calcLoss would need a parameter to identify + # That will possibly involve a time series of targets, for which then the loss would need a parameter to identify # which one to fit to. pass - def calcLoss(self, simData, empData = None): + def loss(self, simData, empData = None): ''' Method to calculate the loss Parameters -------------- - simData: Tensor[ Nodes x Time ] or [ Nodes x Time x Blocks(Batch) ] + simData: dict of Tensor[ Nodes x Time ] or [ Nodes x Time x Blocks(Batch) ] The time series used by the loss function Returns @@ -69,7 +71,9 @@ def calcLoss(self, simData, empData = None): ''' - meanVar = torch.mean(simData, 1) + sim = simData[self.simKey] + + meanVar = torch.mean(sim, 1) return torch.nn.functional.mse_loss(meanVar, self.targetValue) \ No newline at end of file diff --git a/whobpyt/optimization/cost_PSD.py b/whobpyt/optimization/cost_PSD.py index d67958d9..5bf5f931 100644 --- a/whobpyt/optimization/cost_PSD.py +++ b/whobpyt/optimization/cost_PSD.py @@ -1,13 +1,15 @@ import torch from warnings import warn +from whobpyt.datatypes.AbstractLoss import AbstractLoss -class CostsPSD(): +class CostsPSD(AbstractLoss): ''' WARNING: This function is no longer supported. TODO: Needs to be updated. ''' # TODO: Deal with num_region vs. num_channels vs. num_parcels conflict with variable naming def __init__(self, num_regions, simKey, sampleFreqHz, targetValue = None, empiricalData = None): + super(CostsPSD, self).__init__(simKey) self.num_regions = num_regions self.simKey = simKey # This is the index in the data simulation to extract variable time series from @@ -18,7 +20,7 @@ def __init__(self, num_regions, simKey, sampleFreqHz, targetValue = None, empiri if empiricalData != None: # In the future, if given empiricalData then will calculate the target value in this initialization function. - # That will possibly involve a time series of targets, for which then the calcLoss would need a parameter to identify + # That will possibly involve a time series of targets, for which then the loss would need a parameter to identify # which one to fit to. pass warn(f'{self.__class__.__name__} will be deprecated.', DeprecationWarning, stacklevel=2) @@ -80,18 +82,19 @@ def scalePSD(sdAxis_dS, sdValues_dS): return sdAxis_dS, sdValues_dS_scaled - def calcLoss(self, simData): - # simData assumed to be in the form [time_steps, regions or channels, one or more variables] + def loss(self, simData, empData = None): + # simData assumed to be dict with values in the form [time_steps, regions or channels, one or more variables] # Returns the MSE of the difference between the simulated and target power spectrum + sim = simData[self.simKey] - sdAxis, sdValues = powerSpectrumLoss.calcPSD(simData[:, :, self.simKey], sampleFreqHz = self.sampleFreqHz, minFreq = 2, maxFreq = 40) + sdAxis, sdValues = powerSpectrumLoss.calcPSD(sim[:, :, self.simKey], sampleFreqHz = self.sampleFreqHz, minFreq = 2, maxFreq = 40) sdAxis_dS, sdValues_dS = powerSpectrumLoss.downSmoothPSD(sdAxis, sdValues, numPoints = 32) sdAxis_dS, sdValues_dS_scaled = powerSpectrumLoss.scalePSD(sdAxis_dS, sdValues_dS) return torch.nn.functional.mse_loss(sdValues_dS_scaled, self.targetValue) -class CostsFixedPSD(): +class CostsFixedPSD(AbstractLoss): """ Updated Code that fits to a fixed PSD @@ -153,7 +156,8 @@ def __init__(self, num_regions, simKey, sampleFreqHz, minFreq, maxFreq, targetVa Whether to run the objective function on CPU or GPU. """ - + super(CostsFixedPSD, self).__init__(simKey) + self.num_regions = num_regions self.simKey = simKey # This is the index in the data simulation to extract variable time series from self.batch_size = batch_size @@ -173,7 +177,7 @@ def __init__(self, num_regions, simKey, sampleFreqHz, minFreq, maxFreq, targetVa if empiricalData != None: # In the future, if given empiricalData then will calculate the target value in this initialization function. - # That will possibly involve a time series of targets, for which then the calcLoss would need a parameter to identify + # That will possibly involve a time series of targets, for which then the loss would need a parameter to identify # which one to fit to. pass @@ -185,7 +189,7 @@ def calcPSD(self, signal, sampleFreqHz, minFreq = None, maxFreq = None, axMethod Parameters ---------- - signal: torch.tensor + signal: dict of torch.tensor The timeseries outputted by a model. Dimensions: [nodes, time, batch] sampleFreqHz: Int The sampling frequency of the data. @@ -241,7 +245,7 @@ def calcPSD(self, signal, sampleFreqHz, minFreq = None, maxFreq = None, axMethod return sdAxis, sdValues - def calcLoss(self, simData, empData = None): + def loss(self, simData, empData = None): """ NOTE: If using batching, the batches will be averaged before calculating the error (as opposed to having an error for each simulated time series in the batch). @@ -260,8 +264,9 @@ def calcLoss(self, simData, empData = None): The MSE of the difference between the simulated and target power spectrum within the specified range """ + sim = simData[self.simKey] - psdAxis, psdValues = self.calcPSD(simData, sampleFreqHz = self.sampleFreqHz, minFreq = self.minFreq, maxFreq = self.maxFreq) # TODO: Sampling frequency of simulated data and target time series is currently assumed to be the same. + psdAxis, psdValues = self.calcPSD(sim, sampleFreqHz = self.sampleFreqHz, minFreq = self.minFreq, maxFreq = self.maxFreq) # TODO: Sampling frequency of simulated data and target time series is currently assumed to be the same. meanValue = torch.mean(psdValues, 2) diff --git a/whobpyt/optimization/cost_TS.py b/whobpyt/optimization/cost_TS.py index 1b970195..0f8d40de 100644 --- a/whobpyt/optimization/cost_TS.py +++ b/whobpyt/optimization/cost_TS.py @@ -12,22 +12,24 @@ class CostsTS(AbstractLoss): def __init__(self, simKey): - super(CostsTS, self).__init__() + super(CostsTS, self).__init__(simKey) self.simKey = simKey - def loss(self, sim: torch.Tensor, emp: torch.Tensor, model: torch.nn.Module = None, state_vals = None): + def loss(self, simData: dict, empData: torch.Tensor): """ Calculate the Pearson Correlation between the simFC and empFC. From there, compute the probability and negative log-likelihood. Parameters ---------- - sim: tensor with node_size X datapoint + simData: dict of tensor with node_size X datapoint simulated EEG - emp: tensor with node_size X datapoint + empData: tensor with node_size X datapoint empirical EEG """ - method_arg_type_check(self.loss, exclude = ['model', 'state_vals']) # Check that the passed arguments (excluding self) abide by their expected data types + method_arg_type_check(self.loss) # Check that the passed arguments (excluding self) abide by their expected data types + sim = simData[self.simKey] + emp = empData losses = torch.sqrt(torch.mean((sim - emp) ** 2)) # return losses diff --git a/whobpyt/optimization/custom_cost_JR.py b/whobpyt/optimization/custom_cost_JR.py index 377f7c9d..00c72ded 100644 --- a/whobpyt/optimization/custom_cost_JR.py +++ b/whobpyt/optimization/custom_cost_JR.py @@ -12,14 +12,18 @@ from whobpyt.functions.arg_type_check import method_arg_type_check class CostsJR(AbstractLoss): - def __init__(self): - super(CostsJR, self).__init__() + def __init__(self, model): self.mainLoss = CostsTS("eeg") self.simKey = "eeg" + self.model = model - def loss(self, sim: torch.Tensor, emp: torch.Tensor, model: torch.nn.Module, state_vals: dict): + def loss(self, simData: dict, empData: torch.Tensor): method_arg_type_check(self.loss) # Check that the passed arguments (excluding self) abide by their expected data types + sim = simData + emp = empData + + model = self.model # define some constants lb = 0.001 diff --git a/whobpyt/optimization/custom_cost_RWW.py b/whobpyt/optimization/custom_cost_RWW.py index 815fe4e8..19ff4ed3 100644 --- a/whobpyt/optimization/custom_cost_RWW.py +++ b/whobpyt/optimization/custom_cost_RWW.py @@ -8,18 +8,24 @@ import torch from whobpyt.datatypes.parameter import par from whobpyt.datatypes.AbstractLoss import AbstractLoss +from whobpyt.datatypes.AbstractNMM import AbstractNMM from whobpyt.optimization.cost_FC import CostsFC from whobpyt.functions.arg_type_check import method_arg_type_check class CostsRWW(AbstractLoss): - def __init__(self): - super(CostsRWW, self).__init__() + def __init__(self, model : AbstractNMM): self.mainLoss = CostsFC("bold") self.simKey = "bold" + self.model = model - def loss(self, sim: torch.Tensor, emp: torch.Tensor, model: torch.nn.Module, state_vals: dict): + def loss(self, simData: dict, empData: torch.Tensor): method_arg_type_check(self.loss) # Check that the passed arguments (excluding self) abide by their expected data types + sim = simData + emp = empData + + model = self.model + state_vals = sim # define some constants lb = 0.001 diff --git a/whobpyt/optimization/custom_cost_mmRWW2.py b/whobpyt/optimization/custom_cost_mmRWW2.py index 0bf8825f..84d51f1a 100644 --- a/whobpyt/optimization/custom_cost_mmRWW2.py +++ b/whobpyt/optimization/custom_cost_mmRWW2.py @@ -5,10 +5,8 @@ from whobpyt.optimization import CostsPSD from whobpyt.optimization import CostsFixedFC -class CostsmmRWW2(AbstractLoss): - def __init__(self, num_regions, simKey, targetValue, device = torch.device('cpu')): - super(CostsmmRWW2, self).__init__() - +class CostsmmRWWEI2(AbstractLoss): + def __init__(self, num_regions, simKey, targetValue, device = torch.device('cpu')): # Defining the Objective Function # --------------------------------------------------- # Written in such as way as to be able to adjust the relative importance of components that make up the objective function. @@ -31,18 +29,15 @@ def __init__(self, num_regions, simKey, targetValue, device = torch.device('cpu' self.EEG_FC = CostsFixedFC(simKey = "eeg", device = device) #self.BOLD_PSD = CostsPSD(...) # Not Currently Used self.BOLD_FC = CostsFixedFC(simKey = "bold", device = device) - - def loss(self, sim, emp, model: torch.nn.Module, state_vals): - pass - def calcLoss(self, simData, empData, returnLossComponents = False): + def loss(self, simData, empData, returnLossComponents = False): - S_E_mean_loss = self.S_E_mean.calcLoss(simData[self.S_E_mean.simKey]) - S_I_mean_loss = torch.tensor([0]).to(self.device) #self.S_I_mean.calcLoss(node_history) - EEG_PSD_loss = torch.tensor([0]).to(self.device) #self.EEG_PSD.calcLoss(EEG_history) - EEG_FC_loss = self.EEG_FC.calcLoss(simData[self.EEG_FC.simKey], empData['EEG_FC']) - BOLD_PSD_loss = torch.tensor([0]).to(self.device) #self.BOLD_PS.calcLoss(BOLD_history) - BOLD_FC_loss = self.BOLD_FC.calcLoss(simData[self.BOLD_FC.simKey], empData['BOLD_FC']) + S_E_mean_loss = self.S_E_mean.loss(simData) + S_I_mean_loss = torch.tensor([0]).to(self.device) #self.S_I_mean.loss(node_history) + EEG_PSD_loss = torch.tensor([0]).to(self.device) #self.EEG_PSD.loss(EEG_history) + EEG_FC_loss = self.EEG_FC.loss(simData, empData['EEG_FC']) + BOLD_PSD_loss = torch.tensor([0]).to(self.device) #self.BOLD_PS.loss(BOLD_history) + BOLD_FC_loss = self.BOLD_FC.loss(simData, empData['BOLD_FC']) totalLoss = self.S_E_mean_weight*S_E_mean_loss + self.S_I_mean_weight*S_I_mean_loss \ + self.EEG_PSD_weight*EEG_PSD_loss + self.EEG_FC_weight*EEG_FC_loss \ diff --git a/whobpyt/run/batchfitting.py b/whobpyt/run/batchfitting.py index ad93971d..35fea022 100644 --- a/whobpyt/run/batchfitting.py +++ b/whobpyt/run/batchfitting.py @@ -104,7 +104,7 @@ def train(self, stim, empDatas, num_epochs, batch_size, learningrate = 0.05, sta # calculating loss - loss = self.cost.calcLoss(sim_vals[self.cost.simKey], empData) + loss = self.cost.loss(sim_vals, empData) optim.zero_grad() loss.backward() diff --git a/whobpyt/run/customfitting.py b/whobpyt/run/customfitting.py index 1896efe2..77e8164b 100644 --- a/whobpyt/run/customfitting.py +++ b/whobpyt/run/customfitting.py @@ -135,7 +135,7 @@ def train(self, stim, empDatas, num_epochs, block_len, learningrate = 0.05, rese # calculating loss - loss = self.cost.calcLoss(sim_vals, empData) + loss = self.cost.loss(sim_vals, empData) optim.zero_grad() loss.backward() diff --git a/whobpyt/run/modelfitting.py b/whobpyt/run/modelfitting.py index a6ab5606..3f774359 100644 --- a/whobpyt/run/modelfitting.py +++ b/whobpyt/run/modelfitting.py @@ -170,8 +170,7 @@ def train(self, u, empRecs: list, ts_window = torch.tensor(windowedTS[win_idx, :, :], dtype=torch.float32) # calculating loss - sim = next_window[self.cost.simKey] - loss = self.cost.loss(sim, ts_window, self.model, next_window) + loss = self.cost.loss(next_window, ts_window) # TIME SERIES: Put the window of simulated forward model. for name in set(self.model.state_names + self.model.output_names):