Skip to content

Commit

Permalink
Switch to snprintf in chat() for out-of-bounds safety
Browse files Browse the repository at this point in the history
Also terminate a prompt passed via the command line with \n: some chat
frames require it, and prompts read via fgets() normally already end in \n.
  • Loading branch information
zeux committed Apr 19, 2024
1 parent 312d0de commit 2026083
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/run.c
Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,15 @@ static const char* chatframe(const char* style, bool has_system) {
}

void chat(struct Transformer* transformer, struct Tokenizer* tokenizer, struct Sampler* sampler, char* cli_prompt, char* system_prompt, const char* arch) {

// buffers for reading the system prompt and user prompt from stdin
// you'll notice they are somewhat haphazardly and unsafely set atm
char user_prompt[512];
char rendered_prompt[1152];
char rendered_prompt[sizeof(user_prompt) * 2];
int prompt_tokens[sizeof(rendered_prompt) + 4];
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
int user_idx;

// start the main loop
int user_idx = 0;
int user_turn = 1; // user starts
int next = 0; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int token = 0; // stores the current token to feed into the transformer
int pos = 0; // position in the sequence
for (;;) {

Expand All @@ -355,7 +351,7 @@ void chat(struct Transformer* transformer, struct Tokenizer* tokenizer, struct S
// get the user prompt
if (pos == 0 && cli_prompt != NULL) {
// user prompt for position 0 was passed in, use it
strcpy(user_prompt, cli_prompt);
snprintf(user_prompt, sizeof(user_prompt), "%s\n", cli_prompt);
} else {
// otherwise get user prompt from stdin
double seq_pct = (double)pos / (double)transformer->config.seq_len;
Expand All @@ -371,9 +367,9 @@ void chat(struct Transformer* transformer, struct Tokenizer* tokenizer, struct S
// render user/system prompts into the chat schema
const char* style = tokenizer->eot_id >= 0 ? "llama3" : arch;
if (pos == 0 && system_prompt[0] != '\0') {
sprintf(rendered_prompt, chatframe(style, true), system_prompt, user_prompt);
snprintf(rendered_prompt, sizeof(rendered_prompt), chatframe(style, true), system_prompt, user_prompt);
} else {
sprintf(rendered_prompt, chatframe(style, false), user_prompt);
snprintf(rendered_prompt, sizeof(rendered_prompt), chatframe(style, false), user_prompt);
}

// encode the rendered prompt into tokens
Expand Down Expand Up @@ -412,8 +408,6 @@ void chat(struct Transformer* transformer, struct Tokenizer* tokenizer, struct S
}
}
}

free(prompt_tokens);
}

// ----------------------------------------------------------------------------
Expand Down

0 comments on commit 2026083

Please sign in to comment.