Refactor the overall Hugging Face -> TRTLLM export workflow (#133)
* Initial mixtral support
* feat(mixtral): map the correct num_local_experts config key for the MOE config
* feat(mixtral): allow specifying TP/PP configurations when allocating the model
* feat(mixtral): Expose tp/pp in examples/cli
* feat(mixtral): Remove config attributes from the model_kwargs to avoid setting many duplicates
* feat(hub): always do weight layout conversion in CPU memory
* feat(mixtral): enable MOE config conversion from transformers
* feat(parallelism): Enable providing TP/PP/MOE parallelism args
* feat(parallelism): Enable forwarding tp/pp args to trtllm-build
* feat(converter): Introduce base for TRTModelConverter
* Upgrade huggingface-hub dependency to 0.23.0
* feat(hub): Initial refactoring for clear separation of concerns
* feat(hub): Rework the overall separation of concerns for the hub and exporting
* feat(hub): Working for all non-Whisper models
* feat(hub): Disable whisper for now
* feat(trtllm): Update trtllm to 0.10.0
* feat(deps): Pin hf-transfer to 0.1.6
* feat(quant): Rework overall quantization schema
* feat(misc): Fix imports left untouched by the name refactoring
* feat(hub): Expose device_map="auto"
* feat(chore): quality
* feat(hub): expose device_map to enable auto-parallel
* feat(docker): Use repo variable for the image name
* feat(build): Validate new workflow for building engines
* feat(deps): Move to TRTLLM 0.11 preversion
* feat(deps): Use the new executor api for running LLMs
* feat(kvcache): Use floor when computing the number of tokens to store in the kvcache
* feat(ifb): Enable async generation with in-flight batching support
* feat(misc): Add better typing return info for AutoModelCausalLM
* feat(chore): quality
* feat(misc): Remove padding reference in examples
* Ensure building all the ranks in distributed settings
* Update hub tests
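One user-visible change in the list above is async generation with in-flight batching through the new executor API. As a rough sketch only, the snippet below issues several concurrent requests through `agenerate` (the model id and prompts are placeholders, and the assumption that concurrent awaitable calls are scheduled together by the in-flight batcher comes from the commit message, not from this diff):

import asyncio

from transformers import AutoTokenizer

from optimum.nvidia import AutoModelForCausalLM


async def generate_all(model_id: str, prompts: list[str]):
    # Placeholder checkpoint id; any supported causal-LM model should work.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    # agenerate is awaitable, so several requests can be in flight at once;
    # the in-flight batcher is assumed to schedule them together.
    encodings = [tokenizer(p, return_tensors="pt") for p in prompts]
    outputs = await asyncio.gather(
        *(model.agenerate(enc["input_ids"]) for enc in encodings)
    )
    return [
        tokenizer.batch_decode(generated, skip_special_tokens=True)
        for generated in outputs
    ]


results = asyncio.run(
    generate_all(
        "meta-llama/Llama-2-7b-chat-hf",  # assumed placeholder model id
        ["What is TensorRT-LLM?", "What is in-flight batching?"],
    )
)
print(results)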
1 parent 714734f · commit 011b5a9
Showing 37 changed files with 1,027 additions and 3,759 deletions.
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from argparse import ArgumentParser
from logging import getLogger
from pathlib import Path

from transformers import AutoTokenizer

from optimum.nvidia import AutoModelForCausalLM, ExportConfig, setup_logging


# Setup logging needs to happen before importing TRT ...
setup_logging(True)

from optimum.nvidia.utils.cli import (
    postprocess_quantization_parameters,
    register_common_model_topology_args,
    register_optimization_profiles_args,
    register_quantization_args,
)


LOGGER = getLogger(__name__)


async def infer():
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    export = ExportConfig.from_pretrained(args.model)
    export.max_input_len = 1024
    export.max_output_len = 256
    export.max_num_tokens = 256
    export.max_beam_width = 1

    model = AutoModelForCausalLM.from_pretrained(
        args.model, device_map="auto", export_config=export
    )
    # model.save_pretrained(args.output)

    prompt = "What is the latest generation of Nvidia GPUs?"
    tokens = tokenizer(prompt, return_tensors="pt")
    generated = await model.agenerate(
        tokens["input_ids"],
    )

    generated_text = tokenizer.batch_decode(generated, skip_special_tokens=True)
    print(generated_text)


if __name__ == "__main__":
    parser = ArgumentParser("🤗 Optimum-Nvidia Text-Generation Example")
    parser.add_argument(
        "--hub-token",
        type=str,
        help="Hugging Face Hub Token to retrieve private weights.",
    )
    register_common_model_topology_args(parser)
    register_optimization_profiles_args(parser)
    register_quantization_args(parser)  # Inject params.quantization_config

    parser.add_argument("model", type=str, help="The model's id or path to use.")
    parser.add_argument(
        "output", type=Path, help="Path to store generated TensorRT engine."
    )
    args = parser.parse_args()
    args = postprocess_quantization_parameters(args)

    if args.hub_token is not None:
        from huggingface_hub import login

        login(args.hub_token)

    asyncio.run(infer())
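For completeness, a hedged example of how this script might be invoked; the script filename, model id, and output path are placeholders, while the positional model/output arguments and the optional --hub-token flag come from the argument parser above:

python text_generation.py meta-llama/Llama-2-7b-chat-hf ./llama-engines --hub-token $HF_TOKEN

Any additional topology, optimization-profile, or quantization flags registered by register_common_model_topology_args and friends are not shown here, since their exact names are not visible in this diff.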