Skip to content

Commit

Permalink
put set attribute method in abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnWangDataAnalyst committed Jan 11, 2024
1 parent b4eb487 commit 2299e4b
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions whobpyt/datatypes/AbstractParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,15 @@
class AbstractParams:
# This class stores the parameters used by a model. The parameters may be for the Neural Mass Model and/or Neuroimaging Modality.
# It should be useable by both the pytorch model for training and a numpy model for parameter verification.

params ={}
def __init__(self, **kwargs):
# Define the parameters using the par data structure
pass

def getFittedNames(self):

for var in kwargs:
self.params[var] = kwargs[var]
def setParamsAsattr(self):
# Returns a named list of paramters that are being fitted
# Assumes the par datastructure is being used for parameters

fp = []
vars_names = [a for a in dir(self) if not a.startswith('__')]
for var_name in vars_names:
var = getattr(model.param, var_name)
if (type(var) == whobpyt.datatypes.parameter.par):
if (var.fit_par == True):
fp.append(var_name)
if (var.fit_hyper == True):
fp.append(var_name + "_m")
fp.append(var_name + "_v_inv")
return fp

def to(self, device):
# Moves all parameters between CPU and GPU

vars_names = [a for a in dir(self) if not a.startswith('__')]
for var_name in vars_names:
var = getattr(self, var_name)
if (type(var) == par):
var.to(device)
for var in self.params:
setattr(self, var, self.params[var])

0 comments on commit 2299e4b

Please sign in to comment.