Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V0.5.0 forward merge #340

Merged
merged 12 commits into from
Nov 4, 2024
64 changes: 35 additions & 29 deletions nemo_curator/image/classifiers/nsfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import zipfile
from typing import Optional

import requests
Expand All @@ -23,33 +24,35 @@


# MLP code taken from LAION's CLIP-based-NSFW-Detector
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py
class H14_NSFW_Detector(nn.Module):
def __init__(self, input_size=1024):
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
class Normalization(nn.Module):
def __init__(self, shape):
super().__init__()
self.input_size = input_size
self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 16),
nn.Linear(16, 1),
)
self.register_buffer("mean", torch.zeros(shape))
self.register_buffer("variance", torch.ones(shape))

def forward(self, x):
return (x - self.mean) / self.variance.sqrt()


class NSFWModel(nn.Module):
def __init__(self):
super().__init__()
self.norm = Normalization([768])
self.linear_1 = nn.Linear(768, 64)
self.linear_2 = nn.Linear(64, 512)
self.linear_3 = nn.Linear(512, 256)
self.linear_4 = nn.Linear(256, 1)
self.act = nn.ReLU()
self.act_out = nn.Sigmoid()

def forward(self, x):
return self.layers(x)
x = self.norm(x)
x = self.act(self.linear_1(x))
x = self.act(self.linear_2(x))
x = self.act(self.linear_3(x))
x = self.act_out(self.linear_4(x))
return x


class NsfwClassifier(ImageClassifier):
Expand Down Expand Up @@ -90,7 +93,7 @@ def __init__(
pred_column=pred_column,
pred_type=float,
batch_size=batch_size,
embedding_size=1024,
embedding_size=768,
)

if model_path is None:
Expand All @@ -100,21 +103,24 @@ def __init__(

@staticmethod
def _get_default_model():
weights_name = "h14_nsfw.pth"
weights_name = "clip_autokeras_binary_nsfw.pth"
model_path = os.path.join(NEMO_CURATOR_HOME, weights_name)
os.makedirs(NEMO_CURATOR_HOME, exist_ok=True)

if not os.path.exists(model_path):
url = f"https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/{weights_name}?raw=true"
url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip"
r = requests.get(url)

with open(model_path, "wb") as f:
raw_zip_path = os.path.join(NEMO_CURATOR_HOME, "nsfw.zip")
with open(raw_zip_path, "wb") as f:
f.write(r.content)
with zipfile.ZipFile(raw_zip_path, "r") as f:
f.extractall(NEMO_CURATOR_HOME)

return model_path

def load_model(self, device):
model = H14_NSFW_Detector(input_size=self.embedding_size).to(device)
model = NSFWModel().to(device)
weights = torch.load(self.model_path, map_location=torch.device("cpu"))
model.load_state_dict(weights)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def req_file(filename, folder="requirements"):

setup(
name="nemo_curator",
version="0.5.0",
version="0.6.0.dev0",
description="Scalable Data Preprocessing Tool for "
"Training Large Language Models",
long_description=long_description,
Expand Down
123 changes: 123 additions & 0 deletions tutorials/image-curation/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import os
import tarfile
from functools import partial
from multiprocessing import Pool

import aiofiles
import aiohttp
import pandas as pd


async def download_image(session, url, filename):
async with session.get(url) as response:
if response.status == 200:
async with aiofiles.open(filename, mode="wb") as f:
await f.write(await response.read())
return True
return False


async def process_batch(batch, output_dir, batch_num):
tar_filename = os.path.join(output_dir, f"{batch_num:05d}.tar")
tmp_dir = os.path.join(output_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

metadatas = []
async with aiohttp.ClientSession() as session:
tasks = []
for i, (_, row) in enumerate(batch.iterrows()):
caption = row["TEXT"]
url = row["URL"]

key = f"{batch_num:05d}{i:04d}"
jpg_filename = os.path.join(tmp_dir, f"{key}.jpg")
txt_filename = os.path.join(tmp_dir, f"{key}.txt")
json_filename = os.path.join(tmp_dir, f"{key}.json")

meta = {"url": url, "caption": caption, "key": key}
metadatas.append(meta)

tasks.append(download_image(session, url, jpg_filename))

async with aiofiles.open(txt_filename, mode="w") as f:
await f.write(caption)

async with aiofiles.open(json_filename, mode="w") as f:
await f.write(json.dumps(meta))

results = await asyncio.gather(*tasks)

with tarfile.open(tar_filename, "w") as tar:
for i, success in enumerate(results):
if success:
key = f"{batch_num:05d}{i:04d}"
jpg_base = f"{key}.jpg"
txt_base = f"{key}.txt"
json_base = f"{key}.json"
jpg_tmp = os.path.join(tmp_dir, jpg_base)
txt_tmp = os.path.join(tmp_dir, txt_base)
json_tmp = os.path.join(tmp_dir, json_base)

tar.add(jpg_tmp, arcname=jpg_base)
tar.add(txt_tmp, arcname=txt_base)
tar.add(json_tmp, arcname=json_base)

# Clean up temporary files
for i in range(len(batch)):
key = f"{batch_num:05d}{i:04d}"
jpg_tmp = os.path.join(tmp_dir, f"{key}.jpg")
txt_tmp = os.path.join(tmp_dir, f"{key}.txt")
json_tmp = os.path.join(tmp_dir, f"{key}.json")

os.remove(jpg_tmp)
os.remove(txt_tmp)
os.remove(json_tmp)

# Write parquet
meta_df = pd.DataFrame(metadatas)
parquet_path = os.path.join(output_dir, f"{batch_num:05d}.parquet")
meta_df.to_parquet(parquet_path)


def process_parquet_chunk(chunk, output_dir):
batch_num, batch = chunk

asyncio.run(process_batch(batch, output_dir, batch_num))


def download_webdataset(
parquet_path, output_dir, entries_per_tar=10000, num_processes=2
):
os.makedirs(output_dir, exist_ok=True)

# Read the parquet file
df = pd.read_parquet(parquet_path)

# Split the dataframe into chunks for multiprocessing
chunks = [
(batch_num, df[i : i + entries_per_tar])
for batch_num, i in enumerate(range(0, len(df), entries_per_tar))
]

# Use multiprocessing to process chunks in parallel
with Pool(processes=num_processes) as pool:
func = partial(process_parquet_chunk, output_dir=output_dir)
pool.map(func, chunks)

tmp_dir = os.path.join(output_dir, "tmp")
os.rmdir(tmp_dir)
Loading