llama : add llama_model_decoder_start_token() API call that returns decoder_start_token_id

llama : add llama_model_has_encoder() API call
llama-cli : use llama_model_has_encoder() and llama_model_decoder_start_token() API calls
sszymczy committed Jun 19, 2024
1 parent cd9a969 commit e7bd870
Showing 3 changed files with 39 additions and 8 deletions.
22 changes: 14 additions & 8 deletions examples/main/main.cpp
@@ -501,16 +501,22 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    int enc_input_size = embd_inp.size();
-    llama_token * enc_input_buf = embd_inp.data();
+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
 
-    if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
-        LOG_TEE("%s : failed to eval\n", __func__);
-        return 1;
-    }
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
 
-    embd_inp.clear();
-    embd_inp.push_back(llama_token_pad(model));
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
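The pattern introduced above can be distilled into a standalone helper. The following is a minimal caller-side sketch, not part of the commit: it assumes the caller has already created `model` and `ctx` and tokenized the prompt into `tokens`, and the name `prime_decoder` is hypothetical.

// Sketch only: run the encoder (if the model has one) and reset `tokens`
// to hold the decoder start token, falling back to BOS when the model's
// GGUF metadata does not define decoder_start_token_id.
#include "llama.h"

#include <vector>

static bool prime_decoder(llama_context * ctx, llama_model * model, std::vector<llama_token> & tokens) {
    if (!llama_model_has_encoder(model)) {
        return true; // decoder-only model: nothing to prime
    }

    // run the encoder once over the full prompt
    if (llama_encode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0))) {
        return false; // encoding failed
    }

    // seed the decoder with the model's start token
    llama_token start = llama_model_decoder_start_token(model);
    if (start == -1) {
        start = llama_token_bos(model);
    }
    tokens.clear();
    tokens.push_back(start);
    return true;
}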
18 changes: 18 additions & 0 deletions llama.cpp
@@ -296,6 +296,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
+    LLM_KV_DECODER_START_TOKEN_ID,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -384,6 +385,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,         "%s.expert_weights_scale"   },
     { LLM_KV_POOLING_TYPE,                 "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,                  "%s.logit_scale"            },
+    { LLM_KV_DECODER_START_TOKEN_ID,       "%s.decoder_start_token_id" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,         "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,      "%s.attention.head_count_kv" },
@@ -1908,6 +1910,7 @@ struct llama_hparams {
     uint32_t n_expert_used   = 0;
     uint32_t n_vocab_type    = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
+    int32_t  decoder_start_token_id = -1;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -4606,6 +4609,10 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+                uint32_t decoder_start_token_id;
+                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) {
+                    hparams.decoder_start_token_id = decoder_start_token_id;
+                }
                 model.type = e_model::MODEL_UNKNOWN;
             } break;
         default: (void)0;
@@ -17872,6 +17879,17 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
     return it->second;
 }
 
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5: return true;
+        default:          return false;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.decoder_start_token_id;
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
7 changes: 7 additions & 0 deletions llama.h
@@ -482,6 +482,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires a llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns the id of the token that must be provided
+    // to the decoder to start generating the output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
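Taken together, the two header comments define a small contract that callers can probe at runtime. The following diagnostic snippet is a hedged sketch, not code from the commit: `print_enc_dec_info` is a hypothetical name and `model` is assumed to be already loaded.

// Sketch only: report a model's encoder/decoder properties via the new API.
#include "llama.h"

#include <cstdio>

static void print_enc_dec_info(const llama_model * model) {
    if (!llama_model_has_encoder(model)) {
        printf("decoder-only model: llama_decode() alone is sufficient\n");
        return;
    }

    const llama_token dst = llama_model_decoder_start_token(model);
    if (dst == -1) {
        // no decoder_start_token_id in the GGUF metadata; callers fall back to BOS
        printf("encoder-decoder model, no decoder start token (use llama_token_bos)\n");
    } else {
        printf("encoder-decoder model, decoder start token id: %d\n", dst);
    }
}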
