Add AMPLIFY huggingface conversion utility
Signed-off-by: Peter St. John <[email protected]>
pstjohn committed Feb 3, 2025
1 parent bd8ae26 commit 5a849d6
Showing 7 changed files with 309 additions and 32 deletions.
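For orientation, the new importer is registered against BionemoLightningModule via io.model_importer and driven through NeMo's io.import_ckpt with an hf:// model tag, as the smoke tests added below exercise. A minimal usage sketch (the output path here is illustrative, not part of this commit):

from pathlib import Path

from nemo.lightning import io

from bionemo.amplify.convert import HFAMPLIFYImporter  # noqa: F401  (importing registers the "hf" importer)
from bionemo.amplify.model import AMPLIFYConfig
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

# Build an AMPLIFY lightning module and import the Hugging Face weights into a NeMo2 checkpoint.
module = biobert_lightning_module(config=AMPLIFYConfig())
io.import_ckpt(module, "hf://chandar-lab/AMPLIFY_120M", Path("/tmp/amplify_nemo_checkpoint"))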
10 changes: 7 additions & 3 deletions Dockerfile
@@ -38,7 +38,7 @@ EOF
# Reinstall TE to avoid debugpy bug in vscode: https://nvbugspro.nvidia.com/bug/5078830
# Pull the latest TE version from https://github.com/NVIDIA/TransformerEngine/releases
# Use the version that matches the pytorch base container.
ARG TE_TAG=v1.13
ARG TE_TAG=2215fa5c7557b66034068816020f9f611019e457
RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG}
@@ -48,10 +48,13 @@ RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \
RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/Dao-AILab/[email protected]

# Mamba dependancy installation
# Mamba dependency installation
RUN pip --disable-pip-version-check --no-cache-dir install \
git+https://github.com/state-spaces/[email protected]

ARG XFORMER_ENGINE_TAG=v0.0.29.post1
RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@${XFORMER_ENGINE_TAG}#egg=xformers

RUN pip install hatchling # needed to install nemo-run
ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2
RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG}
@@ -100,7 +103,7 @@ COPY ./sub-packages /workspace/bionemo2/sub-packages
RUN --mount=type=bind,source=./.git,target=./.git \
--mount=type=bind,source=./requirements-test.txt,target=/requirements-test.txt \
--mount=type=bind,source=./requirements-cve.txt,target=/requirements-cve.txt \
--mount=type=cache,target=/root/.cache <<EOF
<<EOF
set -eo pipefail

uv pip install maturin --no-build-isolation
@@ -114,6 +117,7 @@ uv pip install --no-build-isolation \
rm -rf ./3rdparty
rm -rf /tmp/*
rm -rf ./sub-packages/bionemo-noodles/target
rm -rf /root/.cache
EOF

# In the devcontainer image, we just copy over the finished `dist-packages` folder from the build image back into the
164 changes: 164 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
@@ -0,0 +1,164 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path

import torch
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf
from transformers import AutoConfig as HFAutoConfig
from transformers import AutoModel

from bionemo.amplify.model import AMPLIFYConfig
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


@io.model_importer(BionemoLightningModule, "hf")
class HFAMPLIFYImporter(io.ModelConnector[AutoModel, BionemoLightningModule]):
"""Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model."""

def init(self) -> BionemoLightningModule:
"""Initialize the converted model."""
return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

def apply(self, output_path: Path) -> Path:
"""Applies the transformation."""
source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
self.nemo_save(output_path, trainer)
teardown(trainer, target)
return output_path

def convert_state(self, source, target):
"""Converting HF state dict to NeMo state dict."""
mapping = {
"encoder.weight": "embedding.word_embeddings.weight",
"transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight",
"transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight",
"transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight",
"transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
"transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
"layer_norm_2.weight": "encoder.final_layernorm.weight",
"decoder.weight": "output_layer.weight",
"decoder.bias": "output_layer.bias",
# "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
}

# lm_head.bias
return io.apply_transforms(
source,
target,
mapping=mapping,
transforms=[_import_qkv_weight],
# transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight],
)

@property
def tokenizer(self) -> BioNeMoAMPLIFYTokenizer:
"""We just have the one tokenizer for ESM-2."""
return BioNeMoAMPLIFYTokenizer()

@property
def config(self) -> AMPLIFYConfig:
"""Returns the transformed ESM-2 config given the model tag."""
source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
output = AMPLIFYConfig(
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
position_embedding_type="rope",
num_attention_heads=source.num_attention_heads,
seq_length=source.max_length,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
params_dtype=dtype_from_hf(source),
)

return output


@io.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
target_key="embedding.word_embeddings.weight",
)
def _pad_embeddings(ctx: io.TransformCTX, source_embed):
"""Pad the embedding layer to the new input dimension."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_embed.size(0)
num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
padding_rows = torch.zeros(num_padding_rows, source_embed.size(1))
return torch.cat((source_embed, padding_rows), dim=0)


@io.state_transform(
source_key="lm_head.bias",
target_key="output_layer.bias",
)
def _pad_bias(ctx: io.TransformCTX, source_bias):
"""Pad the embedding layer to the new input dimension."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_bias.size(0)
output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
output_bias[:hf_embedding_dimension] = source_bias
return output_bias


@io.state_transform(
source_key=(
"transformer_encoder.*.q.weight",
"transformer_encoder.*.k.weight",
"transformer_encoder.*.v.weight",
),
target_key="encoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
concat_weights = torch.cat((query, key, value), dim=0)
input_shape = concat_weights.size()
np = ctx.target.config.num_attention_heads
# Reorder rows from a [(q|k|v), num_heads, head_dim] grouping to
# [num_heads, (q|k|v), head_dim], so the fused linear_qkv weight is interleaved per attention head.
concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
concat_weights = concat_weights.transpose(0, 1).contiguous()
concat_weights = concat_weights.view(*input_shape)
return concat_weights


@io.state_transform(
source_key=(
"esm.encoder.layer.*.attention.self.query.bias",
"esm.encoder.layer.*.attention.self.key.bias",
"esm.encoder.layer.*.attention.self.value.bias",
),
target_key="encoder.layers.*.self_attention.linear_qkv.bias",
)
def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
concat_biases = torch.cat((query, key, value), dim=0)
input_shape = concat_biases.size()
np = ctx.target.config.num_attention_heads
# Reorder bias entries from a [(q|k|v), num_heads, head_dim] grouping to
# [num_heads, (q|k|v), head_dim], matching the fused linear_qkv weight layout.
concat_biases = concat_biases.view(3, np, -1)
concat_biases = concat_biases.transpose(0, 1).contiguous()
concat_biases = concat_biases.view(*input_shape)
return concat_biases
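The _import_qkv_weight transform above rearranges the separate Hugging Face query/key/value matrices into the head-interleaved layout that the fused linear_qkv weight expects: rows grouped per attention head as (q, k, v) rather than all queries, then all keys, then all values. A toy sketch of the reordering (head count, head size, and hidden size are made up for illustration):

import torch

num_heads, head_dim, hidden = 2, 3, 4
q = torch.randn(num_heads * head_dim, hidden)
k = torch.randn(num_heads * head_dim, hidden)
v = torch.randn(num_heads * head_dim, hidden)

concat = torch.cat((q, k, v), dim=0)             # [(q|k|v) * num_heads * head_dim, hidden]
fused = (
    concat.view(3, num_heads, head_dim, hidden)  # split rows into (q|k|v, head, head_dim)
    .transpose(0, 1)                             # -> (head, q|k|v, head_dim): interleave per head
    .contiguous()
    .view(3 * num_heads * head_dim, hidden)
)

# Row order is now [q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...].
assert torch.equal(fused[:head_dim], q[:head_dim])                # q rows of head 0 come first
assert torch.equal(fused[head_dim : 2 * head_dim], k[:head_dim])  # followed by k rows of head 0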
43 changes: 43 additions & 0 deletions sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.lightning import io

from bionemo.amplify.convert import HFAMPLIFYImporter # noqa: F401
from bionemo.amplify.model import AMPLIFYConfig
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
from bionemo.core.utils.dtypes import PrecisionTypes
from bionemo.esm2.testing.compare import get_input_tensors, load_and_evaluate_hf_model
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


def assert_amplify_equivalence(ckpt_path: str, model_tag: str, precision: PrecisionTypes) -> None:
tokenizer = BioNeMoAMPLIFYTokenizer()

input_ids, attention_mask = get_input_tensors(tokenizer)
load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)


def test_convert_smoke_test_120M(tmp_path):
model_tag = "chandar-lab/AMPLIFY_120M"
module = biobert_lightning_module(config=AMPLIFYConfig())
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")


def test_convert_smoke_test_350M(tmp_path):
model_tag = "chandar-lab/AMPLIFY_350M"
module = biobert_lightning_module(config=AMPLIFYConfig())
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
96 changes: 81 additions & 15 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py
@@ -26,7 +26,7 @@
from bionemo.esm2.model.model import ESM2Config


def assert_model_equivalence(
def assert_esm2_equivalence(
ckpt_path: Path | str,
model_tag: str,
precision: PrecisionTypes = "fp32",
@@ -49,13 +49,57 @@ def assert_model_equivalence(
"""
tokenizer = get_tokenizer()

input_ids, attention_mask = get_input_tensors(tokenizer)

nemo_logits, nemo_hidden_state = load_and_evaluate_nemo_esm2(ckpt_path, precision, input_ids, attention_mask)
gc.collect()
torch.cuda.empty_cache()
hf_logits, hf_hidden_state = load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)

# Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
# should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
# We don't care about the padding tokens, so we only compare the non-padding tokens.
assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol)
assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol)


def get_input_tensors(tokenizer) -> tuple[torch.Tensor, torch.Tensor]:
"""Get input tensors for testing.
Args:
tokenizer: A huggingface-like tokenizer object.
Returns:
A tuple of the input IDs and attention mask tensors.
"""
test_proteins = [
"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
"MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
]
tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
input_ids: torch.Tensor = tokens["input_ids"] # type: ignore
attention_mask: torch.Tensor = tokens["attention_mask"] # type: ignore
return input_ids, attention_mask


def load_and_evaluate_nemo_esm2(
ckpt_path: Path | str,
precision: PrecisionTypes,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Load a NeMo2 ESM-2 model and evaluate it on the given inputs.
Args:
ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
precision: The precision type to use for the comparison.
input_ids: The input IDs tensor to evaluate.
attention_mask: The attention mask tensor to evaluate.
Returns:
A tuple of the logits and hidden states tensors calculated by the NeMo2 model, respectively.
"""
tokenizer = get_tokenizer()

dtype = get_autocast_dtype(precision)
nemo_config = ESM2Config(
@@ -77,23 +121,45 @@ def assert_model_equivalence(
nemo_output = nemo_model(input_ids, attention_mask)
nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]
nemo_hidden_state = nemo_output["hidden_states"]
return nemo_logits, nemo_hidden_state

del nemo_model
gc.collect()
torch.cuda.empty_cache()

def load_and_evaluate_hf_model(
model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Load a HuggingFace model and evaluate it on the given inputs.
Args:
model_tag: The HuggingFace model tag for the model to compare against.
precision: The precision type to use for the comparison.
input_ids: The input IDs tensor to evaluate.
attention_mask: The attention mask tensor to evaluate.
Returns:
A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively.
"""
hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval()
hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)
hf_hidden_state = hf_output_all.hidden_states[-1]
return hf_output_all.logits, hf_hidden_state

# Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
# should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
# We don't care about the padding tokens, so we only compare the non-padding tokens.
logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2)
logit_similarity = logit_similarity[attention_mask == 1]

hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2)
hidden_state_similarity = hidden_state_similarity[attention_mask == 1]
def assert_cosine_similarity(
tensor1: torch.Tensor,
tensor2: torch.Tensor,
mask: torch.Tensor,
rtol: float | None = None,
atol: float | None = None,
) -> None:
"""Assert that the cosine similarity between two tensors is close to 1.
torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol)
torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol)
Args:
tensor1: The first tensor to compare.
tensor2: The second tensor to compare.
mask: A mask tensor to apply to the comparison.
rtol: The relative tolerance to use for the comparison. Defaults to 1e-4.
atol: The absolute tolerance to use for the comparison. Defaults to 1e-4.
"""
similarity = torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=2)
similarity = similarity[mask == 1]
torch.testing.assert_close(similarity, torch.ones_like(similarity), rtol=rtol, atol=atol)
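As a quick illustration of the refactored helper, two activations that agree up to numerical noise pass the check while padded positions are excluded by the mask; a toy sketch (shapes and tolerances chosen only for illustration):

import torch

from bionemo.esm2.testing.compare import assert_cosine_similarity

a = torch.randn(2, 8, 16)                  # [batch, sequence, hidden]
b = a + 1e-6 * torch.randn_like(a)         # tiny perturbation: cosine similarity stays ~1
mask = torch.ones(2, 8, dtype=torch.long)  # 1 = real token, 0 = padding
mask[:, -2:] = 0                           # last two positions treated as padding and ignored

assert_cosine_similarity(a, b, mask, rtol=1e-4, atol=1e-4)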