From eb4c17ed4c38db97c72852cf9f408daced1d92ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 20 Jun 2024 13:07:27 +0200 Subject: [PATCH] llama : add n_enc_output field in llm_build_context containing the number of embeddings generated by the encoder --- llama.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 651032a21f513..6b9d6b702a4a7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7347,6 +7347,7 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_outputs; + const int32_t n_enc_outputs; const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_ctx_orig; @@ -7396,6 +7397,7 @@ struct llm_build_context { n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), + n_enc_outputs (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), @@ -7660,14 +7662,14 @@ struct llm_build_context { struct ggml_tensor * llm_build_inp_enc_output() { const int64_t n_embd = hparams.n_embd; - lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd); + lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc_outputs); ggml_set_input(lctx.inp_enc_output); cb(lctx.inp_enc_output, "enc_output", -1); return lctx.inp_enc_output; } struct ggml_tensor * llm_build_inp_cross_KQ_mask() { - lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc_outputs, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); ggml_set_input(lctx.inp_cross_KQ_mask); cb(lctx.inp_cross_KQ_mask, "enc_mask", -1); return lctx.inp_cross_KQ_mask; @@ -11717,7 +11719,6 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - const int32_t n_enc_output = lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd; struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -11926,7 +11927,7 @@ struct llm_build_context { cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_output); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_outputs); struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); @@ -11937,10 +11938,10 @@ struct llm_build_context { kq = ggml_soft_max_ext(ctx0, kq, enc_KQ_mask, 1.0f, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_output))); + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_outputs))); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_output, n_embd_head, n_head_kv), kq); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_outputs, n_embd_head, n_head_kv), kq); cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);