diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h
index 95576ef0f..017f9b675 100644
--- a/include/ctranslate2/layers/transformer.h
+++ b/include/ctranslate2/layers/transformer.h
@@ -119,6 +119,8 @@ namespace ctranslate2 {
       const std::unique_ptr<LayerNorm> _shared_layer_norm;
       const std::unique_ptr<LayerNorm> _input_layer_norm;
       const std::unique_ptr<LayerNorm> _post_attention_layer_norm;
+      const std::unique_ptr<LayerNorm> _pre_feedforward_layer_norm;
+      const std::unique_ptr<LayerNorm> _post_feedforward_layer_norm;
       const std::unique_ptr<MultiHeadAttention> _encoder_attention;
       const FeedForwardNetwork _ff;
     };
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index d98c65860..cd8e8aef4 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1421,6 +1421,110 @@ def set_decoder(self, spec, module):
         gc.collect()
 
 
+@register_loader("Gemma2Config")
+class Gemma2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+            head_dim=model.config.head_dim,
+            pre_post_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+        spec.layer_norm_use_residual = True
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            wq = layer.self_attn.q_proj.weight
+            wk = layer.self_attn.k_proj.weight
+            wv = layer.self_attn.v_proj.weight
+            wo = layer.self_attn.o_proj.weight
+
+            layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
+            layer_spec.self_attention.linear[1].weight = wo
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("LlamaConfig")
 class LlamaLoader(ModelLoader):
     @property
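
Note (annotation, not part of the patch): once the loader above is registered for "Gemma2Config", a Gemma 2 checkpoint is converted through the regular Transformers converter. A minimal usage sketch follows; the model ID, output directory, and quantization mode are illustrative placeholders, not values required by this change.

    # Minimal conversion sketch (assumes the transformers/torch extras are installed
    # and the "google/gemma-2-9b" checkpoint is accessible; paths are placeholders).
    import ctranslate2
    from ctranslate2.converters import TransformersConverter

    converter = TransformersConverter("google/gemma-2-9b")
    output_dir = converter.convert("gemma2-9b-ct2", quantization="int8_float16")

    # The converted directory can then be loaded for decoding.
    generator = ctranslate2.Generator(output_dir, device="cpu")

The same conversion should also be reachable from the command line, e.g. ct2-transformers-converter --model google/gemma-2-9b --output_dir gemma2-9b-ct2.
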
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index abb812c8b..230e62cfd 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -101,6 +101,7 @@ def __init__(
         max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
+        pre_post_layer_norm: bool = False,
         multi_query_attention: bool = False,
         num_heads_kv: Optional[int] = None,
         head_dim: Optional[int] = None,
@@ -147,6 +148,7 @@ def __init__(
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
             attention layer norms.
+          pre_post_layer_norm: Add post layer norm for each pre norm layer
           multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
           num_heads_kv: Number of attention heads for the key and value.
           sliding_window: Max sequence length to retain in KV Cache.
@@ -216,6 +218,7 @@ def __init__(
                 max_position_embeddings=max_position_embeddings,
                 parallel_residual=parallel_residual,
                 shared_layer_norm=shared_layer_norm,
+                pre_post_layer_norm=pre_post_layer_norm,
                 num_heads_kv=num_heads_kv,
                 head_dim=head_dim,
                 sliding_window=sliding_window,
@@ -279,6 +282,7 @@ def __init__(
         max_position_embeddings=0,
         parallel_residual=False,
         shared_layer_norm=False,
+        pre_post_layer_norm=False,
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
@@ -319,6 +323,21 @@ def __init__(
             delattr(self.self_attention, "layer_norm")
             delattr(self.ffn, "layer_norm")
 
+        if pre_post_layer_norm:
+            self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+            self.post_attention_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+
+            delattr(self.self_attention, "layer_norm")
+            delattr(self.ffn, "layer_norm")
+
 
 class FeedForwardSpec(model_spec.LayerSpec):
     def __init__(self, glu=False, rms_norm=False):
@@ -530,6 +549,7 @@ def from_config(
         max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
+        pre_post_layer_norm: bool = False,
         multi_query_attention: bool = False,
         num_heads_kv: Optional[int] = None,
         head_dim: Optional[int] = None,
@@ -570,6 +590,7 @@ def from_config(
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
             attention layer norms.
+          pre_post_layer_norm: add post layer norm for each pre norm layer
           multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
           num_heads_kv: Number of attention heads for the key and value.
           head_dim: Number of head
@@ -602,6 +623,7 @@ def from_config(
             max_position_embeddings=max_position_embeddings,
             parallel_residual=parallel_residual,
             shared_layer_norm=shared_layer_norm,
+            pre_post_layer_norm=pre_post_layer_norm,
             multi_query_attention=multi_query_attention,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
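
Note (annotation, not part of the patch): the new pre_post_layer_norm option describes Gemma 2's "sandwich" normalization, where each sub-block is normalized both before the sub-layer and after it, ahead of the residual add. The sketch below mirrors the runtime ordering implemented in the src/layers/transformer.cc hunk later in this diff; rms_norm, self_attention, and ffn are stand-in callables for illustration only (each of the four norms carries its own learned scale in the real model), not CTranslate2 APIs.

    import numpy as np

    def rms_norm(x, eps=1e-6):
        # Placeholder RMSNorm without a learned scale; data-flow illustration only.
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

    def decoder_layer(x, self_attention, ffn):
        # Attention block: pre-norm, attention, post-norm, then the residual add.
        hidden = rms_norm(x)                  # input_layer_norm
        attn = self_attention(hidden)
        x = x + rms_norm(attn)                # post_attention_layer_norm

        # Feed-forward block: the same sandwich around the FFN.
        hidden = rms_norm(x)                  # pre_feedforward_layer_norm
        return x + rms_norm(ffn(hidden))      # post_feedforward_layer_norm
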
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 24ffffdc8..a206bcd05 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -363,7 +363,7 @@ namespace ctranslate2 {
       if (queries_padder)
         queries_padder->add_padding(fused_proj);
 
-      const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+      const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
       split_op(fused_proj, queries_proj, keys_proj, values_proj);
 
       if (_merge_time_and_head_dims) {
diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc
index 291101eae..5ac5bfa36 100644
--- a/src/layers/transformer.cc
+++ b/src/layers/transformer.cc
@@ -120,6 +120,10 @@ namespace ctranslate2 {
       , _input_layer_norm(build_optional_layer<LayerNorm>(model, scope + "/input_layer_norm"))
       , _post_attention_layer_norm(build_optional_layer<LayerNorm>(
           model, scope + "/post_attention_layer_norm"))
+      , _pre_feedforward_layer_norm(build_optional_layer<LayerNorm>(
+          model, scope + "/pre_feedforward_layer_norm"))
+      , _post_feedforward_layer_norm(build_optional_layer<LayerNorm>(
+          model, scope + "/post_feedforward_layer_norm"))
       , _encoder_attention(build_optional_layer<MultiHeadAttention>(model,
                                                                     scope + "/attention",
                                                                     num_heads,
@@ -149,6 +153,41 @@ namespace ctranslate2 {
       const DataType dtype = input.dtype();
       const Device device = input.device();
 
+      const bool pre_post_layer_norm = _post_feedforward_layer_norm && _pre_feedforward_layer_norm;
+      if (pre_post_layer_norm) {
+        StorageView hidden(dtype, device);
+        StorageView context(dtype, device);
+        (*_input_layer_norm)(input, hidden);
+
+        if (_self_attention)
+          (*_self_attention)(hidden,
+                             hidden,
+                             input_length,
+                             context,
+                             cached_self_attn_keys,
+                             cached_self_attn_values,
+                             nullptr,
+                             input_padder,
+                             input_padder,
+                             true,
+                             position_bias,
+                             offset);
+
+        (*_post_attention_layer_norm)(context, output);
+        ops::Add()(output, input, output);
+
+        context = std::move(output);
+        (*_pre_feedforward_layer_norm)(context, output);
+        hidden = std::move(output);
+
+        _ff(hidden, output);
+
+        hidden = std::move(output);
+        (*_post_feedforward_layer_norm)(hidden, output);
+        ops::Add()(output, context, output);
+        return;
+      }
+
       const bool use_parallel_residual = _shared_layer_norm || _input_layer_norm;
 
       if (use_parallel_residual) {
diff --git a/third_party/googletest b/third_party/googletest
index b7d472f12..f8d7d77c0 160000
--- a/third_party/googletest
+++ b/third_party/googletest
@@ -1 +1 @@
-Subproject commit b7d472f1225c5a64943821d8483fecb469d3f382
+Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571
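
Note (annotation, not part of the patch): the one-line change in src/layers/attention.cc matters because Gemma 2 decouples head_dim from hidden_size, so the query slice of the fused QKV projection is num_heads * head_dim rather than _d_model. The sketch below checks the split sizes with Gemma 2 9B-like values (assumed for illustration; read the actual numbers from the checkpoint's config).

    import torch

    # Illustrative Gemma 2 9B-like dimensions (assumed, not read from a config).
    d_model, num_heads, num_heads_kv, head_dim = 3584, 16, 8, 256

    # Fused QKV projection output: queries, then keys, then values.
    fused_width = num_heads * head_dim + 2 * num_heads_kv * head_dim
    fused = torch.zeros(2, 5, fused_width)  # (batch, time, fused projection)

    # The old split sized the query slice as d_model, which no longer covers the
    # fused width when num_heads * head_dim != d_model (here 4096 vs 3584).
    q, k, v = fused.split(
        [num_heads * head_dim, num_heads_kv * head_dim, num_heads_kv * head_dim], dim=-1
    )
    assert q.shape[-1] == num_heads * head_dim
    assert q.shape[-1] != d_model
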