Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple host names for FLARE server #3018

Merged
merged 11 commits into from
Oct 11, 2024
1 change: 1 addition & 0 deletions nvflare/apis/utils/format_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

type_pattern_mapping = {
"server": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"host_name": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"overseer": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"sp_end_point": r"^((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9]):[0-9]*:[0-9]*)$",
"client": r"^[A-Za-z0-9-_]+$",
Expand Down
48 changes: 42 additions & 6 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ def _build_write_cert_pair(self, participant, base_name, ctx):
f.write(serialize_cert(cert))
with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
f.write(serialize_pri_key(pri_key))
if base_name == "client" and (listening_host := participant.props.get("listening_host")):
tmp_participant = Participant("server", listening_host, participant.org)
if base_name == "client" and (listening_host := participant.get_listening_host()):
tmp_participant = Participant(
type="server",
name=participant.name,
org=participant.org,
default_host=listening_host,
)
tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant)
with open(os.path.join(dest_dir, "server.crt"), "wb") as f:
f.write(serialize_cert(tmp_cert))
Expand Down Expand Up @@ -142,10 +147,20 @@ def get_pri_key_cert(self, participant):
subject = self.get_subject(participant)
subject_org = participant.org
if participant.type == "admin":
role = participant.props.get("role")
role = participant.get_prop("role")
else:
role = None
cert = self._generate_cert(subject, subject_org, self.issuer, self.pri_key, pub_key, role=role)

server = participant if participant.type == "server" else None
cert = self._generate_cert(
subject,
subject_org,
self.issuer,
self.pri_key,
pub_key,
role=role,
server=server,
)
return pri_key, cert

def get_subject(self, participant):
Expand All @@ -157,10 +172,20 @@ def _generate_keys(self):
return pri_key, pub_key

def _generate_cert(
self, subject, subject_org, issuer, signing_pri_key, subject_pub_key, valid_days=360, ca=False, role=None
self,
subject,
subject_org,
issuer,
signing_pri_key,
subject_pub_key,
valid_days=360,
ca=False,
role=None,
server: Participant = None,
):
x509_subject = self._x509_name(subject, subject_org, role)
x509_issuer = self._x509_name(issuer)

builder = (
x509.CertificateBuilder()
.subject_name(x509_subject)
Expand All @@ -174,7 +199,6 @@ def _generate_cert(
+ datetime.timedelta(days=valid_days)
# Sign our certificate with our private key
)
.add_extension(x509.SubjectAlternativeName([x509.DNSName(subject)]), critical=False)
)
if ca:
builder = (
Expand All @@ -188,6 +212,18 @@ def _generate_cert(
)
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=False)
)

if server:
# This is to generate a server cert.
# Use SubjectAlternativeName for all host names
default_host = server.get_default_host()
host_names = server.get_host_names()
sans = [x509.DNSName(default_host)]
if host_names:
for h in host_names:
if h != default_host:
sans.append(x509.DNSName(h))
builder = builder.add_extension(x509.SubjectAlternativeName(sans), critical=False)
return builder.sign(signing_pri_key, hashes.SHA256(), default_backend())

def _x509_name(self, cn_name, org_name=None, role=None):
Expand Down
22 changes: 0 additions & 22 deletions nvflare/lighter/impl/local_cert.py

This file was deleted.

69 changes: 0 additions & 69 deletions nvflare/lighter/impl/local_static_file.py

This file was deleted.

132 changes: 84 additions & 48 deletions nvflare/lighter/impl/static_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import yaml

from nvflare.lighter import utils
from nvflare.lighter.spec import Builder
from nvflare.lighter.spec import Builder, Participant


class StaticFileBuilder(Builder):
Expand Down Expand Up @@ -124,28 +124,18 @@ def _build_server(self, server, ctx):
dest_dir = self.get_kit_dir(server, ctx)
server_0 = config["servers"][0]
server_0["name"] = self.project_name
admin_port = server.props.get("admin_port", 8003)
admin_port = server.get_prop("admin_port", 8003)
ctx["admin_port"] = admin_port
fed_learn_port = server.props.get("fed_learn_port", 8002)
fed_learn_port = server.get_prop("fed_learn_port", 8002)
ctx["fed_learn_port"] = fed_learn_port
ctx["server_name"] = self.get_server_name(server)
server_0["service"]["target"] = f"{self.get_server_name(server)}:{fed_learn_port}"
server_0["service"]["scheme"] = self.scheme
server_0["admin_host"] = self.get_server_name(server)
server_0["admin_port"] = admin_port
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
overseer_agent["args"] = {
"role": "server",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": self.get_server_name(server),
"fl_port": str(fed_learn_port),
"admin_port": str(admin_port),
}
overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent

self._prepare_overseer_agent(server, config, "server", ctx)

utils._write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t")
replacement_dict = {
"admin_port": admin_port,
Expand Down Expand Up @@ -212,14 +202,15 @@ def _build_server(self, server, ctx):
)

def _build_client(self, client, ctx):
project = ctx["project"]
server = project.get_server()
if not server:
raise ValueError("missing server definition in project")
config = json.loads(self.template["fed_client"])
dest_dir = self.get_kit_dir(client, ctx)
fed_learn_port = ctx.get("fed_learn_port")
server_name = ctx.get("server_name")
# config["servers"][0]["service"]["target"] = f"{server_name}:{fed_learn_port}"
config["servers"][0]["service"]["scheme"] = self.scheme
config["servers"][0]["name"] = self.project_name
# config["enable_byoc"] = client.enable_byoc
config["servers"][0]["identity"] = server.name # the official identity of the server
replacement_dict = {
"client_name": f"{client.subject}",
"config_folder": self.config_folder,
Expand All @@ -228,23 +219,8 @@ def _build_client(self, client, ctx):
"type": "client",
"cln_uid": f"uid={client.subject}",
}
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
overseer_agent["args"] = {
"role": "client",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": client.subject,
}
overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent
# components = client.props.get("components", [])
# config["components"] = list()
# for comp in components:
# temp_dict = {"id": comp}
# temp_dict.update(components[comp])
# config["components"].append(temp_dict)

self._prepare_overseer_agent(client, config, "client", ctx)

utils._write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t")
if self.docker_image:
Expand Down Expand Up @@ -302,6 +278,76 @@ def _build_client(self, client, ctx):
"t",
)

def _check_host_name(self, host_name: str, server: Participant) -> str:
if host_name == server.get_default_host():
# Use the default host - OK
return ""

available_host_names = server.get_host_names()
if available_host_names and host_name in available_host_names:
# use alternative host name - OK
return ""

return f"unknown host name '{host_name}'"

def _prepare_overseer_agent(self, participant, config, role, ctx):
project = ctx["project"]
server = project.get_server()
if not server:
raise ValueError(f"Missing server definition in project {project.name}")

fl_port = server.get_prop("fed_learn_port", 8002)
admin_port = server.get_prop("admin_port", 8003)

if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
if role == "server":
overseer_agent["args"] = {
"role": role,
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": server.name,
"fl_port": str(fl_port),
"admin_port": str(admin_port),
}
else:
overseer_agent["args"] = {
"role": role,
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": participant.subject,
}
else:
# do not use overseer system
# Dummy overseer agent is used here
if role == "server":
# the server expects the "connect_to" to be the same as its name
# otherwise the host name generated by the dummy agent won't be accepted!
connect_to = server.name
else:
connect_to = participant.get_connect_to()
if connect_to:
err = self._check_host_name(connect_to, server)
if err:
raise ValueError(f"bad connect_to in {participant.subject}: {err}")
else:
# connect_to is not explicitly specified: use the server's name by default
# Note: by doing this dynamically, we guarantee the sp_end_point to be correct, even if the
# project.yaml does not specify the default server host correctly!
connect_to = server.get_default_host()

# change the sp_end_point to use connect_to
agent_args = overseer_agent.get("args")
if agent_args:
sp_end_point = agent_args.get("sp_end_point")
if sp_end_point:
# format of the sp_end_point: server_host_name:fl_port:admin_port
agent_args["sp_end_point"] = f"{connect_to}:{fl_port}:{admin_port}"

overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent

def _build_admin(self, admin, ctx):
dest_dir = self.get_kit_dir(admin, ctx)
admin_port = ctx.get("admin_port")
Expand Down Expand Up @@ -338,17 +384,7 @@ def _build_admin(self, admin, ctx):
def prepare_admin_config(self, admin, ctx):
    """Build the admin configuration from the ``fed_admin`` template.

    The template JSON is parsed, and whatever ``_prepare_overseer_agent``
    produces for the ``"admin"`` role is merged into its ``"admin"`` section.

    Args:
        admin: the admin participant the configuration is generated for.
        ctx: provisioning context, passed through to ``_prepare_overseer_agent``.

    Returns:
        The admin configuration as a dict.
    """
    config = json.loads(self.template["fed_admin"])

    # The overseer-agent entry is produced solely by the shared helper.
    # Building it inline here as well was redundant: the helper writes the
    # same "overseer_agent" key under the same condition, so the inline
    # result was always overwritten.
    agent_config = {}
    self._prepare_overseer_agent(admin, agent_config, "admin", ctx)

    config["admin"].update(agent_config)
    return config

Expand Down
Loading
Loading