diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py index b431602..85224cc 100644 --- a/crossfit/backend/torch/hf/model.py +++ b/crossfit/backend/torch/hf/model.py @@ -31,6 +31,7 @@ def __init__( self, path_or_name: str, max_mem_gb: int = 16, + model_output_type: str = "numeric", training: bool = False, start_batch_size: int = 1, end_batch_size: int = 2048, @@ -38,7 +39,7 @@ def __init__( start_seq_len: int = 1, seq_len_increment: int = 64, ): - super().__init__(path_or_name, max_mem_gb) + super().__init__(path_or_name, max_mem_gb, model_output_type) self.start_batch_size = start_batch_size self.end_batch_size = end_batch_size self.batch_size_increment = batch_size_increment diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py index 73d0ed3..868f09b 100644 --- a/crossfit/backend/torch/model.py +++ b/crossfit/backend/torch/model.py @@ -11,10 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +import cudf +import cupy as cp + +from crossfit.backend.cudf.series import ( + create_list_series_from_1d_or_2d_ar, + create_nested_list_series_from_3d_ar, +) +from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors + + class Model: - def __init__(self, path_or_name: str, max_mem_gb: int = 16): + def __init__(self, path_or_name: str, max_mem_gb: int = 16, model_output_type: str = "numeric"): self.path_or_name = path_or_name self.max_mem_gb = max_mem_gb + if model_output_type in ["numeric", "string"]: + self.model_output_type = model_output_type + else: + raise ValueError( + "Invalid model output type provided. Allowed values are : 'string' or 'numeric'." + ) def load_model(self, device="cuda"): raise NotImplementedError() @@ -41,3 +59,37 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int: def max_seq_length(self) -> int: raise NotImplementedError() + + def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame: + # importing here to avoid cyclic import error + from crossfit.backend.torch.loader import SortedSeqLoader + + out = cudf.DataFrame(index=index) + _index = loader.sort_column(index.values) if type(loader) is SortedSeqLoader else index + + if self.model_output_type == "string": + all_outputs = [o for output in all_outputs_ls for o in output] + out[pred_output_col] = cudf.Series(data=all_outputs, index=_index) + del all_outputs_ls + del loader + else: + outputs = cp.asarray( + concat_and_pad_tensors( + all_outputs_ls, + pad_token_id=loader.pad_token_id, + padding_side=loader.padding_side, + ) + ) + del all_outputs_ls + del loader + cleanup_torch_cache() + if len(outputs.shape) <= 2: + out[pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index) + elif len(outputs.shape) == 3: + out[pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index) + else: + raise RuntimeError(f"Unexpected output shape: {outputs.shape}") + del outputs + del _index + cleanup_torch_cache() + return out diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py index 8fb6e31..960de95 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -14,18 +14,11 @@ from typing import Optional -import cudf -import cupy as cp import torch -from crossfit.backend.cudf.series import ( - create_list_series_from_1d_or_2d_ar, - create_nested_list_series_from_3d_ar, -) from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader from crossfit.backend.torch.model import Model from crossfit.op.base import Op -from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors class Predictor(Op): @@ -82,26 +75,11 @@ def call(self, data, partition_info=None): output = self.post(output) all_outputs_ls.append(output) - - out = cudf.DataFrame(index=index) - outputs = cp.asarray( - concat_and_pad_tensors( - all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side - ) - ) - _index = loader.sort_column(index.values) if self.sorted_data_loader else index - del all_outputs_ls - del loader - cleanup_torch_cache() - if len(outputs.shape) <= 2: - out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index) - elif len(outputs.shape) == 3: - out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index) - else: - raise RuntimeError(f"Unexpected output shape: {output.shape}") - del outputs, _index - cleanup_torch_cache() + out = self.model.get_model_output(all_outputs_ls, index, loader, self.pred_output_col) return out def meta(self): - return {self.pred_output_col: "float32"} + if self.model.model_output_type == "string": + return {self.pred_output_col: "object"} + else: + return {self.pred_output_col: "float32"} diff --git a/examples/custom_ct2_model.py b/examples/custom_ct2_model.py new file mode 100644 index 0000000..4ae04a2 --- /dev/null +++ b/examples/custom_ct2_model.py @@ -0,0 +1,163 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +from dataclasses import dataclass +from functools import lru_cache + +import ctranslate2 +import dask_cudf +from transformers import AutoConfig, AutoTokenizer + +import crossfit as cf +from crossfit import op +from crossfit.backend.torch.hf.model import HFModel + + +@dataclass +class TranslationConfig: + pretrained_model_name_or_path: str + ct2_model_path: str + + +class CT2CustomModel: + def __init__(self, config: TranslationConfig, device="cuda"): + self.config = config + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=config.pretrained_model_name_or_path, + trust_remote_code=True, + ) + self.model = ctranslate2.Translator(model_path=config.ct2_model_path, device=device) + + def clean_extra_tokens(self, token_2d): + results = [ + [ + t + for t in token_1d + if t + not in { + self.tokenizer.pad_token, + self.tokenizer.bos_token, + self.tokenizer.eos_token, + self.tokenizer.unk_token, + } + ] + for token_1d in token_2d + ] + return results + + def __call__(self, batch): + token_ids_2d = batch["input_ids"] + token_ids_1d = token_ids_2d.view(-1).tolist() + tokens_1d = self.tokenizer.convert_ids_to_tokens(token_ids_1d) + tokens_2d = [ + tokens_1d[i : i + token_ids_2d.size(1)] + for i in range(0, len(tokens_1d), token_ids_2d.size(1)) + ] + tokens = self.clean_extra_tokens(tokens_2d) + + tr_res = self.model.translate_batch( + tokens, + min_decoding_length=0, + max_decoding_length=256, + beam_size=5, + num_hypotheses=1, + ) + translations = ["".join(x.hypotheses[0]) for x in tr_res] + return translations + + +class ModelForSeq2SeqModel(HFModel): + def __init__(self, config): + self.trans_config = config + self.config = self.load_cfg() + super().__init__( + self.trans_config.pretrained_model_name_or_path, model_output_type="string" + ) + + def load_model(self, device="cuda"): + model = CT2CustomModel(self.trans_config) + return model + + def load_config(self): + return self.load_cfg() + + @lru_cache(maxsize=1) + def load_tokenizer(self): + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path, + trust_remote_code=True, + ) + + def max_seq_length(self) -> int: + return self.config.max_source_positions + + @lru_cache(maxsize=1) + def load_cfg(self): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path, + trust_remote_code=True, + ) + return config + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="PyTorch Model Predictions using Crossfit") + parser.add_argument("input_parquet_path", help="Input parquet file") + parser.add_argument("output_parquet_path", help="Output file") + parser.add_argument( + "--ct2-model-dir", + help="Directory where ctranslate2 optimized model is present", + ) + parser.add_argument( + "--input-column", type=str, default="text", help="Column name in input dataframe" + ) + parser.add_argument("--pool-size", type=str, default="12GB", help="RMM pool size") + parser.add_argument("--num-workers", type=int, default=1, help="Number of GPUs") + parser.add_argument( + "--model-name", + type=str, + default="ai4bharat/indictrans2-en-indic-1B", + help="Model name", + ) + parser.add_argument("--batch-size", type=int, default=64, help="Batch size") + parser.add_argument("--partitions", type=int, default=2, help="Number of partitions") + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + Config = TranslationConfig( + pretrained_model_name_or_path=args.model_name, + ct2_model_path=args.ct2_model_dir, + ) + ddf = dask_cudf.read_parquet(args.input_parquet_path) + + with cf.Distributed(rmm_pool_size=args.pool_size, n_workers=args.num_workers): + model = ModelForSeq2SeqModel(Config) + pipe = op.Sequential( + op.Tokenizer(model, cols=[args.input_column], tokenizer_type="default", max_length=255), + op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size), + repartition=args.partitions, + keep_cols=[args.input_column], + ) + outputs = pipe(ddf) + outputs.to_parquet(args.output_parquet_path) + + +if __name__ == "__main__": + main() diff --git a/tests/op/test_model_function.py b/tests/op/test_model_function.py new file mode 100644 index 0000000..7ffd1eb --- /dev/null +++ b/tests/op/test_model_function.py @@ -0,0 +1,75 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +cp = pytest.importorskip("cupy") +cudf = pytest.importorskip("cudf") +dask_cudf = pytest.importorskip("dask_cudf") +dd = pytest.importorskip("dask.dataframe") +pd = pytest.importorskip("pandas") +transformers = pytest.importorskip("transformers") +torch = pytest.importorskip("torch") + +import crossfit as cf # noqa: E402 + +cf_loader = pytest.importorskip("crossfit.backend.torch.loader") + + +@pytest.mark.parametrize("trust_remote_code", ["y"]) +def test_model_output_int(trust_remote_code, model_name="microsoft/deberta-v3-base"): + with patch("builtins.input", return_value=trust_remote_code): + tokens_data = cudf.DataFrame({"input_ids": [[11, 12, 13], [14, 15, 0], [17, 0, 0]]}) + index = tokens_data.index.copy() + model = cf.HFModel(model_name) + data = [[4], [7], [10]] + all_outputs_ls = torch.tensor(data) + loader = cf_loader.SortedSeqLoader( + tokens_data, + model, + ) + pred_output_col = "translation" + out = model.get_model_output(all_outputs_ls, index, loader, pred_output_col) + assert isinstance(out, cudf.DataFrame) + assert isinstance(out["translation"][0][0], int) + assert ( + out["translation"][0] == data[0] + and out["translation"][1] == data[1] + and out["translation"][2] == data[2] + ) + + +@pytest.mark.parametrize("trust_remote_code", ["y"]) +def test_model_output_str(trust_remote_code, model_name="microsoft/deberta-v3-base"): + with patch("builtins.input", return_value=trust_remote_code): + tokens_data = cudf.DataFrame({"input_ids": [[18264, 7728, 8], [123, 99, 0], [3115, 0, 0]]}) + index = tokens_data.index.copy() + model = cf.HFModel(model_name, model_output_type="string") + data = [["▁हमारे▁परीक्षण▁डेटा"], ["▁पर▁हमारे▁दो"], ["▁दूरी▁कार्यों▁की"]] + all_outputs_ls = data + loader = cf_loader.SortedSeqLoader( + tokens_data, + model, + ) + pred_output_col = "translation" + out = model.get_model_output(all_outputs_ls, index, loader, pred_output_col) + assert isinstance(out, cudf.DataFrame) + assert isinstance(out["translation"][0][0], str) + assert ( + out["translation"][0] == data[0][0] + and out["translation"][1] == data[1][0] + and out["translation"][2] == data[2][0] + )