Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ctranslate2 translation example script #83

Merged
merged 10 commits into from
Sep 27, 2024
3 changes: 2 additions & 1 deletion crossfit/backend/torch/hf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def __init__(
self,
path_or_name: str,
max_mem_gb: int = 16,
model_output_type: str = "numeric",
training: bool = False,
start_batch_size: int = 1,
end_batch_size: int = 2048,
batch_size_increment: int = 256,
start_seq_len: int = 1,
seq_len_increment: int = 64,
):
super().__init__(path_or_name, max_mem_gb)
super().__init__(path_or_name, max_mem_gb, model_output_type)
self.start_batch_size = start_batch_size
self.end_batch_size = end_batch_size
self.batch_size_increment = batch_size_increment
Expand Down
54 changes: 53 additions & 1 deletion crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import cudf
import cupy as cp

from crossfit.backend.cudf.series import (
create_list_series_from_1d_or_2d_ar,
create_nested_list_series_from_3d_ar,
)
from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors


class Model:
def __init__(self, path_or_name: str, max_mem_gb: int = 16):
def __init__(self, path_or_name: str, max_mem_gb: int = 16, model_output_type: str = "numeric"):
self.path_or_name = path_or_name
self.max_mem_gb = max_mem_gb
if model_output_type in ["numeric", "string"]:
self.model_output_type = model_output_type
else:
raise ValueError(
"Invalid model output type provided. Allowed values are : 'string' or 'numeric'."
)

def load_model(self, device="cuda"):
raise NotImplementedError()
Expand All @@ -41,3 +59,37 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:

def max_seq_length(self) -> int:
raise NotImplementedError()

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
# importing here to avoid cyclic import error
from crossfit.backend.torch.loader import SortedSeqLoader

out = cudf.DataFrame(index=index)
_index = loader.sort_column(index.values) if type(loader) is SortedSeqLoader else index

if self.model_output_type == "string":
all_outputs = [o for output in all_outputs_ls for o in output]
out[pred_output_col] = cudf.Series(data=all_outputs, index=_index)
del all_outputs_ls
del loader
else:
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls,
pad_token_id=loader.pad_token_id,
padding_side=loader.padding_side,
)
)
del all_outputs_ls
del loader
cleanup_torch_cache()
if len(outputs.shape) <= 2:
out[pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
elif len(outputs.shape) == 3:
out[pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {outputs.shape}")
del outputs
del _index
cleanup_torch_cache()
return out
32 changes: 5 additions & 27 deletions crossfit/backend/torch/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,11 @@

from typing import Optional

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import (
create_list_series_from_1d_or_2d_ar,
create_nested_list_series_from_3d_ar,
)
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op
from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors


class Predictor(Op):
Expand Down Expand Up @@ -82,26 +75,11 @@ def call(self, data, partition_info=None):
output = self.post(output)

all_outputs_ls.append(output)

out = cudf.DataFrame(index=index)
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side
)
)
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
del all_outputs_ls
del loader
cleanup_torch_cache()
if len(outputs.shape) <= 2:
out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
elif len(outputs.shape) == 3:
out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {output.shape}")
del outputs, _index
cleanup_torch_cache()
out = self.model.get_model_output(all_outputs_ls, index, loader, self.pred_output_col)
return out

def meta(self):
return {self.pred_output_col: "float32"}
if self.model.model_output_type == "string":
return {self.pred_output_col: "object"}
else:
return {self.pred_output_col: "float32"}
163 changes: 163 additions & 0 deletions examples/custom_ct2_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
from dataclasses import dataclass
from functools import lru_cache

import ctranslate2
import dask_cudf
from transformers import AutoConfig, AutoTokenizer

import crossfit as cf
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel


@dataclass
class TranslationConfig:
pretrained_model_name_or_path: str
ct2_model_path: str


class CT2CustomModel:
def __init__(self, config: TranslationConfig, device="cuda"):
self.config = config
self.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=config.pretrained_model_name_or_path,
trust_remote_code=True,
)
self.model = ctranslate2.Translator(model_path=config.ct2_model_path, device=device)

def clean_extra_tokens(self, token_2d):
results = [
[
t
for t in token_1d
if t
not in {
self.tokenizer.pad_token,
self.tokenizer.bos_token,
self.tokenizer.eos_token,
self.tokenizer.unk_token,
}
]
for token_1d in token_2d
]
return results

def __call__(self, batch):
token_ids_2d = batch["input_ids"]
token_ids_1d = token_ids_2d.view(-1).tolist()
tokens_1d = self.tokenizer.convert_ids_to_tokens(token_ids_1d)
tokens_2d = [
tokens_1d[i : i + token_ids_2d.size(1)]
for i in range(0, len(tokens_1d), token_ids_2d.size(1))
]
tokens = self.clean_extra_tokens(tokens_2d)

tr_res = self.model.translate_batch(
tokens,
min_decoding_length=0,
max_decoding_length=256,
beam_size=5,
num_hypotheses=1,
)
translations = ["".join(x.hypotheses[0]) for x in tr_res]
return translations


class ModelForSeq2SeqModel(HFModel):
def __init__(self, config):
self.trans_config = config
self.config = self.load_cfg()
super().__init__(
self.trans_config.pretrained_model_name_or_path, model_output_type="string"
)

def load_model(self, device="cuda"):
model = CT2CustomModel(self.trans_config)
return model

def load_config(self):
return self.load_cfg()

@lru_cache(maxsize=1)
def load_tokenizer(self):
return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
trust_remote_code=True,
)

def max_seq_length(self) -> int:
return self.config.max_source_positions

@lru_cache(maxsize=1)
def load_cfg(self):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
trust_remote_code=True,
)
return config


def parse_arguments():
parser = argparse.ArgumentParser(description="PyTorch Model Predictions using Crossfit")
parser.add_argument("input_parquet_path", help="Input parquet file")
parser.add_argument("output_parquet_path", help="Output file")
parser.add_argument(
"--ct2-model-dir",
help="Directory where ctranslate2 optimized model is present",
)
parser.add_argument(
"--input-column", type=str, default="text", help="Column name in input dataframe"
)
parser.add_argument("--pool-size", type=str, default="12GB", help="RMM pool size")
parser.add_argument("--num-workers", type=int, default=1, help="Number of GPUs")
parser.add_argument(
"--model-name",
type=str,
default="ai4bharat/indictrans2-en-indic-1B",
help="Model name",
)
parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
parser.add_argument("--partitions", type=int, default=2, help="Number of partitions")

args = parser.parse_args()
return args


def main():
args = parse_arguments()
Config = TranslationConfig(
pretrained_model_name_or_path=args.model_name,
ct2_model_path=args.ct2_model_dir,
)
ddf = dask_cudf.read_parquet(args.input_parquet_path)

with cf.Distributed(rmm_pool_size=args.pool_size, n_workers=args.num_workers):
model = ModelForSeq2SeqModel(Config)
pipe = op.Sequential(
op.Tokenizer(model, cols=[args.input_column], tokenizer_type="default", max_length=255),
op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size),
repartition=args.partitions,
keep_cols=[args.input_column],
)
outputs = pipe(ddf)
outputs.to_parquet(args.output_parquet_path)


if __name__ == "__main__":
main()
75 changes: 75 additions & 0 deletions tests/op/test_model_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import pytest

cp = pytest.importorskip("cupy")
cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")
dd = pytest.importorskip("dask.dataframe")
pd = pytest.importorskip("pandas")
transformers = pytest.importorskip("transformers")
torch = pytest.importorskip("torch")

import crossfit as cf # noqa: E402

cf_loader = pytest.importorskip("crossfit.backend.torch.loader")


@pytest.mark.parametrize("trust_remote_code", ["y"])
def test_model_output_int(trust_remote_code, model_name="microsoft/deberta-v3-base"):
with patch("builtins.input", return_value=trust_remote_code):
tokens_data = cudf.DataFrame({"input_ids": [[11, 12, 13], [14, 15, 0], [17, 0, 0]]})
index = tokens_data.index.copy()
model = cf.HFModel(model_name)
data = [[4], [7], [10]]
all_outputs_ls = torch.tensor(data)
loader = cf_loader.SortedSeqLoader(
tokens_data,
model,
)
pred_output_col = "translation"
out = model.get_model_output(all_outputs_ls, index, loader, pred_output_col)
assert isinstance(out, cudf.DataFrame)
assert isinstance(out["translation"][0][0], int)
assert (
out["translation"][0] == data[0]
and out["translation"][1] == data[1]
and out["translation"][2] == data[2]
)


@pytest.mark.parametrize("trust_remote_code", ["y"])
def test_model_output_str(trust_remote_code, model_name="microsoft/deberta-v3-base"):
with patch("builtins.input", return_value=trust_remote_code):
tokens_data = cudf.DataFrame({"input_ids": [[18264, 7728, 8], [123, 99, 0], [3115, 0, 0]]})
index = tokens_data.index.copy()
model = cf.HFModel(model_name, model_output_type="string")
data = [["▁हमारे▁परीक्षण▁डेटा"], ["▁पर▁हमारे▁दो"], ["▁दूरी▁कार्यों▁की"]]
all_outputs_ls = data
loader = cf_loader.SortedSeqLoader(
tokens_data,
model,
)
pred_output_col = "translation"
out = model.get_model_output(all_outputs_ls, index, loader, pred_output_col)
assert isinstance(out, cudf.DataFrame)
assert isinstance(out["translation"][0][0], str)
assert (
out["translation"][0] == data[0][0]
and out["translation"][1] == data[1][0]
and out["translation"][2] == data[2][0]
)
Loading