Skip to content

Commit

Permalink
inhibit copying base model dependencies for summ, and translation (#3285
Browse files Browse the repository at this point in the history
)

* inhibit copying base model dependencies for summ, and translation

* comments fix

---------

Co-authored-by: Anubha Jain <[email protected]>
  • Loading branch information
Anubha98 and Anubha Jain authored Aug 22, 2024
1 parent 73db714 commit ae510a0
Showing 1 changed file with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from azureml.acft.contrib.hf.nlp.utils.common_utils import deep_update
from azureml.acft.contrib.hf.nlp.constants.constants import (
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes, Tasks
)

import mlflow
Expand Down Expand Up @@ -269,10 +269,10 @@ def remove_unwanted_packages(model_save_path: str):
yaml.safe_dump(conda_dict, f)
logger.info("Updated conda.yaml file")

def is_t5_text_classification_finetune(self, model_type) -> bool:
"""Check for t5 text-classification."""
return self.mlflow_task_type == MLFlowHFFlavourTasks.SINGLE_LABEL_CLASSIFICATION and \
model_type == HfModelTypes.T5
def is_t5_finetune(self, model_type) -> bool:
"""Check for t5 text-classification, translation, summarization."""
return self.component_args.task_name in [Tasks.SINGLE_LABEL_CLASSIFICATION, Tasks.TRANSLATION,
Tasks.SUMMARIZATION] and model_type == HfModelTypes.T5

def convert_model(self) -> None:
"""Convert pytorch model to oss mlflow model."""
Expand All @@ -282,9 +282,9 @@ def convert_model(self) -> None:
self.set_mlflow_model_parameters(model)

# Temp Fix:
# specific check for t5 text-classification so that base model dependencies doesn't get pass
# and use transformers version 4.40.0 from infer dependencies
if not self.is_t5_text_classification_finetune(model.config.model_type):
# specific check for t5 text-classification, translation, summarization so that base model
# dependencies doesn't get pass and use transformers version 4.44.0 from infer dependencies
if not self.is_t5_finetune(model.config.model_type):
conda_file_path = Path(self.ft_pytorch_model_path, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
if conda_file_path.is_file():
self.mlflow_save_model_kwargs.update({"conda_env": str(conda_file_path)})
Expand Down Expand Up @@ -316,8 +316,8 @@ def convert_model(self) -> None:
self.add_model_signature()
self.copy_finetune_config(self.ft_pytorch_model_path, self.mlflow_model_save_path)

# Temp fix for t5 text-classification
if self.is_t5_text_classification_finetune(model.config.model_type):
# Temp fix for t5 text-classification, translation, summarization
if self.is_t5_finetune(model.config.model_type):
self.remove_unwanted_packages(self.mlflow_model_save_path)

logger.info("Saved MLFlow model using OSS flavour.")

0 comments on commit ae510a0

Please sign in to comment.