Skip to content

Commit

Permalink
Support sys vars for job config and support parameterized template in…
Browse files Browse the repository at this point in the history
… job config (#2145)

* support sys vars for job config

* address pr reviews

* support os env vars used in job config

* support os env vars used in job config

---------

Co-authored-by: Chester Chen <[email protected]>
  • Loading branch information
yanchengnv and chesterxgchen authored Nov 17, 2023
1 parent 8b424b7 commit 2dad0bd
Show file tree
Hide file tree
Showing 19 changed files with 304 additions and 27 deletions.
19 changes: 19 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class FLContextKey(object):
COMPONENT_NODE = "__component_node__"
CONFIG_CTX = "__config_ctx__"
FILTER_DIRECTION = "__filter_dir__"
ROOT_URL = "__root_url__" # the URL for accessing the FL Server


class ReservedTopic(object):
Expand Down Expand Up @@ -213,6 +214,7 @@ class AdminCommandNames(object):
AUX_COMMAND = "aux_command"
SYS_INFO = "sys_info"
REPORT_RESOURCES = "report_resources"
REPORT_ENV = "report_env"
SHOW_SCOPES = "show_scopes"
CALL = "call"
SHELL_PWD = "pwd"
Expand Down Expand Up @@ -334,6 +336,8 @@ class SystemComponents(object):
APP_DEPLOYER = "app_deployer"
DEFAULT_APP_DEPLOYER = "default_app_deployer"
JOB_META_VALIDATOR = "job_meta_validator"
FED_CLIENT = "fed_client"
RUN_MANAGER = "run_manager"


class JobConstants:
Expand Down Expand Up @@ -428,3 +432,18 @@ class ConfigVarName:

RUNNER_SYNC_TIMEOUT = "runner_sync_timeout"
MAX_RUNNER_SYNC_TRIES = "max_runner_sync_tries"


class SystemVarName:
    """Names of the system variables that FLARE generates automatically.

    These vars can be referenced in job config (config_fed_client and
    config_fed_server). For example, SITE_NAME is referenced as "{SITE_NAME}"
    in the config. The names are in UPPER CASE to avoid potential conflict
    with user-defined var names.
    """

    # name of client site or server
    SITE_NAME = "SITE_NAME"
    # directory of the workspace
    WORKSPACE = "WORKSPACE"
    # Job ID
    JOB_ID = "JOB_ID"
    # the URL of the Service Provider (server)
    ROOT_URL = "ROOT_URL"
    # whether the system is running in secure mode
    SECURE_MODE = "SECURE_MODE"
3 changes: 3 additions & 0 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,3 +2053,6 @@ def _is_my_sub(self, candidate_info: FqcnInfo) -> int:
if candidate_info.is_root and not candidate_info.is_on_server:
return self.SUB_TYPE_CLIENT
return self.SUB_TYPE_NONE

def is_secure(self):
    """Return the value of this object's ``secure`` attribute."""
    secure_flag = self.secure
    return secure_flag
44 changes: 40 additions & 4 deletions nvflare/fuel/flare_api/flare_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,10 @@ def _do_command(self, command: str, enforce_meta=True):
if not isinstance(result, dict):
raise InternalError(f"result from server must be dict but got {type(result)}")

# check meta status first
# Check meta status if available
# There are still some commands that do not return meta. But for commands that do return meta, we will check
# its meta status first.
meta = result.get(ResultKey.META, None)
if enforce_meta and not meta:
raise InternalError("missing meta from result")

if meta:
if not isinstance(meta, dict):
raise InternalError(f"meta must be dict but got {type(meta)}")
Expand Down Expand Up @@ -194,6 +193,8 @@ def _do_command(self, command: str, enforce_meta=True):
elif cmd_status != MetaStatusValue.OK:
raise InternalError(f"{cmd_status}: {info}")

# Then check API Status. There are cases where a command does not return meta, or runs into errors before
# setting meta. Even if the command does return meta, we still need to make sure the APIStatus is good.
status = result.get(ResultKey.STATUS, None)
if not status:
raise InternalError("missing status in result")
Expand All @@ -212,6 +213,10 @@ def _do_command(self, command: str, enforce_meta=True):
details = result.get(ResultKey.DETAILS, "")
raise RuntimeError(f"runtime error encountered: {status}: {details}")

if enforce_meta and not meta:
raise InternalError("missing meta from result")

# both API Status and Meta are okay
return result

@staticmethod
Expand Down Expand Up @@ -260,6 +265,7 @@ def submit_job(self, job_definition_path: str) -> str:
if not os.path.isdir(job_definition_path):
if os.path.isdir(os.path.join(self.upload_dir, job_definition_path)):
job_definition_path = os.path.join(self.upload_dir, job_definition_path)
job_definition_path = os.path.abspath(job_definition_path)
else:
raise InvalidJobDefinition(f"job_definition_path '{job_definition_path}' is not a valid folder")

Expand Down Expand Up @@ -802,6 +808,36 @@ def get_connected_client_list(self) -> List[ClientInfo]:
sys_info = self.get_system_info()
return sys_info.client_info

def get_client_env(self, client_names=None):
    """Get running environment values for specified clients.

    The env includes values of client name, workspace directory, root url of the FL server,
    and secure mode or not. These values can be used for 3rd-party system configuration
    (e.g. CellPipe to connect to the FLARE system).

    Args:
        client_names: clients to get env from. Either a single client name (str) or a
            list/tuple of client names. None (or any empty value) means all clients.

    Returns: list of env info for specified clients.

    Raises:
        ValueError: if client_names is not a str or a list/tuple of str.
        RuntimeError: if the command result's meta does not contain the client env info.
        InvalidTarget exception, if no clients are connected or an invalid client name is specified
    """
    if not client_names:
        command = AdminCommandNames.REPORT_ENV
    else:
        if isinstance(client_names, str):
            client_names = [client_names]
        elif not isinstance(client_names, (list, tuple)):
            raise ValueError(f"client_names must be str or list of str but got {type(client_names)}")

        # Validate every name up front: str.join would otherwise fail with an opaque TypeError.
        for name in client_names:
            if not isinstance(name, str):
                raise ValueError(f"client_names must be str or list of str but got element of {type(name)}")

        command = " ".join([AdminCommandNames.REPORT_ENV, *client_names])

    # _do_command enforces presence of meta by default, so META is expected in the result.
    result = self._do_command(command)
    meta = result[ResultKey.META]
    client_envs = meta.get(MetaKey.CLIENTS)
    if not client_envs:
        raise RuntimeError(f"missing {MetaKey.CLIENTS} from meta")
    return client_envs

def monitor_job(
self, job_id: str, timeout: float = 0.0, poll_interval: float = 2.0, cb=None, *cb_args, **cb_kwargs
) -> MonitorReturnCode:
Expand Down
3 changes: 0 additions & 3 deletions nvflare/fuel/hci/server/authz.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def pre_command(self, conn: Connection, args: List[str]):
return True

if return_code == PreAuthzReturnCode.ERROR:
conn.append_error(
"Authorization error", meta=make_meta(MetaStatusValue.NOT_AUTHORIZED, "Authorization error")
)
return False

# authz required - the command name is the name of the right to be checked!
Expand Down
2 changes: 2 additions & 0 deletions nvflare/fuel/utils/config_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def search_file(file_basename: str, dirs: List[str]) -> Union[None, str]:
Returns: the full path of the file, if found; None if not found
"""
if isinstance(dirs, str):
dirs = [dirs]
for d in dirs:
f = find_file_in_dir(file_basename, d)
if f:
Expand Down
104 changes: 100 additions & 4 deletions nvflare/fuel/utils/wfconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nvflare.fuel.common.excepts import ConfigError
from nvflare.security.logging import secure_format_exception

from .argument_utils import parse_vars
from .class_utils import ModuleScanner, get_class, instantiate_class
from .dict_utils import extract_first_level_primitive, merge_dict
from .json_scanner import JsonObjectProcessor, JsonScanner, Node
Expand All @@ -38,10 +39,22 @@ def __init__(self):
class _EnvUpdater(JsonObjectProcessor):
def __init__(self, vs, element_filter=None):
    """Set up the updater with the var values used to resolve references.

    Args:
        vs: the var definitions (name => value). A shallow copy is kept, so the
            OS env vars added here do not show up in the caller's object.
        element_filter: optional callable; stored as-is for use during processing.

    Raises:
        ValueError: if element_filter is specified but is not callable.
    """
    JsonObjectProcessor.__init__(self)
    if element_filter is not None and not callable(element_filter):
        raise ValueError("element_filter must be a callable function but got {}.".format(type(element_filter)))

    # Keep our own copy: env vars are merged in below and must not leak into the caller's vars.
    self.vars = copy.copy(vs)

    # make all os env vars available for config
    for env_name, env_value in os.environ.items():
        # when referencing os env var, must use a $ sign prefix!
        var_name = "$" + env_name
        if var_name not in self.vars:
            # only use env var when it is not locally defined!
            self.vars[var_name] = env_value

    self.element_filter = element_filter
    self.num_updated = 0

def process_element(self, node: Node):
element = node.element
Expand All @@ -58,14 +71,97 @@ def process_element(self, node: Node):
parent_element[node.key] = element

def substitute(self, element: str):
    """Resolve the variable references contained in a string element.

    Two forms are handled:

    - Simple Variable Ref (SVR): the whole (stripped) string is a single reference.
      It is resolved to an object derived from the variable definition:
        - "{var_name}" resolves to the value of var_name (None if not defined);
        - "{@var_name:n1=v1:n2=v2:...}" invokes a definition that contains local vars
          n1, n2, ... The definition must be a dict; a deep copy of it is resolved with
          the local var values added to the known vars. Local var values could also be
          refs, e.g. "{@var_name:n1={varp_name}}".
    - Anything else is treated as a format string and formatted with the known vars.

    Each call that changes the value increments self.num_updated.

    Args:
        element: the string to be resolved

    Returns: the resolved value (not necessarily a str for an SVR)

    Raises:
        ConfigError: if a parameterized expression references a var that is not defined
            (or has an empty value), or whose value is not a dict.
    """
    # Capture the input before any resolution so that the change detection below is accurate.
    original_value = element

    is_svr = False
    exp = element.strip()
    if exp.startswith("{@") and exp.endswith("}"):
        # this is a ref with local vars: drop the enclosing braces only
        is_svr = True
        exp = exp[1:-1]
    else:
        a = re.split("{|}", exp)
        if len(a) == 3 and a[0] == "" and a[2] == "":
            # exactly one pair of braces wrapping the whole expression
            is_svr = True
            exp = a[1]

    if not is_svr:
        element = element.format(**self.vars)
    else:
        var_name, *params = exp.split(":")
        if not params:
            # this is a single var without params
            element = self.vars.get(var_name, None)
        else:
            # the var_name must reference a dict
            local_vars = parse_vars(params)
            item = self.vars.get(var_name)
            if not item:
                raise ConfigError(f"bad parameterized expression '{element}': {var_name} is not defined")
            if not isinstance(item, dict):
                raise ConfigError(
                    f"bad parameterized expression '{element}': {var_name} must be dict but got {type(item)}"
                )

            # scan a private copy of the item to resolve var refs; local vars take precedence
            new_item = copy.deepcopy(item)
            scanner = JsonScanner(new_item)
            new_vars = copy.copy(self.vars)
            new_vars.update(local_vars)
            resolve_var_refs(scanner, new_vars)
            element = new_item

    if element != original_value:
        self.num_updated += 1
    return element


def resolve_var_refs(scanner: JsonScanner, var_values: dict):
    """Resolve var references in the config contained in the scanner.

    var_values may contain multi-level refs (a value that itself contains refs to
    other vars), so the config is scanned repeatedly until a scan resolves nothing,
    or the max number of rounds is reached.

    Args:
        scanner: the scanner that contains config data to be resolved
        var_values: the dict that contains var values.

    Returns: None

    Raises:
        ConfigError: if refs are still being resolved after the max number of rounds,
            either because there are cyclic refs or the ref level is too deep.
    """
    max_rounds = 20
    updater = _EnvUpdater(var_values)
    rounds_done = 0

    while True:
        scanner.scan(updater)
        rounds_done += 1

        if not updater.num_updated:
            # nothing was resolved in this round - everything has been resolved.
            return

        if rounds_done > max_rounds:
            # cyclic refs or nest level too deep.
            raise ConfigError(f"item de-ref exceeds {max_rounds} rounds - cyclic refs or ref level too deep")

        # reset the counter for the next round
        updater.num_updated = 0


class Configurator(JsonObjectProcessor):
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions nvflare/private/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class SysCommandTopic(object):
SYS_INFO = "sys.info"
SHELL = "sys.shell"
REPORT_RESOURCES = "resource_manager.report_resources"
REPORT_ENV = "sys.report_env"


class ControlCommandTopic(object):
Expand Down
1 change: 1 addition & 0 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def parse_arguments():
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
# parser.add_argument("--listen_port", "-p", type=str, help="listen port", required=True)
parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True)
parser.add_argument("--sp_scheme", "-scheme", type=str, help="Sp connection scheme", required=True)
parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True)
parser.add_argument(
"--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True
Expand Down
9 changes: 7 additions & 2 deletions nvflare/private/fed/client/client_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import MachineStatus, SystemComponents, WorkspaceConstants
from nvflare.apis.fl_constant import FLContextKey, MachineStatus, SystemComponents, WorkspaceConstants
from nvflare.apis.fl_context import FLContext, FLContextManager
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.network_utils import get_open_ports
Expand Down Expand Up @@ -74,6 +74,9 @@ def __init__(self, client: FederatedClient, args, rank, workers=5):
private_stickers={
SystemComponents.DEFAULT_APP_DEPLOYER: AppDeployer(),
SystemComponents.JOB_META_VALIDATOR: JobMetaValidator(),
SystemComponents.FED_CLIENT: client,
FLContextKey.SECURE_MODE: self.client.secure_train,
FLContextKey.WORKSPACE_ROOT: args.workspace,
},
)

Expand Down Expand Up @@ -149,6 +152,7 @@ def start_app(

open_port = get_open_ports(1)[0]

server_config = list(self.client.servers.values())[0]
self.client_executor.start_app(
self.client,
job_id,
Expand All @@ -158,7 +162,8 @@ def start_app(
allocated_resource,
token,
resource_manager,
list(self.client.servers.values())[0]["target"],
target=server_config["target"],
scheme=server_config.get("scheme", "grpc"),
)

return "Start the client app..."
Expand Down
6 changes: 6 additions & 0 deletions nvflare/private/fed/client/client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def start_app(
token,
resource_manager,
target: str,
scheme: str,
):
"""Starts the client app.
Expand All @@ -59,6 +60,7 @@ def start_app(
token: token from resource manager
resource_manager: resource manager
target: SP target location
scheme: SP target connection scheme
"""
pass

Expand Down Expand Up @@ -147,6 +149,7 @@ def start_app(
token,
resource_manager: ResourceManagerSpec,
target: str,
scheme: str,
):
"""Starts the app.
Expand All @@ -160,6 +163,7 @@ def start_app(
token: token from resource manager
resource_manager: resource manager
target: SP target location
scheme: SP connection scheme
"""
new_env = os.environ.copy()
if app_custom_folder != "":
Expand All @@ -185,6 +189,8 @@ def start_app(
+ str(client.cell.get_internal_listener_url())
+ " -g "
+ target
+ " -scheme "
+ scheme
+ " -s fed_client.json "
" --set" + command_options + " print_conf=True"
)
Expand Down
Loading

0 comments on commit 2dad0bd

Please sign in to comment.