Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Feb 10, 2025
1 parent d1e3a7a commit ffb72ab
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 57 deletions.
6 changes: 3 additions & 3 deletions nvflare/app_opt/flower/connectors/flower_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class FlowerServerConnector(Connector):
FlowerServerConnector specifies commonly required methods for server connector implementations.
"""

def __init__(self):
Connector.__init__(self)
def __init__(self, monitor_interval):
Connector.__init__(self, monitor_interval)
self.num_rounds = None

def configure(self, config: dict, fl_ctx: FLContext):
Expand Down Expand Up @@ -79,7 +79,7 @@ def process_app_request(self, op: str, request: Shareable, fl_ctx: FLContext, ab
return make_reply(ReturnCode.SERVICE_UNAVAILABLE)

reply = self.send_request_to_flower(request, fl_ctx)
self.log_info(fl_ctx, f"received reply for '{op}'")
self.log_debug(fl_ctx, f"received reply for '{op}'")
return reply


Expand Down
54 changes: 14 additions & 40 deletions nvflare/app_opt/flower/connectors/grpc_client_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

import flwr.proto.grpcadapter_pb2 as pb2
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterServicer

Expand All @@ -23,7 +20,7 @@
from nvflare.app_opt.flower.defs import Constant
from nvflare.app_opt.flower.grpc_server import GrpcServer
from nvflare.app_opt.flower.utils import msg_container_to_shareable, reply_should_exit, shareable_to_msg_container
from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port
from nvflare.fuel.utils.network_utils import get_local_addresses
from nvflare.security.logging import secure_format_exception


Expand All @@ -33,7 +30,7 @@ def __init__(
int_server_grpc_options=None,
per_msg_timeout=2.0,
tx_timeout=10.0,
client_shutdown_timeout=5.0,
client_shutdown_timeout=0.5,
):
"""Constructor of GrpcClientConnector.
GrpcClientConnector is used to connect Flare Client with the Flower Client App.
Expand All @@ -52,31 +49,22 @@ def __init__(
self.internal_server_addr = None
self._training_stopped = False
self._client_name = None
self._stopping = False
self._exit_waiter = threading.Event()

def initialize(self, fl_ctx: FLContext):
super().initialize(fl_ctx)
self._client_name = fl_ctx.get_identity_name()

def _start_client(self, server_addr: str, fl_ctx: FLContext):
def _start_client(self, superlink_addr: str, clientapp_api_addr: str, fl_ctx: FLContext):
app_ctx = {
Constant.APP_CTX_CLIENT_NAME: self._client_name,
Constant.APP_CTX_SERVER_ADDR: server_addr,
Constant.APP_CTX_SUPERLINK_ADDR: superlink_addr,
Constant.APP_CTX_CLIENTAPP_API_ADDR: clientapp_api_addr,
Constant.APP_CTX_NUM_ROUNDS: self.num_rounds,
}
self.start_applet(app_ctx, fl_ctx)

def _stop_client(self):
self._training_stopped = True

# do not stop the applet until should-exit is sent
if not self._exit_waiter.wait(timeout=2.0):
self.logger.warning(f"did not send should-exit before shutting down supernode")

# give 1 sec for the supernode to quite gracefully
self.logger.info("about to stop applet")
time.sleep(1.0)
self.stop_applet(self.client_shutdown_timeout)

def _is_stopped(self) -> (bool, int):
Expand All @@ -87,26 +75,25 @@ def _is_stopped(self) -> (bool, int):
if self._training_stopped:
return True, 0

if self._stopping:
self.stop(fl_ctx=None)
return True, 0

return False, 0

def start(self, fl_ctx: FLContext):
if not self.num_rounds:
raise RuntimeError("cannot start - num_rounds is not set")

# dynamically determine address on localhost
port = get_open_tcp_port(resources={})
if not port:
raise RuntimeError("failed to get a port for Flower server")
self.internal_server_addr = f"127.0.0.1:{port}"
# get addresses for flower supernode:
# - superlink_addr for supernode to connect to superlink
# - clientapp_api_addr for client app to connect to the supernode
addresses = get_local_addresses(2)
superlink_addr = addresses[0]
clientapp_api_addr = addresses[1]

self.internal_server_addr = superlink_addr
self.logger.info(f"Start internal server at {self.internal_server_addr}")
self.internal_grpc_server = GrpcServer(self.internal_server_addr, 10, self.int_server_grpc_options, self)
self.internal_grpc_server.start(no_blocking=True)
self.logger.info(f"Started internal grpc server at {self.internal_server_addr}")
self._start_client(self.internal_server_addr, fl_ctx)
self._start_client(superlink_addr, clientapp_api_addr, fl_ctx)
self.logger.info("Started external Flower grpc client")

def stop(self, fl_ctx: FLContext):
Expand Down Expand Up @@ -144,26 +131,13 @@ def SendReceive(self, request: pb2.MessageContainer, context):
"""
try:
if self.stopped:
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_1!")
return reply_should_exit()

reply = self._send_flower_request(msg_container_to_shareable(request))
rc = reply.get_return_code()
if rc == ReturnCode.OK:
return shareable_to_msg_container(reply)
else:
# server side already ended
self.logger.warning(f"Flower server has stopped with RC {rc}")
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_2!")
return reply_should_exit()
except Exception as ex:
self._abort(reason=f"_send_flower_request exception: {secure_format_exception(ex)}")
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_3!")
return reply_should_exit()
44 changes: 30 additions & 14 deletions nvflare/app_opt/flower/connectors/grpc_server_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,29 @@
from nvflare.app_opt.flower.defs import Constant
from nvflare.app_opt.flower.grpc_client import GrpcClient
from nvflare.app_opt.flower.utils import msg_container_to_shareable, shareable_to_msg_container
from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port
from nvflare.fuel.utils.network_utils import get_local_addresses
from nvflare.security.logging import secure_format_exception


class GrpcServerConnector(FlowerServerConnector):
def __init__(
self,
int_client_grpc_options=None,
flower_server_ready_timeout=Constant.FLOWER_SERVER_READY_TIMEOUT,
monitor_interval: float = 0.5,
):
FlowerServerConnector.__init__(self)
FlowerServerConnector.__init__(self, monitor_interval)
self.int_client_grpc_options = int_client_grpc_options
self.flower_server_ready_timeout = flower_server_ready_timeout
self.internal_grpc_client = None
self._server_stopped = False
self._exit_code = 0

def _start_server(self, addr: str, fl_ctx: FLContext):
def _start_server(self, serverapp_api_addr: str, fleet_api_addr: str, exec_api_addr: str, fl_ctx: FLContext):
app_ctx = {
Constant.APP_CTX_SERVER_ADDR: addr,
Constant.APP_CTX_SERVERAPP_API_ADDR: serverapp_api_addr,
Constant.APP_CTX_FLEET_API_ADDR: fleet_api_addr,
Constant.APP_CTX_EXEC_API_ADDR: exec_api_addr,
Constant.APP_CTX_NUM_ROUNDS: self.num_rounds,
}
self.start_applet(app_ctx, fl_ctx)
Expand All @@ -49,7 +53,7 @@ def _stop_server(self):
def _is_stopped(self) -> (bool, int):
runner_stopped, ec = self.is_applet_stopped()
if runner_stopped:
self.logger.info("applet is stopped!")
self.logger.debug("applet is stopped!")
return runner_stopped, ec

if self._server_stopped:
Expand All @@ -60,16 +64,24 @@ def _is_stopped(self) -> (bool, int):

def start(self, fl_ctx: FLContext):
# we dynamically create server address on localhost
port = get_open_tcp_port(resources={})
if not port:
raise RuntimeError("failed to get a port for Flower grpc server")

server_addr = f"127.0.0.1:{port}"
self.log_info(fl_ctx, f"starting grpc connector: {server_addr=}")
self._start_server(server_addr, fl_ctx)
# we need 3 free local addresses for flwr's superlink:
# - address for client to connect to (fleet-api-address)
# - address for serverapp to connect to (serverapp-api-address)
# - address for "flwr run" to connect to (exec-api-address)
try:
addresses = get_local_addresses(3)
except Exception as ex:
raise RuntimeError(f"failed to get addresses for Flower grpc server: {secure_format_exception(ex)}")

serverapp_api_addr = addresses[0]
fleet_api_addr = addresses[1]
exec_api_addr = addresses[2]

self.log_info(fl_ctx, f"starting grpc connector: {serverapp_api_addr=} {fleet_api_addr=} {exec_api_addr=}")
self._start_server(serverapp_api_addr, fleet_api_addr, exec_api_addr, fl_ctx)

# start internal grpc client
self.internal_grpc_client = GrpcClient(server_addr, self.int_client_grpc_options)
self.internal_grpc_client = GrpcClient(fleet_api_addr, self.int_client_grpc_options)
self.internal_grpc_client.start(ready_timeout=self.flower_server_ready_timeout)

def stop(self, fl_ctx: FLContext):
Expand Down Expand Up @@ -101,7 +113,11 @@ def send_request_to_flower(self, request: Shareable, fl_ctx: FLContext) -> Share
self.log_warning(fl_ctx, "dropped app request since applet is already stopped")
return make_reply(ReturnCode.SERVICE_UNAVAILABLE)

result = self.internal_grpc_client.send_request(shareable_to_msg_container(request))
grpc_client = self.internal_grpc_client
if not grpc_client:
return make_reply(ReturnCode.SERVICE_UNAVAILABLE)

result = grpc_client.send_request(shareable_to_msg_container(request))

if isinstance(result, pb2.MessageContainer):
return msg_container_to_shareable(result)
Expand Down

0 comments on commit ffb72ab

Please sign in to comment.