llama : add n_enc_output field in llm_build_context containing the number of embeddings generated by the encoder

sszymczy committed Jun 20, 2024
1 parent 684160a commit eb4c17e
Showing 1 changed file with 7 additions and 6 deletions.
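The new n_enc_outputs field caches how many embeddings the encoder produced, so the decoder graph no longer recomputes that count at every use site. Below is a minimal sketch of the derivation, assuming (as the diff shows) that lctx.encoder_output is a flat float buffer holding one n_embd-sized embedding per encoder output position; the standalone function and its name are illustrative only, not part of the commit:

    #include <cstdint>
    #include <vector>

    // Sketch only: mirrors the initializer added in this commit.
    // encoder_output holds n_enc_outputs * n_embd floats back to back.
    static int32_t count_enc_outputs(const std::vector<float> & encoder_output,
                                     int64_t n_embd, bool worst_case, int32_t n_tokens) {
        if (worst_case) {
            // While sizing the worst-case graph the encoder has not run yet,
            // so assume one encoder output per input token.
            return n_tokens;
        }
        return (int32_t) (encoder_output.size() / n_embd);
    }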
llama.cpp (7 additions, 6 deletions)
@@ -7347,6 +7347,7 @@ struct llm_build_context {
     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
+    const int32_t n_enc_outputs;
     const int32_t kv_head;  // index of where we store new KV data in the cache
     const int32_t n_ctx_orig;

@@ -7396,6 +7397,7 @@ struct llm_build_context {
         n_tokens      (batch.n_tokens),
         n_kv          (worst_case ? kv_self.size : kv_self.n),
         n_outputs     (worst_case ? n_tokens : lctx.n_outputs),
+        n_enc_outputs (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd),
         kv_head       (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_ctx_orig    (cparams.n_ctx_orig_yarn),
         flash_attn    (cparams.flash_attn),
@@ -7660,14 +7662,14 @@ struct llm_build_context {

     struct ggml_tensor * llm_build_inp_enc_output() {
         const int64_t n_embd = hparams.n_embd;
-        lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd);
+        lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc_outputs);
         ggml_set_input(lctx.inp_enc_output);
         cb(lctx.inp_enc_output, "enc_output", -1);
         return lctx.inp_enc_output;
     }

     struct ggml_tensor * llm_build_inp_cross_KQ_mask() {
-        lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc_outputs, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         ggml_set_input(lctx.inp_cross_KQ_mask);
         cb(lctx.inp_cross_KQ_mask, "enc_mask", -1);
         return lctx.inp_cross_KQ_mask;
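Both input builders now read n_enc_outputs instead of recomputing the count inline with a hard-coded 512 fallback. The mask's second dimension is still rounded up with GGML_PAD; here is a small sketch of that rounding, with a local PAD_UP macro standing in for ggml's GGML_PAD and hypothetical sizes (GGML_KQ_MASK_PAD is assumed to be 32 here):

    #include <cstdio>

    // Round x up to the next multiple of n (what GGML_PAD achieves for the
    // power-of-two pad sizes used in llama.cpp).
    #define PAD_UP(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        const int n_tokens = 7;   // hypothetical decoder batch size
        const int pad      = 32;  // assumed GGML_KQ_MASK_PAD value
        printf("%d\n", PAD_UP(n_tokens, pad)); // prints 32
        return 0;
    }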
@@ -11717,7 +11719,6 @@ struct llm_build_context {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const int32_t n_enc_output = lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -11926,7 +11927,7 @@ struct llm_build_context {
             cb(Vcur, "Vcur", il);

             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_output);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_outputs);

             struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
@@ -11937,10 +11938,10 @@ struct llm_build_context {
             kq = ggml_soft_max_ext(ctx0, kq, enc_KQ_mask, 1.0f, hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);

-            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_output)));
+            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_outputs)));
             cb(v, "v", il);

-            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_output, n_embd_head, n_head_kv), kq);
+            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_outputs, n_embd_head, n_head_kv), kq);
             cb(kqv, "kqv", il);

             struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
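For reference, a sketch (not part of the commit) tracing tensor shapes through the cross-attention block above, using ggml's convention that ne[0] is the innermost dimension and that ggml_mul_mat contracts over the first dimension of both operands; all sizes are hypothetical:

    #include <cassert>

    int main() {
        const int n_embd_head   = 64;  // hypothetical head size
        const int n_head        = 8;   // hypothetical query heads
        const int n_head_kv     = 8;   // hypothetical KV heads
        const int n_tokens      = 16;  // hypothetical decoder tokens
        const int n_enc_outputs = 10;  // hypothetical encoder outputs
        const int n_embd_gqa    = n_embd_head * n_head_kv;

        // q   = permute(Qcur): [n_embd_head, n_tokens,      n_head]
        // k   = permute(Kcur): [n_embd_head, n_enc_outputs, n_head_kv]
        // kq  = mul_mat(k, q): contracts n_embd_head
        //                   -> [n_enc_outputs, n_tokens, n_head]
        // v   = transpose of Vcur, reshaped 3d to
        //                      [n_enc_outputs, n_embd_head, n_head_kv]
        // kqv = mul_mat(v, kq): contracts n_enc_outputs
        //                   -> [n_embd_head, n_tokens, n_head]

        // The reshapes are only valid because they preserve element counts:
        assert(n_embd_gqa * n_enc_outputs ==
               n_enc_outputs * n_embd_head * n_head_kv);
        return 0;
    }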
