Skip to content

Commit

Permalink
Added ability to handle parameters from base class.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen authored and holgerroth committed Apr 12, 2024
1 parent 5cfc9df commit 7e58002
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,16 @@ def _get_base_app(self, custom_dir, app, app_config):
)

def _get_args(self, component, custom_dir):
constructor = component.__class__.__init__
parameters = inspect.signature(constructor).parameters
parameters = self._get_init_parameters(component)
attrs = component.__dict__
args = {}

for param in parameters:
attr_key = param if param in attrs.keys() else "_" + param

if attr_key in ["args", "kwargs"]:
continue

if attr_key in attrs.keys() and parameters[param].default != attrs[attr_key]:
if type(attrs[attr_key]).__name__ in dir(builtins):
args[param] = attrs[attr_key]
Expand All @@ -251,6 +254,19 @@ def _get_args(self, component, custom_dir):

return args

def _get_init_parameters(self, component):
class__ = component.__class__
parameters = {}
self._retrieve_parameters(class__, parameters)
return parameters

def _retrieve_parameters(self, class__, parameters):
constructor = class__.__init__
parameters.update(inspect.signature(constructor).parameters)
for item in class__.__bases__:
parameters.update(self._retrieve_parameters(item, parameters))
return parameters

def _get_filters(self, filters, custom_dir):
r = []
for f in filters:
Expand Down

0 comments on commit 7e58002

Please sign in to comment.