diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index ca34fd8be0cd8..dd451e3138338 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -501,16 +501,22 @@ int main(int argc, char ** argv) {
            exit(1);
        }
 
-        int enc_input_size = embd_inp.size();
-        llama_token * enc_input_buf = embd_inp.data();
+        if (llama_model_has_encoder(model)) {
+            int enc_input_size = embd_inp.size();
+            llama_token * enc_input_buf = embd_inp.data();
 
-        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
-            LOG_TEE("%s : failed to eval\n", __func__);
-            return 1;
-        }
+            if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+                LOG_TEE("%s : failed to eval\n", __func__);
+                return 1;
+            }
 
-        embd_inp.clear();
-        embd_inp.push_back(llama_token_pad(model));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = llama_token_bos(model);
+            }
+            embd_inp.clear();
+            embd_inp.push_back(decoder_start_token_id);
+        }
 
    while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
        // predict
diff --git a/llama.cpp b/llama.cpp
index b1bba775202d7..e4f6c1b020800 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -296,6 +296,7 @@ enum llm_kv {
    LLM_KV_EXPERT_WEIGHTS_SCALE,
    LLM_KV_POOLING_TYPE,
    LLM_KV_LOGIT_SCALE,
+   LLM_KV_DECODER_START_TOKEN_ID,
 
    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -384,6 +385,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_WEIGHTS_SCALE,       "%s.expert_weights_scale" },
    { LLM_KV_POOLING_TYPE ,              "%s.pooling_type" },
    { LLM_KV_LOGIT_SCALE,                "%s.logit_scale" },
+   { LLM_KV_DECODER_START_TOKEN_ID,     "%s.decoder_start_token_id" },
 
    { LLM_KV_ATTENTION_HEAD_COUNT,       "%s.attention.head_count" },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV,    "%s.attention.head_count_kv" },
@@ -1908,6 +1910,7 @@ struct llama_hparams {
    uint32_t n_expert_used   = 0;
    uint32_t n_vocab_type    = 0; // for BERT-style token types
    uint32_t n_rel_attn_bkts = 0;
+   int32_t  decoder_start_token_id = -1;
 
    uint32_t n_layer_dense_lead = 0;
    uint32_t n_lora_q = 0;
@@ -4606,6 +4609,10 @@ static void llm_load_hparams(
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+               uint32_t decoder_start_token_id;
+               if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) {
+                   hparams.decoder_start_token_id = decoder_start_token_id;
+               }
                model.type = e_model::MODEL_UNKNOWN;
            } break;
        default: (void)0;
@@ -17872,6 +17879,17 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
    return it->second;
 }
 
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5: return true;
+        default:          return false;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.decoder_start_token_id;
+}
+
 uint32_t llama_model_quantize(
        const char * fname_inp,
        const char * fname_out,
diff --git a/llama.h b/llama.h
index d1d1a060df922..3945fcc513c02 100644
--- a/llama.h
+++ b/llama.h
@@ -482,6 +482,13 @@ extern "C" {
    // Get a llama model tensor
    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires a llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns the id of the token that must be provided
+    // to the decoder to start generating the output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
    // Returns 0 on success
    LLAMA_API uint32_t llama_model_quantize(
            const char * fname_inp,
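For context, a minimal sketch of how a caller is expected to use the new API, distilled from the examples/main/main.cpp change above. The helper name run_encoder_decoder_prompt and the surrounding setup (a loaded model, its context ctx, and a tokenized prompt in embd_inp) are illustrative assumptions, not part of the patch; only the llama_* calls come from it:

// illustrative sketch, not part of the patch
#include <cstdint>
#include <vector>

#include "llama.h"

static int run_encoder_decoder_prompt(llama_model * model, llama_context * ctx,
                                      std::vector<llama_token> & embd_inp) {
    if (llama_model_has_encoder(model)) {
        // run the encoder once over the whole tokenized prompt
        if (llama_encode(ctx, llama_batch_get_one(embd_inp.data(), (int32_t) embd_inp.size(), 0, 0))) {
            return 1; // encoding failed
        }

        // seed the decoder with the model's decoder start token,
        // falling back to BOS when the GGUF file does not store one
        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == -1) {
            decoder_start_token_id = llama_token_bos(model);
        }
        embd_inp.clear();
        embd_inp.push_back(decoder_start_token_id);
    }

    // from here the usual llama_decode() generation loop consumes embd_inp
    return 0;
}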