support minimum gemma 2 #1772

Merged: 3 commits, Sep 6, 2024
2 changes: 2 additions & 0 deletions include/ctranslate2/layers/transformer.h
@@ -119,6 +119,8 @@ namespace ctranslate2 {
const std::unique_ptr<const LayerNorm> _shared_layer_norm;
const std::unique_ptr<const LayerNorm> _input_layer_norm;
const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
const std::unique_ptr<const AttentionLayer> _encoder_attention;
const FeedForwardNetwork _ff;
};
104 changes: 104 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1421,6 +1421,110 @@ def set_decoder(self, spec, module):
gc.collect()


@register_loader("Gemma2Config")
class Gemma2Loader(ModelLoader):
@property
def architecture_name(self):
return "Gemma2ForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers

num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

activation_config = getattr(
model.config, "hidden_activation", "gelu_pytorch_tanh"
)

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=(
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
),
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
head_dim=model.config.head_dim,
pre_post_layer_norm=True,
)

self.set_decoder(spec.decoder, model.model)
self.set_linear(spec.decoder.projection, model.lm_head)
spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
if model.config.vocab_size < len(tokens):
tokens = tokens[: model.config.vocab_size]

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.rms_norm_eps

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
spec.layer_norm_use_residual = True

def set_decoder(self, spec, module):
spec.scale_embeddings = True
spec.start_from_zero_embedding = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)

self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
)

self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)

self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("LlamaConfig")
class LlamaLoader(ModelLoader):
@property
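For reference, a minimal conversion sketch that would exercise the new Gemma2Loader registered above. The Hugging Face model ID and output directory are placeholders, and the call assumes the usual TransformersConverter Python API (the ct2-transformers-converter CLI is the equivalent command-line route).

import ctranslate2

# The "Gemma2Config" registration above lets the generic Transformers converter
# recognize Gemma 2 checkpoints and build a decoder spec with pre_post_layer_norm=True.
converter = ctranslate2.converters.TransformersConverter("google/gemma-2-9b")  # placeholder model ID
converter.convert("gemma2-9b-ct2", quantization="int8")  # quantization is optional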
22 changes: 22 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -101,6 +101,7 @@ def __init__(
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
pre_post_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -147,6 +148,7 @@ def __init__(
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
attention layer norms.
pre_post_layer_norm: Add a post layer norm after each pre-norm layer.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in KV Cache.
@@ -216,6 +218,7 @@ def __init__(
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
pre_post_layer_norm=pre_post_layer_norm,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
@@ -279,6 +282,7 @@ def __init__(
max_position_embeddings=0,
parallel_residual=False,
shared_layer_norm=False,
pre_post_layer_norm=False,
num_heads_kv=None,
head_dim=None,
sliding_window=None,
@@ -319,6 +323,21 @@ def __init__(
delattr(self.self_attention, "layer_norm")
delattr(self.ffn, "layer_norm")

if pre_post_layer_norm:
self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
self.post_attention_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)
self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)
self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)

delattr(self.self_attention, "layer_norm")
delattr(self.ffn, "layer_norm")


class FeedForwardSpec(model_spec.LayerSpec):
def __init__(self, glu=False, rms_norm=False):
@@ -530,6 +549,7 @@ def from_config(
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
pre_post_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -570,6 +590,7 @@ def from_config(
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
attention layer norms.
pre_post_layer_norm: Add a post layer norm after each pre-norm layer.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
head_dim: Dimension of each attention head.
@@ -602,6 +623,7 @@ def from_config(
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
pre_post_layer_norm=pre_post_layer_norm,
multi_query_attention=multi_query_attention,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
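The new pre_post_layer_norm flag flows from from_config through the decoder spec constructor, where it swaps the single per-sub-block layer norm for the four norms declared above. A minimal sketch of a spec built with the flag, using illustrative hyperparameters rather than any particular checkpoint:

from ctranslate2.specs import common_spec, transformer_spec

# Illustrative values only; a real loader (e.g. Gemma2Loader above) reads them
# from the Hugging Face config.
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    26,                     # num_layers
    8,                      # num_heads
    activation=common_spec.Activation.GELUTanh,
    pre_norm=True,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,
    rotary_interleave=False,
    rotary_base=10000,
    num_heads_kv=4,
    head_dim=256,
    pre_post_layer_norm=True,  # new flag: sandwich norms around attention and FFN
)

layer = spec.decoder.layer[0]
# With the flag set, each layer spec carries the four norms instead of the
# usual per-sub-block layer_norm attributes.
assert hasattr(layer, "pre_feedforward_layer_norm")
assert hasattr(layer, "post_feedforward_layer_norm")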
2 changes: 1 addition & 1 deletion src/layers/attention.cc
@@ -363,7 +363,7 @@ namespace ctranslate2 {
if (queries_padder)
queries_padder->add_padding(fused_proj);

-      const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+      const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
split_op(fused_proj, queries_proj, keys_proj, values_proj);

if (_merge_time_and_head_dims) {
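The query slice of the fused QKV projection used to be taken as _d_model columns wide. Gemma 2 sets head_dim explicitly, so num_heads * head_dim need not equal the model width, and slicing by _d_model would misalign the key and value slices. A quick arithmetic sketch, assuming a roughly Gemma-2-9B-shaped configuration (values are illustrative):

# Split sizes for the fused Q/K/V projection, before and after the fix.
hidden_size = 3584   # d_model
num_heads = 16       # query heads
num_heads_kv = 8     # key/value heads
head_dim = 256       # set explicitly in the config, not hidden_size // num_heads

fused_width = (num_heads + 2 * num_heads_kv) * head_dim  # 8192
old_split = [hidden_size, num_heads_kv * head_dim, num_heads_kv * head_dim]           # [3584, 2048, 2048]
new_split = [num_heads * head_dim, num_heads_kv * head_dim, num_heads_kv * head_dim]  # [4096, 2048, 2048]

assert sum(new_split) == fused_width
assert sum(old_split) != fused_width  # the old split no longer covers the projection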
39 changes: 39 additions & 0 deletions src/layers/transformer.cc
@@ -120,6 +120,10 @@ namespace ctranslate2 {
, _input_layer_norm(build_optional_layer<LayerNorm>(model, scope + "/input_layer_norm"))
, _post_attention_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/post_attention_layer_norm"))
, _pre_feedforward_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/pre_feedforward_layer_norm"))
, _post_feedforward_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/post_feedforward_layer_norm"))
, _encoder_attention(build_optional_layer<MultiHeadAttention>(model,
scope + "/attention",
num_heads,
@@ -149,6 +153,41 @@ namespace ctranslate2 {
const DataType dtype = input.dtype();
const Device device = input.device();

const bool pre_post_layer_norm = _post_feedforward_layer_norm && _pre_feedforward_layer_norm;
if (pre_post_layer_norm) {
StorageView hidden(dtype, device);
StorageView context(dtype, device);
(*_input_layer_norm)(input, hidden);

if (_self_attention)
(*_self_attention)(hidden,
hidden,
input_length,
context,
cached_self_attn_keys,
cached_self_attn_values,
nullptr,
input_padder,
input_padder,
true,
position_bias,
offset);

(*_post_attention_layer_norm)(context, output);
ops::Add()(output, input, output);

context = std::move(output);
(*_pre_feedforward_layer_norm)(context, output);
hidden = std::move(output);

_ff(hidden, output);

hidden = std::move(output);
(*_post_feedforward_layer_norm)(hidden, output);
ops::Add()(output, context, output);
return;
}

const bool use_parallel_residual = _shared_layer_norm || _input_layer_norm;

if (use_parallel_residual) {
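For readability, the same pre/post ("sandwich") layer norm path in Python-like pseudocode; the function name and callables stand in for the layers built above and are not CTranslate2 API:

def gemma2_decoder_layer(x, self_attention, ff,
                         input_layer_norm, post_attention_layer_norm,
                         pre_feedforward_layer_norm, post_feedforward_layer_norm):
    # Attention sub-block: pre-norm, attention, post-norm, residual add.
    hidden = input_layer_norm(x)
    context = self_attention(hidden)
    attn_out = post_attention_layer_norm(context) + x

    # Feed-forward sub-block: the same sandwich pattern around the FFN.
    hidden = pre_feedforward_layer_norm(attn_out)
    ff_out = ff(hidden)
    return post_feedforward_layer_norm(ff_out) + attn_out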
2 changes: 1 addition & 1 deletion third_party/googletest
Submodule googletest updated 245 files