diff --git a/examples/async-text-generation.py b/examples/async-text-generation.py new file mode 100644 index 00000000..8263e68a --- /dev/null +++ b/examples/async-text-generation.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from argparse import ArgumentParser +from logging import getLogger +from pathlib import Path + +from transformers import AutoTokenizer + +from optimum.nvidia import AutoModelForCausalLM, ExportConfig, setup_logging + + +# Setup logging needs to happen before importing TRT ... +setup_logging(True) + +from optimum.nvidia.utils.cli import ( + postprocess_quantization_parameters, + register_common_model_topology_args, + register_optimization_profiles_args, + register_quantization_args, +) + + +LOGGER = getLogger(__name__) + + +async def infer(): + tokenizer = AutoTokenizer.from_pretrained(args.model) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + export = ExportConfig.from_pretrained(args.model) + export.max_input_len = 1024 + export.max_output_len = 256 + export.max_num_tokens = 256 + export.max_beam_width = 1 + + model = AutoModelForCausalLM.from_pretrained( + args.model, device_map="auto", export_config=export + ) + # model.save_pretrained(args.output) + + prompt = "What is the latest generation of Nvidia GPUs?" + tokens = tokenizer(prompt, return_tensors="pt") + generated = await model.agenerate( + tokens["input_ids"], + ) + + generated_text = tokenizer.batch_decode(generated, skip_special_tokens=True) + print(generated_text) + + +if __name__ == "__main__": + parser = ArgumentParser("🤗 Optimum-Nvidia Text-Generation Example") + parser.add_argument( + "--hub-token", + type=str, + help="Hugging Face Hub Token to retrieve private weights.", + ) + register_common_model_topology_args(parser) + register_optimization_profiles_args(parser) + register_quantization_args(parser) # Inject params.quantization_config + + parser.add_argument("model", type=str, help="The model's id or path to use.") + parser.add_argument( + "output", type=Path, help="Path to store generated TensorRT engine." + ) + args = parser.parse_args() + args = postprocess_quantization_parameters(args) + + if args.hub_token is not None: + from huggingface_hub import login + + login(args.hub_token) + + asyncio.run(infer()) diff --git a/examples/text-generation.py b/examples/text-generation.py index 1036dac9..a5b40b8e 100644 --- a/examples/text-generation.py +++ b/examples/text-generation.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer -from optimum.nvidia import AutoModelForCausalLM, setup_logging +from optimum.nvidia import AutoModelForCausalLM, ExportConfig, setup_logging # Setup logging needs to happen before importing TRT ... 
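The new example above mirrors `examples/text-generation.py`, but drives the coroutine-based `model.agenerate(...)` through `asyncio.run`. What it does not show is the main reason for an async API: several requests can be awaited concurrently on one event loop. A minimal sketch under that assumption (the model id is a placeholder, and whether concurrent `agenerate` calls are supported is an assumption on my part, not something this PR states):

```python
import asyncio

from transformers import AutoTokenizer

from optimum.nvidia import AutoModelForCausalLM, ExportConfig


async def generate_one(model, tokenizer, prompt: str):
    tokens = tokenizer(prompt, return_tensors="pt")
    # agenerate() is the coroutine counterpart of the generate() call used in the example above
    generated = await model.agenerate(tokens["input_ids"])
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


async def main(model_id: str) -> None:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    export = ExportConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", export_config=export
    )

    prompts = [
        "What is the latest generation of Nvidia GPUs?",
        "What is TensorRT-LLM?",
    ]
    # Assumption: concurrent agenerate() calls can be scheduled together
    results = await asyncio.gather(*(generate_one(model, tokenizer, p) for p in prompts))
    for text in results:
        print(text)


if __name__ == "__main__":
    asyncio.run(main("meta-llama/Llama-2-7b-chat-hf"))  # placeholder model id
```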
@@ -59,27 +59,26 @@ login(args.hub_token) - tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained(args.model) if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token - # Create the model + export = ExportConfig.from_pretrained(args.model) + export.max_input_len = 1024 + export.max_output_len = 256 + export.max_num_tokens = 256 + export.max_beam_width = 1 + model = AutoModelForCausalLM.from_pretrained( - args.model, use_fp8=args.fp8, tp=args.tp, pp=args.pp + args.model, device_map="auto", export_config=export ) - model.save_pretrained(args.output) + # model.save_pretrained(args.output) prompt = "What is the latest generation of Nvidia GPUs?" - tokens = tokenizer(prompt, padding=True, return_tensors="pt") - generated, lengths = model.generate( - **tokens, - top_k=40, - top_p=0.95, - repetition_penalty=10, - pad_token_id=tokenizer.eos_token_id, - eos_token_id=tokenizer.eos_token_id, - max_new_tokens=args.max_new_tokens, + tokens = tokenizer(prompt, return_tensors="pt") + generated = model.generate( + tokens["input_ids"], ) - generated_text = tokenizer.batch_decode(generated, skip_special_tokens=True) + generated_text = tokenizer.decode(generated, skip_special_tokens=True) print(generated_text) diff --git a/pyproject.toml b/pyproject.toml index ca64d23d..fd65052d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,26 +26,24 @@ classifiers = [ dependencies = [ "accelerate == 0.25", "datasets >= 2.14.0", - "huggingface-hub >= 0.22.0", - "hf-transfer", + "huggingface-hub >= 0.23.0", + "hf-transfer==0.1.6", "mpmath == 1.3.0", "numpy >= 1.26.0, < 2.0.0", "onnx >= 1.12.0", "optimum >= 1.13.0", "setuptools", - "tensorrt-llm == 0.9.0", + "tensorrt-llm > 0.10.0", "torch>=2.2.0a,<=2.3.0a", "transformers >= 4.38.2", "pynvml" ] - [project.urls] Homepage = "https://huggingface.co/hardware/nvidia" Repository = "https://github.com/huggingface/optimum-nvidia" Issues = "https://github.com/huggingface/optimum-nvidia/issues" - # List additional dependencies [project.optional-dependencies] test = ["mock", "pytest", "pytest-xdist", "psutil", "parameterized", "datasets", "safetensors",] @@ -86,7 +84,6 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" - [tool.pytest.ini_options] pythonpath = [ "src" diff --git a/setup.py b/setup.py index 97916cd8..931e722d 100644 --- a/setup.py +++ b/setup.py @@ -29,14 +29,14 @@ INSTALL_REQUIRES = [ "accelerate == 0.25", "datasets >= 2.14", - "huggingface-hub >= 0.22.0", - "hf-transfer", + "huggingface-hub >= 0.23", + "hf-transfer==0.1.6", "mpmath == 1.3.0", "numpy >= 1.26.0", "onnx >= 1.12.0", "optimum >= 1.13.0", "setuptools", - "tensorrt-llm == 0.9.0", + "tensorrt-llm > 0.10.0", "torch>=2.2.0a,<=2.3.0a", "transformers >= 4.38.2", "pynvml" diff --git a/src/optimum/nvidia/__init__.py b/src/optimum/nvidia/__init__.py index 103a6652..ef751b18 100644 --- a/src/optimum/nvidia/__init__.py +++ b/src/optimum/nvidia/__init__.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .config import TensorRTConfig +LIBRARY_NAME = "trtllm" + + +from .export import ExportConfig from .logging import DEFAULT_LOGGING_FMT, setup_logging from .models import AutoModelForCausalLM -from .pipelines import pipeline +from .optimizations import IntoModelOptQuantizeConfig + +# from .pipelines import pipeline from .version import VERSION, __version__ diff --git a/src/optimum/nvidia/builder/__init__.py b/src/optimum/nvidia/builder/__init__.py deleted file mode 100644 index a363f015..00000000 --- a/src/optimum/nvidia/builder/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from .local import LocalEngineBuilder diff --git a/src/optimum/nvidia/builder/config.py b/src/optimum/nvidia/builder/config.py deleted file mode 100644 index fbd43cdb..00000000 --- a/src/optimum/nvidia/builder/config.py +++ /dev/null @@ -1,270 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from dataclasses import dataclass -from logging import getLogger -from typing import Optional, Union - -import torch -from tensorrt_llm.models import PretrainedConfig -from tensorrt_llm.plugin import PluginConfig - -from optimum.nvidia.lang import DataType - - -LOGGER = getLogger() -SUPPORTED_LOGITS_DTYPE = {"float32", "float16"} - - -@dataclass -class InferenceProfile: - max_batch_size: int - max_input_len: int - max_output_len: int - - -@dataclass -class GenerationProfile: - num_beams: int - max_draft_length: int - - -@dataclass -class ShardingProfile: - tensor_parallelism: int = 1 - pipeline_parallelism: int = 1 - world_size: int = 1 - gpus_per_node: int = 1 - - -@dataclass -class EngineConfig: - """ - Represent all the parameters required to tune and build the final TRTLLM engine(s) - """ - - optimisation_level: int - strongly_typed: bool - logits_dtype: str - workload_profile: InferenceProfile - generation_profile: GenerationProfile - sharding_profile: ShardingProfile - plugins_config: PluginConfig - - -class EngineConfigBuilder: - @staticmethod - def from_dict(config: PretrainedConfig, **additional_params): - builder = EngineConfigBuilder(config) - - # Define the data type to export the logits - builder.logits_as(additional_params.pop("logits_dtype", config.logits_dtype)) - - # Workload related - max_batch_size = additional_params.pop("max_batch_size", 1) - max_prompt_length = additional_params.pop("max_prompt_length", 128) - max_new_tokens = ( - additional_params.pop("max_output_length", config.max_position_embeddings) - - max_prompt_length - ) - - if max_new_tokens < 1: - raise ValueError( - "Unable to build the engine because the generation would lead to max_num_tokens < 1. (" - f"max_prompt_length = {max_prompt_length}, " - f"max_position_embeddings={config.max_position_embeddings}, " - f"max_new_tokens={max_new_tokens}" - ")" - ) - - builder.with_inference_profile( - max_batch_size, max_prompt_length, max_new_tokens - ) - - # Generation related - builder.with_generation_profile(additional_params.pop("num_beams", 1)) - - # Speculative decoding - if "max_speculated_draft_length" in additional_params: - builder.with_speculated_decoding( - additional_params.pop("max_speculated_draft_length") - ) - - return builder - - def __init__(self, config: PretrainedConfig): - self._config = config - - self._optimisation_level: int = 3 - self._logits_dtype = config.logits_dtype - self._strongly_typed: bool = False - self._sharding_profile: ShardingProfile = ShardingProfile() - self._workload_profile: Optional[InferenceProfile] = None - self._generation_profile: Optional[GenerationProfile] = None - self._plugin_config: Optional[PluginConfig] = None - - def strongly_typed(self) -> "EngineConfigBuilder": - self._strongly_typed = True - LOGGER.info("Defined engine as strongly typed") - return self - - def shard( - self, - tensor_parallelism: int = 1, - pipeline_parallelism: int = 1, - world_size: int = 1, - gpus_per_node: int = 1, - ) -> "EngineConfigBuilder": - self._sharding_profile = ShardingProfile( - tensor_parallelism, pipeline_parallelism, world_size, gpus_per_node - ) - LOGGER.debug(f"Defined sharding profile as: {self._sharding_profile}") - - return self - - def with_optimisation_level(self, level: int) -> "EngineConfigBuilder": - if level < 1: - raise ValueError(f"level should be >= 1 (got: {level})") - self._optimisation_level = level - LOGGER.info(f"Defined optimisation level to {self._optimisation_level}") - return self - - def logits_as( - self, dtype: Union[str, torch.dtype, 
DataType] - ) -> "EngineConfigBuilder": - if isinstance(dtype, torch.dtype): - dtype = DataType.from_torch(dtype) - - if isinstance(dtype, DataType): - dtype = dtype.value - - if dtype not in SUPPORTED_LOGITS_DTYPE: - dtype = "float32" - - self._logits_dtype = dtype - LOGGER.info(f"Defined logits dtype to: {self._logits_dtype}") - return self - - def with_inference_profile( - self, max_batch_size: int, max_prompt_length: int, max_new_tokens: int - ) -> "EngineConfigBuilder": - if max_batch_size < 1: - raise ValueError(f"max_batch_size should be >= 1 (got: {max_batch_size})") - - if max_prompt_length < 1: - raise ValueError( - f"max_prompt_length should be >= 1 (got: {max_batch_size})" - ) - - if max_prompt_length >= self._config.max_position_embeddings: - raise ValueError( - f"max_prompt_length should be shorter than the maximum length supported by the model." - f" (got: {max_prompt_length} and" - f" maximum sequence length supported by the model is {self._config.max_position_embeddings})" - ) - - if max_new_tokens < 1: - raise ValueError(f"max_new_tokens should be >= 1 (got: {max_new_tokens})") - - if max_new_tokens > self._config.max_position_embeddings: - raise ValueError( - f"max_new_tokens should be shorter than the maximum length supported by the model." - f" (got: {max_new_tokens} and" - f" maximum sequence length supported by the model is {self._config.max_position_embeddings})" - ) - - self._workload_profile = InferenceProfile( - max_batch_size, max_prompt_length, max_new_tokens - ) - LOGGER.info(f"Defined engine inference profile: {self._workload_profile}") - return self - - def with_generation_profile(self, num_beams: int) -> "EngineConfigBuilder": - if num_beams < 1: - raise ValueError(f"num_beams should be >= 1 (got: {num_beams})") - - self._generation_profile = GenerationProfile(num_beams, -1) - LOGGER.info(f"Defined engine generation profile: {self._generation_profile}") - return self - - def with_speculated_decoding(self, max_draft_length: int) -> "EngineConfigBuilder": - if max_draft_length < 1: - raise ValueError( - f"max_draft_length should be >= 1 (got: {max_draft_length})" - ) - - if self._generation_profile is None: - raise ValueError( - "You should specify generation profile first. " - "Please use EngineConfigBuilder.with_generation_profile()" - ) - - self._generation_profile = GenerationProfile( - self._generation_profile.num_beams, max_draft_length - ) - LOGGER.info( - f"Defined engine generation profile with speculation: {self._generation_profile}" - ) - return self - - def with_plugins_config(self, plugin_config: PluginConfig) -> "EngineConfigBuilder": - self._plugin_config = plugin_config - LOGGER.info(f"Defined plugins config: {plugin_config}") - return self - - def validate(self) -> bool: - if self._workload_profile is None: - raise ValueError( - "You need to set an inference profile. Use EngineConfigBuilder.with_inference_profile()." - ) - - if self._generation_profile is None: - raise ValueError( - "You need to set a generation profile. Use EngineConfigBuilder.with_generation_profile()." - ) - - if self._plugin_config is None: - raise ValueError( - "You need to set a plugin profile. Use EngineConfigBuilder.with_plugins_config()." 
- ) - - max_generated_length = ( - self._workload_profile.max_input_len - + self._workload_profile.max_output_len - - 1 - ) - if max_generated_length > self._config.max_position_embeddings: - raise ValueError( - "max_prompt_length + max_new_tokens should be lesser or equals " - "to the maximum length supported by the model (got " - f"max_prompt_length={self._workload_profile.max_input_len}, " - f"max_new_tokens={self._workload_profile.max_output_len}," - f"{self._workload_profile.max_input_len + self._workload_profile.max_output_len}" - f" > {self._config.max_position_embeddings}" - ")" - ) - - return True - - def build(self) -> EngineConfig: - self.validate() - return EngineConfig( - optimisation_level=self._optimisation_level, - sharding_profile=self._sharding_profile, - strongly_typed=self._strongly_typed, - logits_dtype=self._logits_dtype, - workload_profile=self._workload_profile, - generation_profile=self._generation_profile, - plugins_config=self._plugin_config, - ) diff --git a/src/optimum/nvidia/builder/local.py b/src/optimum/nvidia/builder/local.py deleted file mode 100644 index 204e818b..00000000 --- a/src/optimum/nvidia/builder/local.py +++ /dev/null @@ -1,152 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from itertools import chain -from logging import getLogger -from pathlib import Path -from subprocess import PIPE, STDOUT, run -from typing import Any, Dict - -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.builder.config import EngineConfig -from optimum.nvidia.utils.patching import BuilderPatcher - - -LOGGER = getLogger() -CLI_PLUGIN_NAMES = { - # Plugins - "bert_attention_plugin", - "gpt_attention_plugin", - "gemm_plugin", - "lookup_plugin", - "lora_plugin", - "moe_plugin", - # Features - "context_fmha", - "context_fmha_fp32_acc", - "paged_kv_cache", - "remove_input_padding", - "use_custom_all_reduce", - "multi_block_mode", - "enable_xqa", - "attention_qk_half_accumulation", - "tokens_per_block", - "use_paged_context_fmha", - "use_context_fmha_for_generation", -} - - -def process_plugin_flag(_: str, value: Any) -> str: - if isinstance(value, bool): - return "enable" if value else "disable" - else: - return value - - -class LocalEngineBuilder: - TRTLLM_BUILD_EXEC = "trtllm-build" - - @staticmethod - def build_cli_command( - checkpoints: Path, - engines: Path, - model_config: TensorRTConfig, - build_config: EngineConfig, - ) -> Dict[str, Any]: - workload_params = { - "--max_batch_size": build_config.workload_profile.max_batch_size, - "--max_input_len": build_config.workload_profile.max_input_len, - "--max_output_len": build_config.workload_profile.max_output_len, - } - - generation_params = { - "--max_beam_width": build_config.generation_profile.num_beams, - } - - if build_config.generation_profile.max_draft_length >= 1: - generation_params["--max_draft_len"] = ( - build_config.generation_profile.max_draft_length - ) - - plugins_params = { - f"--{name}": process_plugin_flag(name, value) - for name in CLI_PLUGIN_NAMES - if (value := getattr(build_config.plugins_config, name)) is not None - } - - build_params = { - "--checkpoint_dir": checkpoints, - "--output_dir": engines, - "--model_config": checkpoints / "model.json", - "--builder_opt": build_config.optimisation_level, - "--logits_dtype": build_config.logits_dtype, - "--tp_size": model_config.mapping.tp_size, - "--pp_size": model_config.mapping.pp_size, - } - - if hasattr(model_config, "trt_model_class") and hasattr( - model_config, "trt_model_file" - ): - build_params["--model_cls_file"] = model_config.trt_model_file - build_params["--model_cls_name"] = model_config.trt_model_class - - if model_config.supports_strong_typing(): - build_params["--strongly_typed"] = None - - return build_params | generation_params | workload_params | plugins_params - - def __init__( - self, config: TensorRTConfig, checkpoint_folder: Path, output_folder: Path - ): - self._config = config - self._checkpoint_folder = checkpoint_folder - self._output_folder = output_folder - - def build(self, config: EngineConfig): - cli_params = LocalEngineBuilder.build_cli_command( - self._checkpoint_folder, self._output_folder, self._config, config - ) - cli_params_list = [str(t) for t in chain.from_iterable(cli_params.items())] - cli_params_list = [i for i in cli_params_list if i != "None"] - - LOGGER.info(f"trtllm-build parameters: {cli_params_list}") - - for rank in range(self._config.mapping.world_size): - ranked_checkpoint = f"rank{rank}.safetensors" - if not (self._checkpoint_folder / ranked_checkpoint).exists(): - raise ValueError( - f"Missing rank-{rank} checkpoints (rank{rank}.safetensors), cannot build." - ) - - # TODO: Remove BuilderPatcher once TensorRT-LLM updates its codebase to allow to disable `optimize(network)`. 
- with BuilderPatcher(): - # Run the build - result = run( - [LocalEngineBuilder.TRTLLM_BUILD_EXEC] + cli_params_list, - stdout=PIPE, - stderr=STDOUT, - ) - - if result.returncode != 0: - LOGGER.warning( - f"trtllm-build stdout: {result.stdout.decode('utf-8') if result.stdout is not None else None}" - ) - LOGGER.warning( - f"trtllm-build stderr: {result.stderr.decode('utf-8') if result.stderr is not None else None}" - ) - - raise ValueError( - f"Compilation failed ({result.returncode}), " - "please open up an issue at https://github.com/huggingface/optimum-nvidia" - ) # TODO: change with proper error diff --git a/src/optimum/nvidia/config.py b/src/optimum/nvidia/config.py deleted file mode 100644 index fd6f140b..00000000 --- a/src/optimum/nvidia/config.py +++ /dev/null @@ -1,231 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union - -import torch -from tensorrt_llm import Mapping -from tensorrt_llm.layers import MoeConfig -from tensorrt_llm.models import PretrainedConfig as TensorRTPretrainedConfig -from tensorrt_llm.models.modeling_utils import ( - QuantConfig as TensorRTQuantizationConfig, -) -from tensorrt_llm.plugin import PluginConfig -from tensorrt_llm.quantization import QuantMode -from transformers import AutoConfig, PretrainedConfig - -from optimum.nvidia.quantization import AmmoQuantizationConfig -from optimum.nvidia.utils import get_user_agent - - -SUPPORTED_DTYPES = [ - "float32", - "float16", - "bfloat16", - "int32", - "int8", - "uint8", - "float8", -] - - -def dtype_to_str(dtype: Union[torch.dtype, str]) -> str: - if isinstance(dtype, str): - if dtype in SUPPORTED_DTYPES: - return dtype - else: - raise ValueError(f"Unsupported dtype ({dtype}) value") - elif dtype == torch.float32: - return "float32" - elif dtype == torch.float16: - return "float16" - elif dtype == torch.bfloat16: - return "bfloat16" - elif dtype == torch.int32: - return "int32" - elif dtype == torch.int8: - return "int8" - elif dtype == torch.uint8: - return "uint8" - elif dtype == torch.float8_e4m3fn: - return "float8" - else: - raise ValueError(f"Unsupported torch.dtype ({dtype}) value") - - -def convert_quant_method_to_trt( - method: str, weight_num_bits: int, activation_num_bits: Optional[int] = None -) -> (QuantMode, str): - if method == "awq": - if not activation_num_bits: - activation_num_bits = 16 - - if weight_num_bits not in {4, 8}: - raise ValueError( - f"Unsupported AWQ quantization schema with {weight_num_bits}-bits weights. " - "Only 4 and 8-bits weights' quantization schemas are supported." - ) - - if activation_num_bits not in {8, 16}: - raise ValueError( - f"Unsupported AWQ quantization schema with {activation_num_bits}-bits activations. " - "Only 8 and 16-bits activations' quantization schemas are supported." 
- ) - - mode = QuantMode.from_description( - quantize_weights=True, per_group=True, use_int4_weights=weight_num_bits == 4 - ) - return mode, f"W{weight_num_bits}A{activation_num_bits}_AWQ" - elif method == "gptq": - if not activation_num_bits: - activation_num_bits = 16 - - if weight_num_bits not in {4, 8}: - raise ValueError( - f"Unsupported GPTQ quantization schema with {weight_num_bits}-bits weights. " - "Only 4 and 8-bits weights' quantization schemas are supported." - ) - - if activation_num_bits != 16: - raise ValueError( - f"Unsupported GPTQ quantization schema with {activation_num_bits}-bits activations. " - "Only 16-bits activations' quantization schemas are supported." - ) - - mode = QuantMode.from_description( - quantize_weights=True, - per_group=False, - use_int4_weights=weight_num_bits == 4, - ) - return mode, f"W{weight_num_bits}A16_GPTQ" - else: - raise ValueError(f"Unsupported quantization method: {method}") - - -class TensorRTConfig(TensorRTPretrainedConfig, ABC): - @staticmethod - def get_quantization_config( - config: PretrainedConfig, - ) -> (QuantMode, AmmoQuantizationConfig): - if hasattr(config, "quantization_config"): - qconfig = config.quantization_config - num_bits = qconfig.num_bits - group_size = qconfig.group_size - mode, quant_method = convert_quant_method_to_trt( - qconfig.quant_method, num_bits - ) - has_zero_point = qconfig.get("zero_point", False) - exclude_modules = qconfig.get("module_to_not_convert", []) - - return mode, TensorRTQuantizationConfig( - quantization_algo=quant_method, - kv_cache_quant_algo=None, - group_size=group_size, - has_zero_point=has_zero_point, - exclude_modules=exclude_modules, - ) - else: - return QuantMode.from_description(), TensorRTQuantizationConfig( - None, None, None, False, None - ) - - @staticmethod - @abstractmethod - def from_config( - config: PretrainedConfig, mapping: Optional[Mapping] - ) -> "TensorRTConfig": - raise NotImplementedError() - - @staticmethod - def from_pretrained( - model_id_or_path: str, - *, - revision: Optional[str] = None, - token: Union[bool, str, None] = None, - mapping: Optional[Mapping] = None, - ): - config = AutoConfig.from_pretrained( - model_id_or_path, - revision=revision, - token=token, - user_agent=get_user_agent(), - ) - return TensorRTConfig.from_config(config, mapping) - - @staticmethod - @abstractmethod - def supports_strong_typing() -> bool: - raise NotImplementedError() - - @property - def has_moe(self) -> bool: - return getattr(self, "moe_num_experts", 1) - - @property - def moe_config(self) -> Optional[MoeConfig]: - if self.has_moe: - return MoeConfig() - else: - return None - - def shard( - self, - world_size: int, - gpus_per_node: int, - rank: int = 0, - tp_degree: int = 1, - pp_degree: int = 1, - ): - if tp_degree * pp_degree != world_size: - raise ValueError( - f"tensor parallelism ({tp_degree}) x pipeline parallelism ({pp_degree})" - f" != world size ({world_size})" - ) - - self.mapping = Mapping( - world_size=world_size, - rank=rank, - gpus_per_node=gpus_per_node, - tp_size=tp_degree, - pp_size=pp_degree, - ) - - def get_plugins_config(self) -> PluginConfig: - return PluginConfig( - rmsnorm_quantization_plugin="disable", - layernorm_quantization_plugin="disable", - nccl_plugin="disable", - paged_kv_cache="enable", - enable_xqa="enable", - use_paged_context_fmha=None, - use_context_fmha_for_generation=None, - tokens_per_block=None, - attention_qk_half_accumulation=None, - multi_block_mode=None, - use_custom_all_reduce=True, - remove_input_padding=True, - context_fmha=None, - 
context_fmha_fp32_acc=None, - ) - - def save_pretrained(self, save_folder: Union[str, Path]): - config_dict = self.to_dict() - config_dict.pop("trt_model_class", None) - config_dict.pop("trt_model_file", None) - - with open(Path(save_folder, "config.json"), "w") as config_f: - json.dump(config_dict, config_f) diff --git a/src/optimum/nvidia/errors.py b/src/optimum/nvidia/errors.py index ba1841ba..1f22c8bb 100644 --- a/src/optimum/nvidia/errors.py +++ b/src/optimum/nvidia/errors.py @@ -25,7 +25,6 @@ def __init__(self, msg: str, operation: Optional[str] = None): super().__init__(f"{msg}") -### Model support class UnsupportedModelException(OptimumNvidiaException): def __init__(self, model_type: str): super().__init__( @@ -34,7 +33,6 @@ def __init__(self, model_type: str): ) -### Unsupported features blocks class UnsupportedHardwareFeature(OptimumNvidiaException): """ Base exception class for all features not supported by underlying hardware diff --git a/src/optimum/nvidia/export/__init__.py b/src/optimum/nvidia/export/__init__.py new file mode 100644 index 00000000..90c09c64 --- /dev/null +++ b/src/optimum/nvidia/export/__init__.py @@ -0,0 +1,8 @@ +PATH_FOLDER_CHECKPOINTS = "checkpoints" +PATH_FOLDER_ENGINES = "engines" +PATH_FILE_CHECKPOINTS = "rank*.safetensors" +PATH_FILE_ENGINES = "rank*.engine" + +from .workspace import Workspace # noqa +from .config import ExportConfig, auto_parallel +from .converter import TensorRTArtifact, TensorRTArtifactKind, TensorRTModelConverter diff --git a/src/optimum/nvidia/export/config.py b/src/optimum/nvidia/export/config.py new file mode 100644 index 00000000..c2dd7222 --- /dev/null +++ b/src/optimum/nvidia/export/config.py @@ -0,0 +1,186 @@ +from dataclasses import dataclass +from logging import getLogger +from os import PathLike +from typing import TYPE_CHECKING, Optional, Union +from warnings import warn + +from tensorrt_llm import BuildConfig +from tensorrt_llm import Mapping as ShardingInfo +from tensorrt_llm.plugin import PluginConfig +from transformers import AutoConfig + +from optimum.nvidia.lang import DataType +from optimum.utils import NormalizedConfig + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + +INFER_NUM_LOCAL_GPUS = -1 +LOGGER = getLogger() + + +@dataclass +class ExportConfig: + dtype: str + max_input_len: int + max_output_len: int + max_batch_size: int + + # Optional parameters + max_beam_width: int = 1 + max_num_tokens: int = -1 + enabled_chunked_context: int = False + + sharding: Optional[ShardingInfo] = None + + optimization_level: int = 3 + + def __post_init__(self): + if self.max_batch_size < 1: + raise ValueError(f"max_batch_size should >= 1, got {self.max_batch_size}") + + @staticmethod + def from_pretrained( + model_id_or_path: Union[str, PathLike], max_batch_size: int = 1 + ) -> "ExportConfig": + return ExportConfig.from_config( + AutoConfig.from_pretrained(model_id_or_path), max_batch_size + ) + + @staticmethod + def from_config( + config: Union[NormalizedConfig, "PretrainedConfig"], max_batch_size: int = 1 + ) -> "ExportConfig": + if not isinstance(config, NormalizedConfig): + config = NormalizedConfig(config) + + dtype = DataType.from_torch(config.torch_dtype).value + max_input_len = config.max_position_embeddings + max_output_len = config.max_position_embeddings + + return ExportConfig( + dtype=dtype, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + ).validate() + + def validate(self) -> "ExportConfig": + if self.optimization_level < 0: + raise 
ValueError( + f"optimization_level should be >= 0, got {self.optimization_level}" + ) + + if self.max_num_tokens == -1: + if self.enabled_chunked_context: + # Should be N * tokens_per_block + self.max_num_tokens = 128 # hardcode for now + warn( + f"max_num_tokens set to {self.max_num_tokens} with chunked context enabled might not be optimal." + ) + else: + self.max_num_tokens = 2 * self.max_input_len + + LOGGER.debug(f"Inferred max_num_tokens={self.max_num_tokens}") + + return self + + @property + def plugin_config(self) -> "PluginConfig": + config = PluginConfig() + config.gemm_plugin = self.dtype + config.bert_attention_plugin = self.dtype + config.gpt_attention_plugin = self.dtype + config.nccl_plugin = self.dtype + config.mamba_conv1d_plugin = self.dtype + config.moe_plugin = self.dtype + return config + + def to_builder_config( + self, plugin_config: Optional[PluginConfig] = None + ) -> "BuildConfig": + self.validate() + + return BuildConfig( + max_input_len=self.max_input_len, + max_seq_len=self.max_output_len, + max_batch_size=self.max_batch_size, + max_beam_width=self.max_beam_width, + max_num_tokens=self.max_num_tokens, + builder_opt=self.optimization_level, + plugin_config=plugin_config or self.plugin_config, + ) + + def with_sharding( + self, + tp: int = 1, + pp: int = 1, + gpus_per_node: int = 8, + sharding: Optional[ShardingInfo] = None, + ) -> "ExportConfig": + self.sharding = sharding or ShardingInfo( + tp_size=tp, pp_size=pp, world_size=tp * pp, gpus_per_node=gpus_per_node + ) + return self + + +def auto_parallel( + config: "ExportConfig", world_size: int = INFER_NUM_LOCAL_GPUS +) -> "ExportConfig": + """ + Helper to infer the most suitable parallelization strategy to apply to the model with respect to the local hardware. + :param config: `ExportConfig` the quantization process should be added to + :param world_size: Number of GPUs to consider when discovering automatic parallelization strategies + :return: `ExportConfig` + """ + # Infer number of GPUs on the system + if world_size < 1: + from optimum.nvidia.utils.nvml import get_device_count + + world_size = get_device_count() + + LOGGER.info(f"Found {world_size} GPUs on the system") + + # Handle all the different cases (0, 1, N > 1) + if world_size == 0: + raise ValueError("No GPU found") + elif world_size == 1: + return config.with_sharding(tp=1, pp=1, gpus_per_node=world_size) + else: + LOGGER.info(f"Creating auto-parallelization strategy on {world_size}-GPUs") + LOGGER.warning( + "Auto-parallelization strategy is currently in beta and might not be optimal" + ) + + if world_size == 2: + return config.with_sharding(tp=2, pp=1, gpus_per_node=world_size) + elif world_size == 4: + return config.with_sharding(tp=2, pp=2, gpus_per_node=world_size) + elif world_size == 8: + return config.with_sharding(tp=4, pp=2, gpus_per_node=world_size) + else: + raise ValueError( + f"Unsupported number of GPUs: {world_size}. 
" + "Please open-up and issue on the optimum-nvidia repository: " + "https://github.com/huggingface/optimum-nvidia" + ) + + +def sharded(config: "ExportConfig", tp: int = 1, pp: int = 1) -> "ExportConfig": + """ + Helper to specific the parallelization strategy to apply to the model + :param config: `ExportConfig` the quantization process should be added to + :param tp: Tensor Parallelism degree to apply (`int` >= 1) + :param pp: Pipeline Parallelism degree to apply (`int` >= 1) + :return: `ExportConfig` + """ + if tp < 1: + raise ValueError(f"Tensor Parallelism (tp) should be >= 1 (got: tp={tp})") + + if pp < 1: + raise ValueError(f"Pipeline Parallelism (pp) should be >= 1 (got: pp={pp})") + + return config.with_sharding( + sharding=ShardingInfo(tp_size=tp, pp_size=pp, world_size=tp * pp) + ) diff --git a/src/optimum/nvidia/export/converter.py b/src/optimum/nvidia/export/converter.py new file mode 100644 index 00000000..102d4c06 --- /dev/null +++ b/src/optimum/nvidia/export/converter.py @@ -0,0 +1,135 @@ +from abc import ABC +from enum import Enum +from logging import getLogger +from os import PathLike +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Sequence, Type, Union + +from tensorrt_llm.builder import build + +from optimum.nvidia.export import Workspace +from optimum.nvidia.utils.nvml import get_device_name, is_post_ampere + + +if TYPE_CHECKING: + from tensorrt_llm import BuildConfig, Mapping + from tensorrt_llm.models import PretrainedModel + +LOGGER = getLogger() + + +def infer_plugin_from_build_config(config: "BuildConfig") -> "BuildConfig": + if is_post_ampere(): + # Required for Chunk Context + LOGGER.debug("Enabling Paged Context FMHA plugin") + config.plugin_config.update_from_dict({"use_paged_context_fmha": True}) + + return config + + +class TensorRTArtifactKind(Enum): + CHECKPOINTS = "checkpoints" + ENGINES = "engines" + + +class TensorRTArtifact: + @staticmethod + def checkpoints(root: Union[str, PathLike]) -> "TensorRTArtifact": + return TensorRTArtifact(TensorRTArtifactKind.CHECKPOINTS, root) + + @staticmethod + def engines(root: Union[str, PathLike]) -> "TensorRTArtifact": + return TensorRTArtifact(TensorRTArtifactKind.ENGINES, root) + + def __init__(self, kind: TensorRTArtifactKind, root: Union[str, PathLike]): + self._kind = kind + self._root = root + + @property + def kind(self) -> TensorRTArtifactKind: + return self._kind + + @property + def root(self) -> Path: + return Path(self._root) + + def push_to_hub(self): + raise NotImplementedError() + + +class TensorRTModelConverter(ABC): + CONFIG_CLASS: Type + MODEL_CLASS: Type + + def __init__( + self, + model_id: str, + subpart: str = "", + workspace: Optional[Union["Workspace", str, bytes, Path]] = None, + ): + LOGGER.info(f"Creating a model converter for {subpart}") + if not workspace: + target_device = get_device_name(0)[-1] + workspace = Workspace.from_hub_cache( + model_id, target_device, subpart=subpart + ) + + if isinstance(workspace, (str, bytes, Path)): + workspace = Workspace(Path(workspace)) + + LOGGER.debug(f"Initializing model converter workspace at {workspace.root}") + + self._workspace = workspace + + @property + def workspace(self) -> Workspace: + return self._workspace + + def quantize(self): + raise NotImplementedError() + + def convert( + self, + models: Union["PretrainedModel", Sequence["PretrainedModel"]], + mapping: Optional["Mapping"] = None, + ) -> TensorRTArtifact: + """ + Take a local model and create the intermediate TRTLLM checkpoint + :param models + :param mapping + 
:return: + """ + if isinstance(models, PretrainedModel): + models = [models] + + for rank, model in enumerate(models): + LOGGER.info( + f"Converting {models[0].config.architecture} model for rank {rank} to TRTLLM" + ) + model.save_checkpoint(str(self._workspace.checkpoints_path)) + + return TensorRTArtifact.checkpoints(str(self._workspace.checkpoints_path)) + + def build( + self, + models: Union["PretrainedModel", Sequence["PretrainedModel"]], + config: "BuildConfig", + ) -> TensorRTArtifact: + """ + :param models + :param config + :return: + """ + + if not isinstance(models, Sequence): + models = [models] + + config = infer_plugin_from_build_config(config) + + for rank, model in enumerate(models): + LOGGER.info(f"Building TRTLLM engine for rank {rank}") + + engine = build(model, config) + engine.save(str(self._workspace.engines_path)) + + return TensorRTArtifact.engines(str(self._workspace.engines_path)) diff --git a/src/optimum/nvidia/export/workspace.py b/src/optimum/nvidia/export/workspace.py new file mode 100644 index 00000000..bf3bdb32 --- /dev/null +++ b/src/optimum/nvidia/export/workspace.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Optional + +from huggingface_hub import cached_assets_path +from tensorrt_llm import __version__ as TRTLLM_VERSION + +from optimum.nvidia import LIBRARY_NAME +from optimum.nvidia.export import ( + PATH_FILE_CHECKPOINTS, + PATH_FILE_ENGINES, + PATH_FOLDER_CHECKPOINTS, + PATH_FOLDER_ENGINES, +) + + +@dataclass +class Workspace: + root: Path + + @staticmethod + def from_hub_cache( + model_id: str, + device: str, + namespace: str = LIBRARY_NAME, + version: str = TRTLLM_VERSION, + subpart: Optional[str] = None, + ) -> "Workspace": + assets_path = cached_assets_path( + namespace, namespace=version, subfolder=model_id + ) + assets_path = assets_path.joinpath(device) + + if subpart: + assets_path = assets_path.joinpath(subpart) + + assets_path.mkdir(exist_ok=True, parents=True) + return Workspace(assets_path) + + def __post_init__(self): + if not self.checkpoints_path.exists(): + self.checkpoints_path.mkdir(parents=True) + + if not self.engines_path.exists(): + self.engines_path.mkdir(parents=True) + + @property + def checkpoints_path(self) -> Path: + """ + Folder path location holding all the engines + :return: `Path` + """ + return self.root / PATH_FOLDER_CHECKPOINTS + + @property + def engines_path(self) -> Path: + """ + Folder path location holding all the engines + :return: `Path` + """ + return self.root / PATH_FOLDER_ENGINES + + @property + def checkpoints(self) -> Iterable[Path]: + """ + Generator discovering all the checkpoint files present in this workspace + :return: + """ + return self.checkpoints_path.glob(PATH_FILE_CHECKPOINTS) + + def engines(self) -> Iterable[Path]: + """ + Generator discovering all the engine files present in this workspace + :return: + """ + return self.engines_path.glob(PATH_FILE_ENGINES) diff --git a/src/optimum/nvidia/hub.py b/src/optimum/nvidia/hub.py index b5f53ab9..5e47275c 100644 --- a/src/optimum/nvidia/hub.py +++ b/src/optimum/nvidia/hub.py @@ -12,345 +12,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
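Taken together, these new modules replace the deleted `builder` package: `ExportConfig` captures the build-time shapes, `auto_parallel` picks a tp/pp layout for the locally visible GPUs, `Workspace` decides where artifacts live, and `TensorRTModelConverter` drives TRT-LLM's in-process `build`. A minimal sketch of how they compose, assuming a TRT-LLM model already materialized for the current rank (the helper name `export_engines` and the way the ranked model is obtained are illustrative, not part of this PR):

```python
from tensorrt_llm.models import PretrainedModel

from optimum.nvidia.export import ExportConfig, TensorRTModelConverter, auto_parallel


def export_engines(model_id: str, ranked_model: PretrainedModel) -> None:
    # dtype and max sequence lengths are inferred from the transformers config;
    # auto_parallel() then chooses a tp/pp layout for however many GPUs are visible.
    export = auto_parallel(ExportConfig.from_pretrained(model_id, max_batch_size=1))

    # The converter resolves a Workspace under the huggingface_hub assets cache
    # and writes the resulting rank*.engine files into workspace.engines_path.
    converter = TensorRTModelConverter(model_id)
    converter.build(ranked_model, export.to_builder_config())


# `ranked_model` would typically come from the architecture's TRT-LLM class
# (see the `from_hugging_face(...)` call in hub.py further down in this diff).
```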
- -import shutil -from glob import glob, iglob from logging import getLogger +from os import PathLike, symlink from pathlib import Path +from shutil import copytree from typing import ( - Any, Dict, List, + Mapping, Optional, - Protocol, - Tuple, Type, Union, - runtime_checkable, ) -from warnings import warn -import numpy as np -import torch from huggingface_hub import ModelHubMixin, snapshot_download from huggingface_hub.hub_mixin import T -from safetensors.torch import save_file as to_safetensors -from tensorrt_llm import Mapping -from tensorrt_llm._utils import numpy_to_torch -from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME - -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.builder import LocalEngineBuilder -from optimum.nvidia.builder.config import EngineConfigBuilder -from optimum.nvidia.quantization import AutoQuantizationConfig -from optimum.nvidia.quantization.ammo import AmmoQuantizer -from optimum.nvidia.utils import get_user_agent, maybe_offload_weights_to_cpu +from tensorrt_llm import __version__ as trtllm_version +from transformers import AutoConfig, GenerationConfig +from transformers.utils import ( + CONFIG_NAME, + GENERATION_CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, +) + +from optimum.nvidia import LIBRARY_NAME +from optimum.nvidia.export import ( + PATH_FOLDER_CHECKPOINTS, + PATH_FOLDER_ENGINES, + ExportConfig, + TensorRTModelConverter, + auto_parallel, +) +from optimum.nvidia.lang import DataType +from optimum.nvidia.models import ( + SupportsFromHuggingFace, + SupportsTransformersConversion, +) +from optimum.nvidia.utils import get_user_agent +from optimum.nvidia.utils.nvml import get_device_count, get_device_name +from optimum.utils import NormalizedConfig ATTR_TRTLLM_ENGINE_FOLDER = "__trtllm_engine_folder__" -ATTR_TRTLLM_CHECKPOINT_FOLDER = "__trtllm_checkpoint_folder__" -FOLDER_TRTLLM_CHECKPOINTS = "checkpoints" -FOLDER_TRTLLM_ENGINES = "engines" -FILE_TRTLLM_CHECKPOINT_PATTERN = "rank[0-9]*.safetensors" FILE_TRTLLM_ENGINE_PATTERN = "rank[0-9]*.engine" - -HUB_TRTLLM_ENGINE_PATTERNS = ["**/config.json", f"**/{FILE_TRTLLM_ENGINE_PATTERN}"] -HUB_SAFETENSORS_PATTERNS = ["config.json", "*.safetensors", SAFE_WEIGHTS_INDEX_NAME] - -SHARDING_KWARGS = {"tp_size", "pp_size", "world_size", "gpus_per_node", "rank"} -SHARDING_KWARGS_ALIASES = { - "tp": "tp_size", - "pp": "pp_size", -} +FILE_TRTLLM_CHECKPOINT_PATTERN = "rank[0-9]*.engine" +HUB_SNAPSHOT_ALLOW_PATTERNS = [ + CONFIG_NAME, + GENERATION_CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + "*.safetensors", +] LOGGER = getLogger() -LOGGER.setLevel(level="INFO") - - -def extract_model_type(config: Dict[str, Any]) -> Tuple[Optional[str], bool]: - if "model_type" in config: - model_type = config["model_type"] - is_tensorrt_config = False - - # This path try to extract from the TensorRTLLM config - elif ( - "pretrained_config" in config and "architecture" in config["pretrained_config"] - ): - model_type = config["pretrained_config"]["architecture"] - prefix_pos = model_type.index("For") - model_type = model_type[:prefix_pos].lower() - is_tensorrt_config = True - else: - return None, False - - return model_type, is_tensorrt_config - - -def find_prebuilt_engines(root: Path) -> Tuple[List[Path], List[Path]]: - """ - Attempt to locate any prebuilt TRT engines at the provided root path or root' 
subfolders. - :param root: The directory we should look into for the engine files. - :return: The list of `Path` where engines were found in root (or root/FOLDER_TRTLLM_ENGINES, or root/**/FOLDER_TRTLLM_ENGINES), and the corresponding relative path to root. - """ - folders = [] - relative_folders = [] - - files = glob(Path(root, f"**/{FOLDER_TRTLLM_ENGINES}/rank*.engine").as_posix()) - for file in files: - file_directory = Path(Path(file).parents[0]) - folders.append(file_directory) - relative_folders.append(file_directory.relative_to(root)) - - files = glob(Path(root, "rank*.engine").as_posix()) - if len(files) > 0: - folders.append(root) - relative_folders.append(".") - - files = glob(Path(root, f"{FOLDER_TRTLLM_ENGINES}/rank*.engine").as_posix()) - if len(files) > 0: - folders.append(Path(root, FOLDER_TRTLLM_ENGINES)) - relative_folders.append(FOLDER_TRTLLM_ENGINES) - - folders = list(dict.fromkeys(folders)) - relative_folders = list(dict.fromkeys(relative_folders)) - - # TODO: Handle this properly - we should enforce the directory names and use dicts instead of lists here. - if len(relative_folders) == 2 and "decoder" in relative_folders[0].as_posix(): - folders.reverse() - relative_folders.reverse() - - return folders, relative_folders - - -def get_sharding_info(**kwargs): - mapping_kwargs = {} - - for name_, value in kwargs.items(): - # Handle aliased args - if name_ in SHARDING_KWARGS_ALIASES: - name = SHARDING_KWARGS_ALIASES[name_] - else: - name = name_ - - # Ensure we are targeting a sharding kwarg - if name in SHARDING_KWARGS: - # Check if not already present - if name in mapping_kwargs and name != name_: - LOGGER.warning(f"Parameter {name} already defined through {name_}") - - mapping_kwargs[name] = value - - if mapping_kwargs: - if "world_size" not in mapping_kwargs: - mapping_kwargs["world_size"] = ( - mapping_kwargs["tp_size"] * mapping_kwargs["pp_size"] - ) - LOGGER.debug( - f"Set sharding_info's world_size to {mapping_kwargs['world_size']}" - ) - return Mapping(**mapping_kwargs) - else: - return None - - -@runtime_checkable -class SupportsTensorrtConversion(Protocol): - MODEL_CONFIG: Type[TensorRTConfig] - HF_LIBRARY_TARGET_MODEL_CLASS: Type[ModelHubMixin] - TRT_LLM_TARGET_MODEL_CLASS: Type[PretrainedModel] - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, np.ndarray]: ... 
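The deleted `SupportsTensorrtConversion` protocol is superseded by `SupportsFromHuggingFace` and `SupportsTransformersConversion`, imported from `optimum.nvidia.models`. Their actual definitions are not part of this diff; the sketch below only approximates their shape from how `_from_pretrained` consumes them further down (a `TRT_LLM_TARGET_MODEL_CLASSES` attribute mapping engine subparts to TRT-LLM classes, and a `from_hugging_face` loader on those classes):

```python
from typing import Mapping, Protocol, Type, Union, runtime_checkable


@runtime_checkable
class SupportsFromHuggingFace(Protocol):
    """Approximate shape: a TRT-LLM model class exposing a Hugging Face loader."""

    @classmethod
    def from_hugging_face(cls, hf_model_dir, dtype: str, mapping=None, **kwargs): ...


@runtime_checkable
class SupportsTransformersConversion(Protocol):
    """Approximate shape: declares the TRT-LLM class(es) to build, per engine subpart."""

    TRT_LLM_TARGET_MODEL_CLASSES: Union[Type, Mapping[str, Type]]
```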
- - -class HuggingFaceHubModel(ModelHubMixin, SupportsTensorrtConversion): - @classmethod - def convert_and_build( - cls, - local_path: Path, - hf_model_config: Dict, - engine_save_path: Optional[Path] = None, - hf_model: Optional[TransformersPretrainedModel] = None, - config_class: Optional[TensorRTConfig] = None, - **model_kwargs, - ) -> Tuple[List[Path], List[Path], List[Path]]: - """ - :param local_path: - :param hf_model_config: - :param model_kwargs: - :return: - """ - - # Path where will be stored the engines - root = engine_save_path if engine_save_path else local_path - checkpoint_folder = root / FOLDER_TRTLLM_CHECKPOINTS - engines_folder = root / FOLDER_TRTLLM_ENGINES - - # Ensure all the tree exists - checkpoint_folder.mkdir(exist_ok=True, parents=True) - engines_folder.mkdir(exist_ok=True, parents=True) - - # Retrieve configuration - config = AutoConfig.for_model(**hf_model_config) - if "torch_dtype" in model_kwargs: - config.torch_dtype = model_kwargs["torch_dtype"] - - # Get sharding information to forward to the config converted - model_kwargs["mapping"] = get_sharding_info(**model_kwargs) - - # Convert the original config to a model config TRTLLM understands - model_config = HuggingFaceHubModel.convert_config_to_trtllm( - cls, config, config_class=config_class, **model_kwargs +def get_trtllm_artifact(model_id: str, patterns: List[str]) -> Path: + return Path( + snapshot_download( + repo_id=model_id, + repo_type="model", + library_name=LIBRARY_NAME, + library_version=trtllm_version, + user_agent=get_user_agent(), + allow_patterns=patterns, ) + ) - if model_config.has_moe: - model_config.moe_config.validate() - - # We now have a TRTLLM compatible config, so let's feed it to the target TRTLLM model to create a checkpoint - LOGGER.debug("Allocating TRTLLM model to build the checkpoint") - model = cls.TRT_LLM_TARGET_MODEL_CLASS.from_config(model_config) - - # Retrieve the parameters for building the engine - if "engine_config" in model_kwargs: - engine_config = model_kwargs.pop("engine_config") - else: - builder = EngineConfigBuilder.from_dict(model_config, **model_kwargs) - builder.with_plugins_config(model_config.get_plugins_config()) - engine_config = builder.build() - - if engine_config.plugins_config is None: - engine_config.plugins_config = model_config.get_plugins_config() - - if hf_model is None: - LOGGER.debug( - f"Loading weights from {local_path} into the model ({cls.HF_LIBRARY_TARGET_MODEL_CLASS.__name__})" - ) - - hf_model = cls.HF_LIBRARY_TARGET_MODEL_CLASS.from_pretrained( - local_path, - torch_dtype="auto", - device_map="cpu", - local_files_only=True, - ) - else: - if not isinstance(hf_model, cls.HF_LIBRARY_TARGET_MODEL_CLASS): - raise ValueError( - f"Expected a {cls.HF_LIBRARY_TARGET_MODEL_CLASS.__name__} model to be provided, but the argument `hf_model` is a {hf_model.__class__.__name__}." 
- ) - - hf_model = hf_model.eval() - hf_model = maybe_offload_weights_to_cpu(hf_model) - - # Retrieve potential quantization config (If provided) - follow the transformers parameter's name - has_qconfig = "quantization_config" in model_kwargs - use_fp8 = model_kwargs.get("use_fp8", False) - - if has_qconfig or use_fp8: - LOGGER.debug("About to quantize Hugging Face model") - - if has_qconfig: - qconfig = model_kwargs.pop("quantization_config") - elif use_fp8: - if ( - candidate_tokenizer_path := engines_folder.parent.joinpath( - "tokenizer.json" - ) - ).exists(): - tokenizer_path = candidate_tokenizer_path.parent - elif "_model_id" in hf_model_config: - tokenizer_path = hf_model_config["_model_id"] - else: - raise ValueError( - "Unable to determine the tokenizer to use to quantize this model. " - "Please provide a complete QuantizationConfig using " - "from_pretrained(..., quantization_config=AutoQuantizationConfig.from_description())" - ) - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - qconfig = AutoQuantizationConfig.from_description( - weight="float8", - activation="float8", - tokenizer=tokenizer, - dataset="c4-new", - ) - - warn( - "Converting model to support float8 inference.\n" - f"Calibrating model with dataset='c4', split='train', samples={len(qconfig.calibration_dataset)}.\n" - "Note: if text generation doesn't meet your expectations, " - "you can control the quantization process manually with this API: " - "qconfig = AutoQuantizationConfig.from_description(weight='float8', activation='float8', ...) " - "forwarding the configuration to .from_pretrained(..., quantization_config=qconfig)" - ) - - hf_quantizer = AmmoQuantizer( - quantization_config=qconfig, - artifact_path=checkpoint_folder, - tensor_parallel_degree=engine_config.sharding_profile.tensor_parallelism, - pipeline_parallel_degree=engine_config.sharding_profile.pipeline_parallelism, - export_tensorrt_llm_config=True, - ) - - hf_quantizer.preprocess_model(hf_model, batch_size=1) - hf_quantizer.postprocess_model(hf_model) - - else: - # Apply the conversion from Hugging Face weights to TRTLLM - for rank in range(model_config.mapping.world_size): - LOGGER.debug( - f"Converting weights from Hugging Face checkpoint for rank {rank}" - ) - model_config.set_rank(rank) - converted_weights = cls.convert_weights(model, hf_model, model_config) - - # Bind the converted weights against the TRTLLM model - model.load(converted_weights) - - # Write ranked-checkpoints - converted_weights = { - name: numpy_to_torch(tensor) - if isinstance(tensor, np.ndarray) - else tensor - for name, tensor in converted_weights.items() - } - to_safetensors( - converted_weights, - checkpoint_folder / f"rank{model_config.mapping.rank}.safetensors", - ) - # Write global config - model_config.save_pretrained(checkpoint_folder) - - # We are freeing memory used by the HF Model to let the engine build goes forward - del hf_model - torch.cuda.empty_cache() - - # Build - engine_builder = LocalEngineBuilder( - model_config, checkpoint_folder, engines_folder - ) - engine_builder.build(engine_config) - - return ( - [checkpoint_folder], - [engines_folder], - [engines_folder.relative_to(local_path)], - ) +class HuggingFaceHubModel( + ModelHubMixin, + library_name=LIBRARY_NAME, + languages=["python", "c++"], + tags=["optimum-nvidia", "trtllm"], + repo_url="https://github.com/huggingface/optimum-nvidia", + docs_url="https://huggingface.co/docs/optimum/nvidia_overview", +): + def __init__(self, engines_path: Union[str, PathLike, Path]): + self._engines_path 
= Path(engines_path) + super().__init__() @classmethod def _from_pretrained( cls: Type[T], *, model_id: str, - config: Dict[str, Any], + config: Dict, revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, @@ -358,80 +104,61 @@ def _from_pretrained( resume_download: bool, local_files_only: bool, token: Optional[Union[str, bool]], - **model_kwargs, + use_cuda_graph: bool = False, + device_map: Optional[str] = None, + export_config: Optional[ExportConfig] = None, + force_export: bool = False, + save_intermediate_checkpoints: bool = False, ) -> T: - # Config attributes are being injected also along "config"... - for key in config.keys(): - if key in model_kwargs: - model_kwargs.pop(key) - - if not isinstance(cls, SupportsTensorrtConversion): - raise ValueError( - f"{cls} doesn't support converting from Hugging Face Hub model." - " Please open up an issue at https://github.com/huggingface/optimum-nvidia/issues" + if get_device_count() < 1: + raise ValueError("No GPU detected on this platform") + + device_name = get_device_name(0)[-1] + common_hub_path = f"{device_name}/{config['torch_dtype']}" + + # Look for prebuild TRTLLM Engine + engine_files = checkpoint_files = [] + if not force_export: + LOGGER.debug(f"Retrieving prebuild engine(s) for device {device_name}") + cached_path = get_trtllm_artifact( + model_id, [f"{common_hub_path}/**/{PATH_FOLDER_ENGINES}/*.engine"] ) - # Check if we are using a local path - if not (local_path := Path(model_id)).exists(): - LOGGER.debug( - f"Loading potential prebuilt engines from the Hub ({model_id}@{revision})" - ) - - # Let's retrieve the weights for this model - # NOTE: We use `snapshot_download` to be able to provide a custom user-agent - # NOTE: maybe we can do the same with `from_pretrained` - local_path = HuggingFaceHubModel.retrieve_snapshot_from_hub( - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - token, - prebuilt_engines_only=True, - ) + if ( + engines_config_path := ( + cached_path / PATH_FOLDER_ENGINES / "config.json" + ) + ).exists(): + LOGGER.info(f"Found engines at {engines_config_path.parent}") + engine_files = engines_config_path.parent.glob( + FILE_TRTLLM_ENGINE_PATTERN + ) - # Look for prebuilt engine files, if none found, we convert and build - checkpoints_folders, (engines_folders, relative_paths_engines_folders) = ( - None, - find_prebuilt_engines(local_path), - ) - if len(engines_folders) == 0: - LOGGER.info( - f"No engine file found in {local_path}, converting and building engines" + # if no engine is found, then just try to locate a checkpoint + if not engine_files: + LOGGER.debug(f"Retrieving checkpoint(s) for {device_name}") + cached_path = get_trtllm_artifact( + model_id, [f"{common_hub_path}/**/*.safetensors"] ) - # If local_path exists and is not empty we have a local snapshot if ( - not local_path.exists() - or len(list(local_path.glob("*.safetensors"))) == 0 - ): - LOGGER.debug( - f"Loading original transformers weights from the Hub ({model_id}@{revision})" + checkpoints_config_path := ( + cached_path / PATH_FOLDER_CHECKPOINTS / "config.json" ) - - local_path = HuggingFaceHubModel.retrieve_snapshot_from_hub( - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - token, - prebuilt_engines_only=False, + ).exists(): + LOGGER.info(f"Found checkpoints at {checkpoints_config_path.parent}") + checkpoint_files = checkpoints_config_path.parent.glob( + FILE_TRTLLM_CHECKPOINT_PATTERN ) - 
checkpoint_folders, engines_folders, relative_paths_engines_folders = ( - cls.convert_and_build(local_path, config, **model_kwargs) - ) - else: - LOGGER.info(f"Found pre-built engines at: {engines_folders}") + # If no checkpoint is available, fall back to a full export from the Hugging Face Hub + if not checkpoint_files: + LOGGER.info(f"No prebuilt engines or checkpoints were found for {model_id}") - try: - generation_config = GenerationConfig.from_pretrained( + # Retrieve the snapshot if needed + original_checkpoints_path_for_conversion = snapshot_download( model_id, + repo_type="model", revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -439,133 +166,97 @@ def _from_pretrained( resume_download=resume_download, local_files_only=local_files_only, token=token, + allow_patterns=HUB_SNAPSHOT_ALLOW_PATTERNS, ) - except OSError: - generation_config = None - - transformers_config = AutoConfig.for_model(**config) - model = cls( - engines_folders, - gpus_per_node=model_kwargs.pop("gpus_per_node", 1), - use_cuda_graph=model_kwargs.pop("use_cuda_graph", False), - generation_config=generation_config, - transformers_config=transformers_config, - ) - - setattr(model, ATTR_TRTLLM_CHECKPOINT_FOLDER, checkpoints_folders) - setattr(model, ATTR_TRTLLM_ENGINE_FOLDER, engines_folders) - model._relative_path_engines_folders = relative_paths_engines_folders - return model - def _save_pretrained(self, save_directory: Path) -> None: - if not hasattr(self, ATTR_TRTLLM_ENGINE_FOLDER): - raise ValueError( - "Unable to determine the root folder containing TensorRT-LLM engines. " - "Please open-up an issue at https://github.com/huggingface/optimum-nvidia" + # Retrieve a proper transformers config + config = NormalizedConfig(AutoConfig.for_model(**config)) + generation_config = GenerationConfig.from_pretrained( + original_checkpoints_path_for_conversion ) - self.config.save_pretrained(save_directory) - if self.generation_config is not None: - self.generation_config.save_pretrained(save_directory) + # If no export config was provided, fall back to a default one + export_config = export_config or ExportConfig.from_config(config) - # Retrieve the folder - engines_folders = getattr(self, ATTR_TRTLLM_ENGINE_FOLDER) - for i in range(len(engines_folders)): - engine_folder = Path(engines_folders[i]) - save_subfolder = self._relative_path_engines_folders[i] + # Handle the device_map + if device_map and device_map == "auto": + LOGGER.info("Auto-parallelism will be used") + export_config = auto_parallel(export_config) - LOGGER.debug(f"Saving engines at {save_directory}") + # Forward everything to the exporter + if isinstance(cls, SupportsTransformersConversion): + targets = cls.TRT_LLM_TARGET_MODEL_CLASSES - save_engine_directory = save_directory / save_subfolder - save_engine_directory.mkdir(exist_ok=True, parents=True) + if not isinstance(targets, Mapping): + targets = {"": targets} - # Move the engine(s) - for engine in iglob(FILE_TRTLLM_ENGINE_PATTERN, root_dir=engine_folder): - LOGGER.debug( - f"Moving file {engine_folder / engine} to {save_engine_directory / engine}" - ) - shutil.copyfile(engine_folder / engine, save_engine_directory / engine) + for idx, (subpart, clazz) in enumerate(targets.items()): + LOGGER.info( + f"Building {model_id} {subpart} ({idx + 1} / {len(targets)})" + ) - # Move the configuration - if (config_path := engine_folder / "config.json").exists(): - shutil.copyfile(config_path, save_engine_directory / "config.json") + converter = TensorRTModelConverter(model_id, subpart) + + # Artifacts
resulting from a build are not stored in the location `snapshot_download` + # would use. Instead, it uses `cached_assets_path` to create a specific location which + # doesn't interfere with the HF caching system. Users can use `save_pretrained` to store + # the build artifacts in a snapshot-friendly place. + # If this specific location is found, we don't necessarily need to rebuild + + if force_export or not len( + list(converter.workspace.engines_path.glob("*.engine")) + ): + if not isinstance(clazz, SupportsFromHuggingFace): + raise TypeError(f"{clazz} can't convert from HF checkpoint") + + for rank in range(export_config.sharding.world_size): + # Specify the current model's rank + export_config.sharding.rank = rank + + ranked_model = clazz.from_hugging_face( + original_checkpoints_path_for_conversion, + dtype=DataType.from_torch(config.torch_dtype).value, + mapping=export_config.sharding, + load_by_shard=True, + ) + + build_config = export_config.to_builder_config() + + if save_intermediate_checkpoints: + _ = converter.convert(ranked_model) + LOGGER.info( + f"Saved intermediate checkpoints at {converter.workspace.checkpoints_path}" + ) + + _ = converter.build(ranked_model, build_config) + + LOGGER.info( + f"Saved TensorRT-LLM engines at {converter.workspace.engines_path}" + ) + else: + LOGGER.info( + f"Found existing engines at {converter.workspace.engines_path}" + ) else: - LOGGER.warning( - f"No config.json found at {config_path}. It might not be possible to reload the engines." + raise ValueError( + "Model doesn't support Hugging Face transformers conversion, aborting." ) - if (original_config_path := engine_folder / ".." / "config.json").exists(): - shutil.copyfile(original_config_path, save_directory / "config.json") - - @staticmethod - def convert_config_to_trtllm( - cls: Type[SupportsTensorrtConversion], - config: Union[PretrainedConfig, Dict[str, Any]], - config_class: Optional[TensorRTConfig] = None, - **additional_params, - ) -> TensorRTConfig: - """ - Convert a configuration initially generated by various Hugging Face libraries like transformers, diffusers, etc. - to TensorRT-LLM `tensorrt_llm.modeling_utils.PretrainedConfig`. - :param cls: The target class to allocate the model - :param config: The original library configuration file - :param config_class: An optional custom TensorRTConfig subclass to load into.
- :param additional_params: - :return: `tensorrt_llm.modeling_utils.PretrainedConfig` - """ - if config_class is None: - config_class = cls.MODEL_CONFIG - - mapping = additional_params.get("mapping", None) - - trt_config = config_class.from_config(config, mapping) - if hasattr(trt_config, "check_config"): - trt_config.check_config() - - return trt_config - - @staticmethod - def retrieve_snapshot_from_hub( - model_id: str, - revision: Optional[str], - cache_dir: Optional[Union[str, Path]], - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Optional[Union[str, bool]], - prebuilt_engines_only: bool = False, - ) -> Path: - """ - - :param model_id: - :param revision: - :param cache_dir: - :param force_download: - :param proxies: - :param resume_download: - :param local_files_only: - :param token: - :param prebuilt_engines_only: - :return: - """ - patterns = ( - HUB_TRTLLM_ENGINE_PATTERNS - if prebuilt_engines_only - else HUB_SAFETENSORS_PATTERNS - ) - patterns += ["config.json"] - local_path = snapshot_download( - repo_id=model_id, - revision=revision, - cache_dir=cache_dir, - allow_patterns=patterns, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=get_user_agent(), - ) + return cls( + engines_path=converter.workspace.engines_path, + generation_config=generation_config, + ) - return Path(local_path) + def _save_pretrained(self, save_directory: Path) -> None: + try: + # Need target_is_directory on Windows + # Windows10 needs elevated privilege for symlink which will raise OSError if not the case + # Falling back to copytree in this case + symlink(self._engines_path.parent, save_directory, target_is_directory=True) + except OSError as ose: + LOGGER.error( + f"Failed to create symlink from current engine folder {self._engines_path.parent} to {save_directory}. " + "Will default to copy based _save_pretrained", + exc_info=ose, + ) + copytree(self._engines_path.parent, save_directory, symlinks=True) diff --git a/src/optimum/nvidia/logging.py b/src/optimum/nvidia/logging.py index c09cafa8..9035972a 100644 --- a/src/optimum/nvidia/logging.py +++ b/src/optimum/nvidia/logging.py @@ -14,12 +14,14 @@ # limitations under the License. from logging import DEBUG, INFO, basicConfig -from tensorrt_llm.logger import logger - DEFAULT_LOGGING_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -def setup_logging(verbose: bool = False): - logger.set_level("verbose" if verbose else "info") +def setup_logging(verbose: bool = False, for_trtllm_logger: bool = False): basicConfig(format=DEFAULT_LOGGING_FMT, level=DEBUG if verbose else INFO) + + if for_trtllm_logger: + from tensorrt_llm.logger import logger + + logger.set_level("verbose" if verbose else "info") diff --git a/src/optimum/nvidia/models/__init__.py b/src/optimum/nvidia/models/__init__.py index 5655e416..767e6926 100644 --- a/src/optimum/nvidia/models/__init__.py +++ b/src/optimum/nvidia/models/__init__.py @@ -13,5 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
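Two of the pieces introduced above are user-facing: the `for_trtllm_logger` flag added to `setup_logging`, and `save_pretrained`, which now symlinks the engines folder and falls back to a full copy where symlinks are not permitted (e.g. unprivileged Windows). A short, hedged usage sketch; the model id and output folder are placeholders, not part of this change:

from pathlib import Path

from optimum.nvidia import AutoModelForCausalLM, setup_logging

# Forward the verbosity to TRT-LLM's own logger as well (new optional flag)
setup_logging(verbose=True, for_trtllm_logger=True)

# Any supported causal-LM checkpoint works the same way; this id is only a placeholder
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Persist the built engines: a symlink when the OS allows it, a full copy otherwise
model.save_pretrained(Path("gemma-2b-trtllm"))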
+from .base import SupportsFromHuggingFace, SupportsTransformersConversion # noqa from .auto import AutoModelForCausalLM -from .base import SupportsFromHuggingFace diff --git a/src/optimum/nvidia/models/auto.py b/src/optimum/nvidia/models/auto.py index d71e7882..f8c98550 100644 --- a/src/optimum/nvidia/models/auto.py +++ b/src/optimum/nvidia/models/auto.py @@ -14,30 +14,34 @@ # limitations under the License. from pathlib import Path -from typing import Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union from huggingface_hub import ModelHubMixin from optimum.nvidia.errors import UnsupportedModelException +from optimum.nvidia.models.gemma import GemmaForCausalLM +from optimum.nvidia.models.llama import LlamaForCausalLM -from ..hub import extract_model_type -from .gemma import GemmaForCausalLM -from .llama import LlamaForCausalLM -from .mistral import MistralForCausalLM -from .mixtral import MixtralForCausalLM - -_SUPPORTED_MODEL_CLASS = { - "llama": LlamaForCausalLM, - "mistral": MistralForCausalLM, - "mixtral": MixtralForCausalLM, - "gemma": GemmaForCausalLM, -} +if TYPE_CHECKING: + from optimum.nvidia.export import ExportConfig + from optimum.nvidia.runtime import CausalLM class AutoModelForCausalLM(ModelHubMixin): """ """ + _SUPPORTED_MODEL_CLASS = { + "llama": LlamaForCausalLM, + "mistral": LlamaForCausalLM, + "mixtral": LlamaForCausalLM, + "gemma": GemmaForCausalLM, + # "phi": PhiForCausalLM + } + + def __init__(self): + super().__init__() + @classmethod def _from_pretrained( cls: Type, @@ -51,22 +55,19 @@ def _from_pretrained( local_files_only: bool, token: Optional[Union[str, bool]], config: Optional[Dict[str, Any]] = None, + export_config: Optional["ExportConfig"] = None, + force_export: bool = False, + use_cuda_graph: bool = False, **model_kwargs, - ): + ) -> "CausalLM": if config is None: raise ValueError("Unable to determine the model type with config = None") - model_type, _ = extract_model_type(config) - if model_type is None: - raise ValueError( - "Unable to determine the model type from the provided config. " - "Please open-up an issue at https://github.com/huggingface/optimum-nvidia/issues" - ) - - if model_type not in _SUPPORTED_MODEL_CLASS: + model_type = config["model_type"] + if model_type not in AutoModelForCausalLM._SUPPORTED_MODEL_CLASS: raise UnsupportedModelException(model_type) - model_clazz = _SUPPORTED_MODEL_CLASS[model_type] + model_clazz = AutoModelForCausalLM._SUPPORTED_MODEL_CLASS[model_type] model = model_clazz.from_pretrained( pretrained_model_name_or_path=model_id, config=config, @@ -77,6 +78,9 @@ def _from_pretrained( resume_download=resume_download, local_files_only=local_files_only, token=token, + export_config=export_config, + force_export=force_export, + use_cuda_graph=use_cuda_graph, **model_kwargs, ) diff --git a/src/optimum/nvidia/models/base.py b/src/optimum/nvidia/models/base.py index 372008d5..ea0f29c7 100644 --- a/src/optimum/nvidia/models/base.py +++ b/src/optimum/nvidia/models/base.py @@ -1,47 +1,35 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from os import PathLike +from typing import ( + TYPE_CHECKING, + Mapping, + Optional, + Protocol, + Type, + Union, + runtime_checkable, +) -from typing import Optional, Protocol, runtime_checkable - -from tensorrt_llm import Mapping -from tensorrt_llm.quantization import QuantMode +if TYPE_CHECKING: + from tensorrt_llm.top_model_mixin import TopModelMixin + from transformers import PreTrainedModel as TransformersPreTrainedModel @runtime_checkable class SupportsFromHuggingFace(Protocol): - """ - Define the protocol implemented by TensorRT-LLM models to support loading from Hugging Face Hub - """ - @classmethod def from_hugging_face( cls, - hf_model_dir, - dtype="float16", + hf_model_dir: Union[str, bytes, PathLike], + dtype: str = "float16", mapping: Optional[Mapping] = None, - quant_mode: Optional[QuantMode] = None, **kwargs, - ): - """ + ): ... - :param hf_model_dir: - :param dtype: - :param mapping: - :param quant_mode: - :param kwargs: - :return: - """ - ... + +@runtime_checkable +class SupportsTransformersConversion(Protocol): + HF_LIBRARY_TARGET_MODEL_CLASS: Type["TransformersPreTrainedModel"] + TRT_LLM_TARGET_MODEL_CLASSES: Union[ + Type["TopModelMixin"], Mapping[str, Type["TopModelMixin"]] + ] diff --git a/src/optimum/nvidia/models/gemma.py b/src/optimum/nvidia/models/gemma.py index 14f14bd4..2412cc46 100644 --- a/src/optimum/nvidia/models/gemma.py +++ b/src/optimum/nvidia/models/gemma.py @@ -12,108 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from logging import getLogger -from typing import Dict, Optional -import torch -from tensorrt_llm import Mapping -from tensorrt_llm.models import PretrainedConfig, PretrainedModel from tensorrt_llm.models.gemma.model import GemmaForCausalLM as TrtGemmaForCausalLM -from tensorrt_llm.models.gemma.weight import load_from_hf_gemma -from tensorrt_llm.plugin import PluginConfig from transformers import GemmaForCausalLM as TransformersGemmaForCausalLM -from transformers import PretrainedConfig as TransformersPretrainedConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.config import dtype_to_str from optimum.nvidia.hub import HuggingFaceHubModel +from optimum.nvidia.models import SupportsTransformersConversion from optimum.nvidia.runtime import CausalLM LOGGER = getLogger(__name__) -class GemmaConfig(TensorRTConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaGemmaConfig`]. It is used to instantiate an Gemma - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma-7B. - - Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the - documentation from [`TensorRTConfig`] for more information. 
- """ - - @staticmethod - def from_config( - config: TransformersPretrainedConfig, mapping: Optional[Mapping] = None - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = GemmaConfig( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), - head_size=config.head_dim, - hidden_act=config.hidden_act, - intermediate_size=config.intermediate_size, - norm_epsilon=config.rms_norm_eps, - position_embedding_type="rope_gpt_neox", - rotary_base=getattr(config, "rope_theta", 10000.0), - rotary_scaling=getattr(config, "rope_scaling", None), - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - use_prompt_tuning=False, - use_parallel_embedding=mapping.tp_size > 1, - embedding_sharding_dim=0, - share_embedding_table=False, - max_lora_rank=64, - quantization=qconfig, - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.moe_plugin = "disable" - config.bert_attention_plugin = "disable" - config.gpt_attention_plugin = self.dtype - config.gemm_plugin = self.dtype - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class GemmaForCausalLM(CausalLM, HuggingFaceHubModel): - MODEL_CONFIG = GemmaConfig +class GemmaForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion): HF_LIBRARY_TARGET_MODEL_CLASS = TransformersGemmaForCausalLM - TRT_LLM_TARGET_MODEL_CLASS = TrtGemmaForCausalLM - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, torch.Tensor]: - if config.quant_mode.has_any_quant(): - raise NotImplementedError("Quantization is not supported yet.") + TRT_LLM_TARGET_MODEL_CLASSES = TrtGemmaForCausalLM - return load_from_hf_gemma(target, source, config.mapping, config.dtype) + TRT_LLM_MANDATORY_CONVERSION_PARAMS = {"share_embedding_table": True} diff --git a/src/optimum/nvidia/models/llama.py b/src/optimum/nvidia/models/llama.py index 65f84c75..a7b02692 100644 --- a/src/optimum/nvidia/models/llama.py +++ b/src/optimum/nvidia/models/llama.py @@ -13,108 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from logging import getLogger -from typing import Dict, Optional -import numpy as np -from tensorrt_llm import Mapping -from tensorrt_llm.models import PretrainedConfig, PretrainedModel -from tensorrt_llm.models.llama.convert import load_weights_from_hf from tensorrt_llm.models.llama.model import LLaMAForCausalLM -from tensorrt_llm.plugin import PluginConfig from transformers import LlamaForCausalLM as TransformersLlamaForCausalLM -from transformers import PretrainedConfig as TransformersPretrainedConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.config import dtype_to_str from optimum.nvidia.hub import HuggingFaceHubModel +from optimum.nvidia.models import SupportsTransformersConversion from optimum.nvidia.runtime import CausalLM LOGGER = getLogger(__name__) -class LlamaConfig(TensorRTConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the - documentation from [`TensorRTConfig`] for more information. - """ - - @staticmethod - def from_config( - config: TransformersPretrainedConfig, mapping: Optional[Mapping] = None - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = LlamaConfig( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), - hidden_act=config.hidden_act, - intermediate_size=config.intermediate_size, - norm_epsilon=config.rms_norm_eps, - position_embedding_type="rope_gpt_neox", - rotary_base=getattr(config, "rope_theta", 10000.0), - rotary_scaling=getattr(config, "rope_scaling", None), - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - use_prompt_tuning=False, - use_parallel_embedding=False, - embedding_sharding_dim=0, - share_embedding_table=False, - max_lora_rank=64, - head_size=config.hidden_size // config.num_attention_heads, - quantization=qconfig, - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.moe_plugin = "disable" # TODO : Mixtral? 
- config.bert_attention_plugin = "disable" - config.gpt_attention_plugin = self.dtype - config.gemm_plugin = self.dtype - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class LlamaForCausalLM(CausalLM, HuggingFaceHubModel): - MODEL_CONFIG = LlamaConfig +class LlamaForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion): HF_LIBRARY_TARGET_MODEL_CLASS = TransformersLlamaForCausalLM - TRT_LLM_TARGET_MODEL_CLASS = LLaMAForCausalLM - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, np.ndarray]: - if config.quant_mode.has_any_quant(): - raise NotImplementedError("Quantization is not supported yet.") - - return load_weights_from_hf( - config=config.to_dict(), mapping=config.mapping, model=source - ) + TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM diff --git a/src/optimum/nvidia/models/mistral.py b/src/optimum/nvidia/models/mistral.py index c6918537..762b1519 100644 --- a/src/optimum/nvidia/models/mistral.py +++ b/src/optimum/nvidia/models/mistral.py @@ -13,106 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from logging import getLogger -from typing import Dict, Optional -import numpy as np -from tensorrt_llm import Mapping -from tensorrt_llm.models import PretrainedConfig, PretrainedModel from tensorrt_llm.models.llama.model import LLaMAForCausalLM -from tensorrt_llm.models.llama.weight import load_from_hf_llama -from tensorrt_llm.plugin import PluginConfig from transformers import MistralForCausalLM as TransformersMistralForCausalLM -from transformers import PretrainedConfig as TransformersPretrainedConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.config import dtype_to_str from optimum.nvidia.hub import HuggingFaceHubModel +from optimum.nvidia.models import SupportsTransformersConversion from optimum.nvidia.runtime import CausalLM LOGGER = getLogger(__name__) -class MistralConfig(TensorRTConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the - documentation from [`TensorRTConfig`] for more information. 
- """ - - @staticmethod - def from_config( - config: TransformersPretrainedConfig, mapping: Optional[Mapping] - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = MistralConfig( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), - hidden_act=config.hidden_act, - intermediate_size=config.intermediate_size, - norm_epsilon=config.rms_norm_eps, - position_embedding_type="rope_gpt_neox", - rotary_base=getattr(config, "rope_theta", 10000.0), - rotary_scaling=getattr(config, "rope_scaling", None), - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - use_prompt_tuning=False, - use_parallel_embedding=mapping.tp_size > 1, - embedding_sharding_dim=0, - share_embedding_table=False, - max_lora_rank=64, - head_size=config.hidden_size // config.num_attention_heads, - quantization=qconfig, - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.moe_plugin = "disable" # TODO : Mixtral? - config.bert_attention_plugin = "disable" - config.gpt_attention_plugin = self.dtype - config.gemm_plugin = self.dtype - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class MistralForCausalLM(CausalLM, HuggingFaceHubModel): - MODEL_CONFIG = MistralConfig +class MistralForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion): HF_LIBRARY_TARGET_MODEL_CLASS = TransformersMistralForCausalLM - TRT_LLM_TARGET_MODEL_CLASS = LLaMAForCausalLM - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, np.ndarray]: - if config.quant_mode.has_any_quant(): - raise NotImplementedError("Quantization is not supported yet.") - - return load_from_hf_llama(target, source, config.mapping, config.dtype) + TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM diff --git a/src/optimum/nvidia/models/mixtral.py b/src/optimum/nvidia/models/mixtral.py index 70a74267..f4e3063a 100644 --- a/src/optimum/nvidia/models/mixtral.py +++ b/src/optimum/nvidia/models/mixtral.py @@ -13,110 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from logging import getLogger -from typing import Dict, Optional -import numpy as np -from tensorrt_llm import Mapping -from tensorrt_llm.models import PretrainedConfig, PretrainedModel -from tensorrt_llm.models.llama.convert import load_weights_from_hf from tensorrt_llm.models.llama.model import LLaMAForCausalLM -from tensorrt_llm.plugin import PluginConfig from transformers import MixtralForCausalLM as TransformersMixtralForCausalLM -from transformers import PretrainedConfig as TransformersPretrainedConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.config import dtype_to_str from optimum.nvidia.hub import HuggingFaceHubModel +from optimum.nvidia.models import SupportsTransformersConversion from optimum.nvidia.runtime import CausalLM LOGGER = getLogger(__name__) -class MixtralConfig(TensorRTConfig): - r""" - This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an Mixtral - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Mixtral-8x7b. - - Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the - documentation from [`TensorRTConfig`] for more information. - """ - - @staticmethod - def from_config( - config: TransformersPretrainedConfig, mapping: Optional[Mapping] - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = MixtralConfig( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), - hidden_act="swiglu", - intermediate_size=config.intermediate_size, - norm_epsilon=config.rms_norm_eps, - position_embedding_type="rope_gpt_neox", - rotary_base=getattr(config, "rope_theta", 10000.0), - rotary_scaling=getattr(config, "rope_scaling", None), - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - use_prompt_tuning=False, - use_parallel_embedding=False, - embedding_sharding_dim=0, - share_embedding_table=False, - max_lora_rank=64, - head_size=config.hidden_size // config.num_attention_heads, - quantization=qconfig, - moe_num_experts=getattr(config, "num_local_experts", 0), - moe_top_k=getattr(config, "num_experts_per_tok", 0), - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.moe_plugin = self.dtype - config.bert_attention_plugin = "disable" - config.gpt_attention_plugin = self.dtype - config.gemm_plugin = self.dtype - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class MixtralForCausalLM(CausalLM, HuggingFaceHubModel): - MODEL_CONFIG = MixtralConfig +class MixtralForCausalLM(CausalLM, HuggingFaceHubModel, SupportsTransformersConversion): HF_LIBRARY_TARGET_MODEL_CLASS = TransformersMixtralForCausalLM - 
TRT_LLM_TARGET_MODEL_CLASS = LLaMAForCausalLM - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, np.ndarray]: - if config.quant_mode.has_any_quant(): - config.quantization.exclude_modules.append("router") - - return load_weights_from_hf( - config=config.to_dict(), mapping=config.mapping, model=source - ) + TRT_LLM_TARGET_MODEL_CLASSES = LLaMAForCausalLM diff --git a/src/optimum/nvidia/models/whisper.py b/src/optimum/nvidia/models/whisper.py index 6915dc8b..2ae8257b 100644 --- a/src/optimum/nvidia/models/whisper.py +++ b/src/optimum/nvidia/models/whisper.py @@ -12,1241 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import json -import pathlib -from collections import OrderedDict from logging import getLogger -from pathlib import Path from typing import ( TYPE_CHECKING, - Any, - Callable, - Dict, - Iterator, - List, - Optional, - Tuple, - Union, ) -import numpy as np -import torch -from tensorrt_llm import Mapping, mpi_rank, str_dtype_to_torch -from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy, trt_dtype_to_torch -from tensorrt_llm.builder import BuildConfig -from tensorrt_llm.functional import LayerNormPositionType, LayerNormType from tensorrt_llm.models import DecoderModel as TrtDecoderModel -from tensorrt_llm.models import PretrainedConfig, PretrainedModel from tensorrt_llm.models import WhisperEncoder as TrtWhisperEncoder -from tensorrt_llm.plugin import PluginConfig -from tensorrt_llm.runtime import GenerationSession, ModelConfig, SamplingConfig -from tensorrt_llm.runtime.generation import LogitsProcessorList -from tensorrt_llm.runtime.session import Session, TensorInfo -from transformers import GenerationConfig -from transformers import PreTrainedModel as TransformersPretrainedModel -from transformers.models.whisper.generation_whisper import WhisperGenerationMixin -from transformers.models.whisper.modeling_whisper import ( - WhisperDecoder as TransformersWhisperDecoder, -) -from transformers.models.whisper.modeling_whisper import ( - WhisperEncoder as TransformersWhisperEncoder, -) from transformers.models.whisper.modeling_whisper import ( WhisperForConditionalGeneration as TransformersWhisperForConditionalGeneration, ) -from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE -from optimum.nvidia import TensorRTConfig -from optimum.nvidia.config import dtype_to_str -from optimum.nvidia.generation.logits_process import ( - TrtForceTokensLogitsProcessor, - TrtSuppressTokensAtBeginLogitsProcessor, - TrtSuppressTokensLogitsProcessor, - TrtWhisperNoSpeechDetection, -) -from optimum.nvidia.hub import HuggingFaceHubModel -from optimum.nvidia.runtime import TensorRTForSpeechSeq2Seq -from optimum.nvidia.utils.nvml import get_max_memory +from optimum.nvidia.models import SupportsTransformersConversion if TYPE_CHECKING: - from transformers import PretrainedConfig as TransformersPretrainedConfig - from transformers.generation.logits_process import LogitsProcessorList - from transformers.generation.stopping_criteria import StoppingCriteriaList + pass LOGGER = getLogger(__name__) -def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - return np.split(v, tp_size, axis=dim)[idx] - - -def get_qkv(model_params, torch_dtype, attn_prefix: str): - q_weight = torch_to_numpy( - 
model_params[f"{attn_prefix}.q_proj.weight"].to(torch_dtype) - ) - k_weight = torch_to_numpy( - model_params[f"{attn_prefix}.k_proj.weight"].to(torch_dtype) - ) - v_weight = torch_to_numpy( - model_params[f"{attn_prefix}.v_proj.weight"].to(torch_dtype) - ) - - qkv_weight = (q_weight, k_weight, v_weight) - - # At least one of the query, key, value projection has bias. - if any( - bias_name in model_params - for bias_name in [ - f"{attn_prefix}.q_proj.bias", - f"{attn_prefix}.k_proj.bias", - f"{attn_prefix}.v_proj.bias", - ] - ): - # For example whisper encoder k_proj does not have bias, so we will just fill with zeros the fused bias if needed. - numpy_precision = q_weight.dtype - if f"{attn_prefix}.q_proj.bias" in model_params: - q_bias = torch_to_numpy( - model_params[f"{attn_prefix}.q_proj.bias"].to(torch_dtype) - ) - else: - q_bias = np.zeros(q_weight.shape[0], dtype=numpy_precision) - - if f"{attn_prefix}.k_proj.bias" in model_params: - k_bias = torch_to_numpy( - model_params[f"{attn_prefix}.k_proj.bias"].to(torch_dtype) - ) - else: - k_bias = np.zeros(k_weight.shape[0], dtype=numpy_precision) - - if f"{attn_prefix}.v_proj.bias" in model_params: - v_bias = torch_to_numpy( - model_params[f"{attn_prefix}.v_proj.bias"].to(torch_dtype) - ) - else: - v_bias = np.zeros(v_weight.shape[0], dtype=numpy_precision) - qkv_bias = (q_bias, k_bias, v_bias) - else: - qkv_bias = None - - return qkv_weight, qkv_bias - - -def convert_from_hf_whisper_encoder( - hf_whisper_encoder, - mapping=Mapping(), - dtype="float16", -): - num_layers = hf_whisper_encoder.config.encoder_layers - torch_dtype = str_dtype_to_torch(dtype) - - model_params = dict(hf_whisper_encoder.named_parameters()) - weights = {} - - # Convert specific tensors - # conv1 - # TensorRT-LLM Conv1d weight is 4D while transformers checkpoint ones are 3D. - conv1_weight = torch_to_numpy(model_params["conv1.weight"].to(torch_dtype)) - conv1_weight = conv1_weight[..., None] - - weights["model.conv1.weight"] = conv1_weight - weights["model.conv1.bias"] = torch_to_numpy( - model_params["conv1.bias"].to(torch_dtype) - ) - - # conv2 - conv2_weight = torch_to_numpy(model_params["conv2.weight"].to(torch_dtype)) - conv2_weight = conv2_weight[..., None] - - weights["model.conv2.weight"] = conv2_weight - weights["model.conv2.bias"] = torch_to_numpy( - model_params["conv2.bias"].to(torch_dtype) - ) - - # embed_positions - # NOTE: this one is kept as fp32 in Whisper, is this important? 
- weights["model.positional_embedding"] = torch_to_numpy( - model_params["embed_positions.weight"] # .to(torch_dtype) - ) - - # Final layer norm - weights["model.ln_post.weight"] = torch_to_numpy( - model_params["layer_norm.weight"].to(torch_dtype) - ) - weights["model.ln_post.bias"] = torch_to_numpy( - model_params["layer_norm.bias"].to(torch_dtype) - ) - - # Map all the hidden layers - for layer_idx in range(num_layers): - prefix = f"layers.{layer_idx}" - - # attention_layernorm - weights[f"model.encoder_layers.{layer_idx}.attention_layernorm.weight"] = ( - torch_to_numpy( - model_params[f"{prefix}.self_attn_layer_norm.weight"].to(torch_dtype) - ) - ) - weights[f"model.encoder_layers.{layer_idx}.attention_layernorm.bias"] = ( - torch_to_numpy( - model_params[f"{prefix}.self_attn_layer_norm.bias"].to(torch_dtype) - ) - ) - - # mlp_layernorm - weights[f"model.encoder_layers.{layer_idx}.mlp_layernorm.weight"] = ( - torch_to_numpy( - model_params[f"{prefix}.final_layer_norm.weight"].to(torch_dtype) - ) - ) - weights[f"model.encoder_layers.{layer_idx}.mlp_layernorm.bias"] = ( - torch_to_numpy( - model_params[f"{prefix}.final_layer_norm.bias"].to(torch_dtype) - ) - ) - - # Self attention layer - # TensorRT-LLM model definition uses a single GEMM for query/key/value, while transformers does not. - qkv_weight, qkv_bias = get_qkv( - model_params, attn_prefix=f"{prefix}.self_attn", torch_dtype=torch_dtype - ) - q_weight, k_weight, v_weight = qkv_weight - - qkv_weight = np.concatenate((q_weight, k_weight, v_weight), axis=0) - weight = split(qkv_weight, mapping.tp_size, mapping.tp_rank, dim=1) - weights[f"model.encoder_layers.{layer_idx}.attention.qkv.weight"] = weight - - if qkv_bias is not None: - q_bias, k_bias, v_bias = qkv_bias - packed_qkv_bias = np.concatenate((q_bias, k_bias, v_bias), axis=0) - weights[f"model.encoder_layers.{layer_idx}.attention.qkv.bias"] = ( - np.ascontiguousarray(packed_qkv_bias) - ) - - # Common projection logic - # 0: column tensor parallel, 1: row tensor parallel. - for src, dst, shard_axis in [ - ("self_attn.out_proj.weight", "attention.dense.weight", 1), - ("fc1.weight", "mlp.fc.weight", 0), - ("fc2.weight", "mlp.proj.weight", 1), - ]: - weight = torch_to_numpy(model_params[f"{prefix}.{src}"].to(torch_dtype)) - weight = split(weight, mapping.tp_size, mapping.tp_rank, dim=shard_axis) - weights[f"model.encoder_layers.{layer_idx}.{dst}"] = weight - - # Bias is never sharded. 
- for src, dst in [ - ("self_attn.out_proj.bias", "attention.dense.bias"), - ("fc1.bias", "mlp.fc.bias"), - ("fc2.bias", "mlp.proj.bias"), - ]: - weights[f"model.encoder_layers.{layer_idx}.{dst}"] = torch_to_numpy( - model_params[f"{prefix}.{src}"].to(torch_dtype) - ) - - # weights["lm_head.weight"] = np.zeros((0)) # Just a hack for commands/build.py - - return weights - - -def convert_from_hf_whisper_decoder( - hf_whisper_decoder, - mapping=Mapping(), - dtype="float16", -): - weights = {} - - num_layers = hf_whisper_decoder.config.decoder_layers - torch_dtype = str_dtype_to_torch(dtype) - - model_params = dict(hf_whisper_decoder.named_parameters()) - - # Convert specific tensors - if mapping.is_first_pp_rank(): - # embed_tokens - weights["model.embedding.vocab_embedding.weight"] = torch_to_numpy( - model_params["embed_tokens.weight"].to(torch_dtype) - ) - - # embed_positions - weights["model.embedding.position_embedding.weight"] = torch_to_numpy( - model_params["embed_positions.weight"].to(torch_dtype) - ) - - if mapping.is_last_pp_rank(): - # Final layer norm - weights["model.final_layernorm.weight"] = torch_to_numpy( - model_params["layer_norm.weight"].to(torch_dtype) - ) - weights["model.final_layernorm.bias"] = torch_to_numpy( - model_params["layer_norm.bias"].to(torch_dtype) - ) - - # Final vocab projection - lm_head = torch_to_numpy(model_params["embed_tokens.weight"].to(torch_dtype)) - lm_head = split(lm_head, mapping.tp_size, mapping.tp_rank) - weights["model.lm_head.weight"] = lm_head - - # Map all the hidden layers - for layer_idx in range(num_layers): - prefix = f"layers.{layer_idx}" - trt_llm_prefix = f"model.decoder_layers.{layer_idx}" - - # self_attention_layernorm - weights[f"{trt_llm_prefix}.self_attention_layernorm.weight"] = torch_to_numpy( - model_params[f"{prefix}.self_attn_layer_norm.weight"].to(torch_dtype) - ) - weights[f"{trt_llm_prefix}.self_attention_layernorm.bias"] = torch_to_numpy( - model_params[f"{prefix}.self_attn_layer_norm.bias"].to(torch_dtype) - ) - - # cross_attention_layernorm - weights[f"{trt_llm_prefix}.cross_attention_layernorm.weight"] = torch_to_numpy( - model_params[f"{prefix}.encoder_attn_layer_norm.weight"].to(torch_dtype) - ) - weights[f"{trt_llm_prefix}.cross_attention_layernorm.bias"] = torch_to_numpy( - model_params[f"{prefix}.encoder_attn_layer_norm.bias"].to(torch_dtype) - ) - - # mlp_layernorm - weights[f"{trt_llm_prefix}.mlp_layernorm.weight"] = torch_to_numpy( - model_params[f"{prefix}.final_layer_norm.weight"].to(torch_dtype) - ) - weights[f"{trt_llm_prefix}.mlp_layernorm.bias"] = torch_to_numpy( - model_params[f"{prefix}.final_layer_norm.bias"].to(torch_dtype) - ) - - # Self attention layer - qkv_weight, qkv_bias = get_qkv( - model_params, attn_prefix=f"{prefix}.self_attn", torch_dtype=torch_dtype - ) - q_weight, k_weight, v_weight = qkv_weight - - qkv_weight = np.concatenate((q_weight, k_weight, v_weight), axis=0) - weight = split(qkv_weight, mapping.tp_size, mapping.tp_rank, dim=1) - weights[f"{trt_llm_prefix}.self_attention.qkv.weight"] = weight - - if qkv_bias is not None: - q_bias, k_bias, v_bias = qkv_bias - packed_qkv_bias = np.concatenate((q_bias, k_bias, v_bias), axis=0) - weights[f"{trt_llm_prefix}.self_attention.qkv.bias"] = np.ascontiguousarray( - packed_qkv_bias - ) - - # Cross attention layer - qkv_weight, qkv_bias = get_qkv( - model_params, attn_prefix=f"{prefix}.encoder_attn", torch_dtype=torch_dtype - ) - q_weight, k_weight, v_weight = qkv_weight - - qkv_weight = np.concatenate((q_weight, k_weight, v_weight), 
axis=0) - weight = split(qkv_weight, mapping.tp_size, mapping.tp_rank, dim=1) - weights[f"{trt_llm_prefix}.cross_attention.qkv.weight"] = weight - - if qkv_bias is not None: - q_bias, k_bias, v_bias = qkv_bias - packed_qkv_bias = np.concatenate((q_bias, k_bias, v_bias), axis=0) - weights[f"{trt_llm_prefix}.cross_attention.qkv.bias"] = ( - np.ascontiguousarray(packed_qkv_bias) - ) - - # Common projection logic. - # 0: column tensor parallel, 1: row tensor parallel. - for src, dst, shard_axis in [ - ("self_attn.out_proj.weight", "self_attention.dense.weight", 1), - ("encoder_attn.out_proj.weight", "cross_attention.dense.weight", 1), - ("fc1.weight", "mlp.fc.weight", 0), - ("fc2.weight", "mlp.proj.weight", 1), - ]: - weight = torch_to_numpy(model_params[f"{prefix}.{src}"].to(torch_dtype)) - weight = split(weight, mapping.tp_size, mapping.tp_rank, dim=shard_axis) - weights[f"{trt_llm_prefix}.{dst}"] = weight - - # Bias is never sharded. - for src, dst in [ - ("self_attn.out_proj.bias", "self_attention.dense.bias"), - ("encoder_attn.out_proj.bias", "cross_attention.dense.bias"), - ("fc1.bias", "mlp.fc.bias"), - ("fc2.bias", "mlp.proj.bias"), - ]: - weights[f"{trt_llm_prefix}.{dst}"] = torch_to_numpy( - model_params[f"{prefix}.{src}"].to(torch_dtype) - ) - - return weights - - -class WhisperEncoderConfig(TensorRTConfig): - @classmethod - def from_config( - cls, config: "TransformersPretrainedConfig", mapping: Optional[Mapping] = None - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = cls( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), # TODO: always float32? - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_target_positions, - hidden_size=config.d_model, - num_hidden_layers=config.encoder_layers, - num_attention_heads=config.encoder_attention_heads, - num_key_value_heads=config.encoder_attention_heads, - hidden_act=config.activation_function, - intermediate_size=None, - norm_epsilon=None, - position_embedding_type="learned_absolute", - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - quantization=qconfig, - use_parallel_embedding=mapping.tp_size > 1, - embedding_sharding_dim=0 if mapping.tp_size > 1 else None, - share_embedding_table=None, - head_size=-1, # We need to set it otherwise TRT-LLM tries to compute `hidden_size // num_attention_heads` - max_source_positions=config.max_source_positions, - num_mel_bins=config.num_mel_bins, - trt_model_class="TrtWhisperEncoderPretrainedModel", - trt_model_file=pathlib.Path(__file__), - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.bert_attention_plugin = self.dtype - config.gpt_attention_plugin = "disable" - config.remove_input_padding = False # This one is bugged with Whisper. 
- config.paged_kv_cache = "disable" # TODO: getting AssertionError: Paged kv cache is enabled, the kv_cache_block_pointers tensor shall not be None - - config.moe_plugin = "disable" - config.gemm_plugin = self.dtype - config.context_fmha = True - config.enable_xqa = False - config.remove_input_padding = False - config.use_custom_all_reduce = "disable" - - config.layernorm_quantization_plugin = None - config.rmsnorm_quantization_plugin = None - config.nccl_plugin = None - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class WhisperDecoderConfig(TensorRTConfig): - @classmethod - def from_config( - cls, config: "TransformersPretrainedConfig", mapping: Optional[Mapping] - ) -> "TensorRTConfig": - mapping = mapping or Mapping() - - # Retrieve the quantization from the transformers config (if provided) - _, qconfig = TensorRTConfig.get_quantization_config(config) - - trt_config = cls( - architecture=config.architectures[0], - dtype=dtype_to_str(config.torch_dtype), # TODO: always float32? - logits_dtype="float32", - vocab_size=config.vocab_size, - max_position_embeddings=config.max_target_positions, - hidden_size=config.d_model, - num_hidden_layers=config.decoder_layers, - num_attention_heads=config.decoder_attention_heads, - num_key_value_heads=config.decoder_attention_heads, - hidden_act=config.activation_function, - intermediate_size=None, - norm_epsilon=None, - position_embedding_type="learned_absolute", - world_size=mapping.world_size, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size, - quantization=qconfig, - use_parallel_embedding=None, - embedding_sharding_dim=None, - share_embedding_table=None, - head_size=-1, # We need to set it otherwise TRT-LLM tries to compute `hidden_size // num_attention_heads` - max_source_positions=config.max_source_positions, - decoder_ffn_dim=config.decoder_ffn_dim, - trt_model_class="TrtWhisperDecoderPretrainedModel", - trt_model_file=pathlib.Path(__file__), - num_encoder_attention_heads=config.encoder_attention_heads, - ) - - trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) - - return trt_config - - def get_plugins_config(self) -> PluginConfig: - config = super().get_plugins_config() - config.bert_attention_plugin = "disable" - config.gpt_attention_plugin = self.dtype - config.paged_kv_cache = "disable" # TODO: getting AssertionError: Paged kv cache is enabled, the kv_cache_block_pointers tensor shall not be None - - config.context_fmha = True - config.moe_plugin = "disable" - config.gemm_plugin = self.dtype - config.remove_input_padding = False - config.enable_xqa = False - - config.layernorm_quantization_plugin = None - config.rmsnorm_quantization_plugin = None - config.nccl_plugin = None - config.use_custom_all_reduce = "disable" - - return config - - @staticmethod - def supports_strong_typing() -> bool: - return False - - -class TrtWhisperEncoderPretrainedModel(PretrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = TrtWhisperEncoder( - n_mels=config.num_mel_bins, - n_ctx=config.max_source_positions, - n_state=config.hidden_size, - n_head=config.num_attention_heads, - n_layer=config.num_hidden_layers, - dtype=str_dtype_to_trt(config.dtype), - ) - - def forward(self, **kwargs): - return self.model(**kwargs) - - def prepare_inputs(self, max_batch_size=16, **kwargs): - (x, input_lengths) = self.model.prepare_inputs(max_batch_size=max_batch_size) - - return {"x": x, "input_lengths": input_lengths} - - -class 
TrtWhisperDecoderPretrainedModel(PretrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = TrtDecoderModel( - num_layers=config.num_hidden_layers, - num_heads=config.num_attention_heads, - hidden_size=config.hidden_size, - ffn_hidden_size=config.decoder_ffn_dim, - encoder_num_heads=config.num_encoder_attention_heads, - encoder_hidden_size=config.hidden_size, - vocab_size=config.vocab_size, - dtype=str_dtype_to_trt(config.dtype), - logits_dtype=str_dtype_to_trt(config.logits_dtype), - max_position_embeddings=config.max_position_embeddings, - has_position_embedding=True, - relative_attention=False, - head_size=None, - encoder_head_size=None, - num_kv_heads=None, - encoder_num_kv_heads=None, - type_vocab_size=None, - max_distance=0, - num_buckets=0, - has_embedding_layernorm=False, - has_embedding_scale=False, - q_scaling=1.0, - has_attention_qkvo_bias=True, - has_mlp_bias=True, - has_model_final_layernorm=True, - layernorm_eps=1e-5, - layernorm_position=LayerNormPositionType.pre_layernorm, - layernorm_type=LayerNormType.LayerNorm, - hidden_act=config.hidden_act, - rescale_before_lm_head=False, - mapping=config.mapping, - ) - self.config.optimize_network = False # See utils/patching.py. TODO: remove this once native to TensorRT-LLM. - - def forward(self, **kwargs): - return self.model(**kwargs) - - def prepare_inputs( - self, - max_batch_size, - max_input_len, - max_seq_len, - use_cache, - max_beam_width: int = 1, - max_num_tokens: int = None, - prompt_embedding_table_size: int = 0, - position_encoding_2d: bool = False, - max_draft_len: int = 0, - gather_context_logits: bool = False, - gather_generation_logits: bool = False, - lora_target_modules: List[str] = None, - **kwargs, - ): - if not use_cache: - raise NotImplementedError("use_cache=False is not implemented for Whisper.") - - ( - input_ids, - encoder_output, - position_ids, - token_type_ids, - use_cache, - attention_mask, - cross_attention_mask, - last_token_ids, - kv_cache_params, - attention_params, - hidden_states, - lora_params, - cross_kv_cache_gen, - cross_qkv_reuse, - ) = self.model.prepare_inputs( - max_batch_size=max_batch_size, - max_beam_width=max_beam_width, - max_decoder_input_len=max_input_len, - max_new_tokens=max_seq_len, - max_encoder_input_len=self.config.max_source_positions, - ) - - return { - "decoder_input_ids": input_ids, - "encoder_output": encoder_output, - "position_ids": position_ids, - "token_type_ids": token_type_ids, - "use_cache": use_cache, - "attention_mask": attention_mask, - "cross_attention_mask": cross_attention_mask, - "last_token_ids": last_token_ids, - "kv_cache_params": kv_cache_params, - "attention_params": attention_params, - "hidden_states": hidden_states, - "lora_params": lora_params, - "cross_kv_cache_gen": cross_kv_cache_gen, - "cross_qkv_reuse": cross_qkv_reuse, - } - - -class OptimumWhisperEncoder(TensorRTForSpeechSeq2Seq, HuggingFaceHubModel): - MODEL_CONFIG = WhisperEncoderConfig - HF_LIBRARY_TARGET_MODEL_CLASS = TransformersWhisperEncoder - TRT_LLM_TARGET_MODEL_CLASS = TrtWhisperEncoderPretrainedModel - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, torch.Tensor]: - if config.quant_mode.has_any_quant(): - raise NotImplementedError("Quantization is not supported yet.") - - return convert_from_hf_whisper_encoder(source, config.mapping, config.dtype) - - -class OptimumWhisperDecoder(TensorRTForSpeechSeq2Seq, HuggingFaceHubModel): - MODEL_CONFIG = 
WhisperDecoderConfig - HF_LIBRARY_TARGET_MODEL_CLASS = TransformersWhisperDecoder - TRT_LLM_TARGET_MODEL_CLASS = TrtWhisperDecoderPretrainedModel - - @staticmethod - def convert_weights( - target: PretrainedModel, - source: TransformersPretrainedModel, - config: PretrainedConfig, - ) -> Dict[str, torch.Tensor]: - if config.quant_mode.has_any_quant(): - raise NotImplementedError("Quantization is not supported yet.") - - return convert_from_hf_whisper_decoder(source, config.mapping, config.dtype) - - -class WhisperForConditionalGeneration( - TensorRTForSpeechSeq2Seq, HuggingFaceHubModel, WhisperGenerationMixin -): - MODEL_CONFIG = None +class WhisperForConditionalGeneration(SupportsTransformersConversion): HF_LIBRARY_TARGET_MODEL_CLASS = TransformersWhisperForConditionalGeneration - TRT_LLM_TARGET_MODEL_CLASS = None # Whisper is split in two in TRT-LLM. - - def __init__( - self, - engines_folders: List[Path], - *, - gpus_per_node: int, - transformers_config: "TransformersPretrainedConfig", - use_cuda_graph: bool = False, - generation_config: Optional[GenerationConfig] = None, - ): - super().__init__( - engines_folders, - gpus_per_node=gpus_per_node, - transformers_config=transformers_config, - use_cuda_graph=use_cuda_graph, - generation_config=generation_config, - ) - - if generation_config is None: - generation_config = GenerationConfig() - self.generation_config = generation_config - - self.config = transformers_config - - # Encoder. - serialize_path = engines_folders[0] / "rank0.engine" - with open(serialize_path, "rb") as f: - encoder_session = Session.from_serialized_engine(f.read()) - - self.encoder_session = encoder_session - - # Decoder. - decoder_config_path = engines_folders[1] / "config.json" - with open(decoder_config_path, "r") as f: - decoder_config = json.load(f) - - serialize_path = engines_folders[1] / "rank0.engine" - with open(serialize_path, "rb") as f: - decoder_engine_buffer = f.read() - - build_config = BuildConfig.from_dict(decoder_config["build_config"]) - trt_config = WhisperDecoderConfig.from_dict(decoder_config["pretrained_config"]) - - self.dtype = trt_config.dtype - - decoder_model_config = ModelConfig( - max_batch_size=build_config.max_batch_size, - max_beam_width=build_config.max_beam_width, - num_heads=trt_config.num_attention_heads, - num_kv_heads=trt_config.num_key_value_heads, - hidden_size=trt_config.hidden_size, - vocab_size=trt_config.vocab_size, - num_layers=trt_config.num_hidden_layers, - gpt_attention_plugin=build_config.plugin_config.gpt_attention_plugin, - remove_input_padding=build_config.plugin_config.remove_input_padding, - cross_attention=True, - has_position_embedding=True, - has_token_type_embedding=False, - ) - - # world_size > 1 is not supported. 
- world_size = 1 - runtime_rank = mpi_rank() - - runtime_mapping = Mapping(world_size, runtime_rank) - - decoder_generation_session = GenerationSession( - decoder_model_config, - decoder_engine_buffer, - runtime_mapping, - debug_mode=False, - ) - - self.decoder_generation_session = decoder_generation_session - - @classmethod - def convert_and_build( - cls, - local_path: Path, - hf_model_config: Dict, - engine_save_path: Optional[Path] = None, - hf_model: Optional[TransformersPretrainedModel] = None, - **model_kwargs, - ) -> Path: - max_memory = get_max_memory() - - if hf_model is None: - # Allocate required components for quantization - hf_model = cls.HF_LIBRARY_TARGET_MODEL_CLASS.from_pretrained( - local_path, - device_map="auto", - max_memory=max_memory, - local_files_only=True, - ) - - if engine_save_path is None: - engine_save_path = local_path - - LOGGER.info("Building Whisper encoder...") - ( - encoder_checkpoints_folder, - encoder_engines_folder, - encoder_engines_relative_folder, - ) = OptimumWhisperEncoder.convert_and_build( - local_path, - hf_model_config, - engine_save_path=Path(engine_save_path, "encoder"), - hf_model=hf_model.model.encoder, - config_class=WhisperEncoderConfig, - **model_kwargs, - ) - - LOGGER.info("Building Whisper decoder...") - ( - decoder_checkpoints_folder, - decoder_engines_folder, - decoder_engines_relative_folder, - ) = OptimumWhisperDecoder.convert_and_build( - local_path, - hf_model_config, - engine_save_path=Path(engine_save_path, "decoder"), - hf_model=hf_model.model.decoder, - config_class=WhisperDecoderConfig, - **model_kwargs, - ) - - return ( - [encoder_checkpoints_folder[0], decoder_checkpoints_folder[0]], - [encoder_engines_folder[0], decoder_engines_folder[0]], - [ - encoder_engines_relative_folder[0], - decoder_engines_relative_folder[0], - ], - ) - - def encoder( - self, - input_features: torch.Tensor, - ): - if dtype_to_str(input_features.dtype) != self.dtype: - LOGGER.warning( - f"input_features should be of dtype {self.dtype}, got {dtype_to_str(input_features.dtype)}. Automatically casting to {self.dtype}." - ) - input_features = input_features.to(str_dtype_to_torch(self.dtype)) - - input_lengths = torch.tensor( - [input_features.shape[2] // 2 for _ in range(input_features.shape[0])], - dtype=torch.int32, - device=input_features.device, - ) - - inputs = OrderedDict() - inputs["x"] = input_features - inputs["input_lengths"] = input_lengths - - output_list = [ - TensorInfo("x", str_dtype_to_trt(self.dtype), input_features.shape), - TensorInfo("input_lengths", str_dtype_to_trt("int32"), input_lengths.shape), - ] - - output_info = (self.encoder_session).infer_shapes(output_list) - - LOGGER.debug(f"output info {output_info}") - outputs = { - t.name: torch.empty( - tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda" - ) - for t in output_info - } - - stream = torch.cuda.current_stream() - - ok = self.encoder_session.run( - inputs=inputs, outputs=outputs, stream=stream.cuda_stream - ) - - assert ok, "Engine execution failed" - stream.synchronize() - - return outputs["output"] - - def _retrieve_logit_processors( - self, generation_config, logits_processor, begin_index, is_shortform, num_beams - ): - # Adapted from WhisperGenerationMixin._retrieve_logit_processors with XxxLogitsProcessor -> TrtXxxLogitsProcessor - - if generation_config.return_timestamps is True: - # TODO: implement. - raise NotImplementedError( - "return_timestamps=True is not implemented with TensorRT-LLM. 
Please open an issue at https://github.com/huggingface/optimum-nvidia/issues. In the meanwhile, please set `model.generation_config.return_timestamps=False`." - ) - - if generation_config.suppress_tokens is not None: - suppress_tokens_processor = TrtSuppressTokensLogitsProcessor( - generation_config.suppress_tokens - ) - logits_processor = ( - [suppress_tokens_processor] - if logits_processor is None - else [suppress_tokens_processor] + logits_processor - ) - generation_config.suppress_tokens = None - - if generation_config.begin_suppress_tokens is not None: - begin_suppress_processor = TrtSuppressTokensAtBeginLogitsProcessor( - generation_config.begin_suppress_tokens, begin_index=begin_index - ) - logits_processor = ( - [begin_suppress_processor] - if logits_processor is None - else [begin_suppress_processor] + logits_processor - ) - generation_config.begin_suppress_tokens = None - - if generation_config.no_speech_threshold is not None and not is_shortform: - no_speech_detector = TrtWhisperNoSpeechDetection( - no_speech_token=generation_config.no_timestamps_token_id - 1, - begin_index=begin_index, - scores_is_logprobs=num_beams > 1, - ) - logits_processor = ( - [no_speech_detector] - if logits_processor is None - else [no_speech_detector] + logits_processor - ) - no_speech_detector.set_model(self) - - if is_shortform and generation_config.forced_decoder_ids is not None: - forced_tokens_proc = TrtForceTokensLogitsProcessor( - generation_config.forced_decoder_ids - ) - # It's important that the `forced_tokens_proc` processor is appended after - # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf - # which would lead to unexpected behavior - # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead - # initialize all of them as `decoder_input_ids`. - # TODO(Sanchit): Make sure to deprecate this in v4.39 as there will be no `forced_decoder_ids` anymore. - logits_processor = ( - [forced_tokens_proc] - if logits_processor is None - else logits_processor + [forced_tokens_proc] - ) - generation_config.forced_decoder_ids = None - - return logits_processor - - def _retrieve_init_tokens( - self, input_features, generation_config, config, num_segment_frames, kwargs - ): - # Adapted from WhisperGenerationMixin._retrieve_init_tokens with automatic language detection disabled. - - def replace_or_add(lst: List[int], num: int, itr: Iterator[int]): - """short function to replace num with a itr in lst""" - found = any(i in lst for i in itr) - if found: - lst = [num if i in itr else i for i in lst] - else: - lst.append(num) - return lst - - task = getattr(generation_config, "task", None) - language = getattr(generation_config, "language", None) - - if kwargs.get("forced_decoder_ids", None) is not None: - forced_decoder_ids = kwargs["forced_decoder_ids"] - elif ( - hasattr(generation_config, "forced_decoder_ids") - and generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = generation_config.forced_decoder_ids - - if language is None and task is None and forced_decoder_ids[0][1] is None: - LOGGER.warning_once( - "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English." - "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`." 
- ) - elif ( - hasattr(config, "forced_decoder_ids") - and config.forced_decoder_ids is not None - ): - forced_decoder_ids = config.forced_decoder_ids - else: - forced_decoder_ids = None - - if forced_decoder_ids is not None and task is not None: - LOGGER.info( - f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}." - ) - forced_decoder_ids = None - elif forced_decoder_ids is not None and language is not None: - LOGGER.info( - f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}." - ) - forced_decoder_ids = None - - init_tokens = [generation_config.decoder_start_token_id] - if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: - i = 1 - while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[0][1]] - forced_decoder_ids = forced_decoder_ids[1:] - i += 1 - - if len(forced_decoder_ids) > 0: - raise ValueError( - f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.", - ) - - # from v4.39 the forced decoder ids are always None in favour of decoder input ids - generation_config.forced_decoder_ids = None - - is_lang_id_undefined = len(init_tokens) <= 1 or ( - len(init_tokens) > 1 and init_tokens[1] is None - ) - if language is not None: - if language in generation_config.lang_to_id.keys(): - language_token = language - elif language in TO_LANGUAGE_CODE.keys(): - language_token = f"<|{TO_LANGUAGE_CODE[language]}|>" - elif language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{language}|>" - else: - is_language_code = len(language) == 2 - raise ValueError( - f"Unsupported language: {language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." - ) - if language_token not in generation_config.lang_to_id: - raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" - ) - - lang_id = generation_config.lang_to_id[language_token] - - # if language is defined it'll overwrite language ids that might have already been defined via the generation_config - replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values()) - elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined: - raise ValueError( - f"The language is not specified in the model's generation_config, and automatic language detection is not supported with TensorRT-LLM. Please set e.g. model.generation_config.language = '<|en|>' for English language. Available languages: {generation_config.lang_to_id.keys()}. Please refer to https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/whisper/tokenization_whisper.py#L95 for the languages codes." 
- ) - - if task is not None: - if task in TASK_IDS: - init_tokens.append(generation_config.task_to_id[generation_config.task]) - task_id = generation_config.task_to_id[generation_config.task] - - # if task is defined it'll overwrite task ids that might have already been defined via the generation_config - replace_or_add( - init_tokens, task_id, generation_config.task_to_id.values() - ) - else: - raise ValueError( - f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`" - ) - elif language is not None and hasattr(generation_config, "task_to_id"): - # if language is defined, but no task id is in `init_tokens`, default to transcribe - if not any(i in init_tokens for i in generation_config.task_to_id.values()): - init_tokens.append(generation_config.task_to_id["transcribe"]) - - if ( - not generation_config.return_timestamps - and hasattr(generation_config, "no_timestamps_token_id") - and init_tokens[-1] != generation_config.no_timestamps_token_id - ): - init_tokens.append(generation_config.no_timestamps_token_id) - elif ( - generation_config.return_timestamps - and init_tokens[-1] == generation_config.no_timestamps_token_id - ): - LOGGER.info( - "<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`." - ) - init_tokens = init_tokens[:-1] - - # let's make sure we don't pass `None` tokens as prompt tokens - init_tokens = [t for t in init_tokens if t is not None] - - return init_tokens - - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional["LogitsProcessorList"] = None, - stopping_criteria: Optional["StoppingCriteriaList"] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: bool = False, - return_timestamps: Optional[bool] = None, - task: Optional[str] = None, - language: Optional[str] = None, - is_multilingual: Optional[bool] = None, - prompt_ids: Optional[torch.Tensor] = None, - prompt_condition_type: Optional[str] = None, # first-segment, all-segments - condition_on_prev_tokens: Optional[bool] = None, - temperature: Optional[Union[float, Tuple[float, ...]]] = None, - compression_ratio_threshold: Optional[float] = None, - logprob_threshold: Optional[float] = None, - no_speech_threshold: Optional[float] = None, - num_segment_frames: Optional[int] = None, - attention_mask: Optional[torch.Tensor] = None, - time_precision: float = 0.02, - return_token_timestamps: Optional[bool] = None, - return_segments: bool = False, - return_dict_in_generate: Optional[bool] = None, - **kwargs, - ): - if inputs.device.type != "cuda": - raise ValueError( - f"TensorRT-LLM only supports inputs on CUDA device. Got: inputs.device = {inputs.device}" - ) - - def raise_unsupported(value: Any, name: str, default: Any = None): - if value != default: - raise ValueError( - f"TensorRTForSpeechSeq2Seq.generate does not support the argument {name} (got {name}={value}). Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." 
- ) - - raise_unsupported(stopping_criteria, name="stopping_criteria") - raise_unsupported(prefix_allowed_tokens_fn, name="prefix_allowed_tokens_fn") - raise_unsupported(synced_gpus, name="synced_gpus", default=False) - raise_unsupported(return_timestamps, name="return_timestamps") - raise_unsupported(task, name="task") - raise_unsupported(prompt_ids, name="prompt_ids") - raise_unsupported(prompt_condition_type, name="prompt_condition_type") - raise_unsupported(temperature, name="temperature") - raise_unsupported(attention_mask, name="attention_mask") - raise_unsupported(time_precision, name="time_precision", default=0.02) - raise_unsupported(return_token_timestamps, name="return_token_timestamps") - raise_unsupported(return_segments, name="return_segments", default=False) - raise_unsupported(return_dict_in_generate, name="return_dict_in_generate") - - # 1. copy generation config - if generation_config is None: - generation_config = copy.deepcopy(self.generation_config) - else: - generation_config = copy.deepcopy(generation_config) - - self._set_language_and_task( - language=language, - task=task, - is_multilingual=is_multilingual, - generation_config=generation_config, - ) - self._set_token_ids( - generation_config=generation_config, - config=self.config, - kwargs=kwargs, - ) - self._set_thresholds_and_condition( - generation_config=generation_config, - logprob_threshold=logprob_threshold, - compression_ratio_threshold=compression_ratio_threshold, - no_speech_threshold=no_speech_threshold, - condition_on_prev_tokens=condition_on_prev_tokens, - ) - - num_beams = kwargs.pop("num_beams", generation_config.num_beams) - input_stride = 1 * 2 # encoder's conv1 stride * encoder's conv2 stride - - batch_size, total_input_frames = self._retrieve_total_input_frames( - input_features=inputs, input_stride=input_stride, kwargs=kwargs - ) - num_segment_frames = input_stride * self.config.max_source_positions - is_shortform = total_input_frames <= num_segment_frames - if not is_shortform: - raise ValueError( - "Whisper TensorRT-LLM implementation only supports short form for now. Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." - ) - - init_tokens = self._retrieve_init_tokens( - inputs, - generation_config=generation_config, - config=self.config, - num_segment_frames=num_segment_frames, - kwargs=kwargs, - ) - - begin_index = len(init_tokens) - logits_processor = self._retrieve_logit_processors( - generation_config=generation_config, - logits_processor=logits_processor, - begin_index=begin_index, # begin index is index of first generated decoder token - is_shortform=is_shortform, - num_beams=kwargs.get("num_beams", 1), - ) - logits_processor = LogitsProcessorList(logits_processor) - - encoder_outputs = self.encoder(inputs) - - batch_size = inputs.shape[0] - one_tensor = torch.ones((batch_size, 1), device="cuda", dtype=torch.long) - decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) - - max_new_tokens = kwargs.pop("max_new_tokens", generation_config.max_new_tokens) - if max_new_tokens is None: - # Transformers' GenerationConfig.max_new_tokens defaults to None. 
- if generation_config.max_length is not None: - max_new_tokens = ( - generation_config.max_length - decoder_input_ids.shape[1] - ) - else: - raise ValueError("Please specifiy the argument `max_new_tokens`.") - - if ( - max_new_tokens + decoder_input_ids.shape[-1] - > self.config.max_target_positions - ): - max_new_tokens = kwargs.get("max_new_tokens", 0) - raise ValueError( - f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " - f"is {max_new_tokens}. Thus, the combined length of " - f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less than {self.config.max_target_positions}." - ) - - encoder_input_lengths = torch.tensor( - [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])], - dtype=torch.int32, - device=inputs.device, - ) - - decoder_input_lengths = torch.tensor( - [decoder_input_ids.shape[-1] for _ in range(decoder_input_ids.shape[0])], - dtype=torch.int32, - device="cuda", - ) - decoder_max_input_length = torch.max(decoder_input_lengths).item() - - cross_attention_mask = torch.ones( - [encoder_outputs.shape[0], 1, encoder_outputs.shape[1]], - device=inputs.device, - dtype=torch.int32, - ) - - sampling_config = SamplingConfig( - end_id=generation_config.eos_token_id, - pad_id=generation_config.pad_token_id, - num_beams=num_beams, - ) - - self.decoder_generation_session.setup( - decoder_input_lengths.size(0), - decoder_max_input_length, - max_new_tokens, - beam_width=num_beams, - encoder_max_input_length=encoder_outputs.shape[1], - ) - - torch.cuda.synchronize() - - decoder_input_ids = decoder_input_ids.type(torch.int32).cuda() - - # output_ids of shape [batch_size, beam_width, output_len] - output_ids = self.decoder_generation_session.decode( - decoder_input_ids, - decoder_input_lengths, - sampling_config, - encoder_output=encoder_outputs, - encoder_input_lengths=encoder_input_lengths, - cross_attention_mask=cross_attention_mask, - logits_processor=logits_processor, - ) - torch.cuda.synchronize() - - return output_ids[ - :, - 0, - : torch.max(self.decoder_generation_session.sequence_length_buffer) + 1, - ] + TRT_LLM_TARGET_MODEL_CLASSES = { + "encoder": TrtWhisperEncoder, + "decoder": TrtDecoderModel, + } diff --git a/src/optimum/nvidia/optimizations/__init__.py b/src/optimum/nvidia/optimizations/__init__.py new file mode 100644 index 00000000..7fb27ab9 --- /dev/null +++ b/src/optimum/nvidia/optimizations/__init__.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from modelopt.torch.quantization import QuantizeConfig + + +@runtime_checkable +class IntoModelOptQuantizeConfig(Protocol): + def into_model_opt_qconfig(self) -> "QuantizeConfig": ... 
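For illustration only: a minimal sketch of a recipe object that satisfies the runtime-checkable protocol above and could be handed to the quantize() helper introduced below. PrebuiltQuantizeRecipe is a hypothetical name, not part of the library, and it assumes a ModelOpt QuantizeConfig has been built elsewhere.

from dataclasses import dataclass
from typing import TYPE_CHECKING

from optimum.nvidia.optimizations import IntoModelOptQuantizeConfig

if TYPE_CHECKING:
    from modelopt.torch.quantization import QuantizeConfig


@dataclass
class PrebuiltQuantizeRecipe:
    # Hypothetical wrapper around an already-built ModelOpt QuantizeConfig.
    qconfig: "QuantizeConfig"

    def into_model_opt_qconfig(self) -> "QuantizeConfig":
        # Forward the pre-built config; a richer recipe could derive it from
        # higher-level knobs (weight dtype, KV-cache quantization, ...).
        return self.qconfig


# Because the protocol is @runtime_checkable, structural checks work at runtime:
# isinstance(PrebuiltQuantizeRecipe(cfg), IntoModelOptQuantizeConfig) is True.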
+
+
+from .datasets import get_dataset, load_dataset, prepare_dataset
diff --git a/src/optimum/nvidia/quantization/datasets.py b/src/optimum/nvidia/optimizations/datasets.py
similarity index 100%
rename from src/optimum/nvidia/quantization/datasets.py
rename to src/optimum/nvidia/optimizations/datasets.py
diff --git a/src/optimum/nvidia/optimizations/quantize.py b/src/optimum/nvidia/optimizations/quantize.py
new file mode 100644
index 00000000..a8620a44
--- /dev/null
+++ b/src/optimum/nvidia/optimizations/quantize.py
@@ -0,0 +1,59 @@
+import random
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from datasets import Dataset
+from modelopt.torch.export import torch_to_tensorrt_llm_checkpoint
+from tensorrt_llm import Mapping
+from tensorrt_llm.quantization.quantize_by_modelopt import (
+    get_model_type,
+    quantize_model,
+)
+from torch.utils.data import DataLoader
+
+
+if TYPE_CHECKING:
+    from modelopt.torch.quantization import QuantizeConfig
+    from torch import Tensor
+    from transformers import PreTrainedModel as TransformersPreTrainedModel
+
+    from optimum.nvidia import IntoModelOptQuantizeConfig
+
+
+LOGGER = getLogger(__name__)
+
+
+def quantize(
+    model: "TransformersPreTrainedModel",
+    qconfig: Union["QuantizeConfig", "IntoModelOptQuantizeConfig"],
+    dataset: Union[Dataset, DataLoader],
+    mapping: Optional["Mapping"] = None,
+    seed: int = 2014,
+) -> Tuple[Dict[str, Any], Dict[str, "Tensor"]]:
+    if isinstance(dataset, Dataset):
+        dataset = DataLoader(dataset)
+
+    if hasattr(qconfig, "into_model_opt_qconfig"):
+        LOGGER.info(f"Converting {qconfig} to TensorRT-LLM quantization config")
+        qconfig = qconfig.into_model_opt_qconfig()
+
+    # Seed everything
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+    # Retrieve additional information
+    model_type = get_model_type(model)
+    mapping = mapping or Mapping(
+        world_size=1, rank=0, gpus_per_node=1, tp_size=1, pp_size=1
+    )
+
+    # Do the quantization
+    with torch.inference_mode():
+        qmodel = quantize_model(model, qconfig, dataset)
+
+    return torch_to_tensorrt_llm_checkpoint(
+        qmodel, model_type, model.dtype, mapping.tp_size, mapping.pp_size
+    )
diff --git a/src/optimum/nvidia/quantization/__init__.py b/src/optimum/nvidia/quantization/__init__.py
deleted file mode 100644
index 51a5d23f..00000000
--- a/src/optimum/nvidia/quantization/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-# #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# #
-# http://www.apache.org/licenses/LICENSE-2.0
-# #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -from .ammo import ( - AmmoQuantizationConfig, - AutoQuantizationConfig, - Float8QuantizationConfig, - QuantizationMethod, -) diff --git a/src/optimum/nvidia/quantization/ammo/__init__.py b/src/optimum/nvidia/quantization/ammo/__init__.py deleted file mode 100644 index 5dd8c571..00000000 --- a/src/optimum/nvidia/quantization/ammo/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .config import ( - AmmoQuantizationConfig, - AutoQuantizationConfig, - Float8QuantizationConfig, - QuantizationMethod, -) -from .quantizer import AmmoQuantizer diff --git a/src/optimum/nvidia/quantization/ammo/config.py b/src/optimum/nvidia/quantization/ammo/config.py deleted file mode 100644 index dd173197..00000000 --- a/src/optimum/nvidia/quantization/ammo/config.py +++ /dev/null @@ -1,188 +0,0 @@ -import random -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch -from ammo.torch import quantization as atq -from datasets import Dataset -from transformers import PreTrainedTokenizer -from transformers.utils.quantization_config import QuantizationConfigMixin - -from optimum.nvidia.lang import DataType -from optimum.nvidia.quantization.datasets import get_dataset - - -dtype = Union[str, torch.dtype] -TORCH_FLOAT8 = {torch.float8_e4m3fn, torch.float8_e5m2} - - -KV_CACHE_CFG = { - "*.query_key_value.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*.Wqkv.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*.W_pack.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*.c_attn.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*.k_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*.v_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, -} - - -class QuantizationMethod(str, Enum): - FLOAT8 = "fp8" - - -class AmmoQuantizationConfig(ABC, QuantizationConfigMixin): - def __init__( - self, - quant_method: QuantizationMethod, - with_quantized_kv_cache: bool = False, - with_quantized_lm_head: bool = False, - calibration_data: Optional[Dataset] = None, - ): - super().__init__(quant_method) - - self._with_quantized_kv_cache: bool = with_quantized_kv_cache - self._with_quantized_lm_head: bool = with_quantized_lm_head - self._calibration_dataset = calibration_data - - @property - @abstractmethod - def weight_dtype(self) -> torch.dtype: - raise NotImplementedError("AmmoQuantizationConfig::weight_dtype is abstract.") - - @property - @abstractmethod - def has_quantized_kv_cache(self) -> bool: - raise NotImplementedError( - "AmmoQuantizationConfig::has_quantized_kv_cache is abstract." - ) - - @property - def has_calibration_dataset(self) -> bool: - return self._calibration_dataset is not None - - @property - @abstractmethod - def requires_calibration(self) -> bool: - raise NotImplementedError( - "AmmoQuantizationConfig::requires_calibration is abstract." 
- ) - - @property - def calibration_dataset(self) -> Optional[Dataset]: - return self._calibration_dataset - - @abstractmethod - def as_ammo_config(self) -> Dict[str, Any]: - raise NotImplementedError("AmmoQuantizationConfig::as_ammo_config is abstract.") - - def to_dict(self) -> Dict[str, Any]: - return { - "quant_method": self.quant_method, - "with_kv_cache": self._with_quantized_kv_cache, - "with_lm_head": self._with_quantized_lm_head, - } - - def to_diff_dict(self) -> Dict[str, Any]: - return self.to_dict() - - -class Float8QuantizationConfig(AmmoQuantizationConfig): - def __init__( - self, - with_quantized_kv_cache: bool = False, - with_quantized_lm_head: bool = False, - calibration_data: Optional[Dataset] = None, - ): - super().__init__( - QuantizationMethod.FLOAT8, - with_quantized_kv_cache, - with_quantized_lm_head, - calibration_data, - ) - - @property - def weight_dtype(self) -> torch.dtype: - return torch.float8_e4m3fn - - @property - def has_quantized_kv_cache(self) -> bool: - return self._with_quantized_kv_cache - - @property - def has_quantized_lm_head(self) -> bool: - return self._with_quantized_lm_head - - @property - def requires_calibration(self) -> bool: - return True - - def as_ammo_config(self) -> Dict[str, Any]: - cfg = atq.FP8_DEFAULT_CFG.copy() - quant_config = cfg["quant_cfg"] - - if self.has_quantized_kv_cache: - quant_kv_cache_config = KV_CACHE_CFG.copy() - for value in quant_kv_cache_config.values(): - value.update(num_bits=(4, 3)) - - quant_config.update(quant_kv_cache_config) - - quant_config["*lm_head*"] = {"enable": self.has_quantized_lm_head} - - return cfg - - -class AutoQuantizationConfig: - @classmethod - def from_dict(cls, kwargs): - return cls.from_description(**kwargs) - - @classmethod - def from_description( - cls, - weight: dtype, - activation: dtype, - tokenizer: Optional[PreTrainedTokenizer] = None, - dataset: Optional[Union[str, Dataset]] = None, - split: Optional[str] = "train", - num_samples: int = 512, - max_sequence_length: int = 1024, - seed: int = 2016, - device: Union[str, torch.device] = "cpu", - ): - random.seed(seed) - np.random.seed(seed) - torch.random.manual_seed(seed) - - if isinstance(weight, str): - weight = DataType(weight).to_torch() - - if isinstance(activation, str): - activation = DataType(activation).to_torch() - - if isinstance(dataset, str): - dataset = get_dataset( - dataset, - tokenizer, - num_samples, - seqlen=max_sequence_length, - split=split, - seed=seed, - device=device, - ) - else: - raise ValueError("Providing custom dataset is not yet supported") - - # float8 case - if weight in TORCH_FLOAT8: - return Float8QuantizationConfig( - with_quantized_kv_cache=activation in TORCH_FLOAT8, - with_quantized_lm_head=False, - calibration_data=dataset, - ) - else: - raise NotImplementedError( - f"Quantization(weight= {weight}, activation={activation}) is not supported yet." - ) diff --git a/src/optimum/nvidia/quantization/ammo/quantizer.py b/src/optimum/nvidia/quantization/ammo/quantizer.py deleted file mode 100644 index 048bb30b..00000000 --- a/src/optimum/nvidia/quantization/ammo/quantizer.py +++ /dev/null @@ -1,141 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from logging import getLogger -from os import PathLike -from pathlib import Path -from typing import Union - -import torch -from ammo.torch import export as ate -from ammo.torch import quantization as atq -from tensorrt_llm.quantization import QuantMode -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import PreTrainedModel -from transformers.quantizers import HfQuantizer -from transformers.utils.quantization_config import QuantizationConfigMixin - -from optimum.nvidia.quantization.ammo import AmmoQuantizationConfig - - -LOGGER = getLogger(__name__) - -_SUPPORTED_MODEL_ARCHITECTURES = { - "llama": "llama", - "mistral": "llama", - "gemma": "gemma", -} - - -def get_quantization_algorithm_name(qconfig: QuantMode) -> str: - if qconfig.has_fp8_qdq() or qconfig.has_fp8_kv_cache(): - return "fp8" - elif qconfig.is_int4_weight_only(): - return "int4_awq" - elif not qconfig.has_act_and_weight_quant(): - return "int8_sq" - else: - raise ValueError(f"Unable to determine quantization algorithm from: {qconfig}") - - -def infer_decoder_type(model: PreTrainedModel) -> str: - if model.config.model_type in _SUPPORTED_MODEL_ARCHITECTURES: - return _SUPPORTED_MODEL_ARCHITECTURES[model.config.model_type] - - else: - raise ValueError(f"{model.config.model_type} is not supported yet") - - -class AmmoQuantizer(HfQuantizer): - def __init__( - self, - quantization_config: QuantizationConfigMixin, - artifact_path: Union[str, PathLike, Path], - tensor_parallel_degree: int = 1, - pipeline_parallel_degree: int = 1, - export_tensorrt_llm_config: bool = True, - ): - if tensor_parallel_degree < 1: - raise ValueError( - f"tensor_parallel_degree should be >= 1 (got {tensor_parallel_degree})" - ) - - if pipeline_parallel_degree < 1: - raise ValueError( - f"pipeline_parallel_degree should be >= 1 (got {pipeline_parallel_degree})" - ) - - super().__init__(quantization_config=quantization_config) - - if not isinstance(artifact_path, Path): - artifact_path = Path(artifact_path) - - self._artifact_path = artifact_path - self._tp_degree = tensor_parallel_degree - self._pp_degree = pipeline_parallel_degree - self._export_tensorrt_llm_config = export_tensorrt_llm_config - - @property - def is_serializable(self): - return False - - @property - def is_trainable(self): - return False - - def _process_model_before_weight_loading( - self, model, batch_size: int = 1, **kwargs - ): - assert isinstance(self.quantization_config, AmmoQuantizationConfig) - qconfig = self.quantization_config - - if qconfig.requires_calibration: - if not qconfig.has_calibration_dataset: - raise ValueError("Float8 quantization requires a calibration dataset") - - with torch.inference_mode(): - - def _loop(): - data = DataLoader( - qconfig.calibration_dataset, - batch_size=batch_size, - pin_memory=True, - pin_memory_device="cuda:0", - ) - - for sample in tqdm(data): - inputs = { - name: tensor.to("cuda:0") for name, tensor in sample.items() - } - model(**inputs) - - atq.quantize(model, config=qconfig.as_ammo_config(), forward_loop=_loop) - - def _process_model_after_weight_loading(self, model, **kwargs): - assert 
isinstance(self.quantization_config, AmmoQuantizationConfig) - - with torch.inference_mode(): - decoder_type = infer_decoder_type(model) - ate.export_model_config( - model=model, - decoder_type=decoder_type, - dtype=model.config.torch_dtype, - export_dir=self._artifact_path, - inference_tensor_parallel=self._tp_degree, - inference_pipeline_parallel=self._pp_degree, - export_tensorrt_llm_config=self._export_tensorrt_llm_config, - export_npz=False, - ) diff --git a/src/optimum/nvidia/runtime.py b/src/optimum/nvidia/runtime.py index 6572a2d9..a09871ac 100644 --- a/src/optimum/nvidia/runtime.py +++ b/src/optimum/nvidia/runtime.py @@ -1,327 +1,160 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import warnings +import asyncio +import json +import math from logging import getLogger from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union -import tensorrt_llm.bindings as ctrrt import torch -from transformers import GenerationConfig -from transformers.generation.utils import GenerationMixin +from tensorrt_llm.bindings.executor import ExecutorConfig, KvCacheConfig +from tensorrt_llm.executor import ( + GenerationExecutor, + GenerationRequest, + GenerationResult, +) +from tensorrt_llm.hlapi import SamplingParams + +from optimum.nvidia.utils.nvml import is_post_ampere if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedModel - from transformers.generation.logits_process import LogitsProcessorList - from transformers.generation.stopping_criteria import StoppingCriteriaList - from transformers.generation.streamers import BaseStreamer + from transformers import ( + GenerationConfig, + ) LOGGER = getLogger(__name__) -PackedTensor = List[torch.Tensor] +def read_engine_config_file(path: Path) -> Dict[str, Any]: + with open(path / "config.json", "r", encoding="utf-8") as config_f: + return json.load(config_f) + + +def convert_generation_config(config: "GenerationConfig") -> "SamplingParams": + return SamplingParams( + end_id=config.eos_token_id, + pad_id=config.pad_token_id, + top_k=config.top_k if config.do_sample else 1, + top_p=config.top_p, + temperature=config.temperature, + beam_width=config.num_beams if config.do_sample else 1, + bad_words=config.bad_words_ids, + length_penalty=config.length_penalty, + repetition_penalty=config.repetition_penalty, + no_repeat_ngram_size=config.no_repeat_ngram_size + if config.no_repeat_ngram_size > 0 + else 1, + min_length=config.min_length if config.min_length > 0 else 1, + max_new_tokens=config.max_new_tokens, + return_generation_logits=config.output_logits, + return_log_probs=not config.renormalize_logits, + ) -DEFAULT_BATCH_SIZE: int = 1 -DEFAULT_PROMPT_LENGTH: int = 128 -DEFAULT_BEAM_WIDTH: int = 1 +def default_executor_config(config: Dict[str, Any]) -> "ExecutorConfig": + build_config = 
config["build_config"] + plugin_config = config["build_config"]["plugin_config"] -class CompiledModel: - def __init__(self, engines_folders_path: List[Union[Path, PathLike]]): - # A compiled model may have several subfolders (e.g. encoder-decoder model). - self._engines_folders_path = [ - Path(engines_folder_path) for engines_folder_path in engines_folders_path - ] + max_blocks_per_sequence = math.floor( + build_config["max_seq_len"] / plugin_config["tokens_per_block"] + ) - @property - def engine_path(self) -> Path: - """ - Return the local path where the engine(s) is/are located - :return: Path to the folder holding the engine(s) definition(s) - """ - return self._engines_folders_path - - -class CausalLM(CompiledModel, GenerationMixin): - main_input_name = "input_ids" - - __slots__ = ( - "_device", - "_config", - "_mapping", - "_session", - "_session_config", - "_use_packed_inputs", - "max_beam_width", - "max_batch_size", - "max_prompt_length", - "max_output_length", + return ExecutorConfig( + enable_chunked_context=is_post_ampere(), + kv_cache_config=KvCacheConfig( + enable_block_reuse=True, + max_tokens=build_config["max_beam_width"] + * plugin_config["tokens_per_block"] + * max_blocks_per_sequence, + ), ) + +class InferenceRuntimeBase: + __slots__ = ("_config", "_executor", "_generation_config", "_sampling_config") + def __init__( self, - engines_folders: List[Path], - *, - gpus_per_node: int, - transformers_config: "PretrainedConfig", - use_cuda_graph: bool = False, - generation_config: Optional[GenerationConfig] = None, + engines_path: Union[str, PathLike], + generation_config: "GenerationConfig", + executor_config: Optional["ExecutorConfig"] = None, ): - if len(engines_folders) != 1: - raise ValueError( - f"For CausalLM, expecting a single engine folder, got: {engines_folders}" - ) - super().__init__(engines_folders) - engines_folder = engines_folders[0] - - self._device = torch.device("cuda") - self._config = ctrrt.GptJsonConfig.parse_file(engines_folder / "config.json") - self._mapping = ctrrt.WorldConfig.mpi( - gpus_per_node, - self._config.tensor_parallelism, - self._config.pipeline_parallelism, - ) - self._session_config = ctrrt.GptSessionConfig( - max_batch_size=self._config.model_config.max_batch_size, - max_beam_width=self._config.model_config.max_beam_width, - max_sequence_length=self._config.model_config.max_seq_len, + engines_path = Path(engines_path) + + if not engines_path.exists(): + raise OSError(f"engine folder {engines_path} doesn't exist") + + self._config = read_engine_config_file(engines_path) + self._generation_config = generation_config + self._sampling_config = convert_generation_config(generation_config) + + self._executor = GenerationExecutor.create( + engine_dir=engines_path, + executor_config=executor_config or default_executor_config(self._config), + tokenizer=None, ) - self._session_config.cuda_graph_mode = use_cuda_graph - - # Create the engine - engine_file = self._config.engine_filename(self._mapping) - - try: - self._session = ctrrt.GptSession( - config=self._session_config, - model_config=self._config.model_config, - world_config=self._mapping, - engine_file=str(engines_folder.joinpath(engine_file)), - ) - except RuntimeError as e: - if "maxTokensInPagedKvCache" in repr( - e - ) and "must be large enough to process at least 1 sequence" in repr(e): - raise RuntimeError( - f"Could not initialize TensorRT-LLM decoder session, likely due a large maximum output length set at compilation time (max_output_len={self._config.model_config.max_seq_len}). 
Please try and set a lower value for `max_output_length` when building the engine. Error: {e}" - ) - else: - raise e - - # Additional cached properties - self._use_packed_inputs = self._config.model_config.use_packed_input - self.max_batch_size = self._config.model_config.max_batch_size - self.max_prompt_length = self._config.model_config.max_input_len - - self.max_output_length = self._config.model_config.max_seq_len - self.max_beam_width = self._session_config.max_beam_width - - if generation_config is None: - generation_config = GenerationConfig() - self.generation_config = generation_config - - # Required for GenerationMixin compatibility. - self.config = transformers_config - - @torch.no_grad() def generate( self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional["LogitsProcessorList"] = None, - stopping_criteria: Optional["StoppingCriteriaList"] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.LongTensor: - def raise_unsupported(value: Any, name: str, default: Any = None): - if value != default: - raise ValueError( - f"{self.__class__.__name__}.generate does not support the argument {name} (got {name}={value}). Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." - ) - - raise_unsupported(stopping_criteria, name="stopping_criteria") - raise_unsupported(prefix_allowed_tokens_fn, name="prefix_allowed_tokens_fn") - raise_unsupported(synced_gpus, name="synced_gpus") - raise_unsupported(logits_processor, name="logits_processor") - raise_unsupported(assistant_model, name="assistant_model") - raise_unsupported(streamer, name="streamer") - raise_unsupported(negative_prompt_ids, name="negative_prompt_ids") - raise_unsupported( - negative_prompt_attention_mask, name="negative_prompt_attention_mask" + inputs: Union[List[int], "torch.IntTensor"], + generation_config: Optional["GenerationConfig"] = None, + ): + # Retrieve the sampling config + sampling = ( + convert_generation_config(generation_config) + if generation_config + else self._sampling_config ) - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # three conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) the user must have set generation parameters in the model config. - if ( - self.generation_config._from_model_config - and self.generation_config._original_object_hash - == hash(self.generation_config) - and self.config._has_non_default_generation_parameters() - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." 
- " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" - ) - self.generation_config = new_generation_config - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update( - **kwargs - ) # All unused kwargs must be model kwargs - - if ( - generation_config.pad_token_id is None - and generation_config.eos_token_id is not None - ): - if model_kwargs.get("attention_mask", None) is None: - LOGGER.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - LOGGER.warning( - f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." - ) - generation_config.pad_token_id = eos_token_id - - device = self._device - - seed = model_kwargs.pop("seed", 42) - # If no GenerationConfig is provided, let's allocate one with default settings - sampling_config = ctrrt.SamplingConfig( - min(generation_config.num_beams, self.max_beam_width) - ) - sampling_config.random_seed = [seed] - sampling_config.temperature = [generation_config.temperature] - sampling_config.top_k = [generation_config.top_k] - sampling_config.top_p = [generation_config.top_p] - sampling_config.repetition_penalty = [generation_config.repetition_penalty] - sampling_config.length_penalty = [generation_config.length_penalty] - - if generation_config.min_new_tokens is not None: - sampling_config.min_length = [generation_config.min_new_tokens] - - input_ids, _, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs + if isinstance(inputs, torch.Tensor): + inputs = inputs.tolist() + + result = self._executor.generate(inputs, sampling_params=sampling) + return result[0].token_ids + + async def agenerate( + self, + inputs: Union[List[int], "torch.IntTensor"], + generation_config: Optional["GenerationConfig"] = None, + ) -> List[int]: + # Retrieve the sampling config + sampling = ( + convert_generation_config(generation_config) + if generation_config + else self._sampling_config ) - input_length = input_ids.shape[1] - - with torch.no_grad(): - if not isinstance(input_ids, torch.Tensor): - raise TypeError("input_ids should be a PyTorch tensor (torch.Tensor)") - - attention_mask = model_kwargs["attention_mask"] - input_ids, lengths = self._prepare_inputs(input_ids, attention_mask) - if torch.any(torch.gt(lengths, self.max_prompt_length)): - raise ValueError( - f"Input length {lengths} is bigger than maximum prompt length ({self.max_prompt_length})." 
- ) - - trt_inputs = ctrrt.GenerationInput( - end_id=generation_config.eos_token_id, - pad_id=generation_config.pad_token_id, - ids=input_ids.to(device), - lengths=lengths.to(device), - packed=self._use_packed_inputs, - ) - - max_new_tokens = generation_config.max_new_tokens - if max_new_tokens is None or max_new_tokens < 1: - max_new_tokens = self.max_output_length - input_ids.shape[1] - - trt_inputs.max_new_tokens = max_new_tokens - - # Tensors are being allocated as in/out parameters and TRTLLM will resize - trt_outputs = ctrrt.GenerationOutput( - ids=torch.empty(0, device=device, dtype=torch.int32), - lengths=torch.empty(0, device=device, dtype=torch.int32), - ) - - self._session.generate(trt_outputs, trt_inputs, sampling_config) - - total_length = trt_outputs.lengths.max().item() - output_ids = trt_outputs.ids.flatten(0, 1) - - # For some reason not in line with Transformers in case we finish early with EOS token (missing last EOS token). - if total_length - input_length < max_new_tokens: - total_length += 1 - - return output_ids[:, :total_length], total_length - - def _prepare_inputs( - self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - shape = input_ids.size() - input_ids = input_ids.int() - - if input_ids.ndim == 1: - lengths = torch.tensor(shape, dtype=torch.int32, device=self._device) - elif input_ids.ndim == 2 and shape[0] == 1: - lengths = torch.tensor([shape[1]], dtype=torch.int32, device=self._device) - elif attention_mask is not None: - lengths = attention_mask.sum(dim=1, dtype=torch.int32) - else: - warnings.warn( - "Not enough information to compute the non-padded tensor length. " - "Please provide an attention_mask to avoid situations where padding" - " will be attended to in attention modules" - ) - attention_mask = torch.ones_like(input_ids) - lengths = torch.tensor(shape, dtype=torch.int32).flatten() + if isinstance(inputs, torch.Tensor): + inputs = inputs.tolist() - if self._use_packed_inputs and shape[0] > 1: - input_ids = torch.masked_select(input_ids, attention_mask.bool()).view( - 1, -1 - ) + futures = self._executor.generate_async( + inputs, streaming=False, sampling_params=sampling + ) + if isinstance(futures, GenerationRequest): + results = await futures.aresult() + return results.token_ids + else: + results = await asyncio.gather(*[f.aresult() for f in futures]) + return [r.token_ids for r in results] - return input_ids, lengths +class CausalLMOutput: + __slots__ = ("_results",) -class TensorRTForSpeechSeq2Seq(CompiledModel): def __init__( - self, - engines_folders: List[Path], - *, - gpus_per_node: int, - transformers_config: "PretrainedConfig", - use_cuda_graph: bool = False, - generation_config: Optional[GenerationConfig] = None, + self, results: Union["GenerationResult", Sequence["GenerationResult"]] ): - super().__init__(engines_folders) + self._results = results + + @property + def logits(self): + return self._results.token_ids + + @property + def loss(self) -> None: + return None + + +class CausalLM(InferenceRuntimeBase): + pass diff --git a/src/optimum/nvidia/utils/nvml.py b/src/optimum/nvidia/utils/nvml.py index 7049ea2f..6c8ea9a9 100644 --- a/src/optimum/nvidia/utils/nvml.py +++ b/src/optimum/nvidia/utils/nvml.py @@ -33,6 +33,9 @@ _NVML_INITIALIZED = False +CUDA_ARCH_VOLTA = 7 +CUDA_ARCH_AMPERE = 8 +CUDA_ARCH_HOPPER = 9 SM_FP8_SUPPORTED = {89, 90} @@ -108,6 +111,27 @@ def has_float8_support() -> bool: return False +@functools.cache +@nvml_guard +def is_post_volta() -> bool: + 
major, _ = get_device_compute_capabilities(0) + return major >= CUDA_ARCH_VOLTA + + +@functools.cache +@nvml_guard +def is_post_ampere() -> bool: + major, _ = get_device_compute_capabilities(0) + return major >= CUDA_ARCH_AMPERE + + +@functools.cache +@nvml_guard +def is_post_hopper() -> bool: + major, _ = get_device_compute_capabilities(0) + return major >= CUDA_ARCH_HOPPER + + def get_max_memory(): fraction_device_map = { device_id: get_device_memory(device_id) * 0.7 diff --git a/tests/integration/test_float8_quantization.py b/tests/integration/test_float8_quantization.py deleted file mode 100644 index 78e633e6..00000000 --- a/tests/integration/test_float8_quantization.py +++ /dev/null @@ -1,66 +0,0 @@ -import tempfile - -import pytest -import torch -from transformers import AutoConfig, AutoTokenizer -from transformers import AutoModelForCausalLM as HfAutoModelForCausalLM - -from optimum.nvidia import AutoModelForCausalLM -from optimum.nvidia.quantization import AutoQuantizationConfig -from optimum.nvidia.utils.tests.utils import requires_float8 - - -@pytest.mark.parametrize( - "model_id", - ["google/gemma-2b", "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"], -) -@requires_float8() -def test_float8_causallm_use_fp8(model_id: str): - # Use a tinier model - config = AutoConfig.from_pretrained(model_id) - config.num_hidden_layers = 1 - - # Create the flow and convert - with tempfile.TemporaryDirectory() as tmp_f: - _ = AutoTokenizer.from_pretrained(model_id).save_pretrained(tmp_f) - _ = HfAutoModelForCausalLM.from_config(config).save_pretrained(tmp_f) - model = AutoModelForCausalLM.from_pretrained( - tmp_f, max_batch_size=1, beam_width=1, use_fp8=True - ) - - assert model is not None - - -@pytest.mark.parametrize( - "model_id", - # ["google/gemma-2b", "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"], - ["meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"], -) -@pytest.mark.parametrize("weight", ["float8", torch.float8_e4m3fn]) -@pytest.mark.parametrize("activation", ["float8", torch.float8_e4m3fn]) -@pytest.mark.parametrize("dataset", ["c4-new"]) -@requires_float8() -def test_float8_causallm_custom_qconfig_predefined_dataset( - model_id: str, dataset: str, weight, activation -): - # Use a tiner model - config = AutoConfig.from_pretrained(model_id) - config.num_hidden_layers = 1 - - # Create the flow and convert - with tempfile.TemporaryDirectory() as tmp_f: - _ = HfAutoModelForCausalLM.from_config(config).save_pretrained(tmp_f) - tokenizer = AutoTokenizer.from_pretrained(model_id) - - qconfig = AutoQuantizationConfig.from_description( - weight=weight, - activation=activation, - tokenizer=tokenizer, - dataset=dataset, - num_samples=32, - max_sequence_length=128, - ) - model = AutoModelForCausalLM.from_pretrained( - tmp_f, max_batch_size=1, beam_width=1, quantization_config=qconfig - ) - assert model is not None diff --git a/tests/integration/test_whisper.py b/tests/integration/test_whisper.py deleted file mode 100644 index bbd016c1..00000000 --- a/tests/integration/test_whisper.py +++ /dev/null @@ -1,158 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -from glob import glob -from pathlib import Path -from typing import Optional - -import datasets -import pytest -import torch -from transformers import AutoProcessor -from transformers import ( - WhisperForConditionalGeneration as TransformersWhisperForConditionalGeneration, -) -from utils_testing import clean_cached_engines_for_model - -from optimum.nvidia.models.whisper import WhisperForConditionalGeneration - - -TEST_MODELS = [ - "openai/whisper-tiny.en", - "openai/whisper-large-v3", - "distil-whisper/distil-medium.en", -] - - -@pytest.mark.parametrize("model_id", TEST_MODELS) -def test_whisper(model_id: str): - # Make sure we remove the potentially already built engines. - clean_cached_engines_for_model(model_id) - - model = WhisperForConditionalGeneration.from_pretrained(model_id) - with tempfile.TemporaryDirectory() as tmp_f: - model.save_pretrained(tmp_f) - - encoder_engines_files = glob(Path(tmp_f, "encoder/engines/*.engine").as_posix()) - decoder_engines_files = glob(Path(tmp_f, "decoder/engines/*.engine").as_posix()) - - assert len(encoder_engines_files) > 0 - assert len(decoder_engines_files) > 0 - - model = WhisperForConditionalGeneration.from_pretrained(tmp_f) - - -@pytest.mark.parametrize("model_id", TEST_MODELS) -@pytest.mark.parametrize("max_new_tokens", [None, 10]) -def test_generation(model_id: str, max_new_tokens: Optional[int]): - # Make sure we remove the potentially already built engines. - clean_cached_engines_for_model(model_id) - - torch_dtype = torch.float16 # TODO: test fp8, int4, int8, fp32 - - trt_model = WhisperForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch_dtype - ) - with torch.device("cuda"): - torch_model = TransformersWhisperForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch_dtype - ) - - processor = AutoProcessor.from_pretrained(model_id) - data = datasets.load_dataset( - "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" - ) - - kwargs = {} - if max_new_tokens is not None: - kwargs["max_new_tokens"] = max_new_tokens - - if hasattr(torch_model.generation_config, "lang_to_id"): - torch_model.generation_config.language = "<|en|>" - trt_model.generation_config.language = "<|en|>" - - for i in range(20): - inputs = processor( - data[i]["audio"]["array"], - return_tensors="pt", - sampling_rate=data[i]["audio"]["sampling_rate"], - ).to("cuda") - - input_features = inputs.input_features - input_features = input_features.to(torch_dtype) - - torch_model = torch_model.eval() - - # Greedy search. - trt_generated_ids = trt_model.generate( - inputs=input_features, num_beams=1, do_sample=False, top_k=None, **kwargs - ) - torch_generated_ids = torch_model.generate( - inputs=input_features, num_beams=1, do_sample=False, top_k=None, **kwargs - ) - - assert torch.equal(trt_generated_ids, torch_generated_ids) - - -@pytest.mark.parametrize("model_id", TEST_MODELS) -@pytest.mark.parametrize("max_new_tokens", [None, 10]) -def test_batched_generation(model_id: str, max_new_tokens: Optional[int]): - # Make sure we remove the potentially already built engines. 
- clean_cached_engines_for_model(model_id) - - torch_dtype = torch.float16 # TODO: test fp8, int4, int8, fp32 - - trt_model = WhisperForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch_dtype, max_batch_size=5 - ) - with torch.device("cuda"): - torch_model = TransformersWhisperForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch_dtype - ) - - processor = AutoProcessor.from_pretrained(model_id) - data = datasets.load_dataset( - "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" - ) - - if hasattr(torch_model.generation_config, "lang_to_id"): - torch_model.generation_config.language = "<|en|>" - trt_model.generation_config.language = "<|en|>" - - kwargs = {} - if max_new_tokens is not None: - kwargs["max_new_tokens"] = max_new_tokens - - for batch_size in [2, 3, 4]: - subdata = data.select(range(batch_size)) - inputs = processor( - [dat["array"] for dat in subdata["audio"]], return_tensors="pt" - ).to("cuda") - - input_features = inputs.input_features - input_features = input_features.to(torch_dtype) - - assert input_features.shape[0] == batch_size - - # Greedy search. - trt_generated_ids = trt_model.generate( - inputs=input_features, num_beams=1, do_sample=False, top_k=None, **kwargs - ) - torch_generated_ids = torch_model.generate( - inputs=input_features, num_beams=1, do_sample=False, top_k=None, **kwargs - ) - - assert torch.equal(trt_generated_ids, torch_generated_ids) diff --git a/tests/test_hub.py b/tests/test_hub.py index 91ec0561..b89ab2cc 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -2,63 +2,44 @@ from tempfile import TemporaryDirectory import mock +import pytest from transformers import AutoConfig as HfAutoConfig from transformers import AutoModelForCausalLM as HfAutoModelForCausalLM -# import pytest import optimum.nvidia.hub from optimum.nvidia import AutoModelForCausalLM -from optimum.nvidia.hub import FOLDER_TRTLLM_ENGINES -# from optimum.nvidia.utils.nvml import get_device_name +@pytest.mark.parametrize( + "model_id", + ("meta-llama/Llama-2-7b-chat-hf", "google/gemma-2b", "mistralai/Mistral-7B-v0.3"), +) +def test_save_engine_locally_and_reload(model_id: str): + with TemporaryDirectory() as hf_out: + with TemporaryDirectory() as trtllm_out: + trtllm_out = Path(trtllm_out) + def _save(): + config = HfAutoConfig.from_pretrained(model_id) + config.num_hidden_layers = 1 -# def test_load_engine_from_huggingface_hub(): -# with mock.patch("optimum.nvidia.hub.HuggingFaceHubModel.convert_and_build"): -# device = get_device_name(0) -# -# try: -# model = AutoModelForCausalLM.from_pretrained( -# "optimum-nvidia/llama-ci", revision=device[-1].lower() -# ) -# assert model is not None -# assert ( -# optimum.nvidia.hub.HuggingFaceHubModel.convert_and_build.call_count == 0 -# ) -# except ValueError: -# pytest.skip( -# f"No revision found for optimum-nvidia/llama-ci on GPU: {device[-1].lower()}" -# ) + model = HfAutoModelForCausalLM.from_config(config) + model.save_pretrained(hf_out) + model = AutoModelForCausalLM.from_pretrained(hf_out) + model.save_pretrained(trtllm_out) -def test_save_engine_locally_and_reload(): - with TemporaryDirectory() as out: - out = Path(out) - hf_out = out.joinpath("_hf") - trtllm_out = out.joinpath("_trtllm") + assert trtllm_out.exists() + assert (trtllm_out / "rank0.engine").exists() - def _save(): - config = HfAutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - config.num_hidden_layers = 1 + def _reload(): + with mock.patch("optimum.nvidia.export.TensorRTModelConverter.build"): + 
model = AutoModelForCausalLM.from_pretrained(trtllm_out) + assert model is not None + assert ( + optimum.nvidia.export.TensorRTModelConverter.build.call_count + == 0 + ) - model = HfAutoModelForCausalLM.from_config(config) - model.save_pretrained(hf_out) - - model = AutoModelForCausalLM.from_pretrained(hf_out) - model.save_pretrained(trtllm_out) - - assert trtllm_out.exists() - assert (trtllm_out / FOLDER_TRTLLM_ENGINES / "rank0.engine").exists() - - def _reload(): - with mock.patch("optimum.nvidia.hub.HuggingFaceHubModel.convert_and_build"): - model = AutoModelForCausalLM.from_pretrained(trtllm_out) - assert model is not None - assert ( - optimum.nvidia.hub.HuggingFaceHubModel.convert_and_build.call_count - == 0 - ) - - _save() - _reload() + _save() + _reload() diff --git a/third-party/tensorrt-llm b/third-party/tensorrt-llm index 250d9c29..9bd15f19 160000 --- a/third-party/tensorrt-llm +++ b/third-party/tensorrt-llm @@ -1 +1 @@ -Subproject commit 250d9c293d5edbc2a45c20775b3150b1eb68b364 +Subproject commit 9bd15f1937f52658fb116f30d58fea786ce5d03b
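As a usage note for the reworked runtime above: a minimal sketch, assuming the dependency versions pinned in this PR, of how a Hugging Face GenerationConfig is mapped onto TensorRT-LLM SamplingParams by the new convert_generation_config helper. The token ids below are illustrative placeholders, and the assertions assume SamplingParams exposes its constructor fields as attributes.

from transformers import GenerationConfig

from optimum.nvidia.runtime import convert_generation_config

# Greedy decoding: with do_sample=False the helper collapses top_k and beam_width to 1.
greedy = GenerationConfig(
    max_new_tokens=64,
    do_sample=False,
    eos_token_id=2,
    pad_token_id=2,
)
sampling = convert_generation_config(greedy)

assert sampling.top_k == 1 and sampling.beam_width == 1
assert sampling.max_new_tokens == 64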