Skip to content

Commit

Permalink
feat: benchmark. (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Mar 4, 2025
1 parent 47f3ac1 commit a91745d
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 38 deletions.
5 changes: 3 additions & 2 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
}

NnCpuDevice cpu(&net.netConfig, rootNodeConfig, &execution);
NnExecutor executor(&net.netConfig, rootNodeConfig, &cpu, &execution, synchronizer.get());
NnExecutor executor(&net.netConfig, rootNodeConfig, &cpu, &execution, synchronizer.get(), args->benchmark);

NnRootWeightLoader weightLoader(&executor, network, nNodes);
loadLlmNetWeight(args->modelPath, &net, &weightLoader);
Expand All @@ -246,6 +246,7 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
context.sampler = &sampler;
context.tokenizer = &tokenizer;
context.network = network;
context.executor = &executor;

handler(&context);

Expand All @@ -269,7 +270,7 @@ void runWorkerApp(AppCliArgs *args) {

NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
NnCpuDevice cpu(&netConfig, &nodeConfig, &execution);
NnExecutor executor(&netConfig, &nodeConfig, &cpu, &execution, &synchronizer);
NnExecutor executor(&netConfig, &nodeConfig, &cpu, &execution, &synchronizer, false);

NnWorkerWeightReader weightReader(&executor, network);
weightReader.read();
Expand Down
1 change: 1 addition & 0 deletions src/app.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ typedef struct {
Tokenizer *tokenizer;
Sampler *sampler;
NnNetwork *network;
NnExecutor *executor;
} AppInferenceContext;

void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context));
Expand Down
42 changes: 26 additions & 16 deletions src/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "nn/nn-config-builder.hpp"
#include "nn/nn-cpu.hpp"
#include "nn/nn-network.hpp"
#include "nn/nn-executor.hpp"
#include "llm.hpp"
#include "tokenizer.hpp"
#include "app.hpp"
Expand All @@ -26,12 +27,13 @@ static void inference(AppInferenceContext *context) {
if (nInputTokens > context->args->steps)
throw std::runtime_error("The number of prompt tokens is greater than the number of steps");

Timer evalTimer;
NnSize sentBytes = 0;
NnSize recvBytes = 0;
NnUint evalTotalTime = 0;
NnUint predTotalTime = 0;

printf("%s\n", context->args->prompt);
for (;;) {
Timer batchTimer;
long remainingTokens = nInputTokens - 1 - (long)pos;
if (remainingTokens <= 0)
break;
Expand All @@ -51,23 +53,25 @@ static void inference(AppInferenceContext *context) {

if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);
printf("🔷️ E%5u ms S%6zu kB R%6zu kB (%d tokens)\n",
batchTimer.elapsedMiliseconds(),

NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP) + context->executor->getTotalTime(STEP_SYNC_POINTERS);
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
printf("🔷️ Eval%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | (%d tokens)\n",
evalTime / 1000,
syncTime / 1000,
sentBytes / 1024,
recvBytes / 1024,
batchSize);
evalTotalTime += evalTime + syncTime;
}
NnUint evalTime = evalTimer.elapsedMiliseconds();

fflush(stdout);

context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();

Timer predTimer;
const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
for (; pos < maxPos; pos++) {
Timer tokenTimer;
context->inference->setPosition(pos);
context->inference->setToken(0, token);
context->inference->forward();
Expand All @@ -79,29 +83,34 @@ static void inference(AppInferenceContext *context) {
if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);

printf("🔶 P%5u ms S%6zu kB R%6zu kB %s\n",
tokenTimer.elapsedMiliseconds(),
NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP) + context->executor->getTotalTime(STEP_SYNC_POINTERS);
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
printf("🔶 Pred%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | %s\n",
predTime / 1000,
syncTime / 1000,
sentBytes / 1024,
recvBytes / 1024,
piece == nullptr ? "~" : piece);
fflush(stdout);
predTotalTime += predTime + syncTime;
}
NnUint predTime = predTimer.elapsedMiliseconds();

NnUint nEvalTokens = nInputTokens - 1;
NnUint nPredTokens = pos - nEvalTokens;
float evalTotalTimeMs = evalTotalTime / 1000.0;
float predTotalTimeMs = predTotalTime / 1000.0;
printf("\n");
printf("Evaluation\n");
printf(" nBatches: %d\n", context->args->nBatches);
printf(" nTokens: %d\n", nEvalTokens);
printf(" tokens/s: %3.2f (%3.2f ms/tok)\n",
nEvalTokens / (evalTime / 1000.0),
evalTime / ((float) nEvalTokens));
(nEvalTokens * 1000) / evalTotalTimeMs,
evalTotalTimeMs / ((float) nEvalTokens));
printf("Prediction\n");
printf(" nTokens: %d\n", nPredTokens);
printf(" tokens/s: %3.2f (%3.2f ms/tok)\n",
nPredTokens / (predTime / 1000.0),
predTime / ((float) nPredTokens));
(nPredTokens * 1000) / predTotalTimeMs,
predTotalTimeMs / ((float) nPredTokens));
}

static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
Expand Down Expand Up @@ -211,9 +220,10 @@ int main(int argc, char **argv) {
int returnCode = EXIT_SUCCESS;
try {
AppCliArgs args = AppCliArgs::parse(argc, argv, true);
if (std::strcmp(args.mode, "inference") == 0)
if (std::strcmp(args.mode, "inference") == 0) {
args.benchmark = true;
runInferenceApp(&args, &inference);
else if (std::strcmp(args.mode, "chat") == 0)
} else if (std::strcmp(args.mode, "chat") == 0)
runInferenceApp(&args, &chat);
else if (std::strcmp(args.mode, "worker") == 0)
runWorkerApp(&args);
Expand Down
4 changes: 4 additions & 0 deletions src/nn/nn-core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
}

Timer::Timer() {
reset();
}

void Timer::reset() {
startTime = std::chrono::high_resolution_clock::now();
}

Expand Down
1 change: 1 addition & 0 deletions src/nn/nn-core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class Timer {
std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
public:
Timer();
void reset();
NnUint elapsedMiliseconds();
NnUint elapsedMicroseconds();
};
Expand Down
2 changes: 1 addition & 1 deletion src/nn/nn-cpu-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ int main() {
NnCpuDevice device(&netConfig, &nodeConfig, &execution);
NnFakeNodeSynchronizer synchronizer;
float *rms = (float *)device.buffers[0];
NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer);
NnExecutor executor(&netConfig, &nodeConfig, &device, &execution, &synchronizer, false);
executor.loadWeight("rms_norm", 0, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);

execution.setBatchSize(2);
Expand Down
41 changes: 24 additions & 17 deletions src/nn/nn-executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include <stdexcept>
#include "nn-executor.hpp"

#define DEBUG_EXECUTOR_BENCHMARK false

void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
// Nothing
}
Expand Down Expand Up @@ -35,7 +33,7 @@ void NnNetExecution::setBatchSize(NnUint batchSize) {
this->batchSize = batchSize;
}

NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer)
NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
: segments(nodeConfig->nSegments), steps()
{
NnUint maxNThreads = device->maxNThreads();
Expand All @@ -50,7 +48,7 @@ NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevic
if (segmentConfig->nOps > 0) {
NnDeviceSegment *segment = device->createSegment(segmentIndex);
segments[segmentIndex] = std::unique_ptr<NnDeviceSegment>(segment);

for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++)
steps.push_back(NnExecutorStep{ STEP_EXECUTE_OP, segment, opIndex, &segmentConfig->ops[opIndex] });
}
Expand All @@ -67,6 +65,10 @@ NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevic
context.device = device;
context.nSteps = (NnUint)steps.size();
context.steps = steps.data();
if (benchmark)
context.timer = new Timer();
else
context.timer = nullptr;

threads = new NnExecutorThread[netExecution->nThreads];
for (NnUint threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) {
Expand All @@ -77,6 +79,8 @@ NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevic
}

NnExecutor::~NnExecutor() {
if (context.timer != nullptr)
delete context.timer;
delete[] threads;
}

Expand All @@ -97,11 +101,6 @@ void NnExecutor::loadWeight(const char *name, NnUint index, NnSize nBytes, NnByt
}

inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread *thread, NnExecutorContext *context) {
#if DEBUG_EXECUTOR_BENCHMARK
assert(nThreads == 1);
Timer startTime;
#endif

if (step->type == STEP_EXECUTE_OP) {
step->segment->forward(step->arg0, nThreads, thread->threadIndex, context->batchSize);
} else if (step->type == STEP_SYNC_NODES) {
Expand All @@ -112,14 +111,6 @@ inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread
} else {
throw std::invalid_argument("Unsupported step type");
}

#if DEBUG_EXECUTOR_BENCHMARK
NnUint duration = startTime.elapsedMicroseconds();
if (step->type == STEP_EXECUTE_OP)
printf("🕒 [OP %16s %2d] %u μs\n", opCodeToString(step->opConfig->code), step->opConfig->index, duration);
else if (step->type == STEP_SYNC_NODES)
printf("🕒 [SYNC %17d] %u μs\n", step->arg0, duration);
#endif
}

static inline void *executorThreadHandler(void *arg) {
Expand All @@ -138,6 +129,12 @@ static inline void *executorThreadHandler(void *arg) {

NnUint currentCount = context->doneThreadCount.fetch_add(1);
if (currentCount == doneCount) {
if (context->timer != nullptr) {
NnUint time = context->timer->elapsedMicroseconds();
context->totalTime[step->type] += time;
context->timer->reset();
}

context->doneThreadCount.store(0);
context->currentStepIndex.fetch_add(1);
} else {
Expand All @@ -155,6 +152,11 @@ void NnExecutor::forward() {
context.doneThreadCount.exchange(0);
context.batchSize = netExecution->batchSize;

if (context.timer != nullptr) {
std::memset(context.totalTime, 0, sizeof(context.totalTime));
context.timer->reset();
}

NnUint threadIndex;
for (threadIndex = 1; threadIndex < nThreads; threadIndex++) {
int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]);
Expand All @@ -165,3 +167,8 @@ void NnExecutor::forward() {
for (threadIndex = 1; threadIndex < nThreads; threadIndex++)
pthread_join(threads[threadIndex].handler, NULL);
}

NnUint NnExecutor::getTotalTime(NnExecutorStepType type) {
assert(type < N_STEP_TYPES);
return context.totalTime[type];
}
10 changes: 8 additions & 2 deletions src/nn/nn-executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ enum NnExecutorStepType {
STEP_SYNC_POINTERS
};

#define N_STEP_TYPES STEP_SYNC_POINTERS + 1

typedef struct {
NnExecutorStepType type;
NnDeviceSegment *segment;
Expand All @@ -68,6 +70,8 @@ typedef struct {
std::atomic_uint currentStepIndex;
std::atomic_uint doneThreadCount;
NnUint batchSize;
Timer *timer;
NnUint totalTime[N_STEP_TYPES];
} NnExecutorContext;

typedef struct {
Expand All @@ -77,17 +81,19 @@ typedef struct {
} NnExecutorThread;

class NnExecutor {
public:
private:
NnNetExecution *netExecution;
NnNodeConfig *nodeConfig;
std::vector<std::unique_ptr<NnDeviceSegment>> segments;
std::vector<NnExecutorStep> steps;
NnExecutorThread *threads;
NnExecutorContext context;
NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer);
public:
NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnDevice *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark);
~NnExecutor();
void loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight);
void forward();
NnUint getTotalTime(NnExecutorStepType type);
};

#endif

0 comments on commit a91745d

Please sign in to comment.