From 429bc5caffef65bee0da1a36f3cf5d8998fcf748 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 7 Mar 2024 17:25:49 -0800 Subject: [PATCH] Add quantization config option (#433) Co-authored-by: ZHENG, Zhen Co-authored-by: Logan Adams --- mii/api.py | 1 + mii/config.py | 13 +++++++++++++ requirements/requirements.txt | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mii/api.py b/mii/api.py index d46ded1d..d909c837 100644 --- a/mii/api.py +++ b/mii/api.py @@ -50,6 +50,7 @@ def _parse_kwargs_to_model_config( # Create the ModelConfig object and return it with remaining kwargs model_config = ModelConfig(**model_config) + return model_config, remaining_kwargs diff --git a/mii/config.py b/mii/config.py index 490947f1..8e9c5cd7 100644 --- a/mii/config.py +++ b/mii/config.py @@ -131,6 +131,12 @@ class ModelConfig(MIIConfigModel): `inference_engine_config`. """ + quantization_mode: Optional[str] = None + """ + The quantization mode in string format. The supported modes are as follows: + - 'wf6af16', weight-only quantization with FP6 weight and FP16 activation. + """ + inference_engine_config: RaggedInferenceEngineConfig = {} """ DeepSpeed inference engine config. This is automatically generated, but you @@ -210,6 +216,13 @@ def propagate_tp_size(cls, values: Dict[str, Any]) -> Dict[str, Any]: values.get("inference_engine_config").tensor_parallel.tp_size = tensor_parallel return values + @root_validator + def propagate_quantization_mode(cls, values: Dict[str, Any]) -> Dict[str, Any]: + quantization_mode = values.get("quantization_mode") + values.get( + "inference_engine_config").quantization.quantization_mode = quantization_mode + return values + @root_validator def check_replica_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: num_replica_config = len(values.get("replica_configs")) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f067ee09..019fc261 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ asyncio -deepspeed>=0.13.0 +deepspeed>=0.14.0 deepspeed-kernels Flask-RESTful grpcio