-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ops for running a custom pytorch classifier (#38)
* add ops for running a custom pytorch classifier * update * lint * increase num partitions * catch out of memory error * decrease pool size * try moving model to inside context * comment out test for now * minor updates
- Loading branch information
Showing
8 changed files
with
423 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2023 NVIDIA CORPORATION | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import gc | ||
from typing import Optional | ||
|
||
import cudf | ||
import cupy as cp | ||
import torch | ||
|
||
from crossfit.backend.cudf.series import ( | ||
create_list_series_from_2d_ar, | ||
create_nested_list_series_from_3d_ar, | ||
) | ||
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader | ||
from crossfit.backend.torch.model import Model | ||
from crossfit.op.base import Op | ||
|
||
|
||
class Predictor(Op):
    """Op that runs a model over tokenized rows and stores the outputs as a list column.

    ``call`` reads the ``input_ids`` and ``attention_mask`` columns of the
    incoming frame, feeds them to the model batch by batch through one of the
    loaders, stacks the per-batch outputs and writes them to
    ``pred_output_col`` of a new ``cudf.DataFrame``.

    Args:
        model: Model wrapper; used via ``load_on_worker``, ``get_model`` and
            ``max_seq_length``.
        pre: Forwarded unchanged to ``Op.__init__``.
        cols: Forwarded unchanged to ``Op.__init__``.
        keep_cols: Forwarded unchanged to ``Op.__init__``.
        batch_size: Batch size handed to the loader (as the *initial* batch
            size when the sorted loader is used).
        max_mem: Memory budget written as ``"<integer>GB"``, e.g. ``"16GB"``.
        sorted_data_loader: Use ``SortedSeqLoader`` when true, otherwise
            ``InMemoryLoader``.
        model_output_col: Key to pick from the model output when the model
            returns a dict. Must be set for dict-returning models.
        pred_output_col: Name of the output column holding the predictions.
    """

    def __init__(
        self,
        model: Model,
        pre=None,
        cols=False,
        keep_cols=None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        max_mem: str = "16GB",
        sorted_data_loader: bool = True,
        model_output_col: Optional[str] = None,
        pred_output_col: str = "preds",
    ):
        super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
        self.model = model
        self.batch_size = batch_size
        self.max_mem = max_mem
        # Numeric part of ``max_mem`` scaled down by 2.5.
        # NOTE(review): the rationale for the 2.5 divisor is not visible here.
        self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
        self.sorted_data_loader = sorted_data_loader
        self.model_output_col = model_output_col
        self.pred_output_col = pred_output_col

    def setup(self):
        """Prepare the op by delegating to ``model.load_on_worker``."""
        self.model.load_on_worker(self)

    @torch.no_grad()
    def call(self, data, partition_info=None):
        """Run the model over ``data`` and return a frame with the predictions.

        Args:
            data: Frame providing ``input_ids`` and ``attention_mask`` columns.
            partition_info: Passed through to ``create_progress_bar``.

        Returns:
            A ``cudf.DataFrame`` indexed like ``data`` with a single column
            ``pred_output_col``.

        Raises:
            ValueError: If the model returns a dict that does not contain
                ``model_output_col``.
            RuntimeError: If the stacked model output is neither 2-D nor 3-D.
        """
        index = data.index
        tokens = data[["input_ids", "attention_mask"]]
        progress_bar = self.create_progress_bar(len(data), partition_info)

        if self.sorted_data_loader:
            loader = SortedSeqLoader(
                tokens,
                self.model,
                progress_bar=progress_bar,
                initial_batch_size=self.batch_size,
            )
        else:
            loader = InMemoryLoader(
                tokens,
                batch_size=self.batch_size,
                progress_bar=progress_bar,
                max_seq_len=self.model.max_seq_length(),
            )

        all_outputs_ls = []
        for output in loader.map(self.model.get_model(self)):
            if isinstance(output, dict):
                if self.model_output_col not in output:
                    raise ValueError(
                        f"Column '{self.model_output_col}' not found in model output."
                    )
                output = output[self.model_output_col]
            all_outputs_ls.append(output)

        out = cudf.DataFrame(index=index)
        outputs = cp.asarray(torch.vstack(all_outputs_ls))
        # NOTE(review): sort_column presumably reorders the index to match the
        # row order produced by the sorted loader - confirm against the loader.
        _index = loader.sort_column(index.values) if self.sorted_data_loader else index

        n_dims = len(outputs.shape)
        if n_dims == 2:
            out[self.pred_output_col] = create_list_series_from_2d_ar(outputs, _index)
        elif n_dims == 3:
            out[self.pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
        else:
            # Report the shape of the stacked array, not of the last batch.
            raise RuntimeError(f"Unexpected output shape: {outputs.shape}")

        # Release Python garbage and cached CUDA blocks before returning.
        gc.collect()
        torch.cuda.empty_cache()

        return out

    def meta(self):
        """Return the output schema: ``pred_output_col`` mapped to ``"float32"``."""
        return {self.pred_output_col: "float32"}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import List, Union | ||
|
||
import cudf | ||
|
||
from crossfit.op.base import Op | ||
|
||
|
||
class Labeler(Op):
    """Op that turns per-row score lists into string labels via arg-max.

    Args:
        labels: Label names; position ``i`` names the class with index ``i``.
        cols: Columns to label. When ``None`` the op expects a single
            ``cudf.Series`` as input and returns a Series.
        keep_cols: Forwarded unchanged to ``Op.__init__``.
        pre: Forwarded unchanged to ``Op.__init__``.
        keep_prob: Stored on the instance; not read by the methods of this class.
        suffix: Output column name (single column) or column-name suffix
            (several columns).
    """

    def __init__(
        self,
        labels: List[str],
        cols=None,
        keep_cols=None,
        pre=None,
        keep_prob: bool = False,
        suffix: str = "labels",
    ):
        super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
        self.labels = labels
        self.keep_prob = keep_prob
        self.suffix = suffix

    def call_column(self, data: cudf.Series) -> cudf.Series:
        """Map each row of a list column to the label with the highest score.

        Args:
            data: Series whose rows are lists with one score per label.

        Returns:
            A Series of label strings carrying the same index as ``data``.

        Raises:
            ValueError: If ``data`` is a DataFrame, or if the length of the
                first row differs from the number of configured labels.
        """
        if isinstance(data, cudf.DataFrame):
            raise ValueError(
                "data must be a Series, got DataFrame. Add a pre step to convert to Series"
            )

        # The first row is taken as representative for the width of every row.
        num_labels = len(data.iloc[0])
        if len(self.labels) != num_labels:
            raise ValueError(
                f"The number of provided labels is {len(self.labels)} "
                f"but there are {num_labels} in data."
            )

        scores = data.list.leaves.values.reshape(-1, num_labels)
        classes = scores.argmax(1)
        labels_map = dict(enumerate(self.labels))

        # Keep the caller's index so the labels stay aligned with the input rows.
        return cudf.Series(classes, index=data.index).map(labels_map)

    def call(self, data: Union[cudf.Series, cudf.DataFrame]) -> Union[cudf.Series, cudf.DataFrame]:
        """Label a single Series (``cols is None``) or every configured column.

        Returns:
            A Series when ``cols`` is ``None``; otherwise a DataFrame with one
            label column per entry of ``cols``.

        Raises:
            ValueError: If ``cols`` is ``None`` and ``data`` is not a Series,
                or if a configured column is missing from ``data``.
        """
        if self.cols is None:
            if not isinstance(data, cudf.Series):
                raise ValueError("data must be a cudf Series")

            return self.call_column(data)

        output = cudf.DataFrame()
        for col in self.cols:
            if col not in data.columns:
                raise ValueError(f"Column {col} not found in data")

            output[self._construct_name(col, self.suffix)] = self.call_column(data[col])

        return output

    def meta(self):
        """Return the output schema: each output column name mapped to ``"string"``."""
        # Use the configured suffix so the schema matches the names ``call`` emits.
        labeled = {self.suffix: "string"}

        # ``cols`` may be None (Series mode); only expand for several columns.
        if self.cols and len(self.cols) > 1:
            labeled = {
                self._construct_name(col, suffix): dtype
                for col in self.cols
                for suffix, dtype in labeled.items()
            }

        return labeled

    def _construct_name(self, col_name, suffix):
        """Return ``suffix`` alone for zero/one column, else ``"<col_name>_<suffix>"``."""
        if not self.cols or len(self.cols) == 1:
            return suffix

        return f"{col_name}_{suffix}"
Oops, something went wrong.