diff --git a/binaries/conda/build_packages.py b/binaries/conda/build_packages.py
index 4fd8a5d82b..00b9e9c13b 100644
--- a/binaries/conda/build_packages.py
+++ b/binaries/conda/build_packages.py
@@ -22,7 +22,13 @@
 PACKAGES = ["torchserve", "torch-model-archiver", "torch-workflow-archiver"]
 
 # conda convert supported platforms https://docs.conda.io/projects/conda-build/en/stable/resources/commands/conda-convert.html
-PLATFORMS = ["linux-64", "osx-64", "win-64", "osx-arm64"]  # Add a new platform here
+PLATFORMS = [
+    "linux-64",
+    "osx-64",
+    "win-64",
+    "osx-arm64",
+    "linux-aarch64",
+]  # Add a new platform here
 
 if os.name == "nt":
     # Assumes miniconda is installed in windows
diff --git a/docs/linux_aarch64.md b/docs/linux_aarch64.md
new file mode 100644
index 0000000000..5e13410c83
--- /dev/null
+++ b/docs/linux_aarch64.md
@@ -0,0 +1,29 @@
+# TorchServe on linux aarch64 - Experimental
+
+TorchServe has been tested and works on linux aarch64 for some of the examples.
+- Tested on an Amazon Graviton 3 instance (m7g.4xlarge)
+
+## Installation
+
+Currently, installation from PyPI or from source works.
+
+```
+python ts_scripts/install_dependencies.py
+pip install torchserve torch-model-archiver torch-workflow-archiver
+```
+
+## Optimizations
+
+You can also enable these optimizations on Graviton 3 to get improved performance. More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/).
+```
+export DNNL_DEFAULT_FPMATH_MODE=BF16
+export LRU_CACHE_CAPACITY=1024
+```
+
+## Example
+
+This [example](https://github.com/pytorch/serve/tree/master/examples/text_to_speech_synthesizer/SpeechT5) of text to speech synthesis was verified to work on Graviton 3.
+
+## To Dos
+- CI
+- Regression tests
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/README.md b/examples/text_to_speech_synthesizer/SpeechT5/README.md
new file mode 100644
index 0000000000..e2182faf7f
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/README.md
@@ -0,0 +1,50 @@
+# Text to Speech synthesis with SpeechT5
+
+This is an example showing text to speech synthesis using the SpeechT5 model. It has been verified to work on a Graviton 3 (linux-aarch64) instance.
+
+While running this model on `linux-aarch64`, you can enable these optimizations:
+
+```
+export DNNL_DEFAULT_FPMATH_MODE=BF16
+export LRU_CACHE_CAPACITY=1024
+```
+More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/).
+
+
+## Pre-requisites
+```
+chmod +x setup.sh
+./setup.sh
+```
+
+## Download model
+
+This saves the model artifacts to the `model_artifacts` directory.
+```
+huggingface-cli login
+python download_model.py
+```
+
+## Create model archive
+
+```
+mkdir model_store
+
+torch-model-archiver --model-name SpeechT5-TTS --version 1.0 --handler text_to_speech_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
+
+mv model_artifacts/* model_store/SpeechT5-TTS/
+```
+
+## Start TorchServe
+
+```
+torchserve --start --ncs --model-store model_store --models SpeechT5-TTS
+```
+
+## Send inference request
+
+```
+curl http://127.0.0.1:8080/predictions/SpeechT5-TTS -T sample_input.txt -o speech.wav
+```
+
+This generates an audio file `speech.wav` corresponding to the text in `sample_input.txt`.
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/download_model.py b/examples/text_to_speech_synthesizer/SpeechT5/download_model.py
new file mode 100644
index 0000000000..66d1494e0c
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/download_model.py
@@ -0,0 +1,14 @@
+from datasets import load_dataset
+from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+
+processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+
+embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+
+model.save_pretrained(save_directory="model_artifacts/model")
+processor.save_pretrained(save_directory="model_artifacts/processor")
+vocoder.save_pretrained(save_directory="model_artifacts/vocoder")
+embeddings_dataset.save_to_disk("model_artifacts/speaker_embeddings")
+print("Saved model artifacts to directory model_artifacts")
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/model-config.yaml b/examples/text_to_speech_synthesizer/SpeechT5/model-config.yaml
new file mode 100644
index 0000000000..feaf7026b3
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/model-config.yaml
@@ -0,0 +1,8 @@
+minWorkers: 1
+maxWorkers: 1
+handler:
+    model: "model"
+    vocoder: "vocoder"
+    processor: "processor"
+    speaker_embeddings: "speaker_embeddings"
+    output_dir: "/tmp"
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/sample_input.txt b/examples/text_to_speech_synthesizer/SpeechT5/sample_input.txt
new file mode 100644
index 0000000000..e60d898198
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/sample_input.txt
@@ -0,0 +1 @@
+"I love San Francisco"
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/setup.sh b/examples/text_to_speech_synthesizer/SpeechT5/setup.sh
new file mode 100644
index 0000000000..895c08b49e
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/setup.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+# Needed for soundfile
+sudo apt install libsndfile1 -y
+
+pip install --upgrade transformers sentencepiece datasets[audio] soundfile
diff --git a/examples/text_to_speech_synthesizer/SpeechT5/text_to_speech_handler.py b/examples/text_to_speech_synthesizer/SpeechT5/text_to_speech_handler.py
new file mode 100644
index 0000000000..65fbbf1509
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/SpeechT5/text_to_speech_handler.py
@@ -0,0 +1,68 @@
+import logging
+import os
+import uuid
+
+import soundfile as sf
+import torch
+from datasets import load_from_disk
+from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+
+from ts.torch_handler.base_handler import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SpeechT5_TTS(BaseHandler):
+    def __init__(self):
+        self.model = None
+        self.processor = None
+        self.vocoder = None
+        self.speaker_embeddings = None
+        self.output_dir = "/tmp"
+
+    def initialize(self, ctx):
+        properties = ctx.system_properties
+        model_dir = properties.get("model_dir")
+
+        processor = ctx.model_yaml_config["handler"]["processor"]
+        model = ctx.model_yaml_config["handler"]["model"]
+        vocoder = ctx.model_yaml_config["handler"]["vocoder"]
+        embeddings_dataset = ctx.model_yaml_config["handler"]["speaker_embeddings"]
+        self.output_dir = ctx.model_yaml_config["handler"]["output_dir"]
+
+        self.processor = SpeechT5Processor.from_pretrained(processor)
+        self.model = SpeechT5ForTextToSpeech.from_pretrained(model)
+        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder)
+
+        # load xvector containing speaker's voice characteristics from a dataset
+        embeddings_dataset = load_from_disk(embeddings_dataset)
+        self.speaker_embeddings = torch.tensor(
+            embeddings_dataset[7306]["xvector"]
+        ).unsqueeze(0)
+
+    def preprocess(self, requests):
+        assert len(requests) == 1, "This is currently supported with batch_size=1"
+        req_data = requests[0]
+
+        input_data = req_data.get("data") or req_data.get("body")
+
+        if isinstance(input_data, (bytes, bytearray)):
+            input_data = input_data.decode("utf-8")
+
+        inputs = self.processor(text=input_data, return_tensors="pt")
+
+        return inputs
+
+    def inference(self, inputs):
+        output = self.model.generate_speech(
+            inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder
+        )
+        return output
+
+    def postprocess(self, inference_output):
+        path = self.output_dir + "/{}.wav".format(uuid.uuid4().hex)
+        sf.write(path, inference_output.numpy(), samplerate=16000)
+        with open(path, "rb") as output:
+            data = output.read()
+        os.remove(path)
+        return [data]
diff --git a/examples/text_to_speech_synthesizer/README.md b/examples/text_to_speech_synthesizer/WaveGlow/README.md
similarity index 100%
rename from examples/text_to_speech_synthesizer/README.md
rename to examples/text_to_speech_synthesizer/WaveGlow/README.md
diff --git a/examples/text_to_speech_synthesizer/create_mar.sh b/examples/text_to_speech_synthesizer/WaveGlow/create_mar.sh
similarity index 100%
rename from examples/text_to_speech_synthesizer/create_mar.sh
rename to examples/text_to_speech_synthesizer/WaveGlow/create_mar.sh
diff --git a/examples/text_to_speech_synthesizer/requirements.txt b/examples/text_to_speech_synthesizer/WaveGlow/requirements.txt
similarity index 100%
rename from examples/text_to_speech_synthesizer/requirements.txt
rename to examples/text_to_speech_synthesizer/WaveGlow/requirements.txt
diff --git a/examples/text_to_speech_synthesizer/sample_text.txt b/examples/text_to_speech_synthesizer/WaveGlow/sample_text.txt
similarity index 100%
rename from examples/text_to_speech_synthesizer/sample_text.txt
rename to examples/text_to_speech_synthesizer/WaveGlow/sample_text.txt
diff --git a/examples/text_to_speech_synthesizer/waveglow_handler.py b/examples/text_to_speech_synthesizer/WaveGlow/waveglow_handler.py
similarity index 100%
rename from examples/text_to_speech_synthesizer/waveglow_handler.py
rename to examples/text_to_speech_synthesizer/WaveGlow/waveglow_handler.py
diff --git a/examples/text_to_speech_synthesizer/waveglow_model.py b/examples/text_to_speech_synthesizer/WaveGlow/waveglow_model.py
similarity index 100%
rename from examples/text_to_speech_synthesizer/waveglow_model.py
rename to examples/text_to_speech_synthesizer/WaveGlow/waveglow_model.py
diff --git a/requirements/developer.txt b/requirements/developer.txt
index 77b12693d1..5387bf1de0 100644
--- a/requirements/developer.txt
+++ b/requirements/developer.txt
@@ -15,7 +15,7 @@ pre-commit==3.3.2
 twine==4.0.2
 mypy==1.3.0
 torchpippy==0.1.1
-intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin'
+intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin' and platform_machine != 'aarch64'
 onnxruntime==1.17.1
 googleapis-common-protos
 onnx==1.16.0
diff --git a/requirements/torch_linux_aarch64.txt b/requirements/torch_linux_aarch64.txt
new file mode 100644
index 0000000000..5aff2cf43c
--- /dev/null
+++ b/requirements/torch_linux_aarch64.txt
@@ -0,0 +1,6 @@
+#pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+--extra-index-url https://download.pytorch.org/whl/cpu
+-r torch_common.txt
+torch==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'
+torchvision==0.17.1; sys_platform == 'linux' and platform_machine == 'aarch64'
+torchaudio==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index f047de2a2b..f6c208bf5b 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -118,9 +118,14 @@ def install_torch_packages(self, cuda_version):
                 f"{sys.executable} -m pip install -U -r {torch_neuronx_requirements_file}"
             )
         else:
-            os.system(
-                f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
-            )
+            if platform.machine() == "aarch64":
+                os.system(
+                    f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}_{platform.machine()}.txt"
+                )
+            else:
+                os.system(
+                    f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
+                )
 
     def install_python_packages(self, cuda_version, requirements_file_path, nightly):
         check = "where" if platform.system() == "Windows" else "which"
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index 8f81c07e4a..b72ce4a3b0 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1216,6 +1216,10 @@ libomp
 rpath
 venv
 TorchInductor
+Graviton
+aarch
+linux
+SpeechT
 Pytests
 deviceType
 XGBoost
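
For a quick end-to-end check of the SpeechT5 example, the curl request shown in its README can also be issued from Python. The sketch below is illustrative only: it assumes the third-party `requests` package, TorchServe's default inference port (8080), and the `SpeechT5-TTS` model name registered above.

```
# Minimal Python client for the SpeechT5-TTS endpoint (assumes `pip install requests`
# and a TorchServe instance started as in the SpeechT5 README above).
import requests

# Read the input text to send as the request body.
with open("sample_input.txt", "rb") as f:
    text = f.read()

resp = requests.post(
    "http://127.0.0.1:8080/predictions/SpeechT5-TTS",
    data=text,
    timeout=120,
)
resp.raise_for_status()

# The handler returns raw WAV bytes; write them out like curl's `-o speech.wav`.
with open("speech.wav", "wb") as out:
    out.write(resp.content)
print(f"Wrote speech.wav ({len(resp.content)} bytes)")
```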