From 2d6ad87ff85a681dfa9d27e7b8194e925d8a3d06 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:03:24 -0700 Subject: [PATCH 01/57] Add partial image implementation Signed-off-by: Ryan Wolf --- nemo_curator/datasets/__init__.py | 3 +- .../datasets/image_text_pair_dataset.py | 57 +++++++++++++++ nemo_curator/image/classifiers/__init__.py | 13 ++++ nemo_curator/image/embedders/__init__.py | 17 +++++ nemo_curator/image/embedders/base.py | 61 ++++++++++++++++ nemo_curator/image/embedders/open_clip.py | 73 +++++++++++++++++++ setup.py | 4 + 7 files changed, 227 insertions(+), 1 deletion(-) create mode 100644 nemo_curator/datasets/image_text_pair_dataset.py create mode 100644 nemo_curator/image/classifiers/__init__.py create mode 100644 nemo_curator/image/embedders/__init__.py create mode 100644 nemo_curator/image/embedders/base.py create mode 100644 nemo_curator/image/embedders/open_clip.py diff --git a/nemo_curator/datasets/__init__.py b/nemo_curator/datasets/__init__.py index af9695b2..1a67b16a 100644 --- a/nemo_curator/datasets/__init__.py +++ b/nemo_curator/datasets/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .doc_dataset import DocumentDataset +from .image_text_pair_dataset import ImageTextPairDataset -__all__ = ["DocumentDataset"] +__all__ = ["DocumentDataset", "ImageTextPairDataset"] diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py new file mode 100644 index 00000000..6a4d1f9a --- /dev/null +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +import dask.dataframe as dd +from fsspec.core import open_files + + +class ImageTextPairDataset: + def __init__(self, path: str, metadata, tar_files: List[str]) -> None: + self.path = path + self.metadata = metadata + self.tar_files = tar_files + + @classmethod + def from_webdataset(cls, path: str): + metadata = dd.read_parquet(path) + tar_files = cls._get_tar_files(path) + + return cls(path, metadata, tar_files) + + @staticmethod + def _get_tar_files(path: str) -> List[str]: + glob_str = os.path.join(path, "*.tar") + # open_files doesn't actually open a file descriptor + tar_files = [file.path for file in open_files(glob_str)] + + return tar_files + + def save_metadata( + self, path: Optional[str] = None, columns: Optional[List[str]] = None + ) -> None: + if path is None: + path = self.path + + if columns is None: + metadata = self.metadata + else: + metadata = self.metadata[columns] + + metadata.to_parquet(path) + + def reshard(self, path: str, filter_column: str) -> None: + pass diff --git a/nemo_curator/image/classifiers/__init__.py b/nemo_curator/image/classifiers/__init__.py new file mode 100644 index 00000000..d9155f92 --- /dev/null +++ b/nemo_curator/image/classifiers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_curator/image/embedders/__init__.py b/nemo_curator/image/embedders/__init__.py new file mode 100644 index 00000000..27bf6aca --- /dev/null +++ b/nemo_curator/image/embedders/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .base import ImageEmbedder +from .open_clip import OpenClipImageEmbedder + +__all__ = ["ImageEmbedder", "OpenClipImageEmbedder"] diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py new file mode 100644 index 00000000..d2852b00 --- /dev/null +++ b/nemo_curator/image/embedders/base.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +import dask.dataframe as dd + +from nemo_curator.datasets import ImageTextPairDataset +from nemo_curator.utils.distributed_utils import load_object_on_worker + + +class ImageEmbedder(ABC): + def __init__(self) -> None: + pass + + def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + # First, convert df to delayed + delayed_dfs = dataset.df.to_delayed() + + # Set the metadata using dd.from_map(self.inference, delayed_dfs, tar_files) + metadata = dd.from_map(self.inference, delayed_dfs, dataset.tar_files) + + return ImageTextPairDataset( + dataset.path, metadata=metadata, tar_files=dataset.tar_files + ) + + def inference(self, partition, tar_path): + pipeline = self.load_data_pipline(tar_path) + pipeline.build() + model = load_object_on_worker( + "image_embedding_model", self.load_embedding_model + ) + + total_samples = pipeline.epoch_size() + samples_completed = 0 + while samples_completed < total_samples: + image, text, meta = pipeline.run() + + print(f"Image: {image}. Text: {text}. Meta: {meta}") + break + embeddings = model(image) + + return partition + + @abstractmethod + def load_data_pipline(self, tar_path: str): + pass + + @abstractmethod + def load_embedding_model(self): + pass diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py new file mode 100644 index 00000000..c1557663 --- /dev/null +++ b/nemo_curator/image/embedders/open_clip.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import open_clip +from nvidia.dali import pipeline_def + +from nemo_curator.image.embedders.base import ImageEmbedder + + +class OpenClipImageEmbedder(ImageEmbedder): + def __init__( + self, + model_name: str, + pretrained: Optional[str] = None, + batch_size: int = 1, + num_threads_per_worker=4, + ) -> None: + self.model_name = model_name + self.pretrained = pretrained + self.batch_size = batch_size + self.num_threads_per_worker = num_threads_per_worker + + def load_data_pipline(self, tar_path: str): + # Create the DALI pipeline + @pipeline_def( + batch_size=self.batch_size, + num_threads=self.num_threads_per_worker, + device_id=0, + ) + def webdataset_pipeline(_tar_path: str): + img_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB) + img = fn.crop_mirror_normalize( + img, + dtype=types.FLOAT, + mean=[0, 0, 0], + std=[255, 255, 255], + ) + + resized = fn.resize(img, device="gpu", resize_shorter=224) + output = fn.crop_mirror_normalize( + resized, + dtype=types.FLOAT, + crop=(224, 224), + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + return output, text, json + + return webdataset_pipeline(tar_path) + + def load_embedding_model(self, device="cuda"): + return open_clip.create_model( + self.model_name, pretrained=self.pretrained, device=device + ) diff --git a/setup.py b/setup.py index dd0d667b..e8fc0c7c 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,10 @@ "dask-cuda>=24.2", "spacy[cuda12x]>=3.6.0, <4.0.0", ], + "image": [ + "nvidia-dali-cuda120", + "open_clip_torch", + ], }, entry_points={ "console_scripts": [ From 4b32c1e6f1b5b056e9ba3eb9195bbd9d5bd4defc Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:17:59 -0700 Subject: [PATCH 02/57] Refactor requirements Signed-off-by: Ryan Wolf --- requirements/requirements.txt | 29 +++++++++++ requirements/requirements_cuda12x.txt | 6 +++ requirements/requirements_image.txt | 2 + setup.py | 74 +++++++++------------------ 4 files changed, 61 insertions(+), 50 deletions(-) create mode 100644 requirements/requirements.txt create mode 100644 requirements/requirements_cuda12x.txt create mode 100644 requirements/requirements_image.txt diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 00000000..05db445b --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,29 @@ +awscli>=1.22.55 +beautifulsoup4 +charset_normalizer>=3.1.0 +comment_parser +crossfit @ git+https://github.com/rapidsai/crossfit.git@0cc2993 +Cython +dask-mpi>=2021.11.0 +dask[complete]>=2021.7.1 +distributed>=2021.7.1 +fasttext==0.9.2 +ftfy==6.1.1 +in-place==0.5.0 +jieba==0.42.1 +justext==3.0.1 +lxml_html_clean +mwparserfromhell==0.6.5 +nemo_toolkit[nlp]>=1.23.0 +numpy<2 +openai +peft +presidio-analyzer==2.2.351 +presidio-anonymizer==2.2.351 +pycld2 +resiliparse +spacy>=3.6.0, <4.0.0 +unidic-lite==1.0.8 +usaddress==0.5.10 +warcio==1.7.4 +zstandard==0.18.0 diff --git a/requirements/requirements_cuda12x.txt b/requirements/requirements_cuda12x.txt new file mode 100644 index 00000000..ede6af8c --- /dev/null +++ b/requirements/requirements_cuda12x.txt @@ -0,0 +1,6 @@ +cudf-cu12>=24.2 +cugraph-cu12>=24.2 +cuml-cu12>=24.2 +dask-cuda>=24.2 +dask-cudf-cu12>=24.2 +spacy[cuda12x]>=3.6.0, <4.0.0 diff --git a/requirements/requirements_image.txt b/requirements/requirements_image.txt new file mode 100644 index 00000000..2b6688d8 --- /dev/null +++ b/requirements/requirements_image.txt @@ -0,0 +1,2 @@ +nvidia-dali-cuda120 +open_clip_torch diff --git a/setup.py b/setup.py index e8fc0c7c..6c488dec 100644 --- a/setup.py +++ b/setup.py @@ -11,14 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from setuptools import setup, find_packages import pathlib +from itertools import chain here = pathlib.Path(__file__).parent.resolve() long_description = (here / "README.md").read_text(encoding="utf-8") + +def req_file(filename, folder="requirements"): + with open(os.path.join(folder, filename), encoding="utf-8") as f: + content = f.readlines() + return [x.strip() for x in content] + + +install_requires = req_file("requirements.txt") + +extras_require = { + "cuda12x": req_file("requirements_cuda12x.txt"), + "image": req_file("requirements_image.txt"), +} + +extras_require["all"] = list(chain(extras_require.values())) + +extras_require["image"] = list( + chain([extras_require["image"], extras_require["cuda12x"]]) +) + setup( name="nemo_curator", version="0.4.0", @@ -38,55 +59,8 @@ ], packages=find_packages(), python_requires=">=3.10, <3.11", - install_requires=[ - "dask[complete]>=2021.7.1", - "distributed>=2021.7.1", - "dask-mpi>=2021.11.0", - "charset_normalizer>=3.1.0", - "awscli>=1.22.55", - "fasttext==0.9.2", - "pycld2", - "justext==3.0.1", - # lxml_html_clean has difficulty installing in jusText - # installing explicitly to prevent errors/lag with jusText - "lxml_html_clean", - "resiliparse", - "ftfy==6.1.1", - "warcio==1.7.4", - "zstandard==0.18.0", - "in-place==0.5.0", - "unidic-lite==1.0.8", - "jieba==0.42.1", - "comment_parser", - "beautifulsoup4", - "mwparserfromhell==0.6.5", - "spacy>=3.6.0, <4.0.0", - "presidio-analyzer==2.2.351", - "presidio-anonymizer==2.2.351", - "usaddress==0.5.10", - "nemo_toolkit[nlp]>=1.23.0", - "Cython", - "crossfit @ git+https://github.com/rapidsai/crossfit.git@0cc2993", - # Numpy 2.0 breaks with spacy https://github.com/explosion/spaCy/issues/13528 - # TODO: Remove when issue is fixed - "numpy<2", - "openai", - "peft", - ], - extras_require={ - "cuda12x": [ - "cudf-cu12>=24.2", - "dask-cudf-cu12>=24.2", - "cuml-cu12>=24.2", - "cugraph-cu12>=24.2", - "dask-cuda>=24.2", - "spacy[cuda12x]>=3.6.0, <4.0.0", - ], - "image": [ - "nvidia-dali-cuda120", - "open_clip_torch", - ], - }, + install_requires=install_requires, + extras_require=extras_require, entry_points={ "console_scripts": [ "get_common_crawl_urls=nemo_curator.scripts.get_common_crawl_urls:console_script", From 601bf5c56d8861430b73d1a4b781e09653b5fd33 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:29:36 -0700 Subject: [PATCH 03/57] Fix bugs Signed-off-by: Ryan Wolf --- nemo_curator/image/__init__.py | 13 +++++++++++++ nemo_curator/image/embedders/base.py | 6 +++--- 2 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 nemo_curator/image/__init__.py diff --git a/nemo_curator/image/__init__.py b/nemo_curator/image/__init__.py new file mode 100644 index 00000000..d9155f92 --- /dev/null +++ b/nemo_curator/image/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index d2852b00..ff8ea12d 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -25,10 +25,10 @@ def __init__(self) -> None: def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: # First, convert df to delayed - delayed_dfs = dataset.df.to_delayed() + delayed_metadata = dataset.metadata.to_delayed() - # Set the metadata using dd.from_map(self.inference, delayed_dfs, tar_files) - metadata = dd.from_map(self.inference, delayed_dfs, dataset.tar_files) + # Set the metadata + metadata = dd.from_map(self.inference, delayed_metadata, dataset.tar_files) return ImageTextPairDataset( dataset.path, metadata=metadata, tar_files=dataset.tar_files From 6c8ecd64bbc4d37271d2b0d3b0db2efd263cf6b1 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:46:13 -0700 Subject: [PATCH 04/57] Change from_map to map_partitions Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index ff8ea12d..4af84703 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -20,21 +20,22 @@ class ImageEmbedder(ABC): - def __init__(self) -> None: - pass + def __init__(self, image_embedding_column) -> None: + self.image_embedding_column = image_embedding_column def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: - # First, convert df to delayed - delayed_metadata = dataset.metadata.to_delayed() - - # Set the metadata - metadata = dd.from_map(self.inference, delayed_metadata, dataset.tar_files) + meta_df = dataset.metadata._meta.copy() + meta_df[self.image_embedding_column] = [1.0, 2.0] + embedding_df = dataset.metadata.map_partitions( + self.inference, dataset.tar_files, meta=meta_df + ) return ImageTextPairDataset( - dataset.path, metadata=metadata, tar_files=dataset.tar_files + dataset.path, metadata=embedding_df, tar_files=dataset.tar_files ) - def inference(self, partition, tar_path): + def inference(self, partition, tar_paths, partition_info=None): + tar_path = tar_paths[partition_info["number"]] pipeline = self.load_data_pipline(tar_path) pipeline.build() model = load_object_on_worker( @@ -50,6 +51,7 @@ def inference(self, partition, tar_path): break embeddings = model(image) + partition[self.image_embedding_column] = [[1.234]] * len(partition) return partition @abstractmethod From 0856a65fb5f79dcf1c7f3c94b416e76e221862bd Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:50:45 -0700 Subject: [PATCH 05/57] Add super constructor Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/open_clip.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index c1557663..a8f288ef 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -28,7 +28,10 @@ def __init__( pretrained: Optional[str] = None, batch_size: int = 1, num_threads_per_worker=4, + image_embedding_column="image_embedding", ) -> None: + super().__init__(image_embedding_column=image_embedding_column) + self.model_name = model_name self.pretrained = pretrained self.batch_size = batch_size From 4dbef4260f424975e6a158537a4052050082a2fc Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:52:29 -0700 Subject: [PATCH 06/57] Add kwargs for load_object_on_worker Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 4af84703..381552b8 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -39,7 +39,7 @@ def inference(self, partition, tar_paths, partition_info=None): pipeline = self.load_data_pipline(tar_path) pipeline.build() model = load_object_on_worker( - "image_embedding_model", self.load_embedding_model + "image_embedding_model", self.load_embedding_model, {} ) total_samples = pipeline.epoch_size() From abb6b13eabbc5a444fc379fc2718dec3bfb35ead Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 08:58:12 -0700 Subject: [PATCH 07/57] Get proper epoch size Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 381552b8..e6e4e5b3 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -43,6 +43,7 @@ def inference(self, partition, tar_paths, partition_info=None): ) total_samples = pipeline.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] samples_completed = 0 while samples_completed < total_samples: image, text, meta = pipeline.run() From 61eb1e363b4fc6cf067b25e355a2fa81bca43755 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 09:23:24 -0700 Subject: [PATCH 08/57] Complete embedding creation loop Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 39 +++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index e6e4e5b3..35f07a42 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from abc import ABC, abstractmethod import dask.dataframe as dd +import numpy as np +import torch +from nvidia.dali.plugin.pytorch import feed_ndarray from nemo_curator.datasets import ImageTextPairDataset from nemo_curator.utils.distributed_utils import load_object_on_worker @@ -25,7 +29,7 @@ def __init__(self, image_embedding_column) -> None: def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta_df = dataset.metadata._meta.copy() - meta_df[self.image_embedding_column] = [1.0, 2.0] + meta_df[self.image_embedding_column] = [[1.0]] embedding_df = dataset.metadata.map_partitions( self.inference, dataset.tar_files, meta=meta_df ) @@ -45,16 +49,43 @@ def inference(self, partition, tar_paths, partition_info=None): total_samples = pipeline.epoch_size() total_samples = total_samples[list(total_samples.keys())[0]] samples_completed = 0 + final_image_embeddings = [] while samples_completed < total_samples: image, text, meta = pipeline.run() print(f"Image: {image}. Text: {text}. Meta: {meta}") - break - embeddings = model(image) + image = image.as_tensor() - partition[self.image_embedding_column] = [[1.234]] * len(partition) + image_torch = torch.empty( + image.shape(), dtype=torch.float32, device=self.device + ) + feed_ndarray(image, image_torch) # COPY !!! + + image = image_torch + captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] + metadata = [ + json.loads(meta.at(i).tostring().decode("utf-8")) + for i in range(len(meta)) + ] + + with torch.no_grad(): + image_features = model(image) + batch_image_embeddings = np.asarray( + self.normalized(image_features.detach().cpu()) + ) + + for embedding in batch_image_embeddings: + final_image_embeddings.append(embedding) + + partition[self.image_embedding_column] = final_image_embeddings return partition + @staticmethod + def normalized(a, axis=-1, order=2): + l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) + l2[l2 == 0] = 1 + return a / np.expand_dims(l2, axis) + @abstractmethod def load_data_pipline(self, tar_path: str): pass From f8d692f9cf23e945c4b2bbd1721ab0f104f77f86 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 09:29:56 -0700 Subject: [PATCH 09/57] Change devices Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 6 +----- nemo_curator/image/embedders/open_clip.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 35f07a42..71561f8f 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -52,13 +52,9 @@ def inference(self, partition, tar_paths, partition_info=None): final_image_embeddings = [] while samples_completed < total_samples: image, text, meta = pipeline.run() - - print(f"Image: {image}. Text: {text}. Meta: {meta}") image = image.as_tensor() - image_torch = torch.empty( - image.shape(), dtype=torch.float32, device=self.device - ) + image_torch = torch.empty(image.shape(), dtype=torch.float32, device="cuda") feed_ndarray(image, image_torch) # COPY !!! image = image_torch diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index a8f288ef..75ec9591 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -42,7 +42,6 @@ def load_data_pipline(self, tar_path: str): @pipeline_def( batch_size=self.batch_size, num_threads=self.num_threads_per_worker, - device_id=0, ) def webdataset_pipeline(_tar_path: str): img_raw, text, json = fn.readers.webdataset( From 5752562e646be81b7d00aae1bdde8b7d843ec3e9 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 19 Aug 2024 09:40:06 -0700 Subject: [PATCH 10/57] Add device Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/open_clip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index 75ec9591..004ca2b4 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Optional import nvidia.dali.fn as fn @@ -42,6 +43,7 @@ def load_data_pipline(self, tar_path: str): @pipeline_def( batch_size=self.batch_size, num_threads=self.num_threads_per_worker, + device_id=0, ) def webdataset_pipeline(_tar_path: str): img_raw, text, json = fn.readers.webdataset( From 4a4b3566b34e63a88f4f4385c35b7083b8dd510b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 26 Aug 2024 17:20:24 -0700 Subject: [PATCH 11/57] Refactor embedding creation and add classifier Signed-off-by: Ryan Wolf --- .../datasets/image_text_pair_dataset.py | 4 +- nemo_curator/image/classifiers/__init__.py | 4 + nemo_curator/image/classifiers/aesthetic.py | 91 +++++++++++++++++++ nemo_curator/image/classifiers/base.py | 71 +++++++++++++++ nemo_curator/image/embedders/base.py | 82 ++++++++--------- nemo_curator/image/embedders/open_clip.py | 63 +++++++++++-- nemo_curator/utils/cudf_utils.py | 46 ++++++++++ nemo_curator/utils/file_utils.py | 4 + requirements/requirements_image.txt | 1 + 9 files changed, 312 insertions(+), 54 deletions(-) create mode 100644 nemo_curator/image/classifiers/aesthetic.py create mode 100644 nemo_curator/image/classifiers/base.py create mode 100644 nemo_curator/utils/cudf_utils.py diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index 6a4d1f9a..e255bcfe 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -15,7 +15,7 @@ import os from typing import List, Optional, Union -import dask.dataframe as dd +import dask_cudf from fsspec.core import open_files @@ -27,7 +27,7 @@ def __init__(self, path: str, metadata, tar_files: List[str]) -> None: @classmethod def from_webdataset(cls, path: str): - metadata = dd.read_parquet(path) + metadata = dask_cudf.read_parquet(path) tar_files = cls._get_tar_files(path) return cls(path, metadata, tar_files) diff --git a/nemo_curator/image/classifiers/__init__.py b/nemo_curator/image/classifiers/__init__.py index d9155f92..3988438f 100644 --- a/nemo_curator/image/classifiers/__init__.py +++ b/nemo_curator/image/classifiers/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .aesthetic import AestheticClassifier +from .base import ImageClassifier + +__all__ = ["AestheticClassifier", "ImageClassifier"] diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py new file mode 100644 index 00000000..d951c978 --- /dev/null +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import requests +import torch +import torch.nn as nn + +from nemo_curator.image.classifiers import ImageClassifier +from nemo_curator.utils.file_utils import NEMO_CURATOR_HOME + + +# MLP code taken from LAION Aesthetic V2 +# https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py +class MLP(nn.Module): + def __init__(self, input_size, xcol="emb", ycol="avg_rating"): + super().__init__() + self.input_size = input_size + self.xcol = xcol + self.ycol = ycol + self.layers = nn.Sequential( + nn.Linear(self.input_size, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + def forward(self, x): + return self.layers(x) + + +class AestheticClassifier(ImageClassifier): + def __init__( + self, + embedding_column: str = "image_embedding", + pred_column: str = "aesthetic_scores", + batch_size: int = -1, + model_path: Optional[str] = None, + ) -> None: + super().__init__( + embedding_column=embedding_column, + pred_column=pred_column, + class_type=float, + batch_size=batch_size, + ) + + if model_path is None: + model_path = self._get_default_model() + + self.model_path = model_path + + @staticmethod + def _get_default_model(): + weights_name = "sac+logos+ava1-l14-linearMSE.pth" + model_path = os.path.join(NEMO_CURATOR_HOME, weights_name) + + if not os.path.exists(model_path): + url = ( + "https://github.com/christophschuhmann/" + f"improved-aesthetic-predictor/blob/main/{weights_name}?raw=true" + ) + r = requests.get(url) + + with open(model_path, "wb") as f: + f.write(r.content) + + return model_path + + def load_model(self, device): + model = MLP(768).to(device) + weights = torch.load(self.model_path, map_location=torch.device("cpu")) + model.load_state_dict(weights) + model.eval() + + return model diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py new file mode 100644 index 00000000..e0658106 --- /dev/null +++ b/nemo_curator/image/classifiers/base.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod + +import cupy as cp +import torch + +from nemo_curator.datasets import ImageTextPairDataset +from nemo_curator.utils.cudf_utils import create_list_series_from_1d_or_2d_ar +from nemo_curator.utils.distributed_utils import load_object_on_worker + + +class ImageClassifier(ABC): + """ + An abstract base class that represents a classifier on top + of embeddings generated by a CLIP vision encoder + """ + + def __init__( + self, embedding_column: str, pred_column: str, class_type: str, batch_size: int + ) -> None: + self.embedding_column = embedding_column + self.pred_column = pred_column + self.class_type = class_type + self.batch_size = batch_size + + def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + meta = dataset.metadata.dtypes.to_dict() + meta[self.pred_column] = self.class_type + embedding_df = dataset.metadata.map_partitions(self.inference, meta=meta) + + return ImageTextPairDataset( + dataset.path, metadata=embedding_df, tar_files=dataset.tar_files + ) + + def inference(self, partition, partition_info=None): + device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) + device = f"cuda:{device_id}" + + model = load_object_on_worker( + "image_embedding_model", + self.load_model, + {"device": device}, + ) + + embeddings = torch.as_tensor( + partition["embeddings"].list.leaves.values.reshape(len(partition), -1), + device=device, + ) + + scores = cp.asarray(model(embeddings)) + + partition[self.pred_column] = create_list_series_from_1d_or_2d_ar(scores) + + return partition + + @abstractmethod + def load_model(self, device): + pass diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 71561f8f..3fff9725 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json +import os from abc import ABC, abstractmethod +from typing import Iterable -import dask.dataframe as dd -import numpy as np +import cupy as cp import torch -from nvidia.dali.plugin.pytorch import feed_ndarray from nemo_curator.datasets import ImageTextPairDataset +from nemo_curator.utils.cudf_utils import create_list_series_from_1d_or_2d_ar from nemo_curator.utils.distributed_utils import load_object_on_worker @@ -28,10 +28,10 @@ def __init__(self, image_embedding_column) -> None: self.image_embedding_column = image_embedding_column def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: - meta_df = dataset.metadata._meta.copy() - meta_df[self.image_embedding_column] = [[1.0]] + meta = dataset.metadata.dtypes.to_dict() + meta[self.image_embedding_column] = "object" embedding_df = dataset.metadata.map_partitions( - self.inference, dataset.tar_files, meta=meta_df + self.inference, dataset.tar_files, meta=meta ) return ImageTextPairDataset( @@ -40,52 +40,44 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: def inference(self, partition, tar_paths, partition_info=None): tar_path = tar_paths[partition_info["number"]] - pipeline = self.load_data_pipline(tar_path) - pipeline.build() + device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) + model = load_object_on_worker( - "image_embedding_model", self.load_embedding_model, {} + "image_embedding_model", + self.load_embedding_model, + {"device": f"cuda:{device_id}"}, ) - total_samples = pipeline.epoch_size() - total_samples = total_samples[list(total_samples.keys())[0]] - samples_completed = 0 + dataset = self.load_dataset_shard(tar_path, device_id=device_id) final_image_embeddings = [] - while samples_completed < total_samples: - image, text, meta = pipeline.run() - image = image.as_tensor() - - image_torch = torch.empty(image.shape(), dtype=torch.float32, device="cuda") - feed_ndarray(image, image_torch) # COPY !!! - - image = image_torch - captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] - metadata = [ - json.loads(meta.at(i).tostring().decode("utf-8")) - for i in range(len(meta)) - ] - - with torch.no_grad(): - image_features = model(image) - batch_image_embeddings = np.asarray( - self.normalized(image_features.detach().cpu()) - ) - - for embedding in batch_image_embeddings: - final_image_embeddings.append(embedding) - - partition[self.image_embedding_column] = final_image_embeddings - return partition + samples_completed = 0 + with torch.no_grad(): + for batch in dataset: + image_embeddings = model(batch) + final_image_embeddings.append(image_embeddings) + + batch_size = len(image_embeddings) + samples_completed += batch_size - @staticmethod - def normalized(a, axis=-1, order=2): - l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) - l2[l2 == 0] = 1 - return a / np.expand_dims(l2, axis) + print(f"{tar_path} - Samples Completed: {samples_completed}.") + + if samples_completed != len(partition): + raise RuntimeError( + f"Mismatch in sample count for partition {partition_info['number']}. " + f"{len(partition)} samples found in the metadata, but {samples_completed} found in {tar_path}." + ) + + concat_output = cp.asarray(torch.cat(final_image_embeddings, dim=0)) + partition[self.image_embedding_column] = create_list_series_from_1d_or_2d_ar( + concat_output, index=partition.index + ) + + return partition @abstractmethod - def load_data_pipline(self, tar_path: str): + def load_dataset_shard(self, tar_path: str) -> Iterable: pass @abstractmethod - def load_embedding_model(self): + def load_embedding_model(self, device): pass diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index 004ca2b4..70e999f1 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +import json from typing import Optional import nvidia.dali.fn as fn import nvidia.dali.types as types import open_clip +import torch from nvidia.dali import pipeline_def +from nvidia.dali.plugin.pytorch import feed_ndarray from nemo_curator.image.embedders.base import ImageEmbedder @@ -28,8 +30,9 @@ def __init__( model_name: str, pretrained: Optional[str] = None, batch_size: int = 1, - num_threads_per_worker=4, - image_embedding_column="image_embedding", + num_threads_per_worker: int = 4, + image_embedding_column: str = "image_embedding", + normalize_embeddings: bool = True, ) -> None: super().__init__(image_embedding_column=image_embedding_column) @@ -37,13 +40,14 @@ def __init__( self.pretrained = pretrained self.batch_size = batch_size self.num_threads_per_worker = num_threads_per_worker + self.normalize_embeddings = normalize_embeddings - def load_data_pipline(self, tar_path: str): + def load_dataset_shard(self, tar_path: str, device_id=0): # Create the DALI pipeline @pipeline_def( batch_size=self.batch_size, num_threads=self.num_threads_per_worker, - device_id=0, + device_id=device_id, ) def webdataset_pipeline(_tar_path: str): img_raw, text, json = fn.readers.webdataset( @@ -69,9 +73,54 @@ def webdataset_pipeline(_tar_path: str): ) return output, text, json - return webdataset_pipeline(tar_path) + pipeline = webdataset_pipeline(tar_path) + pipeline.build() + + total_samples = pipeline.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] + + samples_completed = 0 + while samples_completed < total_samples: + image, text, meta = pipeline.run() + image = image.as_tensor() + + image_torch = torch.empty( + image.shape(), dtype=torch.float32, device=f"cuda:{device_id}" + ) + feed_ndarray(image, image_torch) # COPY !!! + image = image_torch + + captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] + metadata = [ + json.loads(meta.at(i).tostring().decode("utf-8")) + for i in range(len(meta)) + ] + + remaining_samples = total_samples - samples_completed + if image.shape[0] >= remaining_samples: + image = image[:remaining_samples] + captions = captions[:remaining_samples] + metadata = metadata[:remaining_samples] + + samples_completed += min(image.shape[0], remaining_samples) + + yield image def load_embedding_model(self, device="cuda"): - return open_clip.create_model( + model = open_clip.create_model( self.model_name, pretrained=self.pretrained, device=device ) + model.eval() + + def infer(batch): + image_features = model.encode_image(batch) + if self.normalize_embeddings: + image_features = self.torch_normalized(image_features) + + return image_features + + return infer + + @staticmethod + def torch_normalized(a, dim=-1): + return torch.nn.functional.normalize(a, dim=dim) diff --git a/nemo_curator/utils/cudf_utils.py b/nemo_curator/utils/cudf_utils.py new file mode 100644 index 00000000..2ec6c1e2 --- /dev/null +++ b/nemo_curator/utils/cudf_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cudf +import cupy as cp +from cudf.core.column import as_column + + +@staticmethod +def create_list_series_from_1d_or_2d_ar(ar, index): + """ + Create a cudf list series from 2d arrays + """ + if len(ar.shape) == 1: + n_rows, *_ = ar.shape + n_cols = 1 + elif len(ar.shape) == 2: + n_rows, n_cols = ar.shape + else: + return RuntimeError(f"Unexpected input shape: {ar.shape}") + data = as_column(ar.flatten()) + offset_col = as_column( + cp.arange(start=0, stop=len(data) + 1, step=n_cols), dtype="int32" + ) + mask_col = cp.full(shape=n_rows, fill_value=cp.bool_(True)) + mask = cudf._lib.transform.bools_to_mask(as_column(mask_col)) + lc = cudf.core.column.ListColumn( + size=n_rows, + dtype=cudf.ListDtype(data.dtype), + mask=mask, + offset=0, + null_count=0, + children=(offset_col, data), + ) + + return cudf.Series(lc, index=index) diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index de5c78af..5793b235 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -25,6 +25,10 @@ from nemo_curator.utils.distributed_utils import single_partition_write_with_filename +NEMO_CURATOR_HOME = os.environ.get( + "NEMO_CURATOR_HOME", os.path.join(os.path.expanduser("~"), ".nemo_curator") +) + def mkdir(d): pathlib.Path(d).mkdir(parents=True, exist_ok=True) diff --git a/requirements/requirements_image.txt b/requirements/requirements_image.txt index 2b6688d8..6c1d5f49 100644 --- a/requirements/requirements_image.txt +++ b/requirements/requirements_image.txt @@ -1,2 +1,3 @@ nvidia-dali-cuda120 +nvidia-nvjpeg2k-cu12 open_clip_torch From bfde960304985174dfbf035db26b457b322f6429 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 26 Aug 2024 17:31:44 -0700 Subject: [PATCH 12/57] Fix bugs in classifiers Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 3 ++- nemo_curator/image/classifiers/base.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index d951c978..45eb0642 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from nemo_curator.image.classifiers import ImageClassifier +from nemo_curator.image.classifiers.base import ImageClassifier from nemo_curator.utils.file_utils import NEMO_CURATOR_HOME @@ -69,6 +69,7 @@ def __init__( def _get_default_model(): weights_name = "sac+logos+ava1-l14-linearMSE.pth" model_path = os.path.join(NEMO_CURATOR_HOME, weights_name) + os.makedirs(NEMO_CURATOR_HOME, exist_ok=True) if not os.path.exists(model_path): url = ( diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index e0658106..a57d5183 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -56,7 +56,9 @@ def inference(self, partition, partition_info=None): ) embeddings = torch.as_tensor( - partition["embeddings"].list.leaves.values.reshape(len(partition), -1), + partition[self.embedding_column].list.leaves.values.reshape( + len(partition), -1 + ), device=device, ) From e421a368a3321b0df20e5c1ddd4f806fb5efb66b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 09:10:40 -0700 Subject: [PATCH 13/57] Refactor model names Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 3 ++- nemo_curator/image/classifiers/base.py | 14 ++++++++++---- nemo_curator/image/embedders/base.py | 9 +++++---- nemo_curator/image/embedders/open_clip.py | 6 +++--- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 45eb0642..5e85636e 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -64,6 +64,7 @@ def __init__( model_path = self._get_default_model() self.model_path = model_path + self.embedding_dim = 768 @staticmethod def _get_default_model(): @@ -84,7 +85,7 @@ def _get_default_model(): return model_path def load_model(self, device): - model = MLP(768).to(device) + model = MLP(self.embedding_dim).to(device) weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index a57d5183..db19101f 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -29,8 +29,14 @@ class ImageClassifier(ABC): """ def __init__( - self, embedding_column: str, pred_column: str, class_type: str, batch_size: int + self, + model_name: str, + embedding_column: str, + pred_column: str, + class_type: str, + batch_size: int, ) -> None: + self.model_name = model_name self.embedding_column = embedding_column self.pred_column = pred_column self.class_type = class_type @@ -39,18 +45,18 @@ def __init__( def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() meta[self.pred_column] = self.class_type - embedding_df = dataset.metadata.map_partitions(self.inference, meta=meta) + embedding_df = dataset.metadata.map_partitions(self._run_inference, meta=meta) return ImageTextPairDataset( dataset.path, metadata=embedding_df, tar_files=dataset.tar_files ) - def inference(self, partition, partition_info=None): + def _run_inference(self, partition, partition_info=None): device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) device = f"cuda:{device_id}" model = load_object_on_worker( - "image_embedding_model", + self.model_name, self.load_model, {"device": device}, ) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 3fff9725..1d28d925 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -24,26 +24,27 @@ class ImageEmbedder(ABC): - def __init__(self, image_embedding_column) -> None: + def __init__(self, model_name: str, image_embedding_column: str) -> None: + self.model_name = model_name self.image_embedding_column = image_embedding_column def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() meta[self.image_embedding_column] = "object" embedding_df = dataset.metadata.map_partitions( - self.inference, dataset.tar_files, meta=meta + self._run_inference, dataset.tar_files, meta=meta ) return ImageTextPairDataset( dataset.path, metadata=embedding_df, tar_files=dataset.tar_files ) - def inference(self, partition, tar_paths, partition_info=None): + def _run_inference(self, partition, tar_paths, partition_info=None): tar_path = tar_paths[partition_info["number"]] device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) model = load_object_on_worker( - "image_embedding_model", + self.model_name, self.load_embedding_model, {"device": f"cuda:{device_id}"}, ) diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index 70e999f1..3f689e55 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -34,9 +34,9 @@ def __init__( image_embedding_column: str = "image_embedding", normalize_embeddings: bool = True, ) -> None: - super().__init__(image_embedding_column=image_embedding_column) - - self.model_name = model_name + super().__init__( + model_name=model_name, image_embedding_column=image_embedding_column + ) self.pretrained = pretrained self.batch_size = batch_size self.num_threads_per_worker = num_threads_per_worker From b09892ede66c29dd03c8d516e318335f52b28155 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 09:13:13 -0700 Subject: [PATCH 14/57] Add model name Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 5e85636e..84dc4a20 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -54,6 +54,7 @@ def __init__( model_path: Optional[str] = None, ) -> None: super().__init__( + model_name="aesthetic_model_v2", embedding_column=embedding_column, pred_column=pred_column, class_type=float, From 8d43f9a1bd88def76e1154cd39ed0ae9d31e262d Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 09:49:16 -0700 Subject: [PATCH 15/57] Fix classifier bugs Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 5 ++++- nemo_curator/image/classifiers/base.py | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 84dc4a20..b5763301 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -91,4 +91,7 @@ def load_model(self, device): model.load_state_dict(weights) model.eval() - return model + def infer(batch): + return model(batch).squeeze() + + return infer diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index db19101f..0155f071 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -68,9 +68,13 @@ def _run_inference(self, partition, partition_info=None): device=device, ) - scores = cp.asarray(model(embeddings)) + with torch.no_grad(): + scores = model(embeddings) + scores = cp.asarray(scores) - partition[self.pred_column] = create_list_series_from_1d_or_2d_ar(scores) + partition[self.pred_column] = create_list_series_from_1d_or_2d_ar( + scores, index=partition.index + ) return partition From 49a21ef6d931f0654dffd70d450ee31a1d2445d0 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 13:35:02 -0700 Subject: [PATCH 16/57] Allow postprocessing for classifiers Signed-off-by: Ryan Wolf --- .../datasets/image_text_pair_dataset.py | 2 +- nemo_curator/image/classifiers/aesthetic.py | 3 +++ nemo_curator/image/classifiers/base.py | 25 ++++++++++++++----- nemo_curator/image/embedders/base.py | 4 ++- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index e255bcfe..d622f8b3 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -53,5 +53,5 @@ def save_metadata( metadata.to_parquet(path) - def reshard(self, path: str, filter_column: str) -> None: + def to_webdataset(self, path: str, filter_column: str) -> None: pass diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index b5763301..e837d226 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -95,3 +95,6 @@ def infer(batch): return model(batch).squeeze() return infer + + def postprocess(self, series): + return series.list.leaves diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 0155f071..12cf8442 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -13,6 +13,7 @@ # limitations under the License. import os from abc import ABC, abstractmethod +from typing import Union import cupy as cp import torch @@ -33,13 +34,13 @@ def __init__( model_name: str, embedding_column: str, pred_column: str, - class_type: str, + pred_type: Union[str, type], batch_size: int, ) -> None: self.model_name = model_name self.embedding_column = embedding_column self.pred_column = pred_column - self.class_type = class_type + self.pred_type = pred_type self.batch_size = batch_size def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: @@ -69,15 +70,27 @@ def _run_inference(self, partition, partition_info=None): ) with torch.no_grad(): - scores = model(embeddings) + if self.batch_size > 0: + batches = torch.split(embeddings, self.batch_size) + model_results = [] + for batch in batches: + batch_results = model(batch) + model_results.append(batch_results) + scores = torch.cat(model_results, dim=0) + else: + scores = model(embeddings) + scores = cp.asarray(scores) - partition[self.pred_column] = create_list_series_from_1d_or_2d_ar( - scores, index=partition.index - ) + series = create_list_series_from_1d_or_2d_ar(scores, index=partition.index) + postprocessed_series = self.postprocess(series) + partition[self.pred_column] = postprocessed_series return partition @abstractmethod def load_model(self, device): pass + + def postprocess(self, series): + return series diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 1d28d925..4fbeec60 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -60,7 +60,9 @@ def _run_inference(self, partition, tar_paths, partition_info=None): batch_size = len(image_embeddings) samples_completed += batch_size - print(f"{tar_path} - Samples Completed: {samples_completed}.") + print( + f"{tar_path} - Embedding Creation with {self.model_name} Samples Completed: {samples_completed}." + ) if samples_completed != len(partition): raise RuntimeError( From edf2905b78074ba80e914bde085b75630d6ee242 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 13:42:30 -0700 Subject: [PATCH 17/57] Fix name and add print Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 2 +- nemo_curator/image/classifiers/base.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index e837d226..b2708947 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -57,7 +57,7 @@ def __init__( model_name="aesthetic_model_v2", embedding_column=embedding_column, pred_column=pred_column, - class_type=float, + pred_type=float, batch_size=batch_size, ) diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 12cf8442..851554e1 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -86,6 +86,10 @@ def _run_inference(self, partition, partition_info=None): postprocessed_series = self.postprocess(series) partition[self.pred_column] = postprocessed_series + print( + f"Partition {partition_info['number']} - Classification with {self.model_name} completed for {len(scores)} samples." + ) + return partition @abstractmethod From eaef49aefca639e561d90556bcad8e4f1b25eace Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 27 Aug 2024 13:44:26 -0700 Subject: [PATCH 18/57] Fix variable name Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 851554e1..9a85eca6 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -45,7 +45,7 @@ def __init__( def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() - meta[self.pred_column] = self.class_type + meta[self.pred_column] = self.pred_type embedding_df = dataset.metadata.map_partitions(self._run_inference, meta=meta) return ImageTextPairDataset( From 7ba7c34d39240f3b981cb8081b1d1d75d9c8cea5 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 15:15:58 -0700 Subject: [PATCH 19/57] Add NSFW Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 8 +- nemo_curator/image/classifiers/base.py | 9 ++ nemo_curator/image/classifiers/nsfw.py | 104 ++++++++++++++++++++ 3 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 nemo_curator/image/classifiers/nsfw.py diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index b2708947..5057d15f 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -49,23 +49,23 @@ class AestheticClassifier(ImageClassifier): def __init__( self, embedding_column: str = "image_embedding", - pred_column: str = "aesthetic_scores", + pred_column: str = "aesthetic_score", batch_size: int = -1, model_path: Optional[str] = None, ) -> None: super().__init__( - model_name="aesthetic_model_v2", + model_name="aesthetic_classifier", embedding_column=embedding_column, pred_column=pred_column, pred_type=float, batch_size=batch_size, + embedding_size=768, ) if model_path is None: model_path = self._get_default_model() self.model_path = model_path - self.embedding_dim = 768 @staticmethod def _get_default_model(): @@ -86,7 +86,7 @@ def _get_default_model(): return model_path def load_model(self, device): - model = MLP(self.embedding_dim).to(device) + model = MLP(self.embedding_size).to(device) weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 9a85eca6..29c0f235 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -36,12 +36,14 @@ def __init__( pred_column: str, pred_type: Union[str, type], batch_size: int, + embedding_size: int, ) -> None: self.model_name = model_name self.embedding_column = embedding_column self.pred_column = pred_column self.pred_type = pred_type self.batch_size = batch_size + self.embedding_size = embedding_size def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() @@ -69,6 +71,13 @@ def _run_inference(self, partition, partition_info=None): device=device, ) + if self.embedding_dim != embeddings.shape[-1]: + raise RuntimeError( + f"{self.model_name} expects embedding size {self.embedding_size} but column " + f"'{self.embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " + "classifier is compatible with the CLIP model you used to generate the embeddings." + ) + with torch.no_grad(): if self.batch_size > 0: batches = torch.split(embeddings, self.batch_size) diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py new file mode 100644 index 00000000..a4e14be1 --- /dev/null +++ b/nemo_curator/image/classifiers/nsfw.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import requests +import torch +import torch.nn as nn + +from nemo_curator.image.classifiers.base import ImageClassifier +from nemo_curator.utils.file_utils import NEMO_CURATOR_HOME + + +# MLP code taken from LAION's CLIP-based-NSFW-Detector +# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py +class H14_NSFW_Detector(nn.Module): + def __init__(self, input_size=1024): + super().__init__() + self.input_size = input_size + self.layers = nn.Sequential( + nn.Linear(self.input_size, 1024), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(1024, 2048), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(2048, 1024), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(1024, 256), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(256, 128), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(128, 16), + nn.Linear(16, 1), + ) + + def forward(self, x): + return self.layers(x) + + +class NsfwClassifier(ImageClassifier): + def __init__( + self, + embedding_column: str = "image_embedding", + pred_column: str = "nsfw_score", + batch_size: int = -1, + model_path: Optional[str] = None, + ) -> None: + super().__init__( + model_name="nsfw_classifier", + embedding_column=embedding_column, + pred_column=pred_column, + pred_type=float, + batch_size=batch_size, + embedding_size=1024, + ) + + if model_path is None: + model_path = self._get_default_model() + + self.model_path = model_path + + @staticmethod + def _get_default_model(): + weights_name = "h14_nsfw.pth" + model_path = os.path.join(NEMO_CURATOR_HOME, weights_name) + os.makedirs(NEMO_CURATOR_HOME, exist_ok=True) + + if not os.path.exists(model_path): + url = f"https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/{weights_name}?raw=true" + r = requests.get(url) + + with open(model_path, "wb") as f: + f.write(r.content) + + return model_path + + def load_model(self, device): + model = H14_NSFW_Detector(input_size=self.embedding_size).to(device) + weights = torch.load(self.model_path, map_location=torch.device("cpu")) + model.load_state_dict(weights) + model.eval() + + def infer(batch): + return model(batch).squeeze() + + return infer + + def postprocess(self, series): + return series.list.leaves From c1c1b1a5c2e0eae5038fb3e31905e08958936965 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 15:23:48 -0700 Subject: [PATCH 20/57] Update init for import Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_curator/image/classifiers/__init__.py b/nemo_curator/image/classifiers/__init__.py index 3988438f..94597ce4 100644 --- a/nemo_curator/image/classifiers/__init__.py +++ b/nemo_curator/image/classifiers/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .aesthetic import AestheticClassifier from .base import ImageClassifier +from .nsfw import NsfwClassifier -__all__ = ["AestheticClassifier", "ImageClassifier"] +__all__ = ["AestheticClassifier", "ImageClassifier", "NsfwClassifier"] From 88032c8ae3ff093853c954bd0c42344d514b9920 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 15:59:29 -0700 Subject: [PATCH 21/57] Fix embedding size Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 29c0f235..a32603c7 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -71,7 +71,7 @@ def _run_inference(self, partition, partition_info=None): device=device, ) - if self.embedding_dim != embeddings.shape[-1]: + if self.embedding_size != embeddings.shape[-1]: raise RuntimeError( f"{self.model_name} expects embedding size {self.embedding_size} but column " f"'{self.embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " From b4c5cd5bc0f695d5b89a3abc0e9f601a215c3161 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 16:42:06 -0700 Subject: [PATCH 22/57] Add fused classifiers Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 34 ++++++++++++++++++++--- nemo_curator/image/embedders/open_clip.py | 7 +++-- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 4fbeec60..9767f8bd 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -19,14 +19,21 @@ import torch from nemo_curator.datasets import ImageTextPairDataset +from nemo_curator.image.classifiers import ImageClassifier from nemo_curator.utils.cudf_utils import create_list_series_from_1d_or_2d_ar from nemo_curator.utils.distributed_utils import load_object_on_worker class ImageEmbedder(ABC): - def __init__(self, model_name: str, image_embedding_column: str) -> None: + def __init__( + self, + model_name: str, + image_embedding_column: str, + classifiers: Iterable[ImageClassifier], + ) -> None: self.model_name = model_name self.image_embedding_column = image_embedding_column + self.classifiers = classifiers def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() @@ -42,21 +49,35 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: def _run_inference(self, partition, tar_paths, partition_info=None): tar_path = tar_paths[partition_info["number"]] device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) + device = f"cuda:{device_id}" model = load_object_on_worker( self.model_name, self.load_embedding_model, - {"device": f"cuda:{device_id}"}, + {"device": device}, ) + classifier_models = [] + for classifier in self.classifiers: + loaded_classifier = load_object_on_worker( + classifier.model_name, classifier.load_model, {"device": device} + ) + classifier_models.append(loaded_classifier) dataset = self.load_dataset_shard(tar_path, device_id=device_id) final_image_embeddings = [] + classifier_results = [[] for _ in self.classifiers] samples_completed = 0 with torch.no_grad(): for batch in dataset: image_embeddings = model(batch) final_image_embeddings.append(image_embeddings) + for classifier_model, results in zip( + classifier_models, classifier_results + ): + classifier_result = classifier_model(image_embeddings) + results.append(classifier_result) + batch_size = len(image_embeddings) samples_completed += batch_size @@ -70,11 +91,16 @@ def _run_inference(self, partition, tar_paths, partition_info=None): f"{len(partition)} samples found in the metadata, but {samples_completed} found in {tar_path}." ) - concat_output = cp.asarray(torch.cat(final_image_embeddings, dim=0)) + concat_embedding_output = cp.asarray(torch.cat(final_image_embeddings, dim=0)) partition[self.image_embedding_column] = create_list_series_from_1d_or_2d_ar( - concat_output, index=partition.index + concat_embedding_output, index=partition.index ) + for classifier, results in zip(self.classifiers, classifier_results): + concat_output = cp.asarray(torch.cat(results, dim=0)) + series = create_list_series_from_1d_or_2d_ar(concat_output) + partition[classifier.pred_column] = classifier.postprocess(series) + return partition @abstractmethod diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index 3f689e55..2d3cb3ee 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Optional +from typing import Iterable, Optional import nvidia.dali.fn as fn import nvidia.dali.types as types @@ -33,9 +33,12 @@ def __init__( num_threads_per_worker: int = 4, image_embedding_column: str = "image_embedding", normalize_embeddings: bool = True, + classifiers: Iterable = [], ) -> None: super().__init__( - model_name=model_name, image_embedding_column=image_embedding_column + model_name=model_name, + image_embedding_column=image_embedding_column, + classifiers=classifiers, ) self.pretrained = pretrained self.batch_size = batch_size From 8d939135dec7050215500ca9049727e77efc0828 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 16:47:02 -0700 Subject: [PATCH 23/57] Fix missing index Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 9767f8bd..03da1e19 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -98,7 +98,9 @@ def _run_inference(self, partition, tar_paths, partition_info=None): for classifier, results in zip(self.classifiers, classifier_results): concat_output = cp.asarray(torch.cat(results, dim=0)) - series = create_list_series_from_1d_or_2d_ar(concat_output) + series = create_list_series_from_1d_or_2d_ar( + concat_output, index=partition.index + ) partition[classifier.pred_column] = classifier.postprocess(series) return partition From 873b4109fbcd6d45c5a2e3e4fcd2fc2baa5c9813 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 28 Aug 2024 16:51:03 -0700 Subject: [PATCH 24/57] Update metdata for fused classifiers Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 03da1e19..7817ca11 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -38,6 +38,9 @@ def __init__( def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta = dataset.metadata.dtypes.to_dict() meta[self.image_embedding_column] = "object" + for classifier in self.classifiers: + meta[classifier.pred_column] = classifier.pred_type + embedding_df = dataset.metadata.map_partitions( self._run_inference, dataset.tar_files, meta=meta ) From c73e2921b840eae26529420ae2453f2aa2f29f50 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 3 Sep 2024 19:38:58 -0700 Subject: [PATCH 25/57] Add export to webdataset Signed-off-by: Ryan Wolf --- .../datasets/image_text_pair_dataset.py | 176 +++++++++++++++++- 1 file changed, 169 insertions(+), 7 deletions(-) diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index d622f8b3..cd8ef1c5 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -12,25 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io +import math import os -from typing import List, Optional, Union +import tarfile +from functools import partial +from typing import List, Optional import dask_cudf +import fsspec +import numpy as np +import pandas as pd from fsspec.core import open_files class ImageTextPairDataset: - def __init__(self, path: str, metadata, tar_files: List[str]) -> None: + def __init__(self, path: str, metadata, tar_files: List[str], id_col: str) -> None: self.path = path self.metadata = metadata self.tar_files = tar_files + self.id_col = id_col @classmethod - def from_webdataset(cls, path: str): + def from_webdataset(cls, path: str, id_col: str): metadata = dask_cudf.read_parquet(path) + metadata = metadata.map_partitions(cls._sort_partition, id_col=id_col) + tar_files = cls._get_tar_files(path) - return cls(path, metadata, tar_files) + return cls(path, metadata, tar_files, id_col) + + @staticmethod + def _sort_partition(partition, id_col): + return partition.sort_values(id_col) @staticmethod def _get_tar_files(path: str) -> List[str]: @@ -40,6 +54,20 @@ def _get_tar_files(path: str) -> List[str]: return tar_files + @staticmethod + def _name_partition( + partition_index: int, + temp: bool = False, + max_shards: int = 5, + ext: str = "parquet", + ) -> str: + if temp: + prefix = "temp_" + else: + prefix = "" + + return f"{prefix}{partition_index:0{max_shards}d}.{ext}" + def save_metadata( self, path: Optional[str] = None, columns: Optional[List[str]] = None ) -> None: @@ -51,7 +79,141 @@ def save_metadata( else: metadata = self.metadata[columns] - metadata.to_parquet(path) + metadata.to_parquet(path, name_function=self._name_partition) + + @staticmethod + def _filter_valid_members(members, valid_ids): + def filter_members(member): + full_id = member.name.split(".")[0] + sample_id = int(full_id) + + return sample_id in valid_ids + + return list(filter(filter_members, members)) + + def _get_eligible_samples(self, output_path: str, samples_per_shard: int): + parquet_glob_str = os.path.join(output_path, "temp_*.parquet") + tar_glob_str = os.path.join(self.path, "*.tar") + parquet_files = open_files(parquet_glob_str) + tar_files = open_files(tar_glob_str) + + curr_df = None + total_tar_samples = [] + for parquet_file, tar_file in zip(parquet_files, tar_files): + with parquet_file as f: + shard_df = pd.read_parquet(f) + + # Get all the samples associated with this dataframe from the tar file + valid_member_ids = set(map(int, shard_df[self.id_col])) + with tar_file as f: + tar = tarfile.open(fileobj=f) + valid_members = self._filter_valid_members( + tar.getmembers(), valid_member_ids + ) + valid_members.sort(key=lambda x: x.name) + tar_samples = [ + (member, tar.extractfile(member).read()) for member in valid_members + ] + + if len(tar_samples) % len(shard_df) != 0: + raise RuntimeError( + f"Tarfile {tar_file.path} entries {len(tar_samples)} are not a multiple of the number of samples {len(shard_df)}" + ) + + entries_per_sample = int(len(tar_samples) / len(shard_df)) + + # Concat the dataframe and tar file samples + if curr_df is not None: + curr_df = pd.concat([curr_df, shard_df], ignore_index=True) + else: + curr_df = shard_df + total_tar_samples.extend(tar_samples) + + # Delete the temp shard + parquet_file.fs.delete(parquet_file.path) + + # While there are enough samples, yield a slice and discard it + while len(curr_df) >= samples_per_shard: + yield ( + curr_df.iloc[:samples_per_shard].copy(), + total_tar_samples[: samples_per_shard * entries_per_sample], + ) + curr_df = curr_df.iloc[samples_per_shard:] + total_tar_samples = total_tar_samples[ + samples_per_shard * entries_per_sample : + ] + + # Return the remaining df and samples + yield curr_df, total_tar_samples + + @staticmethod + def combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str: + int_id = sample_id + (10**max_samples_per_shard) * shard_id + n_digits = max_samples_per_shard + max_shards + combined_id = f"{int_id:0{n_digits}d}" + return combined_id + + def split_id(combined_id: str, max_shards=5): + return int(combined_id[:max_shards]), int(combined_id[max_shards:]) + + def to_webdataset( + self, + path: str, + filter_column: str, + samples_per_shard: int = 10000, + max_shards=5, + old_id_col=None, + ) -> None: + max_samples_per_shard = math.ceil(math.log10(samples_per_shard)) + filtered_metadata = self.metadata[self.metadata[filter_column]] + + temp_name_fn = partial(self._name_partition, temp=True, max_shards=max_shards) + filtered_metadata.to_parquet(path, name_function=temp_name_fn) + + for shard_id, (shard_df, shard_tar) in enumerate( + self._get_eligible_samples(path, samples_per_shard) + ): + output_parquet_base = self._name_partition(shard_id, max_shards=max_shards) + output_tar_base = self._name_partition( + shard_id, max_shards=max_shards, ext="tar" + ) + output_parquet_path = os.path.join(path, output_parquet_base) + output_tar_path = os.path.join(path, output_tar_base) + output_parquet_file = fsspec.open(output_parquet_path, mode="wb") + output_tar_file = fsspec.open(output_tar_path, mode="wb") + + # Change the id on the parquet files + if old_id_col: + shard_df[old_id_col] = shard_df[self.id_col] + + new_ids = np.arange(len(shard_df)) + convert_ids = partial( + self.combine_id, + shard_id, + max_shards=max_shards, + max_samples_per_shard=max_samples_per_shard, + ) + shard_df[self.id_col] = list(map(convert_ids, new_ids)) + with output_parquet_file as f: + shard_df.to_parquet(f, index=False) + + members_per_sample = len(shard_tar) / len(shard_df) + with output_tar_file as f: + tar = tarfile.open(fileobj=f, mode="w") + for i, (member, data) in enumerate(shard_tar): + # Rename the each member to match the new id + sample_id = int(i // members_per_sample) + member_id = self.combine_id( + shard_id, + sample_id, + max_shards=max_shards, + max_samples_per_shard=max_samples_per_shard, + ) + extension = member.name.split(".")[-1] + member.name = f"{member_id}.{extension}" - def to_webdataset(self, path: str, filter_column: str) -> None: - pass + tar.addfile(member, io.BytesIO(data)) + print( + f"Finished writing shard {self._name_partition(shard_id, ext='tar')} with " + f"parquet length {len(shard_df)} and tar length {len(shard_tar)}" + ) From 361e0d19249a9995f64cb42f1e39a6b958ef4b27 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 4 Sep 2024 13:37:55 -0700 Subject: [PATCH 26/57] Fix missing id col Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/base.py | 5 ++++- nemo_curator/image/embedders/base.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index a32603c7..2e9d915b 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -51,7 +51,10 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: embedding_df = dataset.metadata.map_partitions(self._run_inference, meta=meta) return ImageTextPairDataset( - dataset.path, metadata=embedding_df, tar_files=dataset.tar_files + dataset.path, + metadata=embedding_df, + tar_files=dataset.tar_files, + id_col=dataset.id_col, ) def _run_inference(self, partition, partition_info=None): diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 7817ca11..e4f1cf8e 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -46,7 +46,10 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: ) return ImageTextPairDataset( - dataset.path, metadata=embedding_df, tar_files=dataset.tar_files + dataset.path, + metadata=embedding_df, + tar_files=dataset.tar_files, + id_col=dataset.id_col, ) def _run_inference(self, partition, tar_paths, partition_info=None): From ce91626bf34e8113fa40b84aa9f83bb5d01acca5 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 4 Sep 2024 14:42:21 -0700 Subject: [PATCH 27/57] Sort embeddings by id Signed-off-by: Ryan Wolf --- .../datasets/image_text_pair_dataset.py | 2 +- nemo_curator/image/embedders/base.py | 17 ++++++++++++----- nemo_curator/image/embedders/open_clip.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index cd8ef1c5..54f91506 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -44,7 +44,7 @@ def from_webdataset(cls, path: str, id_col: str): @staticmethod def _sort_partition(partition, id_col): - return partition.sort_values(id_col) + return partition.sort_values(id_col).reset_index(drop=True) @staticmethod def _get_tar_files(path: str) -> List[str]: diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index e4f1cf8e..8f597b53 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -42,7 +42,7 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: meta[classifier.pred_column] = classifier.pred_type embedding_df = dataset.metadata.map_partitions( - self._run_inference, dataset.tar_files, meta=meta + self._run_inference, dataset.tar_files, dataset.id_col, meta=meta ) return ImageTextPairDataset( @@ -52,7 +52,7 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: id_col=dataset.id_col, ) - def _run_inference(self, partition, tar_paths, partition_info=None): + def _run_inference(self, partition, tar_paths, id_col, partition_info=None): tar_path = tar_paths[partition_info["number"]] device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) device = f"cuda:{device_id}" @@ -71,12 +71,14 @@ def _run_inference(self, partition, tar_paths, partition_info=None): dataset = self.load_dataset_shard(tar_path, device_id=device_id) final_image_embeddings = [] + image_ids = [] classifier_results = [[] for _ in self.classifiers] samples_completed = 0 with torch.no_grad(): - for batch in dataset: + for batch, metadata in dataset: image_embeddings = model(batch) final_image_embeddings.append(image_embeddings) + image_ids.extend(m[id_col] for m in metadata) for classifier_model, results in zip( classifier_models, classifier_results @@ -97,13 +99,18 @@ def _run_inference(self, partition, tar_paths, partition_info=None): f"{len(partition)} samples found in the metadata, but {samples_completed} found in {tar_path}." ) - concat_embedding_output = cp.asarray(torch.cat(final_image_embeddings, dim=0)) + # Order the output of the shard + sorted_indices = sorted(range(len(image_ids)), key=lambda k: image_ids[k]) + sorted_embeddings = torch.cat(final_image_embeddings, dim=0)[sorted_indices] + + concat_embedding_output = cp.asarray(sorted_embeddings) partition[self.image_embedding_column] = create_list_series_from_1d_or_2d_ar( concat_embedding_output, index=partition.index ) for classifier, results in zip(self.classifiers, classifier_results): - concat_output = cp.asarray(torch.cat(results, dim=0)) + sorted_results = torch.cat(results, dim=0)[sorted_indices] + concat_output = cp.asarray(sorted_results) series = create_list_series_from_1d_or_2d_ar( concat_output, index=partition.index ) diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py index 2d3cb3ee..71c7489d 100644 --- a/nemo_curator/image/embedders/open_clip.py +++ b/nemo_curator/image/embedders/open_clip.py @@ -107,7 +107,7 @@ def webdataset_pipeline(_tar_path: str): samples_completed += min(image.shape[0], remaining_samples) - yield image + yield image, metadata def load_embedding_model(self, device="cuda"): model = open_clip.create_model( From d338943a93d01b530e288ad4d3e10ef5eca95669 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 4 Sep 2024 17:28:17 -0700 Subject: [PATCH 28/57] Add timm Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/timm.py | 124 +++++++++++++++++++++++++ nemo_curator/utils/image/__init__.py | 13 +++ nemo_curator/utils/image/transforms.py | 85 +++++++++++++++++ requirements/requirements_image.txt | 1 + 4 files changed, 223 insertions(+) create mode 100644 nemo_curator/image/embedders/timm.py create mode 100644 nemo_curator/utils/image/__init__.py create mode 100644 nemo_curator/utils/image/transforms.py diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py new file mode 100644 index 00000000..124cf654 --- /dev/null +++ b/nemo_curator/image/embedders/timm.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Iterable, Optional + +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import timm +import torch +from nvidia.dali import pipeline_def +from nvidia.dali.plugin.pytorch import feed_ndarray + +from nemo_curator.image.embedders.base import ImageEmbedder +from nemo_curator.utils.image.transforms import convert_transforms_to_dali + + +class TimmImageEmbedder(ImageEmbedder): + def __init__( + self, + model_name: str, + pretrained: bool = False, + batch_size: int = 1, + num_threads_per_worker: int = 4, + image_embedding_column: str = "image_embedding", + normalize_embeddings: bool = True, + classifiers: Iterable = [], + ) -> None: + super().__init__( + model_name=model_name, + image_embedding_column=image_embedding_column, + classifiers=classifiers, + ) + self.pretrained = pretrained + self.batch_size = batch_size + self.num_threads_per_worker = num_threads_per_worker + self.normalize_embeddings = normalize_embeddings + + # Load the model to get the transforms + model = timm.create_model(self.model_name, pretrained=self.pretrained) + torch_transforms = timm.data.create_transform( + **timm.data.resolve_data_config(model.pretrained_cfg) + ) + self.dali_transforms = convert_transforms_to_dali(torch_transforms) + + def load_dataset_shard(self, tar_path: str, device_id=0): + # Create the DALI pipeline + @pipeline_def( + batch_size=self.batch_size, + num_threads=self.num_threads_per_worker, + device_id=device_id, + ) + def webdataset_pipeline(_tar_path: str): + img_raw, text, json = fn.readers.webdataset( + paths=_tar_path, + ext=["jpg", "txt", "json"], + missing_component_behavior="error", + ) + img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB) + + for transform in self.dali_transforms: + img = transform(img) + + return img, text, json + + pipeline = webdataset_pipeline(tar_path) + pipeline.build() + + total_samples = pipeline.epoch_size() + total_samples = total_samples[list(total_samples.keys())[0]] + + samples_completed = 0 + while samples_completed < total_samples: + image, text, meta = pipeline.run() + image = image.as_tensor() + + image_torch = torch.empty( + image.shape(), dtype=torch.float32, device=f"cuda:{device_id}" + ) + feed_ndarray(image, image_torch) # COPY !!! + image = image_torch + + captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] + metadata = [ + json.loads(meta.at(i).tostring().decode("utf-8")) + for i in range(len(meta)) + ] + + remaining_samples = total_samples - samples_completed + if image.shape[0] >= remaining_samples: + image = image[:remaining_samples] + captions = captions[:remaining_samples] + metadata = metadata[:remaining_samples] + + samples_completed += min(image.shape[0], remaining_samples) + + yield image, metadata + + def load_embedding_model(self, device="cuda"): + model = timm.create_model(self.model_name, pretrained=self.pretrained).eval() + model = model.to(device) + + def infer(batch): + image_features = model(batch) + if self.normalize_embeddings: + image_features = self.torch_normalized(image_features) + + return image_features + + return infer + + @staticmethod + def torch_normalized(a, dim=-1): + return torch.nn.functional.normalize(a, dim=dim) diff --git a/nemo_curator/utils/image/__init__.py b/nemo_curator/utils/image/__init__.py new file mode 100644 index 00000000..d9155f92 --- /dev/null +++ b/nemo_curator/utils/image/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_curator/utils/image/transforms.py b/nemo_curator/utils/image/transforms.py new file mode 100644 index 00000000..e01f8f65 --- /dev/null +++ b/nemo_curator/utils/image/transforms.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import List + +import nvidia.dali.fn as fn +from timm.data.transforms import MaybeToTensor +from torchvision.transforms.transforms import CenterCrop, Compose, Normalize, Resize + +ERROR_MESSAGE = """Transforms do not conform to expected style and cannot be automatically converted. +Expected: + Compose( + Resize(interpolation=bicubic, max_size=None, antialias=True), + CenterCrop(), + MaybeToTensor(), + Normalize(), + ) + +Got: {} + +Please manually convert the image transformations to use DALI +""" + + +def convert_transforms_to_dali(torch_transform: Compose) -> List: + """ + Converts a list of PyTorch/Timm image transformations into DALI transformations + Only works with transformations that follow this pattern: + + Compose( + Resize(interpolation=bicubic, max_size=None, antialias=True), + CenterCrop(), + MaybeToTensor(), + Normalize(), + ) + + Anything that does not follow this pattern will cause a ValueError to be raised + """ + if not isinstance(torch_transform, Compose): + raise ValueError(ERROR_MESSAGE.format(torch_transform)) + + crop = None + mean = [0.0] + std = [1.0] + resize_shorter = 0.0 + interp_type = "bicubic" + + # Loop over all transforms and extract relevant parameters + for transform in torch_transform.transforms: + if isinstance(transform, Resize): + if transform.interpolation != "bicubic": + raise ValueError(ERROR_MESSAGE.format(torch_transform)) + resize_shorter = transform.size + elif isinstance(transform, CenterCrop): + crop = transform.size + elif isinstance(transform, Normalize): + mean = transform.mean + std = transform.std + elif isinstance(transform, MaybeToTensor): + continue + else: + raise ValueError(ERROR_MESSAGE.format(torch_transform)) + + dali_transforms = [ + partial( + fn.resize, + device="gpu", + interp_type=interp_type, + resize_shorter=resize_shorter, + ), + partial(fn.crop_mirror_normalize, device="gpu", crop=crop, mean=mean, std=std), + ] + return dali_transforms diff --git a/requirements/requirements_image.txt b/requirements/requirements_image.txt index 6c1d5f49..9fa5930c 100644 --- a/requirements/requirements_image.txt +++ b/requirements/requirements_image.txt @@ -1,3 +1,4 @@ nvidia-dali-cuda120 nvidia-nvjpeg2k-cu12 open_clip_torch +timm From fc5fefbfebb15986fbfa4a5f4a7355e9126b79bd Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 4 Sep 2024 17:31:19 -0700 Subject: [PATCH 29/57] Update init file Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_curator/image/embedders/__init__.py b/nemo_curator/image/embedders/__init__.py index 27bf6aca..e0468c9f 100644 --- a/nemo_curator/image/embedders/__init__.py +++ b/nemo_curator/image/embedders/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .base import ImageEmbedder from .open_clip import OpenClipImageEmbedder +from .timm import TimmImageEmbedder -__all__ = ["ImageEmbedder", "OpenClipImageEmbedder"] +__all__ = ["ImageEmbedder", "OpenClipImageEmbedder", "TimmImageEmbedder"] From 29eb2ba7587ff4f311c5793d931597f199ccd7a0 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 09:07:01 -0700 Subject: [PATCH 30/57] Add autocast to timm Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/timm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 124cf654..0b26d38f 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -35,6 +35,7 @@ def __init__( image_embedding_column: str = "image_embedding", normalize_embeddings: bool = True, classifiers: Iterable = [], + autocast: bool = True, ) -> None: super().__init__( model_name=model_name, @@ -45,6 +46,7 @@ def __init__( self.batch_size = batch_size self.num_threads_per_worker = num_threads_per_worker self.normalize_embeddings = normalize_embeddings + self.autocast = autocast # Load the model to get the transforms model = timm.create_model(self.model_name, pretrained=self.pretrained) @@ -111,11 +113,17 @@ def load_embedding_model(self, device="cuda"): model = model.to(device) def infer(batch): - image_features = model(batch) + if self.autocast: + with torch.autocast(device_type="cuda"): + image_features = model(batch) + else: + image_features = model(batch) + if self.normalize_embeddings: image_features = self.torch_normalized(image_features) - return image_features + # Inference can be done in lower precision, but cuDF can only handle fp32 + return image_features.to(torch.float32) return infer From 09ed9d6b26625abd49a1ff3e59ccd743265b974d Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 09:27:15 -0700 Subject: [PATCH 31/57] Update requirements and transform Signed-off-by: Ryan Wolf --- nemo_curator/utils/image/transforms.py | 10 ++++++++-- requirements/requirements_image.txt | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/nemo_curator/utils/image/transforms.py b/nemo_curator/utils/image/transforms.py index e01f8f65..f56e5843 100644 --- a/nemo_curator/utils/image/transforms.py +++ b/nemo_curator/utils/image/transforms.py @@ -17,7 +17,13 @@ import nvidia.dali.fn as fn from timm.data.transforms import MaybeToTensor -from torchvision.transforms.transforms import CenterCrop, Compose, Normalize, Resize +from torchvision.transforms.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + Resize, +) ERROR_MESSAGE = """Transforms do not conform to expected style and cannot be automatically converted. Expected: @@ -60,7 +66,7 @@ def convert_transforms_to_dali(torch_transform: Compose) -> List: # Loop over all transforms and extract relevant parameters for transform in torch_transform.transforms: if isinstance(transform, Resize): - if transform.interpolation != "bicubic": + if transform.interpolation != InterpolationMode.BICUBIC: raise ValueError(ERROR_MESSAGE.format(torch_transform)) resize_shorter = transform.size elif isinstance(transform, CenterCrop): diff --git a/requirements/requirements_image.txt b/requirements/requirements_image.txt index 9fa5930c..e53da7ee 100644 --- a/requirements/requirements_image.txt +++ b/requirements/requirements_image.txt @@ -1,4 +1,4 @@ nvidia-dali-cuda120 nvidia-nvjpeg2k-cu12 open_clip_torch -timm +timm>=1.0.8 From 2a6b510cb122244ef72d585a0bac699b3dbb7603 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 09:58:01 -0700 Subject: [PATCH 32/57] Add additional interpolation support Signed-off-by: Ryan Wolf --- nemo_curator/utils/image/transforms.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/nemo_curator/utils/image/transforms.py b/nemo_curator/utils/image/transforms.py index f56e5843..201fe39b 100644 --- a/nemo_curator/utils/image/transforms.py +++ b/nemo_curator/utils/image/transforms.py @@ -16,6 +16,7 @@ from typing import List import nvidia.dali.fn as fn +from nvidia.dali.types import DALIInterpType from timm.data.transforms import MaybeToTensor from torchvision.transforms.transforms import ( CenterCrop, @@ -28,7 +29,7 @@ ERROR_MESSAGE = """Transforms do not conform to expected style and cannot be automatically converted. Expected: Compose( - Resize(interpolation=bicubic, max_size=None, antialias=True), + Resize(interpolation=bicubic or linear, max_size=None, antialias=True), CenterCrop(), MaybeToTensor(), Normalize(), @@ -39,6 +40,13 @@ Please manually convert the image transformations to use DALI """ +# Linear = Bilinear and Cubic = Bicubic +# https://docs.nvidia.com/deeplearning/dali/user-guide/docs/data_types.html#nvidia.dali.types.DALIInterpType +SUPPORTED_INTERPOLATIONS = { + InterpolationMode.BICUBIC: DALIInterpType.INTERP_CUBIC, + InterpolationMode.BILINEAR: DALIInterpType.INTERP_LINEAR, +} + def convert_transforms_to_dali(torch_transform: Compose) -> List: """ @@ -46,7 +54,7 @@ def convert_transforms_to_dali(torch_transform: Compose) -> List: Only works with transformations that follow this pattern: Compose( - Resize(interpolation=bicubic, max_size=None, antialias=True), + Resize(interpolation=bicubic or bilinear, max_size=None, antialias=True), CenterCrop(), MaybeToTensor(), Normalize(), @@ -61,13 +69,14 @@ def convert_transforms_to_dali(torch_transform: Compose) -> List: mean = [0.0] std = [1.0] resize_shorter = 0.0 - interp_type = "bicubic" + interp_type = DALIInterpType.INTERP_LINEAR # Loop over all transforms and extract relevant parameters for transform in torch_transform.transforms: if isinstance(transform, Resize): - if transform.interpolation != InterpolationMode.BICUBIC: + if transform.interpolation not in SUPPORTED_INTERPOLATIONS: raise ValueError(ERROR_MESSAGE.format(torch_transform)) + interp_type = SUPPORTED_INTERPOLATIONS[transform.interpolation] resize_shorter = transform.size elif isinstance(transform, CenterCrop): crop = transform.size From b6bda1908b0adea9099e597212b0abc3179fcc3e Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 10:55:33 -0700 Subject: [PATCH 33/57] Fix transform normalization Signed-off-by: Ryan Wolf --- nemo_curator/utils/image/transforms.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nemo_curator/utils/image/transforms.py b/nemo_curator/utils/image/transforms.py index 201fe39b..58ff5474 100644 --- a/nemo_curator/utils/image/transforms.py +++ b/nemo_curator/utils/image/transforms.py @@ -16,7 +16,7 @@ from typing import List import nvidia.dali.fn as fn -from nvidia.dali.types import DALIInterpType +from nvidia.dali.types import FLOAT, DALIInterpType from timm.data.transforms import MaybeToTensor from torchvision.transforms.transforms import ( CenterCrop, @@ -95,6 +95,14 @@ def convert_transforms_to_dali(torch_transform: Compose) -> List: interp_type=interp_type, resize_shorter=resize_shorter, ), - partial(fn.crop_mirror_normalize, device="gpu", crop=crop, mean=mean, std=std), + # We need to multiply by 255 because DALI deals entirely in pixel values + partial( + fn.crop_mirror_normalize, + device="gpu", + crop=crop, + dtype=FLOAT, + mean=mean * 255, + std=std * 255, + ), ] return dali_transforms From d57462e4b759d1aa2f04f06cbe41a9ec4cf64228 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 11:55:44 -0700 Subject: [PATCH 34/57] Remove open_clip Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/__init__.py | 3 +- nemo_curator/image/embedders/open_clip.py | 129 ---------------------- nemo_curator/image/embedders/timm.py | 2 +- requirements/requirements_image.txt | 1 - 4 files changed, 2 insertions(+), 133 deletions(-) delete mode 100644 nemo_curator/image/embedders/open_clip.py diff --git a/nemo_curator/image/embedders/__init__.py b/nemo_curator/image/embedders/__init__.py index e0468c9f..56a1b3ed 100644 --- a/nemo_curator/image/embedders/__init__.py +++ b/nemo_curator/image/embedders/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from .base import ImageEmbedder -from .open_clip import OpenClipImageEmbedder from .timm import TimmImageEmbedder -__all__ = ["ImageEmbedder", "OpenClipImageEmbedder", "TimmImageEmbedder"] +__all__ = ["ImageEmbedder", "TimmImageEmbedder"] diff --git a/nemo_curator/image/embedders/open_clip.py b/nemo_curator/image/embedders/open_clip.py deleted file mode 100644 index 71c7489d..00000000 --- a/nemo_curator/image/embedders/open_clip.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from typing import Iterable, Optional - -import nvidia.dali.fn as fn -import nvidia.dali.types as types -import open_clip -import torch -from nvidia.dali import pipeline_def -from nvidia.dali.plugin.pytorch import feed_ndarray - -from nemo_curator.image.embedders.base import ImageEmbedder - - -class OpenClipImageEmbedder(ImageEmbedder): - def __init__( - self, - model_name: str, - pretrained: Optional[str] = None, - batch_size: int = 1, - num_threads_per_worker: int = 4, - image_embedding_column: str = "image_embedding", - normalize_embeddings: bool = True, - classifiers: Iterable = [], - ) -> None: - super().__init__( - model_name=model_name, - image_embedding_column=image_embedding_column, - classifiers=classifiers, - ) - self.pretrained = pretrained - self.batch_size = batch_size - self.num_threads_per_worker = num_threads_per_worker - self.normalize_embeddings = normalize_embeddings - - def load_dataset_shard(self, tar_path: str, device_id=0): - # Create the DALI pipeline - @pipeline_def( - batch_size=self.batch_size, - num_threads=self.num_threads_per_worker, - device_id=device_id, - ) - def webdataset_pipeline(_tar_path: str): - img_raw, text, json = fn.readers.webdataset( - paths=_tar_path, - ext=["jpg", "txt", "json"], - missing_component_behavior="error", - ) - img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB) - img = fn.crop_mirror_normalize( - img, - dtype=types.FLOAT, - mean=[0, 0, 0], - std=[255, 255, 255], - ) - - resized = fn.resize(img, device="gpu", resize_shorter=224) - output = fn.crop_mirror_normalize( - resized, - dtype=types.FLOAT, - crop=(224, 224), - mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711], - ) - return output, text, json - - pipeline = webdataset_pipeline(tar_path) - pipeline.build() - - total_samples = pipeline.epoch_size() - total_samples = total_samples[list(total_samples.keys())[0]] - - samples_completed = 0 - while samples_completed < total_samples: - image, text, meta = pipeline.run() - image = image.as_tensor() - - image_torch = torch.empty( - image.shape(), dtype=torch.float32, device=f"cuda:{device_id}" - ) - feed_ndarray(image, image_torch) # COPY !!! - image = image_torch - - captions = [text.at(i).tostring().decode("utf-8") for i in range(len(text))] - metadata = [ - json.loads(meta.at(i).tostring().decode("utf-8")) - for i in range(len(meta)) - ] - - remaining_samples = total_samples - samples_completed - if image.shape[0] >= remaining_samples: - image = image[:remaining_samples] - captions = captions[:remaining_samples] - metadata = metadata[:remaining_samples] - - samples_completed += min(image.shape[0], remaining_samples) - - yield image, metadata - - def load_embedding_model(self, device="cuda"): - model = open_clip.create_model( - self.model_name, pretrained=self.pretrained, device=device - ) - model.eval() - - def infer(batch): - image_features = model.encode_image(batch) - if self.normalize_embeddings: - image_features = self.torch_normalized(image_features) - - return image_features - - return infer - - @staticmethod - def torch_normalized(a, dim=-1): - return torch.nn.functional.normalize(a, dim=dim) diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 0b26d38f..79679b4f 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Iterable, Optional +from typing import Iterable import nvidia.dali.fn as fn import nvidia.dali.types as types diff --git a/requirements/requirements_image.txt b/requirements/requirements_image.txt index e53da7ee..0195a6a7 100644 --- a/requirements/requirements_image.txt +++ b/requirements/requirements_image.txt @@ -1,4 +1,3 @@ nvidia-dali-cuda120 nvidia-nvjpeg2k-cu12 -open_clip_torch timm>=1.0.8 From bacd1c09caada4f22128e0c5b7eb05b4bb862dd2 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 5 Sep 2024 17:53:18 -0700 Subject: [PATCH 35/57] Add index path support to wds Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/timm.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 79679b4f..e84f451e 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -36,6 +36,7 @@ def __init__( normalize_embeddings: bool = True, classifiers: Iterable = [], autocast: bool = True, + use_index_files=False, ) -> None: super().__init__( model_name=model_name, @@ -47,6 +48,7 @@ def __init__( self.num_threads_per_worker = num_threads_per_worker self.normalize_embeddings = normalize_embeddings self.autocast = autocast + self.use_index_files = use_index_files # Load the model to get the transforms model = timm.create_model(self.model_name, pretrained=self.pretrained) @@ -63,8 +65,14 @@ def load_dataset_shard(self, tar_path: str, device_id=0): device_id=device_id, ) def webdataset_pipeline(_tar_path: str): + if self.use_index_files: + index_paths = [f"{_tar_path.rsplit('.', 1)[0]}.idx"] + else: + index_paths = [] + img_raw, text, json = fn.readers.webdataset( paths=_tar_path, + index_paths=index_paths, ext=["jpg", "txt", "json"], missing_component_behavior="error", ) From 8e66c8f2e4431935aa9f93e8e17a4a2c5c6687c5 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 6 Sep 2024 09:07:55 -0700 Subject: [PATCH 36/57] Address Vibhu's feedback Signed-off-by: Ryan Wolf --- nemo_curator/classifiers/fineweb_edu.py | 2 ++ .../datasets/image_text_pair_dataset.py | 5 ++++- nemo_curator/image/classifiers/aesthetic.py | 13 ++++++++++--- nemo_curator/image/classifiers/nsfw.py | 13 ++++++++++--- nemo_curator/image/embedders/base.py | 11 +++++++---- nemo_curator/image/embedders/timm.py | 17 ++++++++++++----- 6 files changed, 45 insertions(+), 16 deletions(-) diff --git a/nemo_curator/classifiers/fineweb_edu.py b/nemo_curator/classifiers/fineweb_edu.py index f51c5789..9722b282 100644 --- a/nemo_curator/classifiers/fineweb_edu.py +++ b/nemo_curator/classifiers/fineweb_edu.py @@ -52,6 +52,8 @@ def custom_forward(*args, **kwargs): if autocast: with torch.autocast(device_type="cuda"): output = original_forward(*args, **kwargs) + else: + output = original_forward(*args, **kwargs) return output.logits.squeeze(-1).float() model.forward = custom_forward diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index 54f91506..abdf244b 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -19,6 +19,7 @@ from functools import partial from typing import List, Optional +import dask.dataframe as dd import dask_cudf import fsspec import numpy as np @@ -27,7 +28,9 @@ class ImageTextPairDataset: - def __init__(self, path: str, metadata, tar_files: List[str], id_col: str) -> None: + def __init__( + self, path: str, metadata: dd.DataFrame, tar_files: List[str], id_col: str + ) -> None: self.path = path self.metadata = metadata self.tar_files = tar_files diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 5057d15f..5f88f1d3 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -90,11 +90,18 @@ def load_model(self, device): weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() + model = self.configure_forward(model) - def infer(batch): - return model(batch).squeeze() + return model - return infer + def configure_forward(self, model): + original_forward = model.forward + + def custom_forward(*args, **kwargs): + return original_forward(*args, **kwargs).squeeze() + + model.forward = custom_forward + return model def postprocess(self, series): return series.list.leaves diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index a4e14be1..4fc681e6 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -94,11 +94,18 @@ def load_model(self, device): weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() + model = self.configure_forward(model) - def infer(batch): - return model(batch).squeeze() + return model - return infer + def configure_forward(self, model): + original_forward = model.forward + + def custom_forward(*args, **kwargs): + return original_forward(*args, **kwargs).squeeze() + + model.forward = custom_forward + return model def postprocess(self, series): return series.list.leaves diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 8f597b53..2e864db3 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -17,6 +17,7 @@ import cupy as cp import torch +from tqdm import tqdm from nemo_curator.datasets import ImageTextPairDataset from nemo_curator.image.classifiers import ImageClassifier @@ -74,6 +75,10 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): image_ids = [] classifier_results = [[] for _ in self.classifiers] samples_completed = 0 + progress_bar = tqdm( + total=len(partition), + desc=f"{tar_path} - Embedding creation with {self.model_name}", + ) with torch.no_grad(): for batch, metadata in dataset: image_embeddings = model(batch) @@ -88,10 +93,8 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): batch_size = len(image_embeddings) samples_completed += batch_size - - print( - f"{tar_path} - Embedding Creation with {self.model_name} Samples Completed: {samples_completed}." - ) + progress_bar.update(batch_size) + progress_bar.close() if samples_completed != len(partition): raise RuntimeError( diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index e84f451e..e2761526 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -119,21 +119,28 @@ def webdataset_pipeline(_tar_path: str): def load_embedding_model(self, device="cuda"): model = timm.create_model(self.model_name, pretrained=self.pretrained).eval() model = model.to(device) + model = self.configure_forward(model) - def infer(batch): + return model + + def configure_forward(self, model): + original_forward = model.forward + + def custom_forward(*args, **kwargs): if self.autocast: with torch.autocast(device_type="cuda"): - image_features = model(batch) + image_features = original_forward(*args, **kwargs) else: - image_features = model(batch) + image_features = original_forward(*args, **kwargs) if self.normalize_embeddings: - image_features = self.torch_normalized(image_features) + image_features = torch.nn.functional.normalize(image_features, dim=-1) # Inference can be done in lower precision, but cuDF can only handle fp32 return image_features.to(torch.float32) - return infer + model.forward = custom_forward + return model @staticmethod def torch_normalized(a, dim=-1): From 946053e04a78a0e0fbd4543ee2f21ef5b6b89008 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 6 Sep 2024 10:59:38 -0700 Subject: [PATCH 37/57] Add import guard for image dataset Signed-off-by: Ryan Wolf --- nemo_curator/datasets/__init__.py | 7 +++++- nemo_curator/utils/import_utils.py | 37 ++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/nemo_curator/datasets/__init__.py b/nemo_curator/datasets/__init__.py index 1a67b16a..16f4343a 100644 --- a/nemo_curator/datasets/__init__.py +++ b/nemo_curator/datasets/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo_curator.utils.import_utils import image_only_import_from + from .doc_dataset import DocumentDataset -from .image_text_pair_dataset import ImageTextPairDataset + +ImageTextPairDataset = image_only_import_from( + "nemo_curator.datasets.image_text_pair_dataset", "ImageTextPairDataset" +) __all__ = ["DocumentDataset", "ImageTextPairDataset"] diff --git a/nemo_curator/utils/import_utils.py b/nemo_curator/utils/import_utils.py index d404158e..6ffee38e 100644 --- a/nemo_curator/utils/import_utils.py +++ b/nemo_curator/utils/import_utils.py @@ -382,3 +382,40 @@ def gpu_only_import_from(module, symbol, *, alt=None): msg=f"{module}.{symbol} is not enabled in non GPU-enabled installations or environments. {GPU_INSTALL_STRING}", alt=alt, ) + + +IMAGE_INSTALL_STRING = """Install image packages via `pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]` +or use `pip install --extra-index-url https://pypi.nvidia.com ".[image]"` if installing from source""" + + +def image_only_import_from(module, symbol, *, alt=None): + """A function used to import symbols required only in image installs + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing an image build. + + Parameters + ---------- + module: str + The name of the module to import. + symbol: str + The name of the symbol to import. + alt: object + An optional object to be used in place of the given symbol if it fails + to import in a non-image install + + Returns + ------- + object + The imported symbol, the given alternate, or a class derived from + UnavailableMeta. + """ + return safe_import_from( + module, + symbol, + msg=f"{module}.{symbol} is not enabled in without the nemo-curator[image] installations or environments. {IMAGE_INSTALL_STRING}", + alt=alt, + ) From 015d40cca341ea7c9e11bf61de0b44c2a9044cdb Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 6 Sep 2024 17:13:46 -0700 Subject: [PATCH 38/57] Change default device Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/base.py | 3 +-- nemo_curator/image/embedders/base.py | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 2e9d915b..b9f5d52d 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -58,8 +58,7 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: ) def _run_inference(self, partition, partition_info=None): - device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) - device = f"cuda:{device_id}" + device = "cuda" model = load_object_on_worker( self.model_name, diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 2e864db3..92438881 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -55,8 +55,9 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: def _run_inference(self, partition, tar_paths, id_col, partition_info=None): tar_path = tar_paths[partition_info["number"]] - device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) - device = f"cuda:{device_id}" + # device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) + # device = f"cuda:{device_id}" + device = "cuda" model = load_object_on_worker( self.model_name, @@ -70,7 +71,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): ) classifier_models.append(loaded_classifier) - dataset = self.load_dataset_shard(tar_path, device_id=device_id) + dataset = self.load_dataset_shard(tar_path, device_id=0) final_image_embeddings = [] image_ids = [] classifier_results = [[] for _ in self.classifiers] From 852863d702ece1addc4f19687805251b2fb4fb79 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 6 Sep 2024 17:22:09 -0700 Subject: [PATCH 39/57] Remove commented code Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 92438881..35c60eea 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -55,8 +55,6 @@ def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: def _run_inference(self, partition, tar_paths, id_col, partition_info=None): tar_path = tar_paths[partition_info["number"]] - # device_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) - # device = f"cuda:{device_id}" device = "cuda" model = load_object_on_worker( From e7e320f966f8a28031228dac1db33629caf3648d Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Sun, 8 Sep 2024 16:44:58 -0700 Subject: [PATCH 40/57] Remove device id Signed-off-by: Ryan Wolf --- nemo_curator/image/embedders/base.py | 2 +- nemo_curator/image/embedders/timm.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 35c60eea..5acfc0f0 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -69,7 +69,7 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): ) classifier_models.append(loaded_classifier) - dataset = self.load_dataset_shard(tar_path, device_id=0) + dataset = self.load_dataset_shard(tar_path) final_image_embeddings = [] image_ids = [] classifier_results = [[] for _ in self.classifiers] diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index e2761526..3a7c258a 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -57,12 +57,12 @@ def __init__( ) self.dali_transforms = convert_transforms_to_dali(torch_transforms) - def load_dataset_shard(self, tar_path: str, device_id=0): + def load_dataset_shard(self, tar_path: str): # Create the DALI pipeline @pipeline_def( batch_size=self.batch_size, num_threads=self.num_threads_per_worker, - device_id=device_id, + device_id=0, ) def webdataset_pipeline(_tar_path: str): if self.use_index_files: @@ -94,9 +94,7 @@ def webdataset_pipeline(_tar_path: str): image, text, meta = pipeline.run() image = image.as_tensor() - image_torch = torch.empty( - image.shape(), dtype=torch.float32, device=f"cuda:{device_id}" - ) + image_torch = torch.empty(image.shape(), dtype=torch.float32, device="cuda") feed_ndarray(image, image_torch) # COPY !!! image = image_torch From 92f47a065bbc8822ef67c221b10eeccf54b5d868 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Sun, 8 Sep 2024 17:27:25 -0700 Subject: [PATCH 41/57] Fix index issue Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 4 +++- nemo_curator/image/classifiers/nsfw.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 5f88f1d3..99e8c68b 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -104,4 +104,6 @@ def custom_forward(*args, **kwargs): return model def postprocess(self, series): - return series.list.leaves + new_series = series.list.leaves + new_series.index = series.index + return new_series diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index 4fc681e6..e66abcae 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -108,4 +108,6 @@ def custom_forward(*args, **kwargs): return model def postprocess(self, series): - return series.list.leaves + new_series = series.list.leaves + new_series.index = series.index + return new_series From 0eca48ff21dc45be49dc2546431447e1a6a8d56b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 9 Sep 2024 16:38:14 -0700 Subject: [PATCH 42/57] Add docstrings and standardize variable names Signed-off-by: Ryan Wolf --- .../datasets/image_text_pair_dataset.py | 82 +++++++++++++++++-- nemo_curator/image/classifiers/aesthetic.py | 31 ++++++- nemo_curator/image/classifiers/base.py | 70 ++++++++++++++-- nemo_curator/image/classifiers/nsfw.py | 32 +++++++- nemo_curator/image/embedders/base.py | 65 ++++++++++++++- nemo_curator/image/embedders/timm.py | 71 ++++++++++++++-- 6 files changed, 316 insertions(+), 35 deletions(-) diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index abdf244b..39427b7a 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -28,9 +28,43 @@ class ImageTextPairDataset: + """ + A collection of image text pairs stored in webdataset-like format on disk or in cloud storage. + + The exact format assumes a single directory with sharded .tar, .parquet, and (optionally) + .idx files. Each tar file should have a unique integer id as it's name (00000.tar, + 00001.tar, 00002.tar, etc.). The tar files should contain images in .jpg files, text captions + in .txt files, and metadata in .json files. Each record of the dataset is identified by + a unique id that is a mix of the shard id along with the offset of the record within a shard. + For example, the 32rd record of the 43rd shard would be in 00042.tar and have image 000420031.jpg, + caption 000420031.txt, and metadata 000420031.json (assuming zero indexing). + + In addition to the collection of tar files, ImageTextPairDataset expects there to be .parquet files + in the root directory that follow the same naming convention as the shards (00042.tar -> 00042.parquet). + Each parquet file should contain an aggregated tabular form of the metadata for each record, with + each row in the parquet file corresponding to a record in that shard. The metadata, both in the parquet + files and the json files, must contain a unique id column that is the same as its record id (000420031 + in our examples). + + Index files may also be in the directory to speed up dataloading with DALI. + The index files must be generated by DALI's wds2idx tool. + See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index + for more information. Each index file must follow the same naming convention as the tar files + (00042.tar -> 00042.idx). + """ + def __init__( self, path: str, metadata: dd.DataFrame, tar_files: List[str], id_col: str ) -> None: + """ + Constructs an image text pair dataset. + + Args: + path (str): The root directory of the files. + metadata (dd.DataFrame): A dask cudf dataframe of the metadata. + tar_files (List[str]): A list of paths to the tar files. + id_col (str): The column storing the unique identifier for each record. + """ self.path = path self.metadata = metadata self.tar_files = tar_files @@ -38,6 +72,13 @@ def __init__( @classmethod def from_webdataset(cls, path: str, id_col: str): + """ + Loads an ImageTextPairDataset from a webdataset + + Args: + path (str): The path to the webdataset-like format on disk or cloud storage. + id_col (str): The column storing the unique identifier for each record. + """ metadata = dask_cudf.read_parquet(path) metadata = metadata.map_partitions(cls._sort_partition, id_col=id_col) @@ -53,6 +94,7 @@ def _sort_partition(partition, id_col): def _get_tar_files(path: str) -> List[str]: glob_str = os.path.join(path, "*.tar") # open_files doesn't actually open a file descriptor + # tar_files is sorted by default tar_files = [file.path for file in open_files(glob_str)] return tar_files @@ -74,6 +116,16 @@ def _name_partition( def save_metadata( self, path: Optional[str] = None, columns: Optional[List[str]] = None ) -> None: + """ + Saves the metadata of the dataset to the specified path as a collection + of parquet files. + + Args: + path (Optional[str]): The path to save the metadata to. If None, + writes to the original path. + columns (Optional[List[str]]): If specified, only saves a subset + of columns. + """ if path is None: path = self.path @@ -150,23 +202,37 @@ def _get_eligible_samples(self, output_path: str, samples_per_shard: int): yield curr_df, total_tar_samples @staticmethod - def combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str: + def _combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str: int_id = sample_id + (10**max_samples_per_shard) * shard_id n_digits = max_samples_per_shard + max_shards combined_id = f"{int_id:0{n_digits}d}" return combined_id - def split_id(combined_id: str, max_shards=5): - return int(combined_id[:max_shards]), int(combined_id[max_shards:]) - def to_webdataset( self, path: str, filter_column: str, samples_per_shard: int = 10000, - max_shards=5, - old_id_col=None, + max_shards: int = 5, + old_id_col: Optional[str] = None, ) -> None: + """ + Saves the dataset to a webdataset format with parquet files. + Will reshard the tar files to the specified number of samples per shard. + The id value in ImageTextPairDataset.id_col will be overwritten with a new id. + + Args: + path (str): The output path where the dataset should be written. + filter_column (str): A column of booleans. All samples with a value of True + in this column will be included in the output. Otherwise, the sample + will be omitted. + samples_per_shard (int): The number of samples to include in each tar file. + max_shards (int): The order of magnitude of the maximum number of shards + that will be created from the dataset. Will be used to determine the + number of leading zeros in the shard/sample ids. + old_id_col (Optional[str]): If specified, will preserve the previous + id value in the given column. + """ max_samples_per_shard = math.ceil(math.log10(samples_per_shard)) filtered_metadata = self.metadata[self.metadata[filter_column]] @@ -191,7 +257,7 @@ def to_webdataset( new_ids = np.arange(len(shard_df)) convert_ids = partial( - self.combine_id, + self._combine_id, shard_id, max_shards=max_shards, max_samples_per_shard=max_samples_per_shard, @@ -206,7 +272,7 @@ def to_webdataset( for i, (member, data) in enumerate(shard_tar): # Rename the each member to match the new id sample_id = int(i // members_per_sample) - member_id = self.combine_id( + member_id = self._combine_id( shard_id, sample_id, max_shards=max_shards, diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 99e8c68b..885515c6 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -46,16 +46,39 @@ def forward(self, x): class AestheticClassifier(ImageClassifier): + """ + LAION-Aesthetics_Predictor V2 is a linear classifier trained on top of + OpenAI CLIP ViT-L/14 image embeddings. It is used to assess the aesthetic + quality of images. More information on the model can be found here: + https://laion.ai/blog/laion-aesthetics/. + """ + def __init__( self, - embedding_column: str = "image_embedding", + image_embedding_column: str = "image_embedding", pred_column: str = "aesthetic_score", batch_size: int = -1, model_path: Optional[str] = None, ) -> None: + """ + Constructs the classifier. + + Args: + image_embedding_column (str): The column name that stores the image + embeddings. + pred_column (str): The column name to be added where the aesthetic + scores will be stored. + pred_type (Union[str, type]): The datatype of the pred_column + batch_size (int): If greater than 0, the image embeddings + will be processed in batches of at most this size. If less than 0, + all embeddings will be processed at once. + model_path (Optional[str]): If specified, will load the model from the + given path. If not specified, will default to being stored in + NEMO_CURATOR_HOME. + """ super().__init__( model_name="aesthetic_classifier", - embedding_column=embedding_column, + image_embedding_column=image_embedding_column, pred_column=pred_column, pred_type=float, batch_size=batch_size, @@ -90,11 +113,11 @@ def load_model(self, device): weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() - model = self.configure_forward(model) + model = self._configure_forward(model) return model - def configure_forward(self, model): + def _configure_forward(self, model): original_forward = model.forward def custom_forward(*args, **kwargs): diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index b9f5d52d..597d1b62 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from abc import ABC, abstractmethod -from typing import Union +from typing import Callable, Union +import cudf import cupy as cp import torch @@ -26,26 +26,55 @@ class ImageClassifier(ABC): """ An abstract base class that represents a classifier on top - of embeddings generated by a CLIP vision encoder + of embeddings generated by a CLIP vision encoder. + + Subclasses only need to define how a model is loaded. + They may also override the postprocess method if they would like + to modify output series of predictions before it gets combined into + the dataset. The classifier must be able to fit on a single GPU. """ def __init__( self, model_name: str, - embedding_column: str, + image_embedding_column: str, pred_column: str, pred_type: Union[str, type], batch_size: int, embedding_size: int, ) -> None: + """ + Constructs an image classifier + + Args: + model_name (str): A unqiue name to identify the model on each worker + and in the logs. + image_embedding_column (str): The column name that stores the image + embeddings. + pred_column (str): The column name to be added where the classifier's + predictions will be stored. + pred_type (Union[str, type]): The datatype of the pred_column + batch_size (int): If greater than 0, the image embeddings + will be processed in batches of at most this size. If less than 0, + all embeddings will be processed at once. + """ self.model_name = model_name - self.embedding_column = embedding_column + self.image_embedding_column = image_embedding_column self.pred_column = pred_column self.pred_type = pred_type self.batch_size = batch_size self.embedding_size = embedding_size def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + """ + Classifies all embeddings in the dataset + + Args: + dataset (ImageTextPairDataset): The dataset to classify. + + Returns: + ImageTextPairDataset: A dataset with classifier scores. + """ meta = dataset.metadata.dtypes.to_dict() meta[self.pred_column] = self.pred_type embedding_df = dataset.metadata.map_partitions(self._run_inference, meta=meta) @@ -67,7 +96,7 @@ def _run_inference(self, partition, partition_info=None): ) embeddings = torch.as_tensor( - partition[self.embedding_column].list.leaves.values.reshape( + partition[self.image_embedding_column].list.leaves.values.reshape( len(partition), -1 ), device=device, @@ -76,7 +105,7 @@ def _run_inference(self, partition, partition_info=None): if self.embedding_size != embeddings.shape[-1]: raise RuntimeError( f"{self.model_name} expects embedding size {self.embedding_size} but column " - f"'{self.embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " + f"'{self.image_embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " "classifier is compatible with the CLIP model you used to generate the embeddings." ) @@ -104,8 +133,31 @@ def _run_inference(self, partition, partition_info=None): return partition @abstractmethod - def load_model(self, device): + def load_model(self, device: str) -> Callable: + """ + Loads the classifier model + + Args: + device (str): A PyTorch device identifier that specifies what GPU + to load the model on. + + Returns: + Callable: A callable model, usually a torch.nn.Module. + The input to this model will be the batches of images output + by the ImageEmbedder.load_dataset_shard. + """ pass - def postprocess(self, series): + def postprocess(self, series: cudf.Series) -> cudf.Series: + """ + Postprocesses the predictions of the classifier before saving + them to the metadata. + + Args: + series (cudf.Series): The cudf series of raw model predictions. + + Returns: + cudf.Series: The same series unmodified. Override in your classifier + if needed. + """ return series diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index e66abcae..07588191 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -53,16 +53,40 @@ def forward(self, x): class NsfwClassifier(ImageClassifier): + """ + NSFW Classifier is a small MLP trained on top of + Laion's ViT-H image embeddings. It is used to assess the likelihood + of images containing sexually explicit materal. + More information on the model can be found here: + https://github.com/LAION-AI/CLIP-based-NSFW-Detector. + """ + def __init__( self, - embedding_column: str = "image_embedding", + image_embedding_column: str = "image_embedding", pred_column: str = "nsfw_score", batch_size: int = -1, model_path: Optional[str] = None, ) -> None: + """ + Constructs the classifier. + + Args: + image_embedding_column (str): The column name that stores the image + embeddings. + pred_column (str): The column name to be added where the nsfw + scores will be stored. + pred_type (Union[str, type]): The datatype of the pred_column + batch_size (int): If greater than 0, the image embeddings + will be processed in batches of at most this size. If less than 0, + all embeddings will be processed at once. + model_path (Optional[str]): If specified, will load the model from the + given path. If not specified, will default to being stored in + NEMO_CURATOR_HOME. + """ super().__init__( model_name="nsfw_classifier", - embedding_column=embedding_column, + image_embedding_column=image_embedding_column, pred_column=pred_column, pred_type=float, batch_size=batch_size, @@ -94,11 +118,11 @@ def load_model(self, device): weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() - model = self.configure_forward(model) + model = self._configure_forward(model) return model - def configure_forward(self, model): + def _configure_forward(self, model): original_forward = model.forward def custom_forward(*args, **kwargs): diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 5acfc0f0..cebb4998 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from abc import ABC, abstractmethod -from typing import Iterable +from typing import Callable, Iterable import cupy as cp import torch @@ -26,17 +25,49 @@ class ImageEmbedder(ABC): + """ + An abstract base class for generating image embeddings. + + Subclasses only need to define how a model is loaded and a dataset + is read in from a tar file shard. This class handles distributing + the tasks across workers and saving the metadata to the dataset. + The embedding model must be able to fit onto a single GPU. + """ + def __init__( self, model_name: str, image_embedding_column: str, classifiers: Iterable[ImageClassifier], ) -> None: + """ + Constructs an image embedder + + Args: + model_name (str): A unqiue name to identify the model on each worker + and in the logs. + image_embedding_column (str): The column name to be added where the + image embeddings will be saved. + classifiers (Iterable[ImageClassifier]): A collection of classifiers. If + the iterable has a nonzero length, all classifiers will be loaded + on the GPU at the same time and be passed the image embeddings + immediately after they are created. + """ self.model_name = model_name self.image_embedding_column = image_embedding_column self.classifiers = classifiers def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + """ + Generates image embeddings for all images in the dataset + + Args: + dataset (ImageTextPairDataset): The dataset to create image embeddings for + + Returns: + ImageTextPairDataset: A dataset with image embeddings and potentially + classifier scores. + """ meta = dataset.metadata.dtypes.to_dict() meta[self.image_embedding_column] = "object" for classifier in self.classifiers: @@ -122,8 +153,36 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): @abstractmethod def load_dataset_shard(self, tar_path: str) -> Iterable: + """ + Loads images and metadata from a tarfile in the dataset + + Args: + tar_path (str): The path to a tar file shard in the input webdataset. + + Returns: + Iterable: An iterator over the dataset. Each iteration should produce + A tuple of (image, metadata) pairs. The batch of images will be passed + directly to the model created by ImageEmbedder.load_embedding_model. + The metadata must be a list of dictionaries. Each element of the list + must correspond to the image in the batch at the same position. + Each dictionary must contain a field that is the same as + id_field in the dataset. This id field in the metadata will be used + to match the image to the its record in the metadata (parquet) files. + """ pass @abstractmethod - def load_embedding_model(self, device): + def load_embedding_model(self, device: str) -> Callable: + """ + Loads the model used to generate image embeddings + + Args: + device (str): A PyTorch device identifier that specifies what GPU + to load the model on. + + Returns: + Callable: A callable model, usually a torch.nn.Module. + The input to this model will be the batches of images output + by the ImageEmbedder.load_dataset_shard. + """ pass diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 3a7c258a..4adea8d4 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -26,6 +26,14 @@ class TimmImageEmbedder(ImageEmbedder): + """ + PyTorch Image Models (timm) is a library containing SOTA computer vision + models. Many of these models are useful in generating image embeddings + for modules in NeMo Curator. This module can also automatically convert + the image transformations from PyTorch transformations to DALI transformations + in the supported models. + """ + def __init__( self, model_name: str, @@ -36,8 +44,34 @@ def __init__( normalize_embeddings: bool = True, classifiers: Iterable = [], autocast: bool = True, - use_index_files=False, + use_index_files: bool = False, ) -> None: + """ + Constructs the embedder + + Args: + model_name (str): The timm model to use. A list of available models + can be found by running timm.list_models() + pretrained (bool): If True, loads the pretrained weights of the model. + batch_size (int): The number of images to run inference on in a single batch. + If the batch_size is larger than the number of elements in a shard, only + the number of elements in a shard will be used. + num_threads_per_worker (int): The number of threads per worker (GPU) to use + for loading images with DALI. + image_embedding_column (str): The output column where the embeddings will be + stored in the dataset. + normalize_embeddings (bool): Whether to normalize the embeddings output by the + model. Defaults to True. + classifiers (Iterable): A collection of classifiers to immediately apply on top + of the image embeddings. + autocast (bool): If True, runs the timm model using torch.autocast(). + use_index_files (bool): If True, tries to find and use index files generated + by DALI at the same path as the tar file shards. The index files must be + generated by DALI's wds2idx tool. See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index + for more information. Each index file must be of the form "shard_id.idx" + where shard_id is the same integer as the corresponding tar file for the + data. The index files must be in the same folder as the tar files. + """ super().__init__( model_name=model_name, image_embedding_column=image_embedding_column, @@ -58,6 +92,21 @@ def __init__( self.dali_transforms = convert_transforms_to_dali(torch_transforms) def load_dataset_shard(self, tar_path: str): + """ + Loads a webdataset tar shard using DALI + + Args: + tar_path (str): The path of the tar shard to load + + Returns: + Iterable: An iterator over the dataset. Each tar file + must have 3 files per record. A jpg file, a txt file, + and a json file. The jpg file must contain the image, the + txt file must contain the associated caption, and the + json must contain the metadata for the record (including + its id). Images will be loaded using DALI. + """ + # Create the DALI pipeline @pipeline_def( batch_size=self.batch_size, @@ -115,13 +164,25 @@ def webdataset_pipeline(_tar_path: str): yield image, metadata def load_embedding_model(self, device="cuda"): + """ + Loads the model used to generate image embeddings + + Args: + device (str): A PyTorch device identifier that specifies what GPU + to load the model on. + + Returns: + Callable: A timm model loaded on the specified device. + The model's forward call may be augmented with torch.autocast() + or embedding normalization if specified in the constructor. + """ model = timm.create_model(self.model_name, pretrained=self.pretrained).eval() model = model.to(device) - model = self.configure_forward(model) + model = self._configure_forward(model) return model - def configure_forward(self, model): + def _configure_forward(self, model): original_forward = model.forward def custom_forward(*args, **kwargs): @@ -139,7 +200,3 @@ def custom_forward(*args, **kwargs): model.forward = custom_forward return model - - @staticmethod - def torch_normalized(a, dim=-1): - return torch.nn.functional.normalize(a, dim=dim) From 59763a17927b03caab5655d002c7d418fc3ad8ff Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 10 Sep 2024 09:41:20 -0700 Subject: [PATCH 43/57] Add image curation tutorial Signed-off-by: Ryan Wolf --- tutorials/image-curation/image-curation.ipynb | 417 ++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 tutorials/image-curation/image-curation.ipynb diff --git a/tutorials/image-curation/image-curation.ipynb b/tutorials/image-curation/image-curation.ipynb new file mode 100644 index 00000000..c52faf4e --- /dev/null +++ b/tutorials/image-curation/image-curation.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image Curation in NeMo Curator\n", + "\n", + "In the following notebook, we'll be exploring all of the functionality that NeMo Curator has for image dataset curation.\n", + "NeMo Curator has a few built-in modules for \n", + "\n", + "First, we'll need to install NeMo Curator!\n", + "\n", + "NOTE: Please ensure you meet the [requirements](https://github.com/NVIDIA/NeMo-Curator/tree/main?tab=readme-ov-file#install-nemo-curator) before proceeding!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install NeMo Curator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "!pip install cython\n", + "!pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download a Sample Dataset\n", + "If you already have a dataset in webdataset format, great! You can skip to the [next section](#create-clip-embeddings).\n", + "In order to have a sample dataset to play with, we are going to download a subset of the [Microsoft Common Objects in Context (mscoco)](https://cocodataset.org/#home) dataset.\n", + "MSCOCO is a dataset of 600,000 image-text pairs (around 76GB) that takes around 20 minutes to download.\n", + "For the sake of this tutorial, we are only going to download a subset of the dataset.\n", + "We will download 20,000 image-text pairs (around 3GB).\n", + "\n", + "To download the dataset, we are going to use a tool called img2dataset. Let's install it and download the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "!pip install img2dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we need to get a list of URLs that identify where all the images are hosted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "!wget https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT/resolve/main/mscoco.parquet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We truncate this list of URLs so we don't download the whole dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "NUM_URLS = 20_000\n", + "urls = pd.read_parquet(\"mscoco.parquet\")\n", + "truncated_urls = urls[:NUM_URLS]\n", + "truncated_urls.to_parquet(\"truncated_mscoco.parquet\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's start the download." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "!img2dataset \\\n", + " --url_list truncated_mscoco.parquet \\\n", + " --input_format \"parquet\" \\\n", + " --output_folder mscoco \\\n", + " --output_format webdataset \\\n", + " --url_col \"URL\" \\\n", + " --caption_col \"TEXT\" \\\n", + " --processes_count 16 \\\n", + " --thread_count 64 \\\n", + " --resize_mode no" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create CLIP Embeddings\n", + "\n", + "### Load the Dataset\n", + "Instead of operating on the images directly, most features in NeMo Curator take embeddings as inputs. So, as the first stage in the pipeline, we are going to generate embeddings for all the images in the dataset. To begin, let's load the dataset using NeMo Curator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Change the dataset path if you have your own dataset\n", + "dataset_path = \"./mscoco\"\n", + "# Change the unique identifier depending on your dataset\n", + "id_col = \"key\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator.datasets import ImageTextPairDataset\n", + "\n", + "dataset = ImageTextPairDataset.from_webdataset(dataset_path, id_col)\n", + "# Filter out any entries that failed to download\n", + "dataset.metadata = dataset.metadata[dataset.metadata[\"error_message\"].isna()]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Choose the Embedder\n", + "We can now define the embedding creation step in our pipeline. NeMo Curator has support for all [timm](https://pypi.org/project/timm/) models. NeMo Curator's aesthetic classifier is trained on embeddings from `vit_large_patch14_clip_quickgelu_224.openai`, so we will use that.\n", + "\n", + "The cell below will do the following:\n", + "1. Download the model `vit_large_patch14_clip_quickgelu_224.openai`.\n", + "1. Automatically convert the image preprocessing transformations of `vit_large_patch14_clip_quickgelu_224.openai` from their PyTorch form to DALI." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator.image.embedders import TimmImageEmbedder\n", + "\n", + "embedding_model = TimmImageEmbedder(\n", + " \"vit_large_patch14_clip_quickgelu_224.openai\",\n", + " pretrained=True,\n", + " batch_size=1024,\n", + " num_threads_per_worker=16,\n", + " normalize_embeddings=True,\n", + " autocast=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now create embeddings for the whole dataset. It's important to understand what is going on internally in NeMo Curator so you can modify parameters appropriately.\n", + "\n", + "Once the computation is triggered, the cell below will\n", + "1. Load a shard of metadata (a `.parquet` file) onto each GPU you have available using Dask-cuDF.\n", + "1. Load a copy of `vit_large_patch14_clip_quickgelu_224.openai` onto each GPU.\n", + "1. Repeatedly load images into batches of size `batch_size` onto each GPU with a given threads per worker (`num_threads_per_worker`) using DALI.\n", + "1. The model is run on the batch (without `torch.autocast()` since `autocast=False`).\n", + "1. The output embeddings of the model are normalized since `normalize_embeddings=True`.\n", + "\n", + "\n", + "Since NeMo Curator uses Dask, the cell below will not cause the embeddings to be created. The computation will only begin once we inspect the output in the `dataset.metadata.head()` call or when we write to disk using `dataset.save_metadata()` or `dataset.to_webdataset()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Dask is lazy, so this will not compute embeddings\n", + "dataset = embedding_model(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This triggers the computation for the first shard in the dataset\n", + "dataset.metadata.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Aesthetic Classifier\n", + "With the embeddings now created, we can use the aesthetic classifier. This classifier assigns a score from 0 to 10 that corresponds to how aesthetically pleasing the image is. A score of 0 means that the image is not pleasant to look at, while a score of 10 is pleasant to look at. The exact classifier used is the LAION-Aesthetics_Predictor V2. More information on the model can be found here: https://laion.ai/blog/laion-aesthetics/.\n", + "\n", + "The following cell will download the model to your local storage at `NEMO_CURATOR_HOME` (`/home/user/.nemo_curator`). The model is only 3.6MB." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator.image.classifiers import AestheticClassifier\n", + "\n", + "aesthetic_classifier = AestheticClassifier()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = aesthetic_classifier(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dataset_path = \"./output_dataset\"\n", + "dataset.metadata[\"passes_aesthetic_check\"] = dataset.metadata[\"aesthetic_score\"] > 5\n", + "dataset.to_webdataset(output_dataset_path, \"passes_aesthetic_check\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize Results\n", + "Now that we have filtered our dataset based on aesthetic score, we can see what kinds of images are remaining." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tarfile\n", + "import io\n", + "from PIL import Image\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "\n", + "def display_image_from_tar(tar_file_path, image_file_name):\n", + " # Open the tar file\n", + " with tarfile.open(tar_file_path, 'r') as tar:\n", + " # Extract the specified image file\n", + " image_file = tar.extractfile(image_file_name)\n", + " \n", + " if image_file is not None:\n", + " # Read the image data\n", + " image_data = image_file.read()\n", + " \n", + " # Create a PIL Image object from the image data\n", + " image = Image.open(io.BytesIO(image_data))\n", + " \n", + " # Display the image using matplotlib\n", + " plt.figure(figsize=(10, 8))\n", + " plt.imshow(image)\n", + " plt.axis('off') # Hide axes\n", + " plt.title(f\"Image: {image_file_name}\")\n", + " plt.show()\n", + " else:\n", + " print(f\"Image file '{image_file_name}' not found in the tar archive.\")\n", + " \n", + "\n", + "output_shard = os.path.join(output_dataset_path, \"00000.tar\")\n", + "image_file_name = '000000003.jpg'\n", + "display_image_from_tar(output_shard, image_file_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Semantic Deduplication\n", + "\n", + "NeMo Curator provides an easy module for semantically deduplicating images. Semantic duplicates are images that contain almost the same information content, but are perceptually different. Two imges of the same dog taken from slightly different angles would be considered semantic duplicates. NeMo Curator' semantic deduplication approach is based on the paper [SemDeDup: Data-efficient learning at web-scale through semantic deduplication](https://arxiv.org/pdf/2303.09540) by Abbas et al which has demonstrated that deduplicating your data can lead to the same downstream performance in *half* the number of training iterations. For more information on the algorithm, you can check out the [documentation page](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html#data-curator-semdedup)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator.datasets import DocumentDataset\n", + "from nemo_curator import ClusteringModel, SemanticClusterLevelDedup\n", + "\n", + "# Convert the dataset\n", + "embeddings_dataset = DocumentDataset(dataset.metadata)\n", + "embeddings_dataset.df = embeddings_dataset.df.rename(columns={\"image_embeddings\": \"embeddings\"})\n", + "\n", + "semantic_dedup_outputs = \"./semantic_deduplication\"\n", + "os.makedirs(semantic_dedup_outputs, exist_ok=True)\n", + "\n", + "# Run clustering\n", + "clustering_output = os.path.join(semantic_dedup_outputs, \"cluster_output\")\n", + "clustering_model = ClusteringModel(\n", + " id_col=id_col,\n", + " max_iter=10,\n", + " n_clusters=1,\n", + " clustering_output_dir=clustering_output,\n", + ")\n", + "clustered_dataset = clustering_model(embeddings_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run cluster-level dedup\n", + "emb_by_cluster_output = os.path.join(semantic_dedup_outputs, \"emb_by_cluster\")\n", + "sorted_cluster_output = os.path.join(semantic_dedup_outputs, \"sorted_cluster\")\n", + "duplicate_output = os.path.join(semantic_dedup_outputs, \"duplicates\")\n", + "\n", + "semantic_dedup = SemanticClusterLevelDedup(\n", + " n_clusters=50000,\n", + " emb_by_clust_dir=emb_by_cluster_output,\n", + " sorted_clusters_dir=sorted_cluster_output,\n", + " id_col=id_col,\n", + " id_col_type=\"str\",\n", + " which_to_keep=\"hard\",\n", + " output_dir=duplicate_output,\n", + ")\n", + "semantic_dedup.compute_semantic_match_dfs()\n", + "deduplicated_dataset_ids = semantic_dedup.extract_dedup_data(eps_to_extract=0.07)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize duplicates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Remove duplicates" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 40e1549ff7087d2c000091a9430432628d4cacc6 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 16 Sep 2024 10:03:05 -0700 Subject: [PATCH 44/57] Add initial image docs Signed-off-by: Ryan Wolf --- README.md | 4 +- docs/user-guide/api/datasets.rst | 8 +++ docs/user-guide/api/image/classifiers.rst | 21 +++++++ docs/user-guide/api/image/embedders.rst | 18 ++++++ docs/user-guide/api/image/index.rst | 10 ++++ docs/user-guide/api/index.rst | 1 + .../user-guide/{images => assets}/diagram.png | Bin .../sorted_sequence_dataloader.png | Bin .../{images => assets}/zeroshot_ablations.png | Bin .../distributeddataclassification.rst | 2 +- docs/user-guide/image/index.rst | 7 +++ docs/user-guide/index.rst | 26 +++++++- setup.py | 2 +- tutorials/image-curation/image-curation.ipynb | 56 +++++++++++------- 14 files changed, 128 insertions(+), 27 deletions(-) create mode 100644 docs/user-guide/api/image/classifiers.rst create mode 100644 docs/user-guide/api/image/embedders.rst create mode 100644 docs/user-guide/api/image/index.rst rename docs/user-guide/{images => assets}/diagram.png (100%) rename docs/user-guide/{images => assets}/sorted_sequence_dataloader.png (100%) rename docs/user-guide/{images => assets}/zeroshot_ablations.png (100%) create mode 100644 docs/user-guide/image/index.rst diff --git a/README.md b/README.md index ed52a337..830adf1a 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Large Language Model Data Curation** πŸš€

- diagram + diagram

NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for [large language model (LLM)](https://www.nvidia.com/en-us/glossary/large-language-models/) use-cases such as foundation model pretraining, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and paramter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. @@ -191,7 +191,7 @@ The modules within NeMo Curator were primarily designed to curate high-quality d The following figure shows that the use of different data curation modules implemented in NeMo Curator led to improved model zero-shot downstream task performance.

- drawing + drawing

In terms of scalability and compute performance, using the combination of RAPIDS and Dask fuzzy deduplication enabled us to deduplicate the 1.1 Trillion token Red Pajama dataset in 1.8 hours with 64 NVIDIA A100 Tensor Core GPUs. diff --git a/docs/user-guide/api/datasets.rst b/docs/user-guide/api/datasets.rst index 43e532b1..c8dba791 100644 --- a/docs/user-guide/api/datasets.rst +++ b/docs/user-guide/api/datasets.rst @@ -7,4 +7,12 @@ DocumentDataset ------------------- .. autoclass:: nemo_curator.datasets.DocumentDataset + :members: + + +------------------------------- +ImageTextPairDataset +------------------------------- + +.. autoclass:: nemo_curator.datasets.ImageTextPairDataset :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/classifiers.rst b/docs/user-guide/api/image/classifiers.rst new file mode 100644 index 00000000..a43560e5 --- /dev/null +++ b/docs/user-guide/api/image/classifiers.rst @@ -0,0 +1,21 @@ +====================================== +Classifiers +====================================== + +------------------------------ +Base Class +------------------------------ + +.. autoclass:: nemo_curator.image.classifiers.ImageClassifier + :members: + + +------------------------------ +Image Classifiers +------------------------------ + +.. autoclass:: nemo_curator.image.classifiers.AestheticClassifier + :members: + +.. autoclass:: nemo_curator.image.classifiers.NsfwClassifier + :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/embedders.rst b/docs/user-guide/api/image/embedders.rst new file mode 100644 index 00000000..aa1de81e --- /dev/null +++ b/docs/user-guide/api/image/embedders.rst @@ -0,0 +1,18 @@ +====================================== +Embedders +====================================== + +------------------------------ +Base Class +------------------------------ + +.. autoclass:: nemo_curator.image.embedders.ImageEmbedder + :members: + + +------------------------------ +Timm +------------------------------ + +.. autoclass:: nemo_curator.image.embedders.TimmImageEmbedder + :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/index.rst b/docs/user-guide/api/image/index.rst new file mode 100644 index 00000000..c58862f4 --- /dev/null +++ b/docs/user-guide/api/image/index.rst @@ -0,0 +1,10 @@ +====================================== +Image Curation +====================================== + +.. toctree:: + :maxdepth: 4 + :titlesonly: + + embedders.rst + classifiers.rst \ No newline at end of file diff --git a/docs/user-guide/api/index.rst b/docs/user-guide/api/index.rst index 866f06b9..b76dd75b 100644 --- a/docs/user-guide/api/index.rst +++ b/docs/user-guide/api/index.rst @@ -18,4 +18,5 @@ API Reference decontamination.rst services.rst synthetic.rst + image/index.rst misc.rst \ No newline at end of file diff --git a/docs/user-guide/images/diagram.png b/docs/user-guide/assets/diagram.png similarity index 100% rename from docs/user-guide/images/diagram.png rename to docs/user-guide/assets/diagram.png diff --git a/docs/user-guide/images/sorted_sequence_dataloader.png b/docs/user-guide/assets/sorted_sequence_dataloader.png similarity index 100% rename from docs/user-guide/images/sorted_sequence_dataloader.png rename to docs/user-guide/assets/sorted_sequence_dataloader.png diff --git a/docs/user-guide/images/zeroshot_ablations.png b/docs/user-guide/assets/zeroshot_ablations.png similarity index 100% rename from docs/user-guide/images/zeroshot_ablations.png rename to docs/user-guide/assets/zeroshot_ablations.png diff --git a/docs/user-guide/distributeddataclassification.rst b/docs/user-guide/distributeddataclassification.rst index bc67f127..22252bc7 100644 --- a/docs/user-guide/distributeddataclassification.rst +++ b/docs/user-guide/distributeddataclassification.rst @@ -126,7 +126,7 @@ The key feature of CrossFit used in NeMo Curator is the sorted sequence data loa - Groups sorted sequences into optimized batches. - Efficiently allocates batches to the the provided GPU memories by estimating the memory footprint for each sequence length and batch size. -.. image:: images/sorted_sequence_dataloader.png +.. image:: assets/sorted_sequence_dataloader.png :alt: Sorted Sequence Data Loader Check out the `rapidsai/crossfit`_ repository for more information. diff --git a/docs/user-guide/image/index.rst b/docs/user-guide/image/index.rst new file mode 100644 index 00000000..bbb68054 --- /dev/null +++ b/docs/user-guide/image/index.rst @@ -0,0 +1,7 @@ +.. toctree:: + :maxdepth: 4 + :titlesonly: + + + embedders.rst + classifiers.rst \ No newline at end of file diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index 070bba97..b72513d2 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -1,5 +1,9 @@ .. include:: datacuration.rsts +------------------- +Text Curation +------------------- + :ref:`Downloading and Extracting Text ` Downloading a massive public dataset is usually the first step in data curation, and it can be cumbersome due to the dataset’s massive size and hosting method. This section describes how to download and extract large corpora efficiently. @@ -19,7 +23,7 @@ Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. :ref:`GPU Accelerated Semantic Deduplication ` - NeMo-Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and Pytorch. + NeMo Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and Pytorch. :ref:`Synthetic Data Generation ` Synthetic data generation tools and example piplines are available within NeMo Curator. @@ -30,6 +34,26 @@ :ref:`Personally Identifiable Information Identification and Removal ` The purpose of the personally identifiable information (PII) redaction tool is to help scrub sensitive data out of training datasets +------------------- +Image Curation +------------------- + +:ref:`Image-Text Pair Datasets ` + Image-text pair datasets are commonly used as the basis for training multimodal generative models. NeMo Curator interfaces with the standardized Webdataset format for curating such datasets. + +:ref:`Image Embedding Creation ` + Image embeddings are the backbone to many data curation operations in NeMo Curator. This section describes how to efficiently create embeddings for massive datasets. + +:ref:`Classifiers ` + NeMo Curator provides several ways to use common classifiers like aesthetic scoring, and not-safe-for-work (NSFW) scoring. + +:ref:`Semantic Deduplication ` + Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semnatic deduplication module that can work with any modality. + +------------------- +Reference +------------------- + :ref:`NeMo Curator on Kubernetes ` Demonstration of how to run the NeMo Curator on a Dask Cluster deployed on top of Kubernetes diff --git a/setup.py b/setup.py index 6c488dec..3245bdf8 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def req_file(filename, folder="requirements"): setup( name="nemo_curator", - version="0.4.0", + version="0.5.0", description="Scalable Data Preprocessing Tool for " "Training Large Language Models", long_description=long_description, diff --git a/tutorials/image-curation/image-curation.ipynb b/tutorials/image-curation/image-curation.ipynb index c52faf4e..6cf6ceaa 100644 --- a/tutorials/image-curation/image-curation.ipynb +++ b/tutorials/image-curation/image-curation.ipynb @@ -14,27 +14,6 @@ "NOTE: Please ensure you meet the [requirements](https://github.com/NVIDIA/NeMo-Curator/tree/main?tab=readme-ov-file#install-nemo-curator) before proceeding!" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Install NeMo Curator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "!pip install cython\n", - "!pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -132,6 +111,28 @@ " --resize_mode no" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install NeMo Curator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "!pip install cython\n", + "!pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]\n", + "%env DASK_DATAFRAME__QUERY_PLANNING False" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -142,6 +143,17 @@ "Instead of operating on the images directly, most features in NeMo Curator take embeddings as inputs. So, as the first stage in the pipeline, we are going to generate embeddings for all the images in the dataset. To begin, let's load the dataset using NeMo Curator." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator import get_client\n", + "\n", + "client = get_client(cluster_type=\"gpu\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -271,7 +283,7 @@ "outputs": [], "source": [ "output_dataset_path = \"./output_dataset\"\n", - "dataset.metadata[\"passes_aesthetic_check\"] = dataset.metadata[\"aesthetic_score\"] > 5\n", + "dataset.metadata[\"passes_aesthetic_check\"] = dataset.metadata[\"aesthetic_score\"] > 6\n", "dataset.to_webdataset(output_dataset_path, \"passes_aesthetic_check\")" ] }, From 0d857b4eada7a5ada10d353d47fca28d66b693b9 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 17 Sep 2024 14:37:43 -0700 Subject: [PATCH 45/57] Remove tutorial Signed-off-by: Ryan Wolf --- tutorials/image-curation/image-curation.ipynb | 429 ------------------ 1 file changed, 429 deletions(-) delete mode 100644 tutorials/image-curation/image-curation.ipynb diff --git a/tutorials/image-curation/image-curation.ipynb b/tutorials/image-curation/image-curation.ipynb deleted file mode 100644 index 6cf6ceaa..00000000 --- a/tutorials/image-curation/image-curation.ipynb +++ /dev/null @@ -1,429 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Image Curation in NeMo Curator\n", - "\n", - "In the following notebook, we'll be exploring all of the functionality that NeMo Curator has for image dataset curation.\n", - "NeMo Curator has a few built-in modules for \n", - "\n", - "First, we'll need to install NeMo Curator!\n", - "\n", - "NOTE: Please ensure you meet the [requirements](https://github.com/NVIDIA/NeMo-Curator/tree/main?tab=readme-ov-file#install-nemo-curator) before proceeding!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Download a Sample Dataset\n", - "If you already have a dataset in webdataset format, great! You can skip to the [next section](#create-clip-embeddings).\n", - "In order to have a sample dataset to play with, we are going to download a subset of the [Microsoft Common Objects in Context (mscoco)](https://cocodataset.org/#home) dataset.\n", - "MSCOCO is a dataset of 600,000 image-text pairs (around 76GB) that takes around 20 minutes to download.\n", - "For the sake of this tutorial, we are only going to download a subset of the dataset.\n", - "We will download 20,000 image-text pairs (around 3GB).\n", - "\n", - "To download the dataset, we are going to use a tool called img2dataset. Let's install it and download the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "!pip install img2dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we need to get a list of URLs that identify where all the images are hosted" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "!wget https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT/resolve/main/mscoco.parquet" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We truncate this list of URLs so we don't download the whole dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "NUM_URLS = 20_000\n", - "urls = pd.read_parquet(\"mscoco.parquet\")\n", - "truncated_urls = urls[:NUM_URLS]\n", - "truncated_urls.to_parquet(\"truncated_mscoco.parquet\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, let's start the download." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "!img2dataset \\\n", - " --url_list truncated_mscoco.parquet \\\n", - " --input_format \"parquet\" \\\n", - " --output_folder mscoco \\\n", - " --output_format webdataset \\\n", - " --url_col \"URL\" \\\n", - " --caption_col \"TEXT\" \\\n", - " --processes_count 16 \\\n", - " --thread_count 64 \\\n", - " --resize_mode no" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Install NeMo Curator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "!pip install cython\n", - "!pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]\n", - "%env DASK_DATAFRAME__QUERY_PLANNING False" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create CLIP Embeddings\n", - "\n", - "### Load the Dataset\n", - "Instead of operating on the images directly, most features in NeMo Curator take embeddings as inputs. So, as the first stage in the pipeline, we are going to generate embeddings for all the images in the dataset. To begin, let's load the dataset using NeMo Curator." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_curator import get_client\n", - "\n", - "client = get_client(cluster_type=\"gpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Change the dataset path if you have your own dataset\n", - "dataset_path = \"./mscoco\"\n", - "# Change the unique identifier depending on your dataset\n", - "id_col = \"key\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_curator.datasets import ImageTextPairDataset\n", - "\n", - "dataset = ImageTextPairDataset.from_webdataset(dataset_path, id_col)\n", - "# Filter out any entries that failed to download\n", - "dataset.metadata = dataset.metadata[dataset.metadata[\"error_message\"].isna()]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Choose the Embedder\n", - "We can now define the embedding creation step in our pipeline. NeMo Curator has support for all [timm](https://pypi.org/project/timm/) models. NeMo Curator's aesthetic classifier is trained on embeddings from `vit_large_patch14_clip_quickgelu_224.openai`, so we will use that.\n", - "\n", - "The cell below will do the following:\n", - "1. Download the model `vit_large_patch14_clip_quickgelu_224.openai`.\n", - "1. Automatically convert the image preprocessing transformations of `vit_large_patch14_clip_quickgelu_224.openai` from their PyTorch form to DALI." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_curator.image.embedders import TimmImageEmbedder\n", - "\n", - "embedding_model = TimmImageEmbedder(\n", - " \"vit_large_patch14_clip_quickgelu_224.openai\",\n", - " pretrained=True,\n", - " batch_size=1024,\n", - " num_threads_per_worker=16,\n", - " normalize_embeddings=True,\n", - " autocast=False,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now create embeddings for the whole dataset. It's important to understand what is going on internally in NeMo Curator so you can modify parameters appropriately.\n", - "\n", - "Once the computation is triggered, the cell below will\n", - "1. Load a shard of metadata (a `.parquet` file) onto each GPU you have available using Dask-cuDF.\n", - "1. Load a copy of `vit_large_patch14_clip_quickgelu_224.openai` onto each GPU.\n", - "1. Repeatedly load images into batches of size `batch_size` onto each GPU with a given threads per worker (`num_threads_per_worker`) using DALI.\n", - "1. The model is run on the batch (without `torch.autocast()` since `autocast=False`).\n", - "1. The output embeddings of the model are normalized since `normalize_embeddings=True`.\n", - "\n", - "\n", - "Since NeMo Curator uses Dask, the cell below will not cause the embeddings to be created. The computation will only begin once we inspect the output in the `dataset.metadata.head()` call or when we write to disk using `dataset.save_metadata()` or `dataset.to_webdataset()`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Dask is lazy, so this will not compute embeddings\n", - "dataset = embedding_model(dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This triggers the computation for the first shard in the dataset\n", - "dataset.metadata.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Aesthetic Classifier\n", - "With the embeddings now created, we can use the aesthetic classifier. This classifier assigns a score from 0 to 10 that corresponds to how aesthetically pleasing the image is. A score of 0 means that the image is not pleasant to look at, while a score of 10 is pleasant to look at. The exact classifier used is the LAION-Aesthetics_Predictor V2. More information on the model can be found here: https://laion.ai/blog/laion-aesthetics/.\n", - "\n", - "The following cell will download the model to your local storage at `NEMO_CURATOR_HOME` (`/home/user/.nemo_curator`). The model is only 3.6MB." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_curator.image.classifiers import AestheticClassifier\n", - "\n", - "aesthetic_classifier = AestheticClassifier()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = aesthetic_classifier(dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_dataset_path = \"./output_dataset\"\n", - "dataset.metadata[\"passes_aesthetic_check\"] = dataset.metadata[\"aesthetic_score\"] > 6\n", - "dataset.to_webdataset(output_dataset_path, \"passes_aesthetic_check\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualize Results\n", - "Now that we have filtered our dataset based on aesthetic score, we can see what kinds of images are remaining." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import tarfile\n", - "import io\n", - "from PIL import Image\n", - "import matplotlib.pyplot as plt\n", - "import os\n", - "\n", - "def display_image_from_tar(tar_file_path, image_file_name):\n", - " # Open the tar file\n", - " with tarfile.open(tar_file_path, 'r') as tar:\n", - " # Extract the specified image file\n", - " image_file = tar.extractfile(image_file_name)\n", - " \n", - " if image_file is not None:\n", - " # Read the image data\n", - " image_data = image_file.read()\n", - " \n", - " # Create a PIL Image object from the image data\n", - " image = Image.open(io.BytesIO(image_data))\n", - " \n", - " # Display the image using matplotlib\n", - " plt.figure(figsize=(10, 8))\n", - " plt.imshow(image)\n", - " plt.axis('off') # Hide axes\n", - " plt.title(f\"Image: {image_file_name}\")\n", - " plt.show()\n", - " else:\n", - " print(f\"Image file '{image_file_name}' not found in the tar archive.\")\n", - " \n", - "\n", - "output_shard = os.path.join(output_dataset_path, \"00000.tar\")\n", - "image_file_name = '000000003.jpg'\n", - "display_image_from_tar(output_shard, image_file_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Semantic Deduplication\n", - "\n", - "NeMo Curator provides an easy module for semantically deduplicating images. Semantic duplicates are images that contain almost the same information content, but are perceptually different. Two imges of the same dog taken from slightly different angles would be considered semantic duplicates. NeMo Curator' semantic deduplication approach is based on the paper [SemDeDup: Data-efficient learning at web-scale through semantic deduplication](https://arxiv.org/pdf/2303.09540) by Abbas et al which has demonstrated that deduplicating your data can lead to the same downstream performance in *half* the number of training iterations. For more information on the algorithm, you can check out the [documentation page](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html#data-curator-semdedup)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo_curator.datasets import DocumentDataset\n", - "from nemo_curator import ClusteringModel, SemanticClusterLevelDedup\n", - "\n", - "# Convert the dataset\n", - "embeddings_dataset = DocumentDataset(dataset.metadata)\n", - "embeddings_dataset.df = embeddings_dataset.df.rename(columns={\"image_embeddings\": \"embeddings\"})\n", - "\n", - "semantic_dedup_outputs = \"./semantic_deduplication\"\n", - "os.makedirs(semantic_dedup_outputs, exist_ok=True)\n", - "\n", - "# Run clustering\n", - "clustering_output = os.path.join(semantic_dedup_outputs, \"cluster_output\")\n", - "clustering_model = ClusteringModel(\n", - " id_col=id_col,\n", - " max_iter=10,\n", - " n_clusters=1,\n", - " clustering_output_dir=clustering_output,\n", - ")\n", - "clustered_dataset = clustering_model(embeddings_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run cluster-level dedup\n", - "emb_by_cluster_output = os.path.join(semantic_dedup_outputs, \"emb_by_cluster\")\n", - "sorted_cluster_output = os.path.join(semantic_dedup_outputs, \"sorted_cluster\")\n", - "duplicate_output = os.path.join(semantic_dedup_outputs, \"duplicates\")\n", - "\n", - "semantic_dedup = SemanticClusterLevelDedup(\n", - " n_clusters=50000,\n", - " emb_by_clust_dir=emb_by_cluster_output,\n", - " sorted_clusters_dir=sorted_cluster_output,\n", - " id_col=id_col,\n", - " id_col_type=\"str\",\n", - " which_to_keep=\"hard\",\n", - " output_dir=duplicate_output,\n", - ")\n", - "semantic_dedup.compute_semantic_match_dfs()\n", - "deduplicated_dataset_ids = semantic_dedup.extract_dedup_data(eps_to_extract=0.07)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Visualize duplicates" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Remove duplicates" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From b4e474a0481df353620cc1bdf3f560e60c31f244 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 19 Sep 2024 08:51:22 -0700 Subject: [PATCH 46/57] Add dataset docs Signed-off-by: Ryan Wolf --- docs/user-guide/image/datasets.rst | 117 +++++++++++++++++++++++++++++ docs/user-guide/image/index.rst | 2 +- 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 docs/user-guide/image/datasets.rst diff --git a/docs/user-guide/image/datasets.rst b/docs/user-guide/image/datasets.rst new file mode 100644 index 00000000..f4e09ce3 --- /dev/null +++ b/docs/user-guide/image/datasets.rst @@ -0,0 +1,117 @@ +.. _data-curator-image-datasets: + +========================= +Image-Text Pair Dataset +========================= + +Image-text pair datasets are commonly used for training generative text to image models or CLIP models. +NeMo Curator supports reading and writing datasets based on the `WebDataset `_ file format. +This format allows NeMo Curator to annotate the dataset with metadata including embeddings and classifier scores. +Its sharded format also makes it easier to distribute work to different workers processing the dataset. + +------------ +File Format +------------ + +Here is an example of what a dataset directory that is in the Webdataset format should look like. + +:: + + dataset/ + β”œβ”€β”€ 00000.tar + β”‚ β”œβ”€β”€ 000000000.jpg + β”‚ β”œβ”€β”€ 000000000.json + β”‚ β”œβ”€β”€ 000000000.txt + β”‚ β”œβ”€β”€ 000000001.jpg + β”‚ β”œβ”€β”€ 000000001.json + β”‚ β”œβ”€β”€ 000000001.txt + β”‚ └── ... + β”œβ”€β”€ 00001.tar + β”‚ β”œβ”€β”€ 000010000.jpg + β”‚ β”œβ”€β”€ 000010000.json + β”‚ β”œβ”€β”€ 000010000.txt + β”‚ β”œβ”€β”€ 000010001.jpg + β”‚ β”œβ”€β”€ 000010001.json + β”‚ β”œβ”€β”€ 000010001.txt + β”‚ └── ... + β”œβ”€β”€ 00002.tar + β”‚ └── ... + β”œβ”€β”€ 00000.parquet + β”œβ”€β”€ 00001.parquet + └── 00002.parquet + + +The exact format assumes a single directory with sharded ``.tar``, ``.parquet``, and (optionally) +``.idx`` files. Each tar file should have a unique integer id as it's name (``00000.tar``, +``00001.tar``, ``00002.tar``, etc.). The tar files should contain images in ``.jpg`` files, text captions +in ``.txt`` files, and metadata in ``.json`` files. Each record of the dataset is identified by +a unique id that is a mix of the shard id along with the offset of the record within a shard. +For example, the 32rd record of the 43rd shard would be in ``00042.tar`` and have image ``000420031.jpg``, +caption ``000420031.txt``, and metadata ``000420031.json`` (assuming zero indexing). + +In addition to the collection of tar files, NeMo Curator's ``ImageTextPairDataset`` expects there to be .parquet files +in the root directory that follow the same naming convention as the shards (``00042.tar`` -> ``00042.parquet``). +Each parquet file should contain an aggregated tabular form of the metadata for each record, with +each row in the parquet file corresponding to a record in that shard. The metadata, both in the parquet +files and the json files, must contain a unique id column that is the same as its record id (000420031 +in our examples). + +------- +Reading +------- + +Datasets can be read in using ``ImageTextPairDataset.from_webdataset()`` + +.. code-block:: python + from nemo_curator.datasets import ImageTextPairDataset + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + +* ``path="/path/to/dataset"`` should point to the root directory of the WebDataset. +* ``id_col="key"`` lets us know that the unique id column in the dataset is named "key" + +------- +Writing +------- + +There are two ways to write an image dataset. The first way only saves the metadata, while the second way will reshard the tar files. +Both trigger the computation of all the tasks you have set to run beforehand. + +.. code-block:: python + from nemo_curator.datasets import ImageTextPairDataset + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + # Perform your operations (embedding creation, classifiers, etc.) + + dataset.save_metadata() + +``save_metadata()`` will only save sharded parquet files to the target directory. It does not modify the tar files. +There are two optional parameters + +* ``path`` allows you to change the location of where the dataset is saved. By default, it will overwrite the original parquet files. +* ``columns`` allows you to only save a subset of metadata. By default, all metadata will be saved. + + +.. code-block:: python + from nemo_curator.datasets import ImageTextPairDataset + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + # Perform your operations (embedding creation, classifiers, etc.) + + dataset.to_webdataset(path="/path/to/output", filter_column="passes_curation") + +``to_webdataset()`` will reshard the webdataset to only include elements that have a value of ``True`` in the ``filter_column``. +Resharding can take a while, so this should typically only be done at the end of your curation pipeline when you are ready to export the dataset for training. + + +------------- +Index Files +------------- + +NeMo Curator uses `DALI `_ for image data loading from the tar files. +In order to speed up the data loading, you can supply ``.idx`` files in your dataset. +The index files must be generated by DALI's wds2idx tool. +See the `DALI documentation `_ for more information. +Each index file must follow the same naming convention as the tar files (00042.tar -> 00042.idx). \ No newline at end of file diff --git a/docs/user-guide/image/index.rst b/docs/user-guide/image/index.rst index bbb68054..7cb88c2d 100644 --- a/docs/user-guide/image/index.rst +++ b/docs/user-guide/image/index.rst @@ -2,6 +2,6 @@ :maxdepth: 4 :titlesonly: - + datasets.rst embedders.rst classifiers.rst \ No newline at end of file From 4b1f0080ba6c74524fd3b8ebd88bf0f0a4e20306 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 19 Sep 2024 11:00:44 -0700 Subject: [PATCH 47/57] Add embedder documentation Signed-off-by: Ryan Wolf --- docs/user-guide/image/datasets.rst | 4 ++ docs/user-guide/image/embedders.rst | 76 +++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 docs/user-guide/image/embedders.rst diff --git a/docs/user-guide/image/datasets.rst b/docs/user-guide/image/datasets.rst index f4e09ce3..19c06b3a 100644 --- a/docs/user-guide/image/datasets.rst +++ b/docs/user-guide/image/datasets.rst @@ -70,6 +70,8 @@ Datasets can be read in using ``ImageTextPairDataset.from_webdataset()`` * ``path="/path/to/dataset"`` should point to the root directory of the WebDataset. * ``id_col="key"`` lets us know that the unique id column in the dataset is named "key" +A more thorough list of parameters can be found in the `API Reference `_. + ------- Writing ------- @@ -106,6 +108,8 @@ There are two optional parameters Resharding can take a while, so this should typically only be done at the end of your curation pipeline when you are ready to export the dataset for training. +A more thorough list of parameters can be found in the `API Reference `_. + ------------- Index Files ------------- diff --git a/docs/user-guide/image/embedders.rst b/docs/user-guide/image/embedders.rst new file mode 100644 index 00000000..f9684c5c --- /dev/null +++ b/docs/user-guide/image/embedders.rst @@ -0,0 +1,76 @@ +.. _data-curator-image-embedding: + +========================= +Image Embedders +========================= + +-------------------- +Timm Image Embedder +-------------------- + +PyTorch Image Models (timm) is a library containing SOTA computer vision models. +Many of these models are useful in generating image embeddings for modules in NeMo Curator. +NeMo Curator provides easy support for all these models through ``TimmImageEmbedder``. +This module can also automatically convert the image transformations from PyTorch transformations to DALI transformations in the supported models. + +.. code-block:: python + + from nemo_curator.datasets import ImageTextPairDataset + from nemo_curator.image.embedders import TimmImageEmbedder + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + embedding_model = TimmImageEmbedder( + "vit_large_patch14_clip_quickgelu_224.openai", + pretrained=True, + batch_size=1024, + num_threads_per_worker=16, + normalize_embeddings=True, + autocast=False, + ) + + dataset_with_embeddings = embedding_model(dataset) + + dataset_with_embeddings.save_metadata() + +Here, we load a dataset in and compute the image embeddings using ``vit_large_patch14_clip_quickgelu_224.openai``. +This model is the base for NeMo Curator's aesthetic classifier, so we use it for this example. + +A more thorough list of parameters can be found in the `API Reference `_. + +Under the hood, the image embedding model performs the following operations: + +1. Load a shard of metadata (a `.parquet` file) onto each GPU you have available using Dask-cuDF. +1. Load a copy of `vit_large_patch14_clip_quickgelu_224.openai` onto each GPU. +1. Repeatedly load images into batches of size `batch_size` onto each GPU with a given threads per worker (`num_threads_per_worker`) using DALI. +1. The model is run on the batch (without `torch.autocast()` since `autocast=False`). +1. The output embeddings of the model are normalized since `normalize_embeddings=True`. + + +------------------------ +Custom Image Embedder +------------------------ + +To write your own custom embedder, you inherit from ``nemo_curator.image.embedders.ImageEmbedder`` and override two methods as shown below: + +.. code-block:: python + + from nemo_curator.image.embedders import ImageEmbedder + + class MyCustomEmbedder(ImageEmbedder): + + def load_dataset_shard(self, tar_path: str) -> Iterable: + # Implement me! + pass + + def load_embedding_model(self, device: str) -> Callable: + # Implement me! + pass + + +* ``load_dataset_shard()`` will take in a path to a tar file and return an iterable over the shard. The iterable should return a tuple of (a batch of data, metadata). + The batch of data can be of any form. It will be directly passed to the model returned by ``load_embedding_model()``. + The metadata should be a dictionary of metadata, with a field corresponding to the ``id_col`` of the dataset. + In our example, the metadata should include a value for ``"key"``. +* ``load_embedding_model()`` will take a device and return a callable object. + This callable will take as input a batch of data produced by ``load_dataset_shard()``. \ No newline at end of file From e350ab0662b5c72fc2e0c5dd5a0e64bc91a36878 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 15 Oct 2024 09:51:40 -0700 Subject: [PATCH 48/57] Revert embedding column name change Signed-off-by: Ryan Wolf --- nemo_curator/image/classifiers/aesthetic.py | 6 +++--- nemo_curator/image/classifiers/base.py | 10 +++++----- nemo_curator/image/classifiers/nsfw.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index 885515c6..a0dcfd20 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -55,7 +55,7 @@ class AestheticClassifier(ImageClassifier): def __init__( self, - image_embedding_column: str = "image_embedding", + embedding_column: str = "image_embedding", pred_column: str = "aesthetic_score", batch_size: int = -1, model_path: Optional[str] = None, @@ -64,7 +64,7 @@ def __init__( Constructs the classifier. Args: - image_embedding_column (str): The column name that stores the image + embedding_column (str): The column name that stores the image embeddings. pred_column (str): The column name to be added where the aesthetic scores will be stored. @@ -78,7 +78,7 @@ def __init__( """ super().__init__( model_name="aesthetic_classifier", - image_embedding_column=image_embedding_column, + embedding_column=embedding_column, pred_column=pred_column, pred_type=float, batch_size=batch_size, diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 597d1b62..857161a6 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -37,7 +37,7 @@ class ImageClassifier(ABC): def __init__( self, model_name: str, - image_embedding_column: str, + embedding_column: str, pred_column: str, pred_type: Union[str, type], batch_size: int, @@ -49,7 +49,7 @@ def __init__( Args: model_name (str): A unqiue name to identify the model on each worker and in the logs. - image_embedding_column (str): The column name that stores the image + embedding_column (str): The column name that stores the image embeddings. pred_column (str): The column name to be added where the classifier's predictions will be stored. @@ -59,7 +59,7 @@ def __init__( all embeddings will be processed at once. """ self.model_name = model_name - self.image_embedding_column = image_embedding_column + self.embedding_column = embedding_column self.pred_column = pred_column self.pred_type = pred_type self.batch_size = batch_size @@ -96,7 +96,7 @@ def _run_inference(self, partition, partition_info=None): ) embeddings = torch.as_tensor( - partition[self.image_embedding_column].list.leaves.values.reshape( + partition[self.embedding_column].list.leaves.values.reshape( len(partition), -1 ), device=device, @@ -105,7 +105,7 @@ def _run_inference(self, partition, partition_info=None): if self.embedding_size != embeddings.shape[-1]: raise RuntimeError( f"{self.model_name} expects embedding size {self.embedding_size} but column " - f"'{self.image_embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " + f"'{self.embedding_column}' has embedding size {embeddings.shape[-1]}. Ensure your " "classifier is compatible with the CLIP model you used to generate the embeddings." ) diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index 07588191..81968648 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -63,7 +63,7 @@ class NsfwClassifier(ImageClassifier): def __init__( self, - image_embedding_column: str = "image_embedding", + embedding_column: str = "image_embedding", pred_column: str = "nsfw_score", batch_size: int = -1, model_path: Optional[str] = None, @@ -72,7 +72,7 @@ def __init__( Constructs the classifier. Args: - image_embedding_column (str): The column name that stores the image + embedding_column (str): The column name that stores the image embeddings. pred_column (str): The column name to be added where the nsfw scores will be stored. @@ -86,7 +86,7 @@ def __init__( """ super().__init__( model_name="nsfw_classifier", - image_embedding_column=image_embedding_column, + embedding_column=embedding_column, pred_column=pred_column, pred_type=float, batch_size=batch_size, From 31dbfde6832afd840b81888ded0a0e4370161cbc Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 11:04:21 -0700 Subject: [PATCH 49/57] Update user guide for images Signed-off-by: Ryan Wolf --- .../image/classifiers/aesthetic.rst | 97 +++++++++++++++++++ docs/user-guide/image/classifiers/index.rst | 6 ++ docs/user-guide/image/classifiers/nsfw.rst | 97 +++++++++++++++++++ docs/user-guide/image/embedders.rst | 66 ++++++++++--- docs/user-guide/image/gettingstarted.rst | 56 +++++++++++ docs/user-guide/image/index.rst | 2 +- 6 files changed, 312 insertions(+), 12 deletions(-) create mode 100644 docs/user-guide/image/classifiers/aesthetic.rst create mode 100644 docs/user-guide/image/classifiers/index.rst create mode 100644 docs/user-guide/image/classifiers/nsfw.rst create mode 100644 docs/user-guide/image/gettingstarted.rst diff --git a/docs/user-guide/image/classifiers/aesthetic.rst b/docs/user-guide/image/classifiers/aesthetic.rst new file mode 100644 index 00000000..5817eae1 --- /dev/null +++ b/docs/user-guide/image/classifiers/aesthetic.rst @@ -0,0 +1,97 @@ +========================= +Aesthetic Classifier +========================= + +-------------------- +Overview +-------------------- +Aesthetic classifiers can be used to assess the subjective quality of an image. +NeMo Curator integrates the `improved aesthetic predictor `_ that outputs a score from 0-10 where 10 is aesthetically pleasing. + +-------------------- +Use Cases +-------------------- +Filtering by aesthetic quality is common in generative image pipelines. +For example, `Stable Diffusion `_ progressively filtered by aesthetic score during training. + + +-------------------- +Prerequisities +-------------------- +Make sure you check out the `image curation getting started page `_ to install everything you will need. + +-------------------- +Usage +-------------------- + +The aesthetic classifier is a linear classifier that takes as input OpenAI CLIP ViT-L/14 image embeddings as input. +This model is available through the ``vit_large_patch14_clip_quickgelu_224.openai`` identifier in ``TimmImageEmbedder``. +First, we can compute these embeddings, then we can perform the classification. + +.. code-block:: python + + from nemo_curator import get_client + from nemo_curator.datasets import ImageTextPairDataset + from nemo_curator.image.embedders import TimmImageEmbedder + from nemo_curator.image.classifiers import AestheticClassifier + + client = get_client(cluster_type="gpu") + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + embedding_model = TimmImageEmbedder( + "vit_large_patch14_clip_quickgelu_224.openai", + pretrained=True, + batch_size=1024, + num_threads_per_worker=16, + normalize_embeddings=True, + ) + aesthetic_classifier = AestheticClassifier() + + dataset_with_embeddings = embedding_model(dataset) + dataset_with_aesthetic_scores = aesthetic_classifier(dataset_with_embeddings) + + # Metdata will have a new column named "aesthetic_score" + dataset_with_aesthetic_scores.save_metadata() + +-------------------- +Key Parameters +-------------------- +* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the aesthetic classifier is a linear model, this is usually fine. + +--------------------------- +Performance Considerations +--------------------------- +Since the aesthetic model is so small, you can load it onto the GPU at the same time as the embedding model and perform inference directly after computing the embeddings. +Check out this example: + +.. code-block:: python + + from nemo_curator import get_client + from nemo_curator.datasets import ImageTextPairDataset + from nemo_curator.image.embedders import TimmImageEmbedder + from nemo_curator.image.classifiers import AestheticClassifier + + client = get_client(cluster_type="gpu") + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + embedding_model = TimmImageEmbedder( + "vit_large_patch14_clip_quickgelu_224.openai", + pretrained=True, + batch_size=1024, + num_threads_per_worker=16, + normalize_embeddings=True, + classifiers=[AestheticClassifier()], + ) + + dataset_with_aesthetic_scores = embedding_model(dataset) + + # Metdata will have a new column named "aesthetic_score" + dataset_with_aesthetic_scores.save_metadata() + +--------------------------- +Additional Resources +--------------------------- +* `Image Curation Tutorial `_ +* `API Reference `_ \ No newline at end of file diff --git a/docs/user-guide/image/classifiers/index.rst b/docs/user-guide/image/classifiers/index.rst new file mode 100644 index 00000000..ba4e4be1 --- /dev/null +++ b/docs/user-guide/image/classifiers/index.rst @@ -0,0 +1,6 @@ +.. toctree:: + :maxdepth: 4 + :titlesonly: + + aesthetic.rst + nsfw.rst \ No newline at end of file diff --git a/docs/user-guide/image/classifiers/nsfw.rst b/docs/user-guide/image/classifiers/nsfw.rst new file mode 100644 index 00000000..2fb9cb48 --- /dev/null +++ b/docs/user-guide/image/classifiers/nsfw.rst @@ -0,0 +1,97 @@ +========================= +NSFW Classifier +========================= + +-------------------- +Overview +-------------------- +Not-safe-for-work (NSFW) classifiers determine the likelihood of an image containing sexually explicity material. +NeMo Curator integrates with `CLIP-based-NSFW-Detector `_ that outputs a value between 0 and 1 where 1 means the content is NSFW. + +-------------------- +Use Cases +-------------------- +Removing unsafe content is common in most data processing pipelines to prevent your generative AI model from learning to produce unsafe material. +For example, `Data Comp `_ filter out NSFW content before conducting their experiments. + +-------------------- +Prerequisities +-------------------- +Make sure you check out the `image curation getting started page `_ to install everything you will need. + +-------------------- +Usage +-------------------- + +The NSFW classifier is a small MLP classifier that takes as input OpenAI CLIP ViT-L/14 image embeddings as input. +This model is available through the ``vit_large_patch14_clip_quickgelu_224.openai`` identifier in ``TimmImageEmbedder``. +First, we can compute these embeddings, then we can perform the classification. + +.. code-block:: python + + from nemo_curator import get_client + from nemo_curator.datasets import ImageTextPairDataset + from nemo_curator.image.embedders import TimmImageEmbedder + from nemo_curator.image.classifiers import NsfwClassifier + + client = get_client(cluster_type="gpu") + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + embedding_model = TimmImageEmbedder( + "vit_large_patch14_clip_quickgelu_224.openai", + pretrained=True, + batch_size=1024, + num_threads_per_worker=16, + normalize_embeddings=True, + ) + safety_classifier = NsfwClassifier() + + dataset_with_embeddings = embedding_model(dataset) + dataset_with_aesthetic_scores = safety_classifier(dataset_with_embeddings) + + # Metdata will have a new column named "nsfw_score" + dataset_with_aesthetic_scores.save_metadata() + +-------------------- +Key Parameters +-------------------- +* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the aesthetic classifier is a samll model, this is usually fine. + +--------------------------- +Performance Considerations +--------------------------- +Since the NSFW model is so small, you can load it onto the GPU at the same time as the embedding model and perform inference directly after computing the embeddings. +Check out this example: + +.. code-block:: python + + from nemo_curator import get_client + from nemo_curator.datasets import ImageTextPairDataset + from nemo_curator.image.embedders import TimmImageEmbedder + from nemo_curator.image.classifiers import NsfwClassifier + + client = get_client(cluster_type="gpu") + + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") + + embedding_model = TimmImageEmbedder( + "vit_large_patch14_clip_quickgelu_224.openai", + pretrained=True, + batch_size=1024, + num_threads_per_worker=16, + normalize_embeddings=True, + classifiers=[NsfwClassifier()], + ) + + dataset_with_aesthetic_scores = embedding_model(dataset) + + # Metdata will have a new column named "aesthetic_score" + dataset_with_aesthetic_scores.save_metadata() + + +--------------------------- +Additional Resources +--------------------------- +* `Image Curation Tutorial `_ +* `API Reference `_ \ No newline at end of file diff --git a/docs/user-guide/image/embedders.rst b/docs/user-guide/image/embedders.rst index f9684c5c..0ad77b37 100644 --- a/docs/user-guide/image/embedders.rst +++ b/docs/user-guide/image/embedders.rst @@ -4,20 +4,38 @@ Image Embedders ========================= +-------------------- +Overview +-------------------- +Many image curation features in NeMo Curator operate on image embeddings instead of images directly. +Image embedders provide a scalable way of generating embeddings for each image in the dataset. + +-------------------- +Use Cases +-------------------- +* Aesthetic and NSFW Classification takes image embeddings generated from OpenAI's CLIP ViT-L variant +* Semantic deduplication computes the similarity of datapoints + +-------------------- +Prerequisities +-------------------- +Make sure you check out the `image curation getting started page `_ to install everything you will need. + -------------------- Timm Image Embedder -------------------- PyTorch Image Models (timm) is a library containing SOTA computer vision models. Many of these models are useful in generating image embeddings for modules in NeMo Curator. -NeMo Curator provides easy support for all these models through ``TimmImageEmbedder``. -This module can also automatically convert the image transformations from PyTorch transformations to DALI transformations in the supported models. .. code-block:: python + from nemo_curator import get_client from nemo_curator.datasets import ImageTextPairDataset from nemo_curator.image.embedders import TimmImageEmbedder + client = get_client(cluster_type="gpu") + dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") embedding_model = TimmImageEmbedder( @@ -26,26 +44,42 @@ This module can also automatically convert the image transformations from PyTorc batch_size=1024, num_threads_per_worker=16, normalize_embeddings=True, - autocast=False, ) dataset_with_embeddings = embedding_model(dataset) + # Metdata will have a new column named "image_embedding" dataset_with_embeddings.save_metadata() Here, we load a dataset in and compute the image embeddings using ``vit_large_patch14_clip_quickgelu_224.openai``. -This model is the base for NeMo Curator's aesthetic classifier, so we use it for this example. +At the end of the process, our metadata files have a new column named "image_embedding" that contains the image embedddings for each datapoint. -A more thorough list of parameters can be found in the `API Reference `_. +-------------------- +Key Parameters +-------------------- +* ``pretrained=True`` ensures you download the pretrained weights of the model. +* ``batch_size=1024`` determines the number of images processed on each individual GPU at once. +* ``num_threads_per_worker=16`` determines the number of threads used by DALI for dataloading. +* ``normalize_embeddings=True`` will normalize each embedding. NeMo Curator's classifiers expect normalized embeddings as input. + +--------------------------- +Performance Considerations +--------------------------- Under the hood, the image embedding model performs the following operations: -1. Load a shard of metadata (a `.parquet` file) onto each GPU you have available using Dask-cuDF. -1. Load a copy of `vit_large_patch14_clip_quickgelu_224.openai` onto each GPU. -1. Repeatedly load images into batches of size `batch_size` onto each GPU with a given threads per worker (`num_threads_per_worker`) using DALI. -1. The model is run on the batch (without `torch.autocast()` since `autocast=False`). -1. The output embeddings of the model are normalized since `normalize_embeddings=True`. +1. Download the weights of the model. +2. Download the PyTorch image transformations (resize and center-crop for example). +3. Convert the PyTorch image transformations to DALI transformations. +4. Load a shard of metadata (a ``.parquet`` file) onto each GPU you have available using Dask-cuDF. +5. Load a copy of the model onto each GPU. +6. Repeatedly load images into batches of size ``batch_size`` onto each GPU with a given threads per worker (``num_threads_per_worker``) using DALI. +7. The model is run on the batch (without ``torch.autocast()`` since ``autocast=False``). +8. The output embeddings of the model are normalized since ``normalize_embeddings=True``. +There are a couple of key performance considerations from this flow. +* You must have an NVIDIA GPU that mets the `requirements `_. +* You can create ``.idx`` files in the same directory of the tar files to speed up dataloading times. See the `DALI documentation `_ for more information. ------------------------ Custom Image Embedder @@ -73,4 +107,14 @@ To write your own custom embedder, you inherit from ``nemo_curator.image.embedde The metadata should be a dictionary of metadata, with a field corresponding to the ``id_col`` of the dataset. In our example, the metadata should include a value for ``"key"``. * ``load_embedding_model()`` will take a device and return a callable object. - This callable will take as input a batch of data produced by ``load_dataset_shard()``. \ No newline at end of file + This callable will take as input a batch of data produced by ``load_dataset_shard()``. + +--------------------------- +Additional Resources +--------------------------- + +* `Aesthetic Classifier `_ +* `NSFW Classifier `_ +* `Semantic Deduplication `_ +* `Image Curation Tutorial `_ +* `API Reference `_ \ No newline at end of file diff --git a/docs/user-guide/image/gettingstarted.rst b/docs/user-guide/image/gettingstarted.rst new file mode 100644 index 00000000..92307e24 --- /dev/null +++ b/docs/user-guide/image/gettingstarted.rst @@ -0,0 +1,56 @@ +================ +Get Started +================ + +NeMo Curator provides many tools for curating large scale text-image pair datasets for training generative image models. + +--------------------- +Install NeMo Curator +--------------------- +To install the image curation modules of NeMo Curator, ensure you meet the following requirements: +* Python 3.10 +* Ubuntu 22.04/20.04 +* NVIDIA GPU + * Voltaβ„’ or higher (compute capability 7.0+) + * CUDA 12 (or above) + +Note: While some of the text-based NeMo Curator modules do not require a GPU, all image curation modules require a GPU. + +You can install NeMo Curator in 3 ways. +1. PyPi +2. Source +3. NeMo Framework Container + +##################### +PyPi +##################### +NeMo Curator's PyPi page can be found `here `_. + +.. code-block:: bash + pip install cython + pip install nemo-curator[image] + +##################### +Source +##################### +NeMo Curator's GitHub can be found `here `_. + +.. code-block:: bash + git clone https://github.com/NVIDIA/NeMo-Curator.git + pip install cython + pip install ./NeMo-Curator[image] + +############################ +NeMo Framework Container +############################ +NeMo Curator comes preinstalled in the NeMo Framework container. You can find a list of all the NeMo Framework container tags `here `_. + +--------------------- +Use NeMo Curator +--------------------- + +NeMo Curator can be run locally, or on a variety of compute platforms (Slurm, k8s, and more). + +To get started using the image modules in NeMo Curator, we recommend you check out the following resources: +* `Image Curation Tutorial `_ +* `API Reference `_ \ No newline at end of file diff --git a/docs/user-guide/image/index.rst b/docs/user-guide/image/index.rst index 7cb88c2d..c6a53a86 100644 --- a/docs/user-guide/image/index.rst +++ b/docs/user-guide/image/index.rst @@ -4,4 +4,4 @@ datasets.rst embedders.rst - classifiers.rst \ No newline at end of file + classifiers/index.rst \ No newline at end of file From 30e004a37c456ecb48c9989b097b97881ea2222b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 13:32:38 -0700 Subject: [PATCH 50/57] Update README Signed-off-by: Ryan Wolf --- README.md | 133 ++++++++++------------- docs/user-guide/image/gettingstarted.rst | 2 +- 2 files changed, 58 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index ecf64ba5..9617adf0 100644 --- a/README.md +++ b/README.md @@ -9,51 +9,49 @@ # NeMo Curator -πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Large Language Model Data Curation** πŸš€ +πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Generative AI Model Data Curation** πŸš€ -

- diagram -

- -NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for [large language model (LLM)](https://www.nvidia.com/en-us/glossary/large-language-models/) use-cases such as foundation model pretraining, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and paramter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. - -At the core of the NeMo Curator is the `DocumentDataset` which serves as the the main dataset class. It acts as a straightforward wrapper around a Dask `DataFrame`. The Python library offers easy-to-use methods for expanding the functionality of your curation pipeline while eliminating scalability concerns. +NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for generative AI use-cases such as foundation language model pretraining, text to image model training, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and paramter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. ## Key Features -NeMo Curator provides a collection of scalable data-mining modules. Some of the key features include: +NeMo Curator provides a collection of scalable data curation modules for text and image curation. -- [Data download and text extraction](docs/user-guide/download.rst) +### Text +All of our text pipelines have great multilingual support. - - Default implementations for downloading and extracting Common Crawl, Wikipedia, and ArXiv data - - Easily customize the download and extraction and extend to other datasets +- [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) -- [Language identification and separation](docs/user-guide/languageidentificationunicodeformatting.rst) with [fastText](https://fasttext.cc/docs/en/language-identification.html) and [pycld2](https://pypi.org/project/pycld2/) + - Common Crawl, Wikipedia, and ArXiv sources + - Easily customize and extend to other sources -- [Text reformatting and cleaning](docs/user-guide/languageidentificationunicodeformatting.rst) to fix unicode decoding errors via [ftfy](https://ftfy.readthedocs.io/en/latest/) +- [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) -- [Quality filtering](docs/user-guide/qualityfiltering.rst) +- [Unicode Fixing](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - - Multilingual heuristic-based filtering +- [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) - Classifier-based filtering via [fastText](https://fasttext.cc/) -- [Document-level deduplication](docs/user-guide/gpudeduplication.rst) +- Classifier Filtering + - [fastText]((https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)) + - GPU-based: [Domain, Quality, Safety](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) - - exact and fuzzy (near-identical) deduplication are accelerated using cuDF and Dask - - For fuzzy deduplication, our implementation follows the method described in [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990) - - For semantic deduplication, our implementation follows the method described in [SemDeDup](https://arxiv.org/pdf/2303.09540) by Meta AI (FAIR) [facebookresearch/SemDeDup](https://github.com/facebookresearch/SemDeDup) +- **GPU Deduplication** + - [Exact](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) + - [Fuzzy](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) (Minhash LSH) + - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) -- [Multilingual downstream-task decontamination](docs/user-guide/taskdecontamination.rst) following the approach of [OpenAI GPT3](https://arxiv.org/pdf/2005.14165.pdf) and [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990) +- [Downstream-task Decontamination](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/taskdecontamination.html) -- [Distributed data classification](docs/user-guide/distributeddataclassification.rst) +- [Personal Identifiable Information (PII) Redaction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/personalidentifiableinformationidentificationandremoval.html) - - Multi-node, multi-GPU classifier inference - - Provides sophisticated domain and quality classification - - Flexible interface for extending to your own classifier network +### Image -- [Personal identifiable information (PII) redaction](docs/user-guide/personalidentifiableinformationidentificationandremoval.rst) for removing addresses, credit card numbers, social security numbers, and more - -These modules offer flexibility and permit reordering, with only a few exceptions. In addition, the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher) provides pre-built pipelines that can serve as a foundation for your customization use cases. +- [Embedding Creation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/embedders.html) +- Classifier Filtering + - [Aesthetic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/aesthetic.html), [NSFW](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/nsfw.html) +- GPU Deduplication + - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) ## Resources @@ -83,57 +81,51 @@ Before installing NeMo Curator, ensure that the following requirements are met: - Voltaβ„’ or higher ([compute capability 7.0+](https://developer.nvidia.com/cuda-gpus)) - CUDA 12 (or above) -You can install NeMo-Curator -1. from PyPi -2. from source -3. get it through the [NeMo Framework container](https://github.com/NVIDIA/NeMo?tab=readme-ov-file#docker-containers). - - - -#### From PyPi +You can get NeMo-Curator in 3 ways. +1. PyPi +2. Source +3. NeMo Framework Container -To install the CPU-only modules: +#### PyPi ```bash pip install cython -pip install nemo-curator +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[all] ``` -To install the CPU and CUDA-accelerated modules: - +#### Source ```bash +git clone https://github.com/NVIDIA/NeMo-Curator.git pip install cython -pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x] +pip install ./NeMo-Curator[all] ``` -#### From Source - -1. Clone the NeMo Curator repository in GitHub. - - ```bash - git clone https://github.com/NVIDIA/NeMo-Curator.git - cd NeMo-Curator - ``` +#### From the NeMo Framework Container -2. Install the modules that you need. +The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using: - To install the CPU-only modules: +```bash +pip uninstall nemo-curator +rm -r /opt/NeMo-Curator +git clone https://github.com/NVIDIA/NeMo-Curator.git /opt/NeMo-Curator +pip install --extra-index-url https://pypi.nvidia.com /opt/NeMo-Curator[all] +``` - ```bash - pip install cython - pip install . - ``` +#### Extras +NeMo Curator has a set of extras you can use to only install the necessary modules for your workload. +These extras are available for all installation methods provided. - To install the CPU and CUDA-accelerated modules: +```bash +pip install nemo-curator # Installs CPU-only text curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x] # Installs CPU + GPU text curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image] # Installs CPU + GPU text and image curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[all] # Installs all of the above +``` - ```bash - pip install cython - pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]" - ``` -#### Using Nightly Dependencies for Rapids +#### Using Nightly Dependencies for RAPIDS -You can also install NeMo Curator using the Rapids nightly, to do so you can set the environment variable `RAPIDS_NIGHTLY=1`. +You can also install NeMo Curator using the RAPIDS nightly, to do so you can set the environment variable `RAPIDS_NIGHTLY=1`. ```bash @@ -146,18 +138,6 @@ RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsa When the environment variable set to 0 or not set (default behavior) it'll use the stable version of Rapids. -#### From the NeMo Framework Container - -The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using: - -```bash -pip uninstall nemo-curator -rm -r /opt/NeMo-Curator -git clone https://github.com/NVIDIA/NeMo-Curator.git /opt/NeMo-Curator -pip install --extra-index-url https://pypi.nvidia.com /opt/NeMo-Curator[cuda12x] -``` -And follow the instructions for installing from source from [above](#from-source). - ## Use NeMo Curator ### Python API Quick Example @@ -189,6 +169,7 @@ To get started with NeMo Curator, you can follow the tutorials [available here]( - [`peft-curation`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/peft-curation) which focuses on data curation for LLM parameter-efficient fine-tuning (PEFT) use-cases. - [`distributed_data_classification`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/distributed_data_classification) which focuses on using the quality and domain classifiers to help with data annotation. - [`single_node_tutorial`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/single_node_tutorial) which demonstrates an end-to-end data curation pipeline for curating Wikipedia data in Thai. +- [`image-curation`](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/image-curation/image-curation.ipynb) which explores the scalable image curation modules. ### Access Python Modules @@ -201,9 +182,9 @@ NeMo Curator also offers CLI scripts for you to use. The scripts in `nemo_curato ### Use NeMo Framework Launcher -As an alternative method for interfacing with NeMo Curator, you can use the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher). The launcher enables you to easily configure the parameters and cluster. It can also automatically generate the SLURM batch scripts that wrap around the CLI scripts required to run your pipeline. +As an alternative method for interfacing with NeMo Curator, you can use the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher). The launcher enables you to easily configure the parameters and cluster. It can also automatically generate the Slurm batch scripts that wrap around the CLI scripts required to run your pipeline. -In addition, other methods are available to run NeMo Curator on SLURM. For example, refer to the example scripts in [`examples/slurm`](examples/slurm/) for information on how to run NeMo Curator on SLURM without the NeMo Framework Launcher. +In addition, other methods are available to run NeMo Curator on Slurm. For example, refer to the example scripts in [`examples/slurm`](examples/slurm/) for information on how to run NeMo Curator on Slurm without the NeMo Framework Launcher. ## Module Ablation and Compute Performance diff --git a/docs/user-guide/image/gettingstarted.rst b/docs/user-guide/image/gettingstarted.rst index 92307e24..c7955fe3 100644 --- a/docs/user-guide/image/gettingstarted.rst +++ b/docs/user-guide/image/gettingstarted.rst @@ -16,7 +16,7 @@ To install the image curation modules of NeMo Curator, ensure you meet the follo Note: While some of the text-based NeMo Curator modules do not require a GPU, all image curation modules require a GPU. -You can install NeMo Curator in 3 ways. +You can get NeMo Curator in 3 ways. 1. PyPi 2. Source 3. NeMo Framework Container From 9c81b6e628b9ad93106234f603bbe6a109f5d518 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 13:34:33 -0700 Subject: [PATCH 51/57] Update README with RAPIDS nightly instructions Signed-off-by: Ryan Wolf --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 9617adf0..163113af 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,7 @@ pip install --extra-index-url https://pypi.nvidia.com nemo-curator[all] # Instal #### Using Nightly Dependencies for RAPIDS -You can also install NeMo Curator using the RAPIDS nightly, to do so you can set the environment variable `RAPIDS_NIGHTLY=1`. - +You can also install NeMo Curator using the [RAPIDS Nightly Builds](https://docs.rapids.ai/install). To do so, you can set the environment variable `RAPIDS_NIGHTLY=1`. ```bash # installing from pypi From f9f47ed1880c090c70a362cb47ea12a232ac36e4 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 13:53:44 -0700 Subject: [PATCH 52/57] Fix formatting issues in image documentation Signed-off-by: Ryan Wolf --- README.md | 7 ------- docs/user-guide/image/classifiers/index.rst | 2 ++ docs/user-guide/image/classifiers/nsfw.rst | 12 ++++++------ docs/user-guide/image/gettingstarted.rst | 8 ++++++++ docs/user-guide/index.rst | 5 ++++- nemo_curator/image/classifiers/nsfw.py | 4 ++-- 6 files changed, 22 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 163113af..03b6c326 100644 --- a/README.md +++ b/README.md @@ -21,28 +21,21 @@ NeMo Curator provides a collection of scalable data curation modules for text an All of our text pipelines have great multilingual support. - [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) - - Common Crawl, Wikipedia, and ArXiv sources - Easily customize and extend to other sources - [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - - [Unicode Fixing](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) - Classifier-based filtering via [fastText](https://fasttext.cc/) - - Classifier Filtering - [fastText]((https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)) - GPU-based: [Domain, Quality, Safety](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) - - **GPU Deduplication** - [Exact](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) - [Fuzzy](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) (Minhash LSH) - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) - - [Downstream-task Decontamination](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/taskdecontamination.html) - - [Personal Identifiable Information (PII) Redaction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/personalidentifiableinformationidentificationandremoval.html) ### Image diff --git a/docs/user-guide/image/classifiers/index.rst b/docs/user-guide/image/classifiers/index.rst index ba4e4be1..bbd67f0e 100644 --- a/docs/user-guide/image/classifiers/index.rst +++ b/docs/user-guide/image/classifiers/index.rst @@ -1,3 +1,5 @@ +.. _data-curator-image-classifiers: + .. toctree:: :maxdepth: 4 :titlesonly: diff --git a/docs/user-guide/image/classifiers/nsfw.rst b/docs/user-guide/image/classifiers/nsfw.rst index 2fb9cb48..4bf505e8 100644 --- a/docs/user-guide/image/classifiers/nsfw.rst +++ b/docs/user-guide/image/classifiers/nsfw.rst @@ -48,15 +48,15 @@ First, we can compute these embeddings, then we can perform the classification. safety_classifier = NsfwClassifier() dataset_with_embeddings = embedding_model(dataset) - dataset_with_aesthetic_scores = safety_classifier(dataset_with_embeddings) + dataset_with_nsfw_scores = safety_classifier(dataset_with_embeddings) # Metdata will have a new column named "nsfw_score" - dataset_with_aesthetic_scores.save_metadata() + dataset_with_nsfw_scores.save_metadata() -------------------- Key Parameters -------------------- -* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the aesthetic classifier is a samll model, this is usually fine. +* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the NSFW classifier is a samll model, this is usually fine. --------------------------- Performance Considerations @@ -84,10 +84,10 @@ Check out this example: classifiers=[NsfwClassifier()], ) - dataset_with_aesthetic_scores = embedding_model(dataset) + dataset_with_nsfw_scores = embedding_model(dataset) - # Metdata will have a new column named "aesthetic_score" - dataset_with_aesthetic_scores.save_metadata() + # Metdata will have a new column named "nsfw_score" + dataset_with_nsfw_scores.save_metadata() --------------------------- diff --git a/docs/user-guide/image/gettingstarted.rst b/docs/user-guide/image/gettingstarted.rst index c7955fe3..dae4240d 100644 --- a/docs/user-guide/image/gettingstarted.rst +++ b/docs/user-guide/image/gettingstarted.rst @@ -1,3 +1,6 @@ + +.. _data-curator-image-getting-started: + ================ Get Started ================ @@ -8,6 +11,7 @@ NeMo Curator provides many tools for curating large scale text-image pair datase Install NeMo Curator --------------------- To install the image curation modules of NeMo Curator, ensure you meet the following requirements: + * Python 3.10 * Ubuntu 22.04/20.04 * NVIDIA GPU @@ -17,6 +21,7 @@ To install the image curation modules of NeMo Curator, ensure you meet the follo Note: While some of the text-based NeMo Curator modules do not require a GPU, all image curation modules require a GPU. You can get NeMo Curator in 3 ways. + 1. PyPi 2. Source 3. NeMo Framework Container @@ -27,6 +32,7 @@ PyPi NeMo Curator's PyPi page can be found `here `_. .. code-block:: bash + pip install cython pip install nemo-curator[image] @@ -36,6 +42,7 @@ Source NeMo Curator's GitHub can be found `here `_. .. code-block:: bash + git clone https://github.com/NVIDIA/NeMo-Curator.git pip install cython pip install ./NeMo-Curator[image] @@ -52,5 +59,6 @@ Use NeMo Curator NeMo Curator can be run locally, or on a variety of compute platforms (Slurm, k8s, and more). To get started using the image modules in NeMo Curator, we recommend you check out the following resources: + * `Image Curation Tutorial `_ * `API Reference `_ \ No newline at end of file diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index b72513d2..de53263f 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -38,6 +38,9 @@ Text Curation Image Curation ------------------- +:ref:`Get Started ` + Install NeMo Curator's image curation modules. + :ref:`Image-Text Pair Datasets ` Image-text pair datasets are commonly used as the basis for training multimodal generative models. NeMo Curator interfaces with the standardized Webdataset format for curating such datasets. @@ -47,7 +50,7 @@ Image Curation :ref:`Classifiers ` NeMo Curator provides several ways to use common classifiers like aesthetic scoring, and not-safe-for-work (NSFW) scoring. -:ref:`Semantic Deduplication ` +:ref:`Semantic Deduplication ` Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semnatic deduplication module that can work with any modality. ------------------- diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index 81968648..f641b887 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -55,8 +55,8 @@ def forward(self, x): class NsfwClassifier(ImageClassifier): """ NSFW Classifier is a small MLP trained on top of - Laion's ViT-H image embeddings. It is used to assess the likelihood - of images containing sexually explicit materal. + OpenAI's ViT-L CLIP image embeddings. It is used to assess the likelihood + of images containing sexually explicit material. More information on the model can be found here: https://github.com/LAION-AI/CLIP-based-NSFW-Detector. """ From c090ab67b407468c53cb953ed245671180cd75bd Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 18 Oct 2024 13:58:10 -0700 Subject: [PATCH 53/57] Remove extra newline in README Signed-off-by: Ryan Wolf --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 03b6c326..b80122a7 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,6 @@ All of our text pipelines have great multilingual support. - [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) - Common Crawl, Wikipedia, and ArXiv sources - Easily customize and extend to other sources - - [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - [Unicode Fixing](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) From e91e70d4892dca34e222a7d22bf93e237e7415cb Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 22 Oct 2024 13:13:24 -0700 Subject: [PATCH 54/57] Address most of Sarah's feedback Signed-off-by: Ryan Wolf --- README.md | 25 +++++++-------- .../image/classifiers/aesthetic.rst | 8 ++--- docs/user-guide/image/classifiers/nsfw.rst | 10 +++--- docs/user-guide/image/datasets.rst | 26 +++++++-------- docs/user-guide/image/embedders.rst | 13 ++++---- docs/user-guide/index.rst | 8 ++--- .../datasets/image_text_pair_dataset.py | 32 +++++++++---------- nemo_curator/image/classifiers/aesthetic.py | 2 +- nemo_curator/image/classifiers/base.py | 10 +++--- nemo_curator/image/classifiers/nsfw.py | 2 +- nemo_curator/image/embedders/base.py | 18 +++++------ nemo_curator/image/embedders/timm.py | 18 +++++------ 12 files changed, 86 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index b80122a7..34a19b12 100644 --- a/README.md +++ b/README.md @@ -11,37 +11,36 @@ # NeMo Curator πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Generative AI Model Data Curation** πŸš€ -NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for generative AI use-cases such as foundation language model pretraining, text to image model training, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and paramter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. +NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for generative AI use cases such as foundation language model pretraining, text-to-image model training, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and parameter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. ## Key Features NeMo Curator provides a collection of scalable data curation modules for text and image curation. -### Text +### Text Curation All of our text pipelines have great multilingual support. - [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) - - Common Crawl, Wikipedia, and ArXiv sources + - Default implementations Common Crawl, Wikipedia, and ArXiv sources - Easily customize and extend to other sources - [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) -- [Unicode Fixing](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) +- [Unicode Reformatting](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) - - Classifier-based filtering via [fastText](https://fasttext.cc/) - Classifier Filtering - [fastText]((https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)) - - GPU-based: [Domain, Quality, Safety](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) + - GPU-based models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) - **GPU Deduplication** - - [Exact](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) - - [Fuzzy](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) (Minhash LSH) - - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) + - [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) + - [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing + - [Semantic Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) - [Downstream-task Decontamination](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/taskdecontamination.html) - [Personal Identifiable Information (PII) Redaction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/personalidentifiableinformationidentificationandremoval.html) -### Image +### Image Curation - [Embedding Creation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/embedders.html) - Classifier Filtering - - [Aesthetic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/aesthetic.html), [NSFW](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/nsfw.html) + - [Aesthetic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/aesthetic.html) and [NSFW](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/nsfw.html) Classification - GPU Deduplication - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) @@ -92,7 +91,7 @@ pip install cython pip install ./NeMo-Curator[all] ``` -#### From the NeMo Framework Container +#### NeMo Framework Container The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using: @@ -127,7 +126,7 @@ RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsa RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple ".[cuda12x]" ``` -When the environment variable set to 0 or not set (default behavior) it'll use the stable version of Rapids. +When the `RAPIDS_NIGHTLY` variable is set to 0 (which is the default), it will use the stable version of RAPIDS. ## Use NeMo Curator ### Python API Quick Example diff --git a/docs/user-guide/image/classifiers/aesthetic.rst b/docs/user-guide/image/classifiers/aesthetic.rst index 5817eae1..3a43cebe 100644 --- a/docs/user-guide/image/classifiers/aesthetic.rst +++ b/docs/user-guide/image/classifiers/aesthetic.rst @@ -16,7 +16,7 @@ For example, `Stable Diffusion `_ to install everything you will need. @@ -24,7 +24,7 @@ Make sure you check out the `image curation getting started page `_ filter out NSFW content before conducting their experiments. -------------------- -Prerequisities +Prerequisites -------------------- Make sure you check out the `image curation getting started page `_ to install everything you will need. @@ -23,7 +23,7 @@ Make sure you check out the `image curation getting started page ``00042.parquet``). -Each parquet file should contain an aggregated tabular form of the metadata for each record, with -each row in the parquet file corresponding to a record in that shard. The metadata, both in the parquet -files and the json files, must contain a unique id column that is the same as its record id (000420031 +Each Parquet file should contain an aggregated tabular form of the metadata for each record, with +each row in the Parquet file corresponding to a record in that shard. The metadata, both in the Parquet +files and the JSON files, must contain a unique ID column that is the same as its record ID (000420031 in our examples). ------- @@ -68,7 +68,7 @@ Datasets can be read in using ``ImageTextPairDataset.from_webdataset()`` dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key") * ``path="/path/to/dataset"`` should point to the root directory of the WebDataset. -* ``id_col="key"`` lets us know that the unique id column in the dataset is named "key" +* ``id_col="key"`` lets us know that the unique ID column in the dataset is named "key". A more thorough list of parameters can be found in the `API Reference `_. @@ -88,10 +88,10 @@ Both trigger the computation of all the tasks you have set to run beforehand. dataset.save_metadata() -``save_metadata()`` will only save sharded parquet files to the target directory. It does not modify the tar files. -There are two optional parameters +``save_metadata()`` will only save sharded Parquet files to the target directory. It does not modify the tar files. +There are two optional parameters: -* ``path`` allows you to change the location of where the dataset is saved. By default, it will overwrite the original parquet files. +* ``path`` allows you to change the location of where the dataset is saved. By default, it will overwrite the original Parquet files. * ``columns`` allows you to only save a subset of metadata. By default, all metadata will be saved. @@ -104,7 +104,7 @@ There are two optional parameters dataset.to_webdataset(path="/path/to/output", filter_column="passes_curation") -``to_webdataset()`` will reshard the webdataset to only include elements that have a value of ``True`` in the ``filter_column``. +``to_webdataset()`` will reshard the WebDataset to only include elements that have a value of ``True`` in the ``filter_column``. Resharding can take a while, so this should typically only be done at the end of your curation pipeline when you are ready to export the dataset for training. diff --git a/docs/user-guide/image/embedders.rst b/docs/user-guide/image/embedders.rst index 0ad77b37..83033f9c 100644 --- a/docs/user-guide/image/embedders.rst +++ b/docs/user-guide/image/embedders.rst @@ -13,11 +13,11 @@ Image embedders provide a scalable way of generating embeddings for each image i -------------------- Use Cases -------------------- -* Aesthetic and NSFW Classification takes image embeddings generated from OpenAI's CLIP ViT-L variant -* Semantic deduplication computes the similarity of datapoints +* Aesthetic and NSFW classification both use image embeddings generated from OpenAI's CLIP ViT-L variant. +* Semantic deduplication computes the similarity of datapoints. -------------------- -Prerequisities +Prerequisites -------------------- Make sure you check out the `image curation getting started page `_ to install everything you will need. @@ -25,7 +25,7 @@ Make sure you check out the `image curation getting started page `_ is a library containing SOTA computer vision models. Many of these models are useful in generating image embeddings for modules in NeMo Curator. .. code-block:: python @@ -48,7 +48,7 @@ Many of these models are useful in generating image embeddings for modules in Ne dataset_with_embeddings = embedding_model(dataset) - # Metdata will have a new column named "image_embedding" + # Metadata will have a new column named "image_embedding" dataset_with_embeddings.save_metadata() Here, we load a dataset in and compute the image embeddings using ``vit_large_patch14_clip_quickgelu_224.openai``. @@ -78,6 +78,7 @@ Under the hood, the image embedding model performs the following operations: 8. The output embeddings of the model are normalized since ``normalize_embeddings=True``. There are a couple of key performance considerations from this flow. + * You must have an NVIDIA GPU that mets the `requirements `_. * You can create ``.idx`` files in the same directory of the tar files to speed up dataloading times. See the `DALI documentation `_ for more information. @@ -102,7 +103,7 @@ To write your own custom embedder, you inherit from ``nemo_curator.image.embedde pass -* ``load_dataset_shard()`` will take in a path to a tar file and return an iterable over the shard. The iterable should return a tuple of (a batch of data, metadata). +* ``load_dataset_shard()`` will take in a path to a tar file and return an iterable over the shard. The iterable should return a tuple of ``(a batch of data, metadata)``. The batch of data can be of any form. It will be directly passed to the model returned by ``load_embedding_model()``. The metadata should be a dictionary of metadata, with a field corresponding to the ``id_col`` of the dataset. In our example, the metadata should include a value for ``"key"``. diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index de53263f..fa09f7fa 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -23,7 +23,7 @@ Text Curation Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. :ref:`GPU Accelerated Semantic Deduplication ` - NeMo Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and Pytorch. + NeMo Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and PyTorch. :ref:`Synthetic Data Generation ` Synthetic data generation tools and example piplines are available within NeMo Curator. @@ -42,16 +42,16 @@ Image Curation Install NeMo Curator's image curation modules. :ref:`Image-Text Pair Datasets ` - Image-text pair datasets are commonly used as the basis for training multimodal generative models. NeMo Curator interfaces with the standardized Webdataset format for curating such datasets. + Image-text pair datasets are commonly used as the basis for training multimodal generative models. NeMo Curator interfaces with the standardized WebDataset format for curating such datasets. :ref:`Image Embedding Creation ` Image embeddings are the backbone to many data curation operations in NeMo Curator. This section describes how to efficiently create embeddings for massive datasets. :ref:`Classifiers ` - NeMo Curator provides several ways to use common classifiers like aesthetic scoring, and not-safe-for-work (NSFW) scoring. + NeMo Curator provides several ways to use common classifiers like aesthetic scoring and not-safe-for-work (NSFW) scoring. :ref:`Semantic Deduplication ` - Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semnatic deduplication module that can work with any modality. + Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semantic deduplication module that can work with any modality. ------------------- Reference diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py index 39427b7a..b580015c 100644 --- a/nemo_curator/datasets/image_text_pair_dataset.py +++ b/nemo_curator/datasets/image_text_pair_dataset.py @@ -29,21 +29,21 @@ class ImageTextPairDataset: """ - A collection of image text pairs stored in webdataset-like format on disk or in cloud storage. + A collection of image text pairs stored in WebDataset-like format on disk or in cloud storage. The exact format assumes a single directory with sharded .tar, .parquet, and (optionally) - .idx files. Each tar file should have a unique integer id as it's name (00000.tar, + .idx files. Each tar file should have a unique integer ID as its name (00000.tar, 00001.tar, 00002.tar, etc.). The tar files should contain images in .jpg files, text captions in .txt files, and metadata in .json files. Each record of the dataset is identified by - a unique id that is a mix of the shard id along with the offset of the record within a shard. - For example, the 32rd record of the 43rd shard would be in 00042.tar and have image 000420031.jpg, + a unique ID that is a mix of the shard ID along with the offset of the record within a shard. + For example, the 32nd record of the 43rd shard would be in 00042.tar and have image 000420031.jpg, caption 000420031.txt, and metadata 000420031.json (assuming zero indexing). In addition to the collection of tar files, ImageTextPairDataset expects there to be .parquet files in the root directory that follow the same naming convention as the shards (00042.tar -> 00042.parquet). - Each parquet file should contain an aggregated tabular form of the metadata for each record, with - each row in the parquet file corresponding to a record in that shard. The metadata, both in the parquet - files and the json files, must contain a unique id column that is the same as its record id (000420031 + Each Parquet file should contain an aggregated tabular form of the metadata for each record, with + each row in the Parquet file corresponding to a record in that shard. The metadata, both in the Parquet + files and the JSON files, must contain a unique ID column that is the same as its record ID (000420031 in our examples). Index files may also be in the directory to speed up dataloading with DALI. @@ -57,11 +57,11 @@ def __init__( self, path: str, metadata: dd.DataFrame, tar_files: List[str], id_col: str ) -> None: """ - Constructs an image text pair dataset. + Constructs an image-text pair dataset. Args: path (str): The root directory of the files. - metadata (dd.DataFrame): A dask cudf dataframe of the metadata. + metadata (dd.DataFrame): A Dask-cuDF DataFrame of the metadata. tar_files (List[str]): A list of paths to the tar files. id_col (str): The column storing the unique identifier for each record. """ @@ -73,10 +73,10 @@ def __init__( @classmethod def from_webdataset(cls, path: str, id_col: str): """ - Loads an ImageTextPairDataset from a webdataset + Loads an ImageTextPairDataset from a WebDataset Args: - path (str): The path to the webdataset-like format on disk or cloud storage. + path (str): The path to the WebDataset-like format on disk or cloud storage. id_col (str): The column storing the unique identifier for each record. """ metadata = dask_cudf.read_parquet(path) @@ -118,7 +118,7 @@ def save_metadata( ) -> None: """ Saves the metadata of the dataset to the specified path as a collection - of parquet files. + of Parquet files. Args: path (Optional[str]): The path to save the metadata to. If None, @@ -217,9 +217,9 @@ def to_webdataset( old_id_col: Optional[str] = None, ) -> None: """ - Saves the dataset to a webdataset format with parquet files. + Saves the dataset to a WebDataset format with Parquet files. Will reshard the tar files to the specified number of samples per shard. - The id value in ImageTextPairDataset.id_col will be overwritten with a new id. + The ID value in ImageTextPairDataset.id_col will be overwritten with a new ID. Args: path (str): The output path where the dataset should be written. @@ -229,9 +229,9 @@ def to_webdataset( samples_per_shard (int): The number of samples to include in each tar file. max_shards (int): The order of magnitude of the maximum number of shards that will be created from the dataset. Will be used to determine the - number of leading zeros in the shard/sample ids. + number of leading zeros in the shard/sample IDs. old_id_col (Optional[str]): If specified, will preserve the previous - id value in the given column. + ID value in the given column. """ max_samples_per_shard = math.ceil(math.log10(samples_per_shard)) filtered_metadata = self.metadata[self.metadata[filter_column]] diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py index a0dcfd20..c8881d12 100644 --- a/nemo_curator/image/classifiers/aesthetic.py +++ b/nemo_curator/image/classifiers/aesthetic.py @@ -68,7 +68,7 @@ def __init__( embeddings. pred_column (str): The column name to be added where the aesthetic scores will be stored. - pred_type (Union[str, type]): The datatype of the pred_column + pred_type (Union[str, type]): The datatype of the pred_column. batch_size (int): If greater than 0, the image embeddings will be processed in batches of at most this size. If less than 0, all embeddings will be processed at once. diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py index 857161a6..7ad9de01 100644 --- a/nemo_curator/image/classifiers/base.py +++ b/nemo_curator/image/classifiers/base.py @@ -44,7 +44,7 @@ def __init__( embedding_size: int, ) -> None: """ - Constructs an image classifier + Constructs an image classifier. Args: model_name (str): A unqiue name to identify the model on each worker @@ -53,7 +53,7 @@ def __init__( embeddings. pred_column (str): The column name to be added where the classifier's predictions will be stored. - pred_type (Union[str, type]): The datatype of the pred_column + pred_type (Union[str, type]): The datatype of the pred_column. batch_size (int): If greater than 0, the image embeddings will be processed in batches of at most this size. If less than 0, all embeddings will be processed at once. @@ -67,7 +67,7 @@ def __init__( def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: """ - Classifies all embeddings in the dataset + Classifies all embeddings in the dataset. Args: dataset (ImageTextPairDataset): The dataset to classify. @@ -135,7 +135,7 @@ def _run_inference(self, partition, partition_info=None): @abstractmethod def load_model(self, device: str) -> Callable: """ - Loads the classifier model + Loads the classifier model. Args: device (str): A PyTorch device identifier that specifies what GPU @@ -154,7 +154,7 @@ def postprocess(self, series: cudf.Series) -> cudf.Series: them to the metadata. Args: - series (cudf.Series): The cudf series of raw model predictions. + series (cudf.Series): The cuDF series of raw model predictions. Returns: cudf.Series: The same series unmodified. Override in your classifier diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py index f641b887..ef18fed7 100644 --- a/nemo_curator/image/classifiers/nsfw.py +++ b/nemo_curator/image/classifiers/nsfw.py @@ -76,7 +76,7 @@ def __init__( embeddings. pred_column (str): The column name to be added where the nsfw scores will be stored. - pred_type (Union[str, type]): The datatype of the pred_column + pred_type (Union[str, type]): The datatype of the pred_column. batch_size (int): If greater than 0, the image embeddings will be processed in batches of at most this size. If less than 0, all embeddings will be processed at once. diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index cebb4998..d910e170 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -41,7 +41,7 @@ def __init__( classifiers: Iterable[ImageClassifier], ) -> None: """ - Constructs an image embedder + Constructs an image embedder. Args: model_name (str): A unqiue name to identify the model on each worker @@ -59,10 +59,10 @@ def __init__( def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: """ - Generates image embeddings for all images in the dataset + Generates image embeddings for all images in the dataset. Args: - dataset (ImageTextPairDataset): The dataset to create image embeddings for + dataset (ImageTextPairDataset): The dataset to create image embeddings for. Returns: ImageTextPairDataset: A dataset with image embeddings and potentially @@ -154,27 +154,27 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): @abstractmethod def load_dataset_shard(self, tar_path: str) -> Iterable: """ - Loads images and metadata from a tarfile in the dataset + Loads images and metadata from a tarfile in the dataset. Args: - tar_path (str): The path to a tar file shard in the input webdataset. + tar_path (str): The path to a tar file shard in the input WebDataset. Returns: Iterable: An iterator over the dataset. Each iteration should produce - A tuple of (image, metadata) pairs. The batch of images will be passed + a tuple of (image, metadata) pairs. The batch of images will be passed directly to the model created by ImageEmbedder.load_embedding_model. The metadata must be a list of dictionaries. Each element of the list must correspond to the image in the batch at the same position. Each dictionary must contain a field that is the same as - id_field in the dataset. This id field in the metadata will be used - to match the image to the its record in the metadata (parquet) files. + id_field in the dataset. This ID field in the metadata will be used + to match the image to the its record in the metadata (Parquet) files. """ pass @abstractmethod def load_embedding_model(self, device: str) -> Callable: """ - Loads the model used to generate image embeddings + Loads the model used to generate image embeddings. Args: device (str): A PyTorch device identifier that specifies what GPU diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 4adea8d4..fac2fba2 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -47,7 +47,7 @@ def __init__( use_index_files: bool = False, ) -> None: """ - Constructs the embedder + Constructs the embedder. Args: model_name (str): The timm model to use. A list of available models @@ -93,18 +93,18 @@ def __init__( def load_dataset_shard(self, tar_path: str): """ - Loads a webdataset tar shard using DALI + Loads a WebDataset tar shard using DALI. Args: - tar_path (str): The path of the tar shard to load + tar_path (str): The path of the tar shard to load. Returns: Iterable: An iterator over the dataset. Each tar file - must have 3 files per record. A jpg file, a txt file, - and a json file. The jpg file must contain the image, the - txt file must contain the associated caption, and the - json must contain the metadata for the record (including - its id). Images will be loaded using DALI. + must have 3 files per record: a .jpg file, a .txt file, + and a .json file. The .jpg file must contain the image, the + .txt file must contain the associated caption, and the + .json must contain the metadata for the record (including + its ID). Images will be loaded using DALI. """ # Create the DALI pipeline @@ -165,7 +165,7 @@ def webdataset_pipeline(_tar_path: str): def load_embedding_model(self, device="cuda"): """ - Loads the model used to generate image embeddings + Loads the model used to generate image embeddings. Args: device (str): A PyTorch device identifier that specifies what GPU From 9746a66c2a9b93228ee98dea0c32c41c721b52cc Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 22 Oct 2024 14:49:07 -0700 Subject: [PATCH 55/57] Add section summary Signed-off-by: Ryan Wolf --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 34a19b12..1630d464 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,9 @@ All of our text pipelines have great multilingual support. - GPU Deduplication - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) +These modules offer flexibility and permit reordering, with only a few exceptions. +All the modules automatically scale to multiple nodes to increase throughput. + ## Resources - [Documentation](docs/) From 2a069b6c05231aba8f340f0525dfc5ba1b712b6c Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 24 Oct 2024 10:16:47 -0700 Subject: [PATCH 56/57] Fix errors and REWORD GPU bullets in README Signed-off-by: Ryan Wolf --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1630d464..21127d34 100644 --- a/README.md +++ b/README.md @@ -21,15 +21,15 @@ NeMo Curator provides a collection of scalable data curation modules for text an All of our text pipelines have great multilingual support. - [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) - - Default implementations Common Crawl, Wikipedia, and ArXiv sources + - Default implementations for Common Crawl, Wikipedia, and ArXiv sources - Easily customize and extend to other sources - [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - [Unicode Reformatting](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) - Classifier Filtering - - [fastText]((https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)) - - GPU-based models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) -- **GPU Deduplication** + - [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) + - GPU-Accelerated models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html) +- **GPU-Accelerated Deduplication** - [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) - [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing - [Semantic Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html) From 34318e7cf01f9c2243fd67e3ca05580decda343e Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 24 Oct 2024 11:24:17 -0700 Subject: [PATCH 57/57] Fix how table of contents displays with new sections Signed-off-by: Ryan Wolf --- docs/user-guide/index.rst | 40 +++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index 7baef104..1db64716 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -37,6 +37,23 @@ Text Curation :ref:`Personally Identifiable Information Identification and Removal ` The purpose of the personally identifiable information (PII) redaction tool is to help scrub sensitive data out of training datasets +.. toctree:: + :maxdepth: 4 + :titlesonly: + + + download.rst + documentdataset.rst + cpuvsgpu.rst + qualityfiltering.rst + languageidentificationunicodeformatting.rst + gpudeduplication.rst + semdedup.rst + syntheticdata.rst + taskdecontamination.rst + personalidentifiableinformationidentificationandremoval.rst + distributeddataclassification.rst + ------------------- Image Curation ------------------- @@ -56,6 +73,16 @@ Image Curation :ref:`Semantic Deduplication ` Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semantic deduplication module that can work with any modality. +.. toctree:: + :maxdepth: 4 + :titlesonly: + + image/gettingstarted.rst + image/datasets.rst + image/classifiers/index.rst + semdedup.rst + + ------------------- Reference ------------------- @@ -83,19 +110,8 @@ Reference :titlesonly: - download.rst - documentdataset.rst - cpuvsgpu.rst - qualityfiltering.rst - languageidentificationunicodeformatting.rst - gpudeduplication.rst - semdedup.rst - syntheticdata.rst - taskdecontamination.rst - personalidentifiableinformationidentificationandremoval.rst - distributeddataclassification.rst kubernetescurator.rst sparkother.rst bestpractices.rst nextsteps.rst - api/index.rst + api/index.rst \ No newline at end of file