Skip to content

Commit

Permalink
implement c server bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Jan 20, 2024
1 parent 637d783 commit 8fb8318
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 70 deletions.
36 changes: 17 additions & 19 deletions nvflare/app_common/xgb/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class XGBBridge(ABC, FLComponent):
def __init__(self):
FLComponent.__init__(self)
self.abort_signal = None
self.target_stopped = False
self.target_rc = 0

def set_abort_signal(self, abort_signal: Signal):
self.abort_signal = abort_signal
Expand Down Expand Up @@ -146,7 +144,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
def _send_request(self, op: str, req: Shareable) -> bytes:
reply = self.sender.send_to_server(op, req, self.abort_signal)
if isinstance(reply, Shareable):
rcv_buf = reply.get(Constant.KEY_XGB_RCV_BUF)
rcv_buf = reply.get(Constant.PARAM_KEY_RCV_BUF)
if not isinstance(rcv_buf, bytes):
raise RuntimeError(f"invalid reply for op {op}: expect bytes but got {type(rcv_buf)}")
return rcv_buf
Expand All @@ -155,31 +153,31 @@ def _send_request(self, op: str, req: Shareable) -> bytes:

def send_all_gather(self, rank: int, seq: int, send_buf: bytes) -> bytes:
req = Shareable()
req[Constant.KEY_XGB_RANK] = rank
req[Constant.KEY_XGB_SEQ] = seq
req[Constant.KEY_XGB_SEND_BUF] = send_buf
req[Constant.PARAM_KEY_RANK] = rank
req[Constant.PARAM_KEY_SEQ] = seq
req[Constant.PARAM_KEY_SEND_BUF] = send_buf
return self._send_request(Constant.OP_ALL_GATHER, req)

def send_all_gather_v(self, rank: int, seq: int, send_buf: bytes) -> bytes:
req = Shareable()
req[Constant.KEY_XGB_RANK] = rank
req[Constant.KEY_XGB_SEQ] = seq
req[Constant.KEY_XGB_SEND_BUF] = send_buf
req[Constant.PARAM_KEY_RANK] = rank
req[Constant.PARAM_KEY_SEQ] = seq
req[Constant.PARAM_KEY_SEND_BUF] = send_buf
return self._send_request(Constant.OP_ALL_GATHER_V, req)

def send_all_reduce(self, rank: int, seq: int, data_type: int, reduce_op: int, send_buf: bytes) -> bytes:
req = Shareable()
req[Constant.KEY_XGB_RANK] = rank
req[Constant.KEY_XGB_SEQ] = seq
req[Constant.KEY_XGB_DATA_TYPE] = data_type
req[Constant.KEY_XGB_REDUCE_OP] = reduce_op
req[Constant.KEY_XGB_SEND_BUF] = send_buf
req[Constant.PARAM_KEY_RANK] = rank
req[Constant.PARAM_KEY_SEQ] = seq
req[Constant.PARAM_KEY_DATA_TYPE] = data_type
req[Constant.PARAM_KEY_REDUCE_OP] = reduce_op
req[Constant.PARAM_KEY_SEND_BUF] = send_buf
return self._send_request(Constant.OP_ALL_REDUCE, req)

def send_broadcast(self, rank: int, seq: int, root: int, send_buf: bytes) -> bytes:
req = Shareable()
req[Constant.KEY_XGB_RANK] = rank
req[Constant.KEY_XGB_SEQ] = seq
req[Constant.KEY_XGB_ROOT] = root
req[Constant.KEY_XGB_SEND_BUF] = send_buf
return self._send_request(Constant.OP_ALL_GATHER, req)
req[Constant.PARAM_KEY_RANK] = rank
req[Constant.PARAM_KEY_SEQ] = seq
req[Constant.PARAM_KEY_ROOT] = root
req[Constant.PARAM_KEY_SEND_BUF] = send_buf
return self._send_request(Constant.OP_BROADCAST, req)
6 changes: 4 additions & 2 deletions nvflare/app_common/xgb/bridges/c/client_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def _get_pending_op(self) -> (int, dict):
if op < 0:
raise RuntimeError(f"error get_pending_op: {op}")

self.logger.info(f"***** got pending op: {op=} {seq.value=} {data_type.value=} {root.value=}")

props = {
Constant.PARAM_KEY_SEQ: seq.value,
Constant.PARAM_KEY_SEND_BUF: bytes(send_buf[0:send_size.value]),
Expand Down Expand Up @@ -176,7 +178,7 @@ def _poll_requests(self):
handler_f = self.op_table.get(op)
if handler_f is None:
self.logger.error(f"no handler for opcode {op}")
self.target_rc = Constant.ERR_CLIENT_ERROR
self.target_rc = Constant.ERR_TARGET_ERROR
self.target_done = True
return
else:
Expand All @@ -187,7 +189,7 @@ def _poll_requests(self):
except Exception as ex:
self.logger.error(f"exception handling all_gather: {secure_format_exception(ex)}")
self._stop_target()
self.target_rc = Constant.ERR_CLIENT_ERROR
self.target_rc = Constant.ERR_TARGET_ERROR
self.target_done = True
return
time.sleep(0.001)
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_common/xgb/bridges/c/impl/mock/build.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
gcc -fPIC -o client.o -c client.c
gcc -shared -o libxgb.so client.o
gcc -fPIC -o server.o -c server.c
gcc -shared -o libxgb.so client.o server.o
11 changes: 8 additions & 3 deletions nvflare/app_common/xgb/bridges/c/impl/mock/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <string.h>
#include <stdlib.h>
#include <pthread.h>
#include "../xgb_common.h"
#include "../xgb_client.h"

int XGBC_max_num_clients = 0;
Expand Down Expand Up @@ -75,6 +76,7 @@ int xgbc_get_pending_op(
}

pthread_mutex_lock(&XGBC_op_mutex);
*seq = c->seq;
*send_buf = c->send_buf;
*send_size = c->send_buf_size;
*data_type = c->data_type;
Expand Down Expand Up @@ -148,14 +150,17 @@ int _send_op(
c->send_buf = send_buf;
c->send_buf_size = send_size;
c->seq = seq;
c->data_type = data_type;
c->reduce_op = reduce_op;
c->root = root;
c->received = 0;
c->op = op;
pthread_mutex_unlock(&XGBC_op_mutex);

// wait for reply
while (!c->received) {
if (c->aborted) {
return 0;
return ERR_ABORTED;
}
usleep(1000);
}
Expand Down Expand Up @@ -185,13 +190,13 @@ int xgbc_send_all_reduce(
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size
) {
return _send_op(OP_ALL_GATHER_V, rank, seq, send_buf, send_size, data_type, reduce_op, 0, rcv_buf, rcv_size);
return _send_op(OP_ALL_REDUCE, rank, seq, send_buf, send_size, data_type, reduce_op, 0, rcv_buf, rcv_size);
}

int xgbc_send_broadcast(
int rank, int seq, int root, unsigned char* send_buf, size_t send_size, unsigned char** rcv_buf, size_t* rcv_size
) {
return _send_op(OP_ALL_GATHER_V, rank, seq, send_buf, send_size, 0, 0, root, rcv_buf, rcv_size);
return _send_op(OP_BROADCAST, rank, seq, send_buf, send_size, 0, 0, root, rcv_buf, rcv_size);
}

void _check_result(
Expand Down
Binary file modified nvflare/app_common/xgb/bridges/c/impl/mock/client.o
Binary file not shown.
117 changes: 117 additions & 0 deletions nvflare/app_common/xgb/bridges/c/impl/mock/server.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <stdio.h>
#include <pthread.h>
#include "../xgb_common.h"
#include "../xgb_server.h"

int XGBS_rcv_count = 0;
int XGBS_pending_count = 0;
int XGBS_pending_seq = -1;
int XGBS_pending_op = 0;
int XGBS_world_size = 0;
int XGBS_aborted = 0;
pthread_mutex_t XGBS_op_mutex;
pthread_mutex_t XGBS_count_mutex;

void xgbs_initialize(int world_size) {
XGBS_world_size = world_size;
}

int _xgbs_process_request(
int op, int rank, int seq,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size) {

unsigned char* rcv_buf_ptr;
int rc = 0;

printf("start processing req: op=%d, rank=%d seq=%d\n", op, rank, seq);

pthread_mutex_lock(&XGBS_op_mutex);
if (XGBS_pending_seq < 0) {
XGBS_pending_seq = seq;
XGBS_pending_op = op;
} else if (seq != XGBS_pending_seq) {
printf("received seq %d from rank %d while working on seq %d\n", seq, rank, XGBS_pending_seq);
rc = ERR_SEQ_MISMATCH;
} else if (op != XGBS_pending_op) {
printf("received op %d from rank %d while working on op %d\n", op, rank, XGBS_pending_op);
rc = ERR_OP_MISMATCH;
}
pthread_mutex_unlock(&XGBS_op_mutex);

if (rc != 0) {
return rc;
}

// echo back
rcv_buf_ptr = (unsigned char*)malloc(send_size);
memcpy(rcv_buf_ptr, send_buf, send_size);
*rcv_buf = rcv_buf_ptr;
*rcv_size = send_size;

// don't return until all ranks are received
pthread_mutex_lock(&XGBS_count_mutex);
XGBS_rcv_count ++;
XGBS_pending_count ++;
pthread_mutex_unlock(&XGBS_count_mutex);

while (XGBS_rcv_count < XGBS_world_size) {
if (XGBS_aborted) {
printf("CCC: aborted while waiting for ranks for op=%d, rank=%d, seq=%d\n", op, rank, seq);
return ERR_ABORTED;
}
usleep(1000);
}

pthread_mutex_lock(&XGBS_count_mutex);
XGBS_pending_count --;
pthread_mutex_unlock(&XGBS_count_mutex);
if (XGBS_pending_count == 0) {
// every rank has got its result - reset for next request
XGBS_rcv_count = 0;
XGBS_pending_seq = -1;
}
printf("finished req: op=%d rank=%d seq=%d\n", op, rank, seq);
return 0;
}

int xgbs_all_gather(
int rank, int seq,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size) {
return _xgbs_process_request(OP_ALL_GATHER, rank, seq, send_buf, send_size, rcv_buf, rcv_size);
}

int xgbs_all_gather_v(
int rank, int seq,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size) {
return _xgbs_process_request(OP_ALL_GATHER_V, rank, seq, send_buf, send_size, rcv_buf, rcv_size);
}

int xgbs_all_reduce(
int rank, int seq,
int data_type, int reduce_op,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size) {
return _xgbs_process_request(OP_ALL_REDUCE, rank, seq, send_buf, send_size, rcv_buf, rcv_size);
}

int xgbs_broadcast(
int rank, int seq, int root,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size) {
return _xgbs_process_request(OP_BROADCAST, rank, seq, send_buf, send_size, rcv_buf, rcv_size);
}

void xgbs_free_buf(unsigned char* buf) {
free(buf);
}

void xgbs_abort() {
printf("CCC: abort received\n");
XGBS_aborted = 1;
}
Binary file not shown.
11 changes: 0 additions & 11 deletions nvflare/app_common/xgb/bridges/c/impl/xgb_client.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
#ifndef _XGB_CLIENT_H_
#define _XGB_CLIENT_H_

#define OP_ALL_GATHER 1
#define OP_ALL_GATHER_V 2
#define OP_ALL_REDUCE 3
#define OP_BROADCAST 4
#define OP_DONE 99

#define ERR_OP_MISMATCH -1
#define ERR_INVALID_RANK -2
#define ERR_NO_CLIENT_FOR_RANK -3
#define ERR_ABORTED -4

typedef struct {
int rank;
int op;
Expand Down
17 changes: 17 additions & 0 deletions nvflare/app_common/xgb/bridges/c/impl/xgb_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef _XGB_COMMON_H_
#define _XGB_COMMON_H_

#define OP_ALL_GATHER 1
#define OP_ALL_GATHER_V 2
#define OP_ALL_REDUCE 3
#define OP_BROADCAST 4
#define OP_DONE 99

#define ERR_OP_MISMATCH -1
#define ERR_INVALID_RANK -2
#define ERR_NO_CLIENT_FOR_RANK -3
#define ERR_SEQ_MISMATCH -4
#define ERR_ABORTED -5


#endif
31 changes: 31 additions & 0 deletions nvflare/app_common/xgb/bridges/c/impl/xgb_server.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef _XGB_SERVER_H_
#define _XGB_SERVER_H_

extern void xgbs_initialize(int world_size);

extern int xgbs_all_gather(
int rank, int seq,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size);

extern int xgbs_all_gather_v(
int rank, int seq,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size);

extern int xgbs_all_reduce(
int rank, int seq,
int data_type, int reduce_op,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size);

extern int xgbs_broadcast(
int rank, int seq, int root,
unsigned char* send_buf, size_t send_size,
unsigned char** rcv_buf, size_t* rcv_size);

extern void xgbs_free_buf(unsigned char* buf);

extern void xgbs_abort();

#endif
Loading

0 comments on commit 8fb8318

Please sign in to comment.