diff --git a/pyproject.toml b/pyproject.toml
index af5b48fb..8cf67152 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,14 +31,18 @@ dependencies = [
     "mpmath == 1.3.0",
     "numpy >= 1.26.0, < 2.0.0",
     "onnx >= 1.12.0",
-    "optimum >= 1.13.0",
+    "optimum >= 1.21.0",
     "setuptools",
-    "tensorrt-llm == 0.12.0.dev2024072300",
-    "torch>=2.3.0a,<=2.4.0a",
-    "transformers >= 4.38.2",
+    "tensorrt-llm == 0.13.0.dev2024082700",
+    "torch>=2.4.0a,<=2.5.0a",
+#    "transformers >= 4.43.2",
     "pynvml"
 ]
 
+[project.scripts]
+optimum-cli="optimum.commands.optimum_cli:main"
+
+
 [project.urls]
 Homepage = "https://huggingface.co/hardware/nvidia"
 Repository = "https://github.com/huggingface/optimum-nvidia"
diff --git a/setup.py b/setup.py
index f8191027..eeb41b27 100644
--- a/setup.py
+++ b/setup.py
@@ -34,11 +34,11 @@
     "mpmath == 1.3.0",
     "numpy >= 1.26.0",
     "onnx >= 1.12.0",
-    "optimum >= 1.13.0",
+    "optimum >= 1.21.0",
     "setuptools",
-    "tensorrt-llm == 0.12.0.dev2024072300",
-    "torch>=2.3.0a,<=2.4.0a",
-    "transformers >= 4.38.2",
+    "tensorrt-llm == 0.13.0.dev2024082700",
+    "torch>=2.3.0a,<=2.5.0a",
+    "transformers >= 4.43.2",
     "pynvml"
 ]
 
@@ -98,4 +98,9 @@
     dependency_links=["https://pypi.nvidia.com"],
     include_package_data=True,
     zip_safe=False,
+    entry_points={
+        "console_scripts": [
+            "optimum-cli=optimum.commands.optimum_cli:main",
+        ]
+    },
 )
diff --git a/src/optimum/commands/env.py b/src/optimum/commands/env.py
new file mode 100644
index 00000000..ecf1e0ad
--- /dev/null
+++ b/src/optimum/commands/env.py
@@ -0,0 +1,57 @@
+import platform
+import subprocess
+
+import huggingface_hub
+from tensorrt import __version__ as trt_version
+from tensorrt_llm import __version__ as trtllm_version
+from transformers import __version__ as transformers_version
+from transformers.utils import is_torch_available
+
+from ..nvidia.version import __version__ as optimum_nvidia_version
+from ..version import __version__ as optimum_version
+from . import BaseOptimumCLICommand, CommandInfo
+
+
+class EnvironmentCommand(BaseOptimumCLICommand):
+    COMMAND = CommandInfo(
+        name="env", help="Get information about the environment used."
+    )
+
+    @staticmethod
+    def print_apt_pkgs():
+        apt = subprocess.Popen(["apt", "list", "--installed"], stdout=subprocess.PIPE)
+        grep = subprocess.Popen(
+            ["grep", "cuda"], stdin=apt.stdout, stdout=subprocess.PIPE
+        )
+        pkgs_list = list(grep.stdout)
+        for pkg in pkgs_list:
+            print(pkg.decode("utf-8").split("\n")[0])
+
+    def run(self):
+        pt_version = "not installed"
+        if is_torch_available():
+            import torch
+
+            pt_version = torch.__version__
+
+        platform_info = {
+            "Platform": platform.platform(),
+            "Python version": platform.python_version(),
+        }
+        info = {
+            "`optimum-nvidia` version": optimum_nvidia_version,
+            "`tensorrt` version": trt_version,
+            "`tensorrt-llm` version": trtllm_version,
+            "`optimum` version": optimum_version,
+            "`transformers` version": transformers_version,
+            "`huggingface_hub` version": huggingface_hub.__version__,
+            "`torch` version": f"{pt_version}",
+        }
+
+        print("\nCopy-and-paste the text below in your GitHub issue:\n")
+        print("\nPlatform:\n")
+        print(self.format_dict(platform_info))
+        print("\nPython packages:\n")
+        print(self.format_dict(info))
+        print("\nCUDA system packages:\n")
+        self.print_apt_pkgs()
diff --git a/src/optimum/commands/export/trtllm.py b/src/optimum/commands/export/trtllm.py
new file mode 100644
index 00000000..daf5c76e
--- /dev/null
+++ b/src/optimum/commands/export/trtllm.py
@@ -0,0 +1,41 @@
+import subprocess
+import sys
+from typing import TYPE_CHECKING, Optional
+
+from ..base import BaseOptimumCLICommand, CommandInfo
+from ...nvidia.export.cli import common_trtllm_export_args
+
+if TYPE_CHECKING:
+    from argparse import ArgumentParser, Namespace, _SubParsersAction
+
+
+
+class TrtLlmExportCommand(BaseOptimumCLICommand):
+    COMMAND = CommandInfo(
+        name="trtllm", help="Export PyTorch models to TensorRT-LLM compiled engines"
+    )
+
+    def __init__(
+        self,
+        subparsers: "_SubParsersAction",
+        args: Optional["Namespace"] = None,
+        command: Optional["CommandInfo"] = None,
+        from_defaults_factory: bool = False,
+        parser: Optional["ArgumentParser"] = None,
+    ):
+        super().__init__(
+            subparsers,
+            args=args,
+            command=command,
+            from_defaults_factory=from_defaults_factory,
+            parser=parser,
+        )
+        self.args_string = " ".join(sys.argv[3:])
+
+    @staticmethod
+    def parse_args(parser: "ArgumentParser"):
+        return common_trtllm_export_args(parser)
+
+    def run(self):
+        full_command = f"python3 -m optimum.exporters.trtllm {self.args_string}"
+        subprocess.run(full_command, shell=True, check=True)
diff --git a/src/optimum/commands/register/register_export.py b/src/optimum/commands/register/register_export.py
new file mode 100644
index 00000000..858f8987
--- /dev/null
+++ b/src/optimum/commands/register/register_export.py
@@ -0,0 +1,13 @@
+"""Registers the export command for TRTLLM to the Optimum CLI."""
+
+from ...nvidia.utils.import_utils import is_tensorrt_llm_available
+from ..export import ExportCommand
+
+
+if _tensorrt_llm_export_command_was_imported := is_tensorrt_llm_available():
+    from ..export.trtllm import TrtLlmExportCommand  # noqa: F811
+
+if _tensorrt_llm_export_command_was_imported:
+    REGISTER_COMMANDS = [(TrtLlmExportCommand, ExportCommand)]
+else:
+    REGISTER_COMMANDS = []
diff --git a/src/optimum/exporters/trtllm/__init__.py b/src/optimum/exporters/trtllm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/optimum/exporters/trtllm/main.py b/src/optimum/exporters/trtllm/main.py
new file mode 100644
index 00000000..b6b5a0a5
--- /dev/null
+++ b/src/optimum/exporters/trtllm/main.py
@@ -0,0 +1,26 @@
+from argparse import ArgumentParser
+
+from huggingface_hub import login
+
+from optimum.nvidia import AutoModelForCausalLM
+from optimum.nvidia.export import ExportConfig
+from optimum.nvidia.export.cli import common_trtllm_export_args
+
+from transformers import AutoConfig
+
+if __name__ == '__main__':
+    parser = ArgumentParser("Hugging Face Optimum TensorRT-LLM exporter")
+    common_trtllm_export_args(parser)
+    parser.add_argument("--push-to-hub", type=str, required=False, help="Repository id where to push engines")
+    parser.add_argument("destination", help="Local path where the generated engines will be saved")
+    args = parser.parse_args()
+
+    config = AutoConfig.from_pretrained(args.model)
+    export = ExportConfig.from_config(config, args.max_batch_size)
+    model = AutoModelForCausalLM.from_pretrained(args.model, export_config=export, export_only=True)
+    model.save_pretrained(args.destination)
+
+    if args.push_to_hub:
+        print(f"Exporting model to the Hugging Face Hub: {args.push_to_hub}")
+        model.push_to_hub(args.push_to_hub, commit_message=f"Optimum-CLI TensorRT-LLM {args.model} export")
+
diff --git a/src/optimum/nvidia/export/cli.py b/src/optimum/nvidia/export/cli.py
new file mode 100644
index 00000000..54360786
--- /dev/null
+++ b/src/optimum/nvidia/export/cli.py
@@ -0,0 +1,43 @@
+def common_trtllm_export_args(parser: "ArgumentParser"):
+    required_group = parser.add_argument_group("Required arguments")
+    required_group.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="Model ID on huggingface.co or path on disk to load model from.",
+    )
+    required_group.add_argument(
+        "--max-input-length",
+        type=int,
+        default=1,
+        help="Maximum sequence length, in number of tokens, the prompt can be. The maximum number of potential tokens "
+        "generated will be <max-output-length> - <max-input-length>.",
+    )
+    required_group.add_argument(
+        "--max-output-length",
+        type=int,
+        default=1,
+        help="Maximum sequence length, in number of tokens, the model supports.",
+    )
+
+    optional_group = parser.add_argument_group("Optional arguments")
+    optional_group.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        default="auto",
+        help="Computational data type used for the model.",
+    )
+    optional_group.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=1,
+        help="Maximum number of concurrent requests the model can process.",
+    )
+    optional_group.add_argument(
+        "--max-beams-width",
+        type=int,
+        default=1,
+        help='Maximum number of sampling paths ("beam") to evaluate when decoding a new token.',
+    )
\ No newline at end of file
diff --git a/src/optimum/nvidia/export/config.py b/src/optimum/nvidia/export/config.py
index 6f5479b9..7a479236 100644
--- a/src/optimum/nvidia/export/config.py
+++ b/src/optimum/nvidia/export/config.py
@@ -7,6 +7,7 @@
 from tensorrt_llm import BuildConfig
 from tensorrt_llm import Mapping as ShardingInfo
 from tensorrt_llm.plugin import PluginConfig
+from tensorrt_llm.plugin.plugin import ContextFMHAType
 from transformers import AutoConfig
 
 from optimum.nvidia.lang import DataType
@@ -94,12 +95,18 @@ def validate(self) -> "ExportConfig":
     @property
     def plugin_config(self) -> "PluginConfig":
         config = PluginConfig()
-        config.gemm_plugin = self.dtype
-        config.bert_attention_plugin = self.dtype
-        config.gpt_attention_plugin = self.dtype
-        config.nccl_plugin = self.dtype
-        config.mamba_conv1d_plugin = self.dtype
-        config.moe_plugin = self.dtype
+
+        config.gemm_plugin = "auto"
+        config.gpt_attention_plugin = "auto"
+        config.set_context_fmha(ContextFMHAType.enabled)
+
+        if self.sharding.world_size > 1:
+            config.lookup_plugin = "auto"
"auto" + config.set_nccl_plugin() + + if DataType(self.dtype) == DataType.FLOAT8: + config.gemm_swiglu_plugin = True + return config def to_builder_config( @@ -115,6 +122,7 @@ def to_builder_config( max_num_tokens=self.max_num_tokens, builder_opt=self.optimization_level, plugin_config=plugin_config or self.plugin_config, + use_fused_mlp=True, ) def with_sharding( diff --git a/src/optimum/nvidia/export/converter.py b/src/optimum/nvidia/export/converter.py index 102d4c06..8f14d956 100644 --- a/src/optimum/nvidia/export/converter.py +++ b/src/optimum/nvidia/export/converter.py @@ -1,3 +1,4 @@ +import shutil from abc import ABC from enum import Enum from logging import getLogger @@ -66,6 +67,7 @@ def __init__( model_id: str, subpart: str = "", workspace: Optional[Union["Workspace", str, bytes, Path]] = None, + license_path: Optional[Union[str, bytes, Path]] = None, ): LOGGER.info(f"Creating a model converter for {subpart}") if not workspace: @@ -80,11 +82,26 @@ def __init__( LOGGER.debug(f"Initializing model converter workspace at {workspace.root}") self._workspace = workspace + self._license_path = license_path @property def workspace(self) -> Workspace: return self._workspace + def save_license(self, licence_filename: str = "LICENSE"): + """ + Save the license if provided and if the license is not already present. + This method doesn't check the content of the license + :param licence_filename: Name of the file containing the license content + """ + if ( + not ( + dst_licence_file_path := self.workspace.root / licence_filename + ).exists() + and self._license_path + ): + shutil.copyfile(self._license_path, dst_licence_file_path) + def quantize(self): raise NotImplementedError() @@ -108,6 +125,7 @@ def convert( ) model.save_checkpoint(str(self._workspace.checkpoints_path)) + self.save_license() return TensorRTArtifact.checkpoints(str(self._workspace.checkpoints_path)) def build( @@ -126,10 +144,13 @@ def build( config = infer_plugin_from_build_config(config) - for rank, model in enumerate(models): - LOGGER.info(f"Building TRTLLM engine for rank {rank}") + for model in models: + LOGGER.info( + f"Building TRTLLM engine for rank {model.config.mapping.rank} ->> {config.to_dict()}" + ) engine = build(model, config) engine.save(str(self._workspace.engines_path)) + self.save_license() return TensorRTArtifact.engines(str(self._workspace.engines_path)) diff --git a/src/optimum/nvidia/hub.py b/src/optimum/nvidia/hub.py index c0a27b6d..8ce68d56 100644 --- a/src/optimum/nvidia/hub.py +++ b/src/optimum/nvidia/hub.py @@ -60,13 +60,16 @@ ATTR_TRTLLM_ENGINE_FOLDER = "__trtllm_engine_folder__" FILE_TRTLLM_ENGINE_PATTERN = "rank[0-9]*.engine" FILE_TRTLLM_CHECKPOINT_PATTERN = "rank[0-9]*.engine" +FILE_LICENSE_NAME = "LICENSE" HUB_SNAPSHOT_ALLOW_PATTERNS = [ CONFIG_NAME, GENERATION_CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, "*.safetensors", + FILE_LICENSE_NAME, ] + LOGGER = getLogger() @@ -144,6 +147,7 @@ def _from_pretrained( device_map: Optional[str] = None, export_config: Optional[ExportConfig] = None, force_export: bool = False, + export_only: bool = False, save_intermediate_checkpoints: bool = False, ) -> T: if get_device_count() < 1: @@ -233,6 +237,14 @@ def _from_pretrained( original_checkpoints_path_for_conversion ) + # This is required to complain with binding license for derivative work + if FILE_LICENSE_NAME in original_checkpoints_path_for_conversion: + licence_path = original_checkpoints_path_for_conversion.joinpath( + FILE_LICENSE_NAME + ) + else: + licence_path = None + # If no export config, let's 
             export_config = export_config or ExportConfig.from_config(config)
 
@@ -253,14 +265,15 @@ def _from_pretrained(
                     f"Building {model_id} {subpart} ({idx + 1} / {len(targets)})"
                 )
 
-                converter = TensorRTModelConverter(model_id, subpart, workspace)
+                converter = TensorRTModelConverter(
+                    model_id, subpart, workspace, licence_path
+                )
 
                 # Artifacts resulting from a build are not stored in the location `snapshot_download`
                 # would use. Instead, it uses `cached_assets_path` to create a specific location which
                 # doesn't mess up with the HF caching system. Use can use `save_pretrained` to store
                 # the build artifact into a snapshot friendly place
                 # If this specific location is found, we don't necessary need to rebuild
-
                 if force_export or not len(
                     list(converter.workspace.engines_path.glob("*.engine"))
                 ):
@@ -276,6 +289,9 @@ def _from_pretrained(
                         dtype=DataType.from_torch(config.torch_dtype).value,
                         mapping=export_config.sharding,
                         load_by_shard=True,
+                        use_parallel_embedding=export_config.sharding.world_size
+                        > 1,
+                        share_embedding_table=config.tie_word_embeddings,
                     )
 
                     ranked_model.config.mapping.rank = rank
@@ -307,6 +323,7 @@ def _from_pretrained(
         return cls(
             engines_path=engines_folder,
             generation_config=generation_config,
+            load_engines=not export_only
         )
 
     @abstractmethod
@@ -314,25 +331,32 @@ def _save_additional_parcels(self, save_directory: Path):
         raise NotImplementedError()
 
     def _save_pretrained(self, save_directory: Path) -> None:
-        try:
-            # Need target_is_directory on Windows
-            # Windows10 needs elevated privilege for symlink which will raise OSError if not the case
-            # Falling back to copytree in this case
-            for file in self._engines_path.glob("*"):
+        device_name = get_device_name(0)[-1]
+        save_directory = save_directory.joinpath(device_name)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        src_license_file_path = self._engines_path.parent / FILE_LICENSE_NAME
+        dst_files = [src_license_file_path] if src_license_file_path.exists() else []
+        dst_files += list(self._engines_path.glob("*"))
+
+        for file in dst_files:
+            try:
+                # Need target_is_directory on Windows
+                # Windows10 needs elevated privilege for symlink which will raise OSError if not the case
+                # Falling back to copytree in this case
                 symlink(
                     file, save_directory.joinpath(file.relative_to(self._engines_path))
                 )
-        except OSError as ose:
-            LOGGER.error(
-                f"Failed to create symlink from current engine folder {self._engines_path.parent} to {save_directory}. "
-                "Will default to copy based _save_pretrained",
-                exc_info=ose,
-            )
-            for file in self._engines_path.glob("*"):
+            except OSError as ose:
+                LOGGER.error(
+                    f"Failed to create symlink from current engine folder {self._engines_path.parent} to {save_directory}. "
" + "Will default to copy based _save_pretrained", + exc_info=ose, + ) copytree( file, save_directory.joinpath(file.relative_to(self._engines_path)), symlinks=True, ) - finally: - self._save_additional_parcels(save_directory) + + self._save_additional_parcels(save_directory) diff --git a/src/optimum/nvidia/lang/__init__.py b/src/optimum/nvidia/lang/__init__.py index 8352cec5..4b17ff37 100644 --- a/src/optimum/nvidia/lang/__init__.py +++ b/src/optimum/nvidia/lang/__init__.py @@ -20,7 +20,7 @@ import torch -class DataType(Enum): +class DataType(str, Enum): FLOAT32 = "float32" FLOAT16 = "float16" BFLOAT16 = "bfloat16" diff --git a/src/optimum/nvidia/runtime.py b/src/optimum/nvidia/runtime.py index b70af98f..f3c765eb 100644 --- a/src/optimum/nvidia/runtime.py +++ b/src/optimum/nvidia/runtime.py @@ -80,6 +80,7 @@ def __init__( engines_path: Union[str, PathLike], generation_config: "GenerationConfig", executor_config: Optional["ExecutorConfig"] = None, + load_engines: bool = True ): engines_path = Path(engines_path) @@ -90,10 +91,11 @@ def __init__( self._generation_config = generation_config self._sampling_config = convert_generation_config(generation_config) - self._executor = GenerationExecutor.create( - engine_dir=engines_path, - executor_config=executor_config or default_executor_config(self._config), - ) + if load_engines: + self._executor = GenerationExecutor.create( + engine_dir=engines_path, + executor_config=executor_config or default_executor_config(self._config), + ) def generate( self, @@ -162,9 +164,10 @@ def __init__( engines_path: Union[str, PathLike, Path], generation_config: "GenerationConfig", executor_config: Optional["ExecutorConfig"] = None, + load_engines: bool = True ): InferenceRuntimeBase.__init__( - self, engines_path, generation_config, executor_config + self, engines_path, generation_config, executor_config, load_engines ) HuggingFaceHubModel.__init__(self, engines_path) diff --git a/src/optimum/nvidia/utils/import_utils.py b/src/optimum/nvidia/utils/import_utils.py new file mode 100644 index 00000000..af4943fe --- /dev/null +++ b/src/optimum/nvidia/utils/import_utils.py @@ -0,0 +1,5 @@ +import importlib.util + + +def is_tensorrt_llm_available() -> bool: + return importlib.util.find_spec("tensorrt_llm") is not None diff --git a/src/optimum/nvidia/version.py b/src/optimum/nvidia/version.py index 03b6f5e9..f99d0a41 100644 --- a/src/optimum/nvidia/version.py +++ b/src/optimum/nvidia/version.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from distutils.version import StrictVersion +from packaging.version import Version __version__ = "0.1.0b7" -VERSION = StrictVersion(__version__) +VERSION = Version(__version__)