Skip to content

Commit

Permalink
fix: fixed inference getting stuck (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Feb 16, 2025
1 parent 1e73dcb commit a4964a0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
4 changes: 3 additions & 1 deletion src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,10 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont

RootLlmInference inference(&net, &cpu, &execution, &executor, network);

if (network != nullptr)
if (network != nullptr) {
network->resetStats();
network->setTurbo(true);
}

AppInferenceContext context;
context.args = args;
Expand Down
46 changes: 23 additions & 23 deletions src/nn/nn-network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ typedef SSIZE_T ssize_t;
#define SOCKET_LAST_ERROR strerror(errno)

#define ACK 23571113
#define ONE_MB 1048576
#define MAX_CHUNK_SIZE 4096

static inline bool isEagainError() {
#ifdef _WIN32
Expand Down Expand Up @@ -338,11 +338,11 @@ std::unique_ptr<NnNetwork> NnNetwork::connect(NnSize nSockets, char **hosts, NnS
return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
}

// Builds a network wrapper around an already-created set of sockets.
//
// nSockets - number of entries in `sockets`.
// sockets  - array of socket descriptors; only the pointer is stored here.
//
// The byte counters are zero-initialized in the member initializer list, so
// no separate reset of them is needed in the constructor body.
NnNetwork::NnNetwork(NnSize nSockets, int *sockets)
    : sentBytes(0), recvBytes(0)
{
    this->nSockets = nSockets;
    this->sockets = sockets;
}

NnNetwork::~NnNetwork() {
Expand All @@ -362,25 +362,25 @@ void NnNetwork::setTurbo(bool enabled) {

// Sends `size` bytes starting at `data` over the socket at `socketIndex`.
//
// The payload is handed to writeSocket() in pieces of at most MAX_CHUNK_SIZE
// bytes. `size` is added to the sentBytes counter exactly once, before any
// data is sent.
void NnNetwork::write(NnSize socketIndex, const void *data, size_t size) {
    assert(socketIndex >= 0 && socketIndex < nSockets);
    sentBytes.fetch_add(size);

    char *current = (char*)data;
    int s = sockets[socketIndex];
    for (size_t offset = 0; offset < size; offset += MAX_CHUNK_SIZE) {
        // Derive the chunk from the bytes still pending rather than from
        // `offset + MAX_CHUNK_SIZE`, so the comparison cannot wrap around.
        size_t remaining = size - offset;
        size_t chunkSize = remaining > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : remaining;
        writeSocket(s, current, chunkSize);
        current += chunkSize;
    }
}

void NnNetwork::read(NnSize socketIndex, void *data, size_t size) {
assert(socketIndex >= 0 && socketIndex < nSockets);
recvBytes += size;
recvBytes.fetch_add(size);

char *current = (char*)data;
int s = sockets[socketIndex];
for (size_t chunk = 0; chunk < size; chunk += ONE_MB) {
size_t chunkSize = chunk + ONE_MB < size ? ONE_MB : size - chunk;
for (size_t chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
readSocket(s, current, chunkSize);
current += chunkSize;
}
Expand All @@ -399,7 +399,7 @@ void NnNetwork::readAck(NnSize socketIndex) {
bool NnNetwork::tryReadWithMaxAttempts(NnSize socketIndex, void *data, size_t size, unsigned long maxAttempts) {
assert(socketIndex >= 0 && socketIndex < nSockets);
if (tryReadSocket(sockets[socketIndex], data, size, maxAttempts)) {
recvBytes += size;
recvBytes.fetch_add(size);
return true;
}
return false;
Expand All @@ -420,7 +420,8 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
if (io->size > 0) {
isWriting = true;
int socket = sockets[io->socketIndex];
ssize_t s = send(socket, (const char*)io->data, io->size, 0);
ssize_t chunkSize = io->size > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : io->size;
ssize_t s = send(socket, (const char*)io->data, chunkSize, 0);
if (s < 0) {
if (isEagainError()) {
continue;
Expand All @@ -434,7 +435,7 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
}
}
} while (isWriting);
sentBytes += nBytes;
sentBytes.fetch_add(nBytes);
}

void NnNetwork::writeAll(void *data, size_t size) {
Expand Down Expand Up @@ -477,18 +478,18 @@ void NnNetwork::readMany(NnSize n, NnSocketIo *ios) {
}
}
} while (isReading);
recvBytes += nBytes;
recvBytes.fetch_add(nBytes);
}

// Reports the byte counters accumulated since the last reset and clears them.
//
// sentBytes - out: receives the value of the sent-bytes counter.
// recvBytes - out: receives the value of the received-bytes counter.
//
// Each counter is read and zeroed with a single exchange(0). A separate
// load() followed by a reset would discard any bytes added to the counter
// between the two steps.
void NnNetwork::getStats(size_t *sentBytes, size_t *recvBytes) {
    *sentBytes = this->sentBytes.exchange(0);
    *recvBytes = this->recvBytes.exchange(0);
}

// Clears both traffic counters (sentBytes and recvBytes) back to zero.
// The previous values are not needed here, so a plain atomic store is used.
void NnNetwork::resetStats() {
    sentBytes.store(0);
    recvBytes.store(0);
}

static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnSize nThreads, NnSize threadIndex) {
Expand Down Expand Up @@ -525,8 +526,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
if (nSocketsPerThread == 0) return;
NnSize sliceBytes = nBytes / nNodes;

std::unique_ptr<NnSocketIo> iosPtr(new NnSocketIo[nSocketsPerThread]);
NnSocketIo *ios = iosPtr.get();
std::vector<NnSocketIo> ios(nSocketsPerThread);

if (!onlyFromWorkerToRoot || isWorker) {
NnByte *mySliceData = &buffer[sliceBytes * nodeIndex];
Expand All @@ -537,7 +537,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
ios[i].data = mySliceData;
ios[i].size = sliceBytes;
}
network->writeMany(nSocketsPerThread, ios);
network->writeMany(nSocketsPerThread, &ios[0]);
}

if (!onlyFromWorkerToRoot || !isWorker) {
Expand All @@ -549,7 +549,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
ios[i].data = sliceData;
ios[i].size = sliceBytes;
}
network->readMany(nSocketsPerThread, ios);
network->readMany(nSocketsPerThread, &ios[0]);
}
}

Expand Down

0 comments on commit a4964a0

Please sign in to comment.