Skip to content

Commit

Permalink
fix fqcn handling
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Jan 1, 2024
1 parent 5fb42f1 commit ca00690
Show file tree
Hide file tree
Showing 15 changed files with 333 additions and 111 deletions.
3 changes: 3 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class FLContextKey(object):
JOB_PARTICIPANTS = "__job_participants"
JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling
SSID = "__ssid__"

CLIENT_TOKEN = "__client_token"
AUTHORIZATION_RESULT = "_authorization_result"
AUTHORIZATION_REASON = "_authorization_reason"
Expand All @@ -174,6 +175,8 @@ class FLContextKey(object):
FILTER_DIRECTION = "__filter_dir__"
ROOT_URL = "__root_url__" # the URL for accessing the FL Server
NOT_READY_TO_END_RUN = "not_ready_to_end_run__" # component sets this to indicate it's not ready to end run yet
CLIENT_CONFIG = "__client_config__"
SERVER_CONFIG = "__server_config__"


class ReservedTopic(object):
Expand Down
16 changes: 7 additions & 9 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,26 @@ def get_certificate(self, target: str) -> bytes:

cert = self.credential_manager.get_certificate(target)
if cert:
print(f"==== CERT CACHED for {target}")
return cert

print(f"==== Exchange cert with {target}")
cert = self.exchange_certificate(target)
self.credential_manager.save_certificate(target, cert)

return cert

def exchange_certificate(self, target: str) -> bytes:
root = FQCN.get_root(target)
req = self.credential_manager.create_request(root)
response = self.core_cell.send_request(_SM_CHANNEL, _SM_TOPIC, root, Message(None, req))
req = self.credential_manager.create_request()
response = self.core_cell.send_request(_SM_CHANNEL, _SM_TOPIC, target, Message(None, req))
reply = response.payload

if not reply:
error_code = response.get_header(MessageHeaderKey.RETURN_CODE)
raise RuntimeError(f"Cert exchanged to {root} failed: {error_code}")
raise RuntimeError(f"Cert exchanged to {target} failed: {error_code}")

return self.credential_manager.process_response(reply)
return self.credential_manager.process_response(response)

def _handle_cert_request(self, request: Message):

reply = self.credential_manager.process_request(request.payload)
reply = self.credential_manager.process_request(request)
return Message(None, reply)


Expand Down
40 changes: 17 additions & 23 deletions nvflare/fuel/f3/cellnet/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from cryptography.x509 import Certificate

from nvflare.fuel.f3.cellnet.cell_cipher import SimpleCellCipher
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.endpoint import Endpoint
from nvflare.fuel.f3.message import Message

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,35 +92,25 @@ def get_certificate(self, fqcn: str) -> bytes:
if not self.cell_cipher:
raise RuntimeError("This cell doesn't support certificate exchange, not running in secure mode")

target = FQCN.get_root(fqcn)
return self.cert_cache.get(target)

def save_certificate(self, fqcn: str, cert: bytes):
target = FQCN.get_root(fqcn)
self.cert_cache[target] = cert

def create_request(self, target: str) -> dict:
return self.cert_cache.get(fqcn)

def create_request(self) -> dict:
req = {
CERT_TARGET: target,
CERT_ORIGIN: FQCN.get_root(self.local_endpoint.name),
CERT_CONTENT: self.local_cert,
CERT_CA_CONTENT: self.ca_cert,
}

return req

def process_request(self, request: dict) -> dict:

target = request.get(CERT_TARGET)
origin = request.get(CERT_ORIGIN)

reply = {CERT_TARGET: target, CERT_ORIGIN: origin}

def process_request(self, request: Message) -> dict:
origin = request.get_header(MessageHeaderKey.ORIGIN)
target = request.get_header(MessageHeaderKey.DESTINATION)
reply = {}
if not self.local_cert:
reply[CERT_ERROR] = f"Target {target} is not running in secure mode"
else:
cert = request.get(CERT_CONTENT)
payload = request.payload
cert = payload.get(CERT_CONTENT)

# Save cert from requester in the cache
self.cert_cache[origin] = cert
Expand All @@ -128,14 +120,16 @@ def process_request(self, request: dict) -> dict:

return reply

@staticmethod
def process_response(reply: dict) -> bytes:

def process_response(self, message: Message) -> bytes:
origin = message.get_header(MessageHeaderKey.ORIGIN)
reply = message.payload
error = reply.get(CERT_ERROR)
if error:
raise RuntimeError(f"Request to get certificate from {target} failed: {error}")
raise RuntimeError(f"Request to get certificate from {origin} failed: {error}")

return reply.get(CERT_CONTENT)
cert = reply.get(CERT_CONTENT)
self.cert_cache[origin] = cert
return cert

def get_local_cert(self) -> Certificate:
return x509.load_pem_x509_certificate(self.local_cert)
Expand Down
7 changes: 5 additions & 2 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ def serialize_cert(cert):


def load_crt(path):
serialized_cert = open(path, "rb").read()
return x509.load_pem_x509_certificate(serialized_cert, default_backend())
return load_crt_bytes(open(path, "rb").read())


def load_crt_bytes(data: bytes):
return x509.load_pem_x509_certificate(data, default_backend())


class CertBuilder(Builder):
Expand Down
89 changes: 50 additions & 39 deletions nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def generate_password(passlen=16):


def sign_one(content, signing_pri_key):
if isinstance(content, str):
content = content.encode("utf-8") # to bytes
signature = signing_pri_key.sign(
data=content,
padding=padding.PSS(
Expand All @@ -44,10 +46,35 @@ def sign_one(content, signing_pri_key):
return b64encode(signature).decode("utf-8")


def verify_one(content, signature, public_key):
if isinstance(content, str):
content = content.encode("utf-8") # to bytes
if isinstance(signature, str):
signature = b64decode(signature.encode("utf-8")) # decode to bytes
public_key.verify(
signature=signature,
data=content,
padding=padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
algorithm=hashes.SHA256(),
)


def verify_cert(cert_to_be_verified, root_ca_public_key):
root_ca_public_key.verify(
cert_to_be_verified.signature,
cert_to_be_verified.tbs_certificate_bytes,
padding.PKCS1v15(),
cert_to_be_verified.signature_hash_algorithm,
)


def load_private_key_file(file_path):
with open(file_path, "rt") as f:
pri_key = serialization.load_pem_private_key(f.read().encode("ascii"), password=None, backend=default_backend())
return pri_key
return load_private_key(f.read())


def load_private_key(data: str):
return serialization.load_pem_private_key(data.encode("ascii"), password=None, backend=default_backend())


def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
Expand All @@ -58,25 +85,17 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
for file in files:
if file == ".__nvfl_sig.json" or file == ".__nvfl_submitter.crt":
continue
signature = signing_pri_key.sign(
data=open(os.path.join(root, file), "rb").read(),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
signature = sign_one(
content=open(os.path.join(root, file), "rb").read(),
signing_pri_key=signing_pri_key,
)
signatures[file] = b64encode(signature).decode("utf-8")
signatures[file] = signature
for folder in folders:
signature = signing_pri_key.sign(
data=folder.encode("utf-8"),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
signature = sign_one(
content=folder,
signing_pri_key=signing_pri_key,
)
signatures[folder] = b64encode(signature).decode("utf-8")
signatures[folder] = signature

json.dump(signatures, open(os.path.join(root, ".__nvfl_sig.json"), "wt"))
shutil.copyfile(crt_path, os.path.join(root, ".__nvfl_submitter.crt"))
Expand All @@ -95,31 +114,27 @@ def verify_folder_signature(src_folder, root_ca_path):
public_key = cert.public_key()
except:
continue # TODO: shall return False
root_ca_public_key.verify(
cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), cert.signature_hash_algorithm
)
for k in signatures:
signatures[k] = b64decode(signatures[k].encode("utf-8"))
verify_cert(cert_to_be_verified=cert, root_ca_public_key=root_ca_public_key)
for file in files:
if file == ".__nvfl_sig.json" or file == ".__nvfl_submitter.crt":
continue
signature = signatures.get(file)
if signature:
public_key.verify(
verify_one(
content=open(os.path.join(root, file), "rb").read(),
signature=signature,
data=open(os.path.join(root, file), "rb").read(),
padding=padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
algorithm=hashes.SHA256(),
public_key=public_key,
)

for folder in folders:
signature = signatures.get(folder)
if signature:
public_key.verify(
verify_one(
content=folder,
signature=signature,
data=folder.encode("utf-8"),
padding=padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
algorithm=hashes.SHA256(),
public_key=public_key,
)

return True
except Exception as e:
return False
Expand All @@ -130,15 +145,11 @@ def sign_all(content_folder, signing_pri_key):
for f in os.listdir(content_folder):
path = os.path.join(content_folder, f)
if os.path.isfile(path):
signature = signing_pri_key.sign(
data=open(path, "rb").read(),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
signature = sign_one(
content=open(path, "rb").read(),
signing_pri_key=signing_pri_key,
)
signatures[f] = b64encode(signature).decode("utf-8")
signatures[f] = signature
return signatures


Expand Down
29 changes: 24 additions & 5 deletions nvflare/private/aux_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
from nvflare.fuel.f3.cellnet.core_cell import ReturnCode as CellReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.private.defs import CellChannel
from nvflare.security.logging import secure_format_traceback
from nvflare.security.logging import secure_format_exception, secure_format_traceback


class AuxMsgTarget:

def __init__(self, name: str, fqcn: str):
self.name = name
self.fqcn = fqcn
Expand Down Expand Up @@ -158,7 +157,7 @@ def dispatch(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareab

def send_aux_request(
self,
targets: List[AuxMsgTarget], # AuxMsgTargets of targets
targets: List[AuxMsgTarget],
topic: str,
request: Shareable,
timeout: float,
Expand All @@ -167,6 +166,24 @@ def send_aux_request(
optional: bool = False,
secure: bool = False,
) -> dict:
"""Send aux request to specified targets.
Args:
targets: a list of AuxMsgTarget(s)
topic: topic of the message
request: the request to be sent
timeout: timeout of the request
fl_ctx: FL context data
bulk_send: whether to bulk send
optional: whether the request is optional
secure: whether to use P2P message encryption
Returns: a dict of target_name => reply
Note: each AuxMsgTarget in "targets" has the target's name and FQCN.
The returned dict is keyed on the client Name, not client FQCN (which can be multiple levels).
"""
try:
return self._send_to_cell(
targets=targets,
Expand All @@ -179,9 +196,11 @@ def send_aux_request(
optional=optional,
secure=secure,
)
except Exception:
except Exception as ex:
if optional:
self.logger.debug(f"Failed to send aux message {topic} to targets: {targets}")
self.logger.debug(
f"Failed to send aux message {topic} to targets: {targets}: {secure_format_exception(ex)}"
)
self.logger.debug(secure_format_traceback())
else:
self.logger.error(f"Failed to send aux message {topic} to targets: {targets}")
Expand Down
9 changes: 4 additions & 5 deletions nvflare/private/fed/app/bridge/bridge.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import argparse
import sys
import threading
import logging
import logging.config
import sys
import threading

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.utils.config_service import search_file
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.fuel.utils.config_service import ConfigService, search_file

SSL_ROOT_CERT = "rootCA.pem"

Expand Down
Loading

0 comments on commit ca00690

Please sign in to comment.