llama : inference support for FLAN-T5 model family
sszymczy committed Jun 23, 2024
1 parent c7db40e commit dae5b79
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions llama.cpp
@@ -523,9 +523,9 @@ enum llm_tensor {
LLM_TENSOR_DEC_CROSS_ATTN_OUT,
LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
LLM_TENSOR_DEC_FFN_NORM,
+ LLM_TENSOR_DEC_FFN_GATE,
LLM_TENSOR_DEC_FFN_DOWN,
LLM_TENSOR_DEC_FFN_UP,
- LLM_TENSOR_DEC_OUTPUT,
LLM_TENSOR_DEC_OUTPUT_NORM,
LLM_TENSOR_ENC_ATTN_NORM,
LLM_TENSOR_ENC_ATTN_Q,
@@ -534,9 +534,9 @@ enum llm_tensor {
LLM_TENSOR_ENC_ATTN_OUT,
LLM_TENSOR_ENC_ATTN_REL_B,
LLM_TENSOR_ENC_FFN_NORM,
+ LLM_TENSOR_ENC_FFN_GATE,
LLM_TENSOR_ENC_FFN_DOWN,
LLM_TENSOR_ENC_FFN_UP,
- LLM_TENSOR_ENC_OUTPUT,
LLM_TENSOR_ENC_OUTPUT_NORM,
};

@@ -1155,6 +1155,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
LLM_ARCH_T5,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" },
{ LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" },
{ LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" },
@@ -1169,6 +1170,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" },
{ LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
{ LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" },
+ { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" },
{ LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" },
{ LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
@@ -1179,6 +1181,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
{ LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
+ { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
@@ -2237,6 +2240,7 @@ struct llama_layer {
struct ggml_tensor * ffn_gate; // w1
struct ggml_tensor * ffn_down; // w2
struct ggml_tensor * ffn_up; // w3
+ struct ggml_tensor * enc_ffn_gate;
struct ggml_tensor * enc_ffn_down;
struct ggml_tensor * enc_ffn_up;

@@ -6827,13 +6831,13 @@ static bool llm_load_tensors(

model.rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", 0), {hparams.n_head, hparams.n_rel_attn_bkts});
// this tensor seems to be unused in HF transformers implementation
- model.cross_rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", 0), {hparams.n_head, hparams.n_rel_attn_bkts});
+ model.cross_rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", 0), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);

// output
{
model.enc_output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
- model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_DEC_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -6854,6 +6858,7 @@ static bool llm_load_tensors(
layer.enc_wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});

layer.enc_ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+ layer.enc_ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.enc_ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.enc_ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff});

@@ -6872,6 +6877,7 @@ static bool llm_load_tensors(
layer.cross_wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff});

@@ -12074,12 +12080,15 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

+ // T5 uses relu, flan-T5 uses gelu-gated
cur = llm_build_ffn(ctx0, cur,
model.layers[il].enc_ffn_up, NULL,
- NULL, NULL,
+ model.layers[il].enc_ffn_gate, NULL,
model.layers[il].enc_ffn_down, NULL,
NULL,
- LLM_FFN_RELU, LLM_FFN_SEQ, cb, il);
+ model.layers[il].enc_ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+ model.layers[il].enc_ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
+ cb, il);
cb(cur, "ffn_out", il);
}

@@ -12246,12 +12255,15 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

+ // T5 uses relu, flan-T5 uses gelu-gated
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
- NULL, NULL,
+ model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
- LLM_FFN_RELU, LLM_FFN_SEQ, cb, il);
+ model.layers[il].enc_ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+ model.layers[il].enc_ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
+ cb, il);
cb(cur, "ffn_out", il);
}

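The activation change above is the core of the FLAN-T5 support: when a checkpoint ships ffn_gate / enc_ffn_gate tensors, the graph builds a parallel gated-GELU FFN instead of T5's sequential ReLU FFN. The snippet below is a minimal standalone sketch of the two variants, not code from this commit or from llama.cpp; matvec, ffn_relu_seq, and ffn_gelu_gated are illustrative names, and the GELU uses the tanh approximation.

// Minimal sketch (assumed helper names, not llama.cpp API): the two FFN
// variants selected by the ffn_gate presence check above.
#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>; // row-major: mat[row][col]

// y = W * x
static vec matvec(const mat & w, const vec & x) {
    vec y(w.size(), 0.0f);
    for (size_t r = 0; r < w.size(); ++r)
        for (size_t c = 0; c < x.size(); ++c)
            y[r] += w[r][c] * x[c];
    return y;
}

// T5 (LLM_FFN_RELU + LLM_FFN_SEQ): down(relu(up(x)))
static vec ffn_relu_seq(const mat & up, const mat & down, const vec & x) {
    vec h = matvec(up, x);
    for (float & v : h) v = std::max(v, 0.0f);
    return matvec(down, h);
}

// FLAN-T5 (LLM_FFN_GELU + LLM_FFN_PAR): down(gelu(gate(x)) * up(x))
static vec ffn_gelu_gated(const mat & up, const mat & gate, const mat & down, const vec & x) {
    vec g = matvec(gate, x);
    vec u = matvec(up, x);
    for (size_t i = 0; i < g.size(); ++i) {
        const float v = g[i];
        // tanh approximation of GELU
        const float gelu = 0.5f * v * (1.0f + std::tanh(0.7978845608f * (v + 0.044715f * v * v * v)));
        g[i] = gelu * u[i]; // elementwise gating of the up projection
    }
    return matvec(down, g);
}

int main() {
    const vec x    = {0.5f, -1.0f};                                // n_embd = 2
    const mat up   = {{1.0f, 0.0f}, {0.0f, 1.0f}, {1.0f, 1.0f}};   // n_ff = 3
    const mat gate = {{0.5f, 0.5f}, {1.0f, -1.0f}, {0.0f, 1.0f}};
    const mat down = {{1.0f, 1.0f, 1.0f}, {0.5f, 0.5f, 0.5f}};     // back to n_embd

    const vec a = ffn_relu_seq(up, down, x);         // T5-style block
    const vec b = ffn_gelu_gated(up, gate, down, x); // FLAN-T5-style block
    std::printf("relu/seq:   %.4f %.4f\n", a[0], a[1]);
    std::printf("gelu/gated: %.4f %.4f\n", b[0], b[1]);
    return 0;
}

Because both gate tensors are created with llama_model_loader::TENSOR_NOT_REQUIRED, plain T5 checkpoints leave them NULL and the ternaries in the diff keep the original LLM_FFN_RELU / LLM_FFN_SEQ path.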
