fix save_pretrained
mfuntowicz committed Jul 24, 2024
1 parent c729068 · commit 877336a
Showing 7 changed files with 68 additions and 80 deletions.
56 changes: 40 additions & 16 deletions src/optimum/nvidia/hub.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from abc import ABCMeta, abstractmethod
from logging import getLogger
from os import PathLike, scandir, symlink
from pathlib import Path
@@ -51,7 +52,7 @@
SupportsFromHuggingFace,
SupportsTransformersConversion,
)
from optimum.nvidia.utils import get_user_agent, model_type_from_known_config
from optimum.nvidia.utils import get_user_agent
from optimum.nvidia.utils.nvml import get_device_count, get_device_name
from optimum.utils import NormalizedConfig

@@ -94,7 +95,9 @@ def folder_list_checkpoints(folder: Path) -> Iterable[Path]:
return checkpoint_candidates


def get_trtllm_artifact(model_id: str, patterns: List[str], add_default_allow_patterns: bool = True) -> Path:
def get_trtllm_artifact(
model_id: str, patterns: List[str], add_default_allow_patterns: bool = True
) -> Path:
if (local_path := Path(model_id)).exists():
return local_path

@@ -105,7 +108,9 @@ def get_trtllm_artifact(model_id: str, patterns: List[str], add_default_allow_pa
library_name=LIBRARY_NAME,
library_version=trtllm_version,
user_agent=get_user_agent(),
allow_patterns=patterns + HUB_SNAPSHOT_ALLOW_PATTERNS if add_default_allow_patterns else patterns,
allow_patterns=patterns + HUB_SNAPSHOT_ALLOW_PATTERNS
if add_default_allow_patterns
else patterns,
)
)

Expand All @@ -117,10 +122,10 @@ class HuggingFaceHubModel(
tags=["optimum-nvidia", "trtllm"],
repo_url="https://github.com/huggingface/optimum-nvidia",
docs_url="https://huggingface.co/docs/optimum/nvidia_overview",
metaclass=ABCMeta,
):
def __init__(self, engines_path: Union[str, PathLike, Path]):
self._engines_path = Path(engines_path)
super().__init__()

@classmethod
def _from_pretrained(
@@ -158,22 +163,32 @@ def _from_pretrained(
# Check if the model_id is not a local path
local_model_id = Path(model_id)

engines_folder = checkpoints_folder = None
engine_files = checkpoint_files = []

# Check if we have a local path to a model OR a model_id on the hub
if local_model_id.exists() and local_model_id.is_dir():
if any(engine_files := folder_list_engines(local_model_id)):
checkpoint_files = []
if any(engine_files := list(folder_list_engines(local_model_id))):
engines_folder = engine_files[
0
].parent # Looking for parent folder not actual specific engine file
checkpoints_folder = None
else:
checkpoint_files = folder_list_checkpoints(local_model_id)
checkpoint_files = list(folder_list_checkpoints(local_model_id))

if checkpoint_files:
checkpoints_folder = checkpoint_files[0].parent

else:
# Look for prebuild TRTLLM Engine
engine_files = checkpoint_files = []
if not force_export:
LOGGER.debug(f"Retrieving prebuild engine(s) for device {device_name}")
cached_path = get_trtllm_artifact(
model_id, [f"{common_hub_path}/**/{PATH_FOLDER_ENGINES}/*.engine"]
)

engine_files = folder_list_engines(cached_path / PATH_FOLDER_ENGINES)
engines_folder = cached_path / PATH_FOLDER_ENGINES
engine_files = folder_list_engines(engines_folder)

# if no engine is found, then just try to locate a checkpoint
if not engine_files:
@@ -182,10 +197,11 @@ def _from_pretrained(
model_id, [f"{common_hub_path}/**/*.safetensors"]
)

checkpoint_files = folder_list_checkpoints(cached_path / PATH_FOLDER_CHECKPOINTS)
checkpoints_folder = cached_path / PATH_FOLDER_CHECKPOINTS
checkpoint_files = folder_list_checkpoints(checkpoints_folder)

# If no checkpoint available, we are good for a full export from the Hugging Face Hub
if not checkpoint_files:
if not engine_files:
LOGGER.info(f"No prebuild engines nor checkpoint were found for {model_id}")

# Retrieve the snapshot if needed
@@ -272,7 +288,7 @@ def _from_pretrained(
)

_ = converter.build(ranked_model, build_config)
engine_files = converter.workspace.engines_path
engines_folder = converter.workspace.engines_path

LOGGER.info(
f"Saved TensorRT-LLM engines at {converter.workspace.engines_path}"
@@ -285,11 +301,17 @@ def _from_pretrained(
raise ValueError(
"Model doesn't support Hugging Face transformers conversion, aborting."
)
else:
generation_config = GenerationConfig.from_pretrained(engines_folder)

return cls(
engines_path=engine_files,
generation_config=generation_config,
)
return cls(
engines_path=engines_folder,
generation_config=generation_config,
)

@abstractmethod
def _save_additional_parcels(self, save_directory: Path):
raise NotImplementedError()

def _save_pretrained(self, save_directory: Path) -> None:
try:
@@ -312,3 +334,5 @@ def _save_pretrained(self, save_directory: Path) -> None:
save_directory.joinpath(file.relative_to(self._engines_path)),
symlinks=True,
)
finally:
self._save_additional_parcels(save_directory)
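
For context, every concrete subclass of HuggingFaceHubModel must now implement _save_additional_parcels, and _save_pretrained invokes it from the finally block so the extra files are written even when copying the engine files raises. A minimal sketch of the hook pattern follows; the MyEngineModel class and extra_config.json file are hypothetical and only illustrate the shape, they are not part of this commit.

from pathlib import Path

from optimum.nvidia.hub import HuggingFaceHubModel


class MyEngineModel(HuggingFaceHubModel):  # hypothetical subclass, for illustration only
    def _save_additional_parcels(self, save_directory: Path):
        # Write anything the engine files alone do not capture
        # (a generation config, tokenizer files, ...), so that
        # save_pretrained round-trips without rebuilding the engines.
        (save_directory / "extra_config.json").write_text("{}")
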
15 changes: 0 additions & 15 deletions src/optimum/nvidia/models/gemma.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from os import PathLike
from typing import TYPE_CHECKING, Optional, Union

from tensorrt_llm.models.gemma.model import GemmaForCausalLM as TrtGemmaForCausalLM
from transformers import GemmaForCausalLM as TransformersGemmaForCausalLM
@@ -24,10 +22,6 @@
from optimum.nvidia.runtime import CausalLM


if TYPE_CHECKING:
from optimum.nvidia.runtime import ExecutorConfig, GenerationConfig


LOGGER = getLogger(__name__)


@@ -36,12 +30,3 @@ class GemmaForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConver
TRT_LLM_TARGET_MODEL_CLASSES = TrtGemmaForCausalLM

TRT_LLM_MANDATORY_CONVERSION_PARAMS = {"share_embedding_table": True}

def __init__(
self,
engines_path: Union[str, PathLike],
generation_config: "GenerationConfig",
executor_config: Optional["ExecutorConfig"] = None,
):
CausalLM.__init__(self, engines_path, generation_config, executor_config)
HuggingFaceHubModel.__init__(self, engines_path)
18 changes: 1 addition & 17 deletions src/optimum/nvidia/models/llama.py
@@ -13,33 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from os import PathLike
from typing import TYPE_CHECKING, Optional, Union

from tensorrt_llm.models.llama.model import LLaMAForCausalLM
from transformers import LlamaForCausalLM as TransformersLlamaForCausalLM

from optimum.nvidia.hub import HuggingFaceHubModel
from optimum.nvidia.models import SupportsTransformersConversion
from optimum.nvidia.runtime import CausalLM


if TYPE_CHECKING:
from optimum.nvidia.runtime import ExecutorConfig, GenerationConfig


LOGGER = getLogger(__name__)


class LlamaForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion):
class LlamaForCausalLM(CausalLM, SupportsTransformersConversion):
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersLlamaForCausalLM
TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM

def __init__(
self,
engines_path: Union[str, PathLike],
generation_config: "GenerationConfig",
executor_config: Optional["ExecutorConfig"] = None,
):
CausalLM.__init__(self, engines_path, generation_config, executor_config)
HuggingFaceHubModel.__init__(self, engines_path)
15 changes: 0 additions & 15 deletions src/optimum/nvidia/models/mistral.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from os import PathLike
from typing import TYPE_CHECKING, Optional, Union

from tensorrt_llm.models.llama.model import LLaMAForCausalLM
from transformers import MistralForCausalLM as TransformersMistralForCausalLM
@@ -24,22 +22,9 @@
from optimum.nvidia.runtime import CausalLM


if TYPE_CHECKING:
from optimum.nvidia.runtime import ExecutorConfig, GenerationConfig


LOGGER = getLogger(__name__)


class MistralForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion):
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersMistralForCausalLM
TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM

def __init__(
self,
engines_path: Union[str, PathLike],
generation_config: "GenerationConfig",
executor_config: Optional["ExecutorConfig"] = None,
):
CausalLM.__init__(self, engines_path, generation_config, executor_config)
HuggingFaceHubModel.__init__(self, engines_path)
15 changes: 0 additions & 15 deletions src/optimum/nvidia/models/mixtral.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from os import PathLike
from typing import TYPE_CHECKING, Optional, Union

from tensorrt_llm.models.llama.model import LLaMAForCausalLM
from transformers import MixtralForCausalLM as TransformersMixtralForCausalLM
@@ -24,22 +22,9 @@
from optimum.nvidia.runtime import CausalLM


if TYPE_CHECKING:
from optimum.nvidia.runtime import ExecutorConfig, GenerationConfig


LOGGER = getLogger(__name__)


class MixtralForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion):
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersMixtralForCausalLM
TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM

def __init__(
self,
engines_path: Union[str, PathLike],
generation_config: "GenerationConfig",
executor_config: Optional["ExecutorConfig"] = None,
):
CausalLM.__init__(self, engines_path, generation_config, executor_config)
HuggingFaceHubModel.__init__(self, engines_path)
19 changes: 17 additions & 2 deletions src/optimum/nvidia/runtime.py
@@ -15,6 +15,7 @@
)
from tensorrt_llm.hlapi import SamplingParams

from optimum.nvidia.hub import HuggingFaceHubModel
from optimum.nvidia.utils.nvml import is_post_ampere


@@ -155,5 +156,19 @@ def loss(self) -> None:
return None


class CausalLM(InferenceRuntimeBase):
pass
class CausalLM(HuggingFaceHubModel, InferenceRuntimeBase):
def __init__(
self,
engines_path: Union[str, PathLike, Path],
generation_config: "GenerationConfig",
executor_config: Optional["ExecutorConfig"] = None,
):
InferenceRuntimeBase.__init__(
self, engines_path, generation_config, executor_config
)
HuggingFaceHubModel.__init__(self, engines_path)

def _save_additional_parcels(self, save_directory: Path):
self._generation_config.save_pretrained(
save_directory, "generation_config.json"
)
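
Taken together with the hub.py change, a rough usage sketch of the save/reload path this commit fixes; the model id and output directory are illustrative, and the AutoModelForCausalLM entry point is assumed from the test below rather than shown in this diff.

from optimum.nvidia import AutoModelForCausalLM

# Building the engines and saving them now also writes generation_config.json
# through CausalLM._save_additional_parcels.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model.save_pretrained("./llama-2-7b-trtllm")

# The saved directory can be reloaded without rebuilding the engines.
reloaded = AutoModelForCausalLM.from_pretrained("./llama-2-7b-trtllm")
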
10 changes: 10 additions & 0 deletions tests/test_hub.py
@@ -4,6 +4,8 @@

import mock
import pytest
import torch.cuda
from huggingface_hub import login
from transformers import AutoConfig as HfAutoConfig
from transformers import AutoModelForCausalLM as HfAutoModelForCausalLM

@@ -60,6 +62,7 @@ def test_folder_list_engines(rank: int):
("meta-llama/Llama-2-7b-chat-hf", "google/gemma-2b", "mistralai/Mistral-7B-v0.3"),
)
def test_save_engine_locally_and_reload(model_id: str):
login("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")

with TemporaryDirectory() as hf_out:
with TemporaryDirectory() as trtllm_out:
trtllm_out = Path(trtllm_out)
@@ -70,9 +73,13 @@ def _save():

model = HfAutoModelForCausalLM.from_config(config)
model.save_pretrained(hf_out)
del model
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(hf_out)
model.save_pretrained(trtllm_out)
del model
torch.cuda.empty_cache()

assert trtllm_out.exists()
assert (trtllm_out / "rank0.engine").exists()
Expand All @@ -86,6 +93,9 @@ def _reload():
== 0
)

del model
torch.cuda.empty_cache()

_save()
_reload()
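
The added del / torch.cuda.empty_cache() pairs follow the common pattern for keeping several CUDA-backed models from accumulating in one process. A standalone sketch of that pattern, with a small Linear layer standing in for any model used in the test:

import torch

# Stand-in for any CUDA-backed model used during the test.
model = torch.nn.Linear(8, 8)
if torch.cuda.is_available():
    model = model.to("cuda")

# ... run whatever needs the model ...

del model  # drop the last Python reference so its tensors can be collected
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached CUDA blocks before the next model is created
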

