From 5a849d69accde8915ff48c24508f2bc7a0562612 Mon Sep 17 00:00:00 2001
From: "Peter St. John"
Date: Mon, 3 Feb 2025 08:44:41 -0800
Subject: [PATCH] Add AMPLIFY huggingface conversion utility

Signed-off-by: Peter St. John
---
 Dockerfile                                    |  10 +-
 .../src/bionemo/amplify/convert.py           | 164 ++++++++++++++++++
 .../tests/bionemo/amplify/test_convert.py    |  43 +++++
 .../src/bionemo/esm2/testing/compare.py      |  96 ++++++++--
 .../tests/bionemo/esm2/model/test_convert.py |   8 +-
 .../tests/bionemo/esm2/model/test_model.py   |  10 +-
 .../llm/model/biobert/transformer_specs.py   |  10 +-
 7 files changed, 309 insertions(+), 32 deletions(-)
 create mode 100644 sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
 create mode 100644 sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py

diff --git a/Dockerfile b/Dockerfile
index 6642f3c5e7..dcae4b6a54 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -38,7 +38,7 @@ EOF
 # Reinstall TE to avoid debugpy bug in vscode: https://nvbugspro.nvidia.com/bug/5078830
 # Pull the latest TE version from https://github.com/NVIDIA/TransformerEngine/releases
 # Use the version that matches the pytorch base container.
-ARG TE_TAG=v1.13
+ARG TE_TAG=2215fa5c7557b66034068816020f9f611019e457
 RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
     pip --disable-pip-version-check --no-cache-dir install \
     git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG}
@@ -48,10 +48,13 @@ RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
 RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip --disable-pip-version-check --no-cache-dir install \
     git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1
 
-# Mamba dependancy installation
+# Mamba dependency installation
 RUN pip --disable-pip-version-check --no-cache-dir install \
     git+https://github.com/state-spaces/mamba.git@v2.2.2
 
+ARG XFORMER_ENGINE_TAG=v0.0.29.post1
+RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@${XFORMER_ENGINE_TAG}#egg=xformers
+
 RUN pip install hatchling   # needed to install nemo-run
ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2
 RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG}
@@ -100,7 +103,7 @@ COPY ./sub-packages /workspace/bionemo2/sub-packages
 RUN --mount=type=bind,source=./.git,target=./.git \
     --mount=type=bind,source=./requirements-test.txt,target=/requirements-test.txt \
     --mount=type=bind,source=./requirements-cve.txt,target=/requirements-cve.txt \
-    --mount=type=cache,target=/root/.cache <

+    def init(self) -> BionemoLightningModule:
+        """Initialize the converted model."""
+        return biobert_lightning_module(self.config, tokenizer=self.tokenizer)
+
+    def apply(self, output_path: Path) -> Path:
+        """Applies the transformation."""
+        source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
+        target = self.init()
+        trainer = self.nemo_setup(target)
+        self.convert_state(source, target)
+        self.nemo_save(output_path, trainer)
+        teardown(trainer, target)
+        return output_path
+
+    def convert_state(self, source, target):
+        """Convert the HF state dict to the NeMo state dict."""
+        mapping = {
+            "encoder.weight": "embedding.word_embeddings.weight",
+            "transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight",
+            "transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight",
+            "transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight",
+            "transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
+            "transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
+            "layer_norm_2.weight": "encoder.final_layernorm.weight",
+            "decoder.weight": "output_layer.weight",
+            "decoder.bias": "output_layer.bias",
+            # "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
+        }
+
+        # lm_head.bias
+        return io.apply_transforms(
+            source,
+            target,
+            mapping=mapping,
+            transforms=[_import_qkv_weight],
+            # transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight],
+        )
+
+    @property
+    def tokenizer(self) -> BioNeMoAMPLIFYTokenizer:
+        """We just have the one tokenizer for AMPLIFY."""
+        return BioNeMoAMPLIFYTokenizer()
+
+    @property
+    def config(self) -> AMPLIFYConfig:
+        """Returns the transformed AMPLIFY config given the model tag."""
+        source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
+        output = AMPLIFYConfig(
+            num_layers=source.num_hidden_layers,
+            hidden_size=source.hidden_size,
+            ffn_hidden_size=source.intermediate_size,
+            position_embedding_type="rope",
+            num_attention_heads=source.num_attention_heads,
+            seq_length=source.max_length,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
+        )
+
+        return output
+
+
+@io.state_transform(
+    source_key="esm.embeddings.word_embeddings.weight",
+    target_key="embedding.word_embeddings.weight",
+)
+def _pad_embeddings(ctx: io.TransformCTX, source_embed):
+    """Pad the embedding layer to the new input dimension."""
+    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
+    hf_embedding_dimension = source_embed.size(0)
+    num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
+    padding_rows = torch.zeros(num_padding_rows, source_embed.size(1))
+    return torch.cat((source_embed, padding_rows), dim=0)
+
+
+@io.state_transform(
+    source_key="lm_head.bias",
+    target_key="output_layer.bias",
+)
+def _pad_bias(ctx: io.TransformCTX, source_bias):
+    """Pad the output bias to the new vocabulary dimension."""
+    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
+    hf_embedding_dimension = source_bias.size(0)
+    output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
+    output_bias[:hf_embedding_dimension] = source_bias
+    return output_bias
+
+
+@io.state_transform(
+    source_key=(
+        "transformer_encoder.*.q.weight",
+        "transformer_encoder.*.k.weight",
+        "transformer_encoder.*.v.weight",
+    ),
+    target_key="encoder.layers.*.self_attention.linear_qkv.weight",
+)
+def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
+    """Interleave the separate query, key, and value weights into Megatron's fused QKV layout."""
+    concat_weights = torch.cat((query, key, value), dim=0)
+    input_shape = concat_weights.size()
+    np = ctx.target.config.num_attention_heads
+    # transpose weights
+    # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
+    # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
+    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.transpose(0, 1).contiguous()
+    concat_weights = concat_weights.view(*input_shape)
+    return concat_weights
+
+
+@io.state_transform(
+    source_key=(
+        "esm.encoder.layer.*.attention.self.query.bias",
+        "esm.encoder.layer.*.attention.self.key.bias",
+        "esm.encoder.layer.*.attention.self.value.bias",
+    ),
+    target_key="encoder.layers.*.self_attention.linear_qkv.bias",
+)
+def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
+    """Interleave the separate query, key, and value biases into Megatron's fused QKV layout."""
+    concat_biases = torch.cat((query, key, value), dim=0)
+    input_shape = concat_biases.size()
+    np = ctx.target.config.num_attention_heads
+    # transpose biases
+    # [num_splits_model_parallel * attention head size * #attention heads]
+    # --> [attention head size * num_splits_model_parallel * #attention heads]
+    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases = concat_biases.transpose(0, 1).contiguous()
+    concat_biases = concat_biases.view(*input_shape)
+    return concat_biases
diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
new file mode 100644
index 0000000000..97bda74441
--- /dev/null
+++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.lightning import io
+
+from bionemo.amplify.convert import HFAMPLIFYImporter  # noqa: F401
+from bionemo.amplify.model import AMPLIFYConfig
+from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
+from bionemo.core.utils.dtypes import PrecisionTypes
+from bionemo.esm2.testing.compare import get_input_tensors, load_and_evaluate_hf_model
+from bionemo.llm.model.biobert.lightning import biobert_lightning_module
+
+
+def assert_amplify_equivalence(ckpt_path: str, model_tag: str, precision: PrecisionTypes) -> None:
+    tokenizer = BioNeMoAMPLIFYTokenizer()
+
+    input_ids, attention_mask = get_input_tensors(tokenizer)
+    load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)
+
+
+def test_convert_smoke_test_120M(tmp_path):
+    model_tag = "chandar-lab/AMPLIFY_120M"
+    module = biobert_lightning_module(config=AMPLIFYConfig())
+    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
+
+
+def test_convert_smoke_test_350M(tmp_path):
+    model_tag = "chandar-lab/AMPLIFY_350M"
+    module = biobert_lightning_module(config=AMPLIFYConfig())
+    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py
index e8690c1d04..a08ada0ae6 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py
@@ -26,7 +26,7 @@
 from bionemo.esm2.model.model import ESM2Config
 
 
-def assert_model_equivalence(
+def assert_esm2_equivalence(
     ckpt_path: Path | str,
     model_tag: str,
     precision: PrecisionTypes = "fp32",
@@ -49,13 +49,57 @@
     """
     tokenizer = get_tokenizer()
 
+    input_ids, attention_mask = get_input_tensors(tokenizer)
+
+    nemo_logits, nemo_hidden_state = load_and_evaluate_nemo_esm2(ckpt_path, precision, input_ids, attention_mask)
+    gc.collect()
+    torch.cuda.empty_cache()
+    hf_logits, hf_hidden_state = load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)
+
+    # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
+    # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
+    # We don't care about the padding tokens, so we only compare the non-padding tokens.
+    assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol)
+    assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol)
+
+
+def get_input_tensors(tokenizer) -> tuple[torch.Tensor, torch.Tensor]:
+    """Get input tensors for testing.
+
+    Args:
+        tokenizer: A huggingface-like tokenizer object.
+
+    Returns:
+        A tuple of the input IDs and attention mask tensors.
+    """
     test_proteins = [
         "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
         "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG",
     ]
     tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
-    input_ids = tokens["input_ids"]
-    attention_mask = tokens["attention_mask"]
+    input_ids: torch.Tensor = tokens["input_ids"]  # type: ignore
+    attention_mask: torch.Tensor = tokens["attention_mask"]  # type: ignore
+    return input_ids, attention_mask
+
+
+def load_and_evaluate_nemo_esm2(
+    ckpt_path: Path | str,
+    precision: PrecisionTypes,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Load a NeMo2 ESM-2 model and evaluate it on the given inputs.
+
+    Args:
+        ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
+        precision: The precision type to use for the comparison.
+        input_ids: The input IDs tensor to evaluate.
+        attention_mask: The attention mask tensor to evaluate.
+
+    Returns:
+        A tuple of the logits and hidden states tensors calculated by the NeMo2 model, respectively.
+    """
+    tokenizer = get_tokenizer()
 
     dtype = get_autocast_dtype(precision)
     nemo_config = ESM2Config(
@@ -77,23 +121,45 @@
     nemo_output = nemo_model(input_ids, attention_mask)
     nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]
     nemo_hidden_state = nemo_output["hidden_states"]
+    return nemo_logits, nemo_hidden_state
 
-    del nemo_model
-    gc.collect()
-    torch.cuda.empty_cache()
 
+def load_and_evaluate_hf_model(
+    model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Load a HuggingFace model and evaluate it on the given inputs.
+
+    Args:
+        model_tag: The HuggingFace model tag for the model to compare against.
+        precision: The precision type to use for the comparison.
+        input_ids: The input IDs tensor to evaluate.
+        attention_mask: The attention mask tensor to evaluate.
+
+    Returns:
+        A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively.
+    """
     hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval()
     hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)
     hf_hidden_state = hf_output_all.hidden_states[-1]
+    return hf_output_all.logits, hf_hidden_state
 
-    # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
-    # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
-    # We don't care about the padding tokens, so we only compare the non-padding tokens.
-    logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2)
-    logit_similarity = logit_similarity[attention_mask == 1]
-    hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2)
-    hidden_state_similarity = hidden_state_similarity[attention_mask == 1]
 
+def assert_cosine_similarity(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    mask: torch.Tensor,
+    rtol: float | None = None,
+    atol: float | None = None,
+) -> None:
+    """Assert that the cosine similarity between two tensors is close to 1.
 
-    torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol)
-    torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol)
+    Args:
+        tensor1: The first tensor to compare.
+        tensor2: The second tensor to compare.
+        mask: A mask tensor to apply to the comparison.
+        rtol: The relative tolerance to use for the comparison. If None, the torch.testing.assert_close default is used.
+        atol: The absolute tolerance to use for the comparison. If None, the torch.testing.assert_close default is used.
+    """
+    similarity = torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=2)
+    similarity = similarity[mask == 1]
+    torch.testing.assert_close(similarity, torch.ones_like(similarity), rtol=rtol, atol=atol)
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py
index de8a23a107..51f6fb7124 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py
@@ -19,7 +19,7 @@
 from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401
 from bionemo.esm2.model.model import ESM2Config
-from bionemo.esm2.testing.compare import assert_model_equivalence
+from bionemo.esm2.testing.compare import assert_esm2_equivalence
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.testing import megatron_parallel_state_utils
 
 
@@ -35,7 +35,7 @@ def test_nemo2_conversion_equivalent_8m(tmp_path):
     module = biobert_lightning_module(config=ESM2Config())
     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag)
+        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag)
 
 
 def test_nemo2_conversion_equivalent_8m_bf16(tmp_path):
@@ -43,7 +43,7 @@ def test_nemo2_conversion_equivalent_8m_bf16(tmp_path):
     module = biobert_lightning_module(config=ESM2Config())
     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"):
-        assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16")
+        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16")
 
 
 @pytest.mark.slow
@@ -52,4 +52,4 @@ def test_nemo2_conversion_equivalent_650m(tmp_path):
     module = biobert_lightning_module(config=ESM2Config())
     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4)
+        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4)
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
index 8895b3719a..12cf4800ea 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
@@ -29,7 +29,7 @@
 from bionemo.esm2.data.datamodule import ESMDataModule
 from bionemo.esm2.data.tokenizer import get_tokenizer
 from bionemo.esm2.model.embedding import ESM2Embedding
-from bionemo.esm2.testing.compare import assert_model_equivalence
+from bionemo.esm2.testing.compare import assert_esm2_equivalence
 from bionemo.llm.model.biobert.model import MegatronBioBertModel
 from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping
 from bionemo.testing import megatron_parallel_state_utils
@@ -180,7 +180,7 @@ def test_model_equivalence_with_huggingface_8m(precision):
     model_tag = "facebook/esm2_t6_8M_UR50D"
     ckpt_path = load("esm2/8m:2.0")
     with megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision):
-        assert_model_equivalence(ckpt_path, model_tag, precision=precision)
+        assert_esm2_equivalence(ckpt_path, model_tag, precision=precision)
 
 
 @pytest.mark.slow
@@ -188,7 +188,7 @@ def test_model_equivalence_with_huggingface_650m():
     model_tag = "facebook/esm2_t33_650M_UR50D"
     ckpt_path = load("esm2/650m:2.0")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4)
+        assert_esm2_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4)
 
 
 @pytest.mark.slow
@@ -196,7 +196,7 @@ def test_model_equivalence_with_huggingface_650m_bf16():
     model_tag = "facebook/esm2_t33_650M_UR50D"
     ckpt_path = load("esm2/650m:2.0")
     with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"):
-        assert_model_equivalence(ckpt_path, model_tag, precision="bf16")
+        assert_esm2_equivalence(ckpt_path, model_tag, precision="bf16")
 
 
 @pytest.mark.slow
@@ -205,4 +205,4 @@ def test_model_equivalence_with_huggingface_3b():
     model_tag = "facebook/esm2_t36_3B_UR50D"
     ckpt_path = load("esm2/3b:2.0")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4)
+        assert_esm2_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4)
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py
index d6baae9044..b3f6bceb5c 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py
@@ -17,17 +17,17 @@
 from enum import Enum
 from typing import Optional, Sequence, Type
 
+from megatron.core.extensions.transformer_engine import (
+    TEDotProductAttention,
+    TELayerNormColumnParallelLinear,
+    TERowParallelLinear,
+)
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.models.bert import bert_layer_specs
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer import spec_utils
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-from megatron.core.transformer.custom_layers.transformer_engine import (
-    TEDotProductAttention,
-    TELayerNormColumnParallelLinear,
-    TERowParallelLinear,
-)
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp