Skip to content

Commit

Permalink
fix: tokenizer utf8 support (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Feb 14, 2025
1 parent 5ce3c19 commit 63465c5
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 199 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ jobs:
make dllama
make nn-cpu-test
make nn-cpu-ops-test
make tokenizer-test
- name: nn-cpu-test
run: ./nn-cpu-test
- name: nn-cpu-ops-test
run: ./nn-cpu-ops-test
- name: tokenizer-test
run: ./tokenizer-test

build-windows:
name: Windows
Expand All @@ -51,7 +54,10 @@ jobs:
make dllama
make nn-cpu-test
make nn-cpu-ops-test
make tokenizer-test
- name: nn-cpu-test
run: ./nn-cpu-test
- name: nn-cpu-ops-test
run: ./nn-cpu-ops-test
- name: tokenizer-test
run: ./tokenizer-test
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ifndef TERMUX_VERSION
endif

ifdef DEBUG
CXXFLAGS += -g
CXXFLAGS += -g -fsanitize=address
else
CXXFLAGS += -O3
endif
Expand Down
16 changes: 8 additions & 8 deletions src/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ class ApiServer {
}

inference->setBatchSize(1);
tokenizer->resetDecoder();

for (; pos < maxPredPos;) {
int prevToken = token;
Expand All @@ -407,24 +408,23 @@ class ApiServer {

token = sampler->sample(inference->logitsPipe);

char* piece = tokenizer->decode(prevToken, token);
bool isSafe = isSafePiece(piece);
EosDetectorType eosType = eosDetector->append(token, isSafe ? piece : "");
char *piece = tokenizer->decode(token);
EosDetectorType eosType = eosDetector->append(token, piece);

if (isSafePiece(piece)) {
if (piece != nullptr) {
printf("%s", piece);
fflush(stdout);
}

if (eosType == NOT_EOS || eosType == EOS) {
char* delta = eosDetector->getDelta();
if (delta != NULL) {
char *delta = eosDetector->getDelta();
if (delta != nullptr) {
std::string deltaStr(delta);
if (params.stream)
writeChatCompletionChunk(request, deltaStr, false);
buffer += deltaStr;
}
eosDetector->clear();
eosDetector->reset();
}
pos++;
if (eosType == EOS) break;
Expand Down Expand Up @@ -501,7 +501,7 @@ static void server(AppInferenceContext *context) {

TokenizerChatStops stops(context->tokenizer);
ChatTemplate chatTemplate(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(context->tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &chatTemplate);

printf("Server URL: http://127.0.0.1:%d/v1/\n", context->args->port);
Expand Down
41 changes: 18 additions & 23 deletions src/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,29 @@ static void inference(AppInferenceContext *context) {
fflush(stdout);

context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();

Timer predTimer;
const NnSize maxPos = std::min(context->header->seqLen, context->args->steps);
for (; pos < maxPos; pos++) {
Timer tokenTimer;
unsigned int prevToken = token;
context->inference->setPosition(pos);
context->inference->setToken(0, token);
context->inference->forward();

token = context->sampler->sample(context->inference->logitsPipe);

char* piece = context->tokenizer->decode(prevToken, token);

if (isSafePiece(piece)) {
if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);

printf("🔶 P%5u ms S%6zu kB R%6zu kB %s\n",
tokenTimer.elapsed(),
sentBytes / 1024,
recvBytes / 1024,
piece);
fflush(stdout);
}
char *piece = context->tokenizer->decode(token);

if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);

printf("🔶 P%5u ms S%6zu kB R%6zu kB %s\n",
tokenTimer.elapsed(),
sentBytes / 1024,
recvBytes / 1024,
piece == nullptr ? "~" : piece);
fflush(stdout);
}
NnSize predTime = predTimer.elapsed();

Expand Down Expand Up @@ -126,7 +124,7 @@ static void chat(AppInferenceContext *context) {

TokenizerChatStops stops(context->tokenizer);
ChatTemplate chatTemplate(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(context->tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);

const size_t sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
std::vector<ChatItem> deltaItems;
Expand Down Expand Up @@ -173,29 +171,26 @@ static void chat(AppInferenceContext *context) {
}

context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();

printf("\n🤖 Assistant\n");
std::string answer;
while (pos < seqLen) {
int prevToken = token;

context->inference->setPosition(pos);
context->inference->setToken(0, token);

context->inference->forward();

token = context->sampler->sample(context->inference->logitsPipe);

char *piece = context->tokenizer->decode(prevToken, token);
bool isSafe = isSafePiece(piece);
EosDetectorType eosType = eosDetector.append(token, isSafe ? piece : "");
char *piece = context->tokenizer->decode(token);
EosDetectorType eosType = eosDetector.append(token, piece);
if (eosType == NOT_EOS || eosType == EOS) {
char *delta = eosDetector.getDelta();
if (delta != NULL) {
if (delta != nullptr) {
printf("%s", delta);
fflush(stdout);
}
eosDetector.clear();
eosDetector.reset();
}
pos++;
if (eosType == EOS) break;
Expand Down
Loading

0 comments on commit 63465c5

Please sign in to comment.