diff --git a/colab_env/tagging.py b/colab_env/tagging.py index 7049fda..7aa99f0 100644 --- a/colab_env/tagging.py +++ b/colab_env/tagging.py @@ -1,27 +1,23 @@ # https://github.com/neggles/wdv3-timm/blob/main/wdv3_timm.py -import os, time, io -import numpy as np -import onnxruntime as rt -from huggingface_hub import hf_hub_download -from PIL import Image +import os, time import pandas as pd import argparse import traceback, sys import re +from pathlib import Path from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol -from simple_parsing import field, parse_known_args -from huggingface_hub.utils import HfHubHTTPError +import numpy as np from numpy import signedinteger +from PIL import Image import timm +from timm.data import create_transform, resolve_data_config import torch from torch import Tensor, nn from torch.nn import functional as F -from timm.data import create_transform, resolve_data_config - -from dataclasses import dataclass -from pathlib import Path +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import HfHubHTTPError kaomojis: List[str] = [ "0_0", @@ -46,66 +42,64 @@ ] TAGGER_VIT_MODEL_REPO: str = "SmilingWolf/wd-eva02-large-tagger-v3" -MODEL_FILE_NAME: str = "model.onnx" -LABEL_FILENAME: str = "selected_tags.csv" +# MODEL_FILE_NAME: str = "model.onnx" +# LABEL_FILENAME: str = "selected_tags.csv" EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"] -BATCH_SIZE: int = 100 +BATCH_SIZE: int = 10 +PROGRESS_INTERVAL: int = 100 torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -transform = None - - -@dataclass -class LabelData: - names: list[str] - rating: list[np.int64] - general: list[np.int64] - character: list[np.int64] - - -@dataclass -class ScriptOptions: - image_file: Path = field(positional=True) - model: str = field(default="vit") - gen_threshold: float = field(default=0.35) - char_threshold: float = field(default=0.75) - - -def get_tags( - probs: Tensor, - labels: LabelData, - gen_threshold: float, - char_threshold: float, -): - # Convert indices+probs to labels - probs = list(zip(labels.names, probs.numpy())) - - # First 4 labels are actually ratings - rating_labels = dict([probs[i] for i in labels.rating]) - - # General labels, pick any where prediction confidence > threshold - gen_labels = [probs[i] for i in labels.general] - gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold]) - gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True)) - - # Character labels, pick any where prediction confidence > threshold - char_labels = [probs[i] for i in labels.character] - char_labels = dict([x for x in char_labels if x[1] > char_threshold]) - char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True)) - - # Combine general and character labels, sort by confidence - combined_names = [x for x in gen_labels] - combined_names.extend([x for x in char_labels]) - - # Convert to a string suitable for use as a training caption - caption = ", ".join(combined_names) - taglist = caption.replace("_", " ").replace("(", "\(").replace(")", "\)") - - return caption, taglist, rating_labels, char_labels, gen_labels - +# for apple silicon +if torch.backends.mps.is_available(): + torch_device = torch.device("mps") + +# @dataclass +# class LabelData: +# names: list[str] +# rating: list[np.int64] +# general: list[np.int64] +# character: list[np.int64] + +# @dataclass +# class ScriptOptions: +# image_file: Path = field(positional=True) +# model: str = field(default="vit") +# gen_threshold: float = field(default=0.35) +# char_threshold: float = field(default=0.75) + +# def get_tags( +# probs: Tensor, +# labels: LabelData, +# gen_threshold: float, +# char_threshold: float, +# ): +# # Convert indices+probs to labels +# probs = list(zip(labels.names, probs.numpy())) +# +# # First 4 labels are actually ratings +# rating_labels = dict([probs[i] for i in labels.rating]) +# +# # General labels, pick any where prediction confidence > threshold +# gen_labels = [probs[i] for i in labels.general] +# gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold]) +# gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True)) +# +# # Character labels, pick any where prediction confidence > threshold +# char_labels = [probs[i] for i in labels.character] +# char_labels = dict([x for x in char_labels if x[1] > char_threshold]) +# char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True)) +# +# # Combine general and character labels, sort by confidence +# combined_names = [x for x in gen_labels] +# combined_names.extend([x for x in char_labels]) +# +# # Convert to a string suitable for use as a training caption +# caption = ", ".join(combined_names) +# taglist = caption.replace("_", " ").replace("(", "\(").replace(")", "\)") +# +# return caption, taglist, rating_labels, char_labels, gen_labels def mcut_threshold(probs: np.ndarray) -> float: sorted_probs: np.ndarray = probs[probs.argsort()[::-1]] @@ -114,19 +108,17 @@ def mcut_threshold(probs: np.ndarray) -> float: thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh - -def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]: - name_series: pd.Series = dataframe["name"] - name_series = name_series.map( - lambda x: x.replace("_", " ") if x not in kaomojis else x - ) - tag_names: List[str] = name_series.tolist() - - rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0]) - general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0]) - character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0]) - return tag_names, rating_indexes, general_indexes, character_indexes - +# def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]: +# name_series: pd.Series = dataframe["name"] +# name_series = name_series.map( +# lambda x: x.replace("_", " ") if x not in kaomojis else x +# ) +# tag_names: List[str] = name_series.tolist() +# +# rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0]) +# general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0]) +# character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0]) +# return tag_names, rating_indexes, general_indexes, character_indexes def print_traceback() -> None: tb: traceback.StackSummary = traceback.extract_tb(sys.exc_info()[2]) @@ -143,14 +135,14 @@ def print_traceback() -> None: class Predictor: def __init__(self) -> None: - self.model_target_size: Optional[int] = None + # self.model_target_size: Optional[int] = None self.last_loaded_repo: Optional[str] = None - self.tagger_model_path: Optional[str] = None - self.tagger_model: Optional[rt.InferenceSession] = None + self.tagger_model: Optional[nn.Module] = None self.tag_names: Optional[List[str]] = None - self.rating_indexes: Optional[List[int]] = None - self.general_indexes: Optional[List[int]] = None - self.character_indexes: Optional[List[int]] = None + self.rating_index: Optional[List[int]] = None + self.general_index: Optional[List[int]] = None + self.character_index: Optional[List[int]] = None + self.transform: Optional[Callable] = None def list_files_recursive(self, tarfile_path: str) -> List[str]: file_list: List[str] = [] @@ -161,8 +153,8 @@ def list_files_recursive(self, tarfile_path: str) -> List[str]: file_list.append(file_path) return file_list - def prepare_image(self, image: Image.Image) -> np.ndarray: - target_size: int = self.model_target_size + def prepare_image(self, image: Image.Image) -> Image.Image: + # target_size: int = self.model_target_size if image.mode in ('RGBA', 'LA'): background: Image.Image = Image.new("RGB", image.size, (255, 255, 255)) @@ -179,23 +171,25 @@ def prepare_image(self, image: Image.Image) -> np.ndarray: padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) - if max_dim != target_size: - padded_image = padded_image.resize( - (target_size, target_size), - Image.BICUBIC, - ) + # if max_dim != target_size: + # padded_image = padded_image.resize( + # (target_size, target_size), + # Image.BICUBIC, + # ) - image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) - image_array = image_array[:, :, ::-1] + # image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) + # image_array = image_array[:, :, ::-1] + # + # return np.expand_dims(image_array, axis=0) - return np.expand_dims(image_array, axis=0) + return padded_image def load_labels_hf( self, repo_id: str, revision: Optional[str] = None, token: Optional[str] = None, - ) -> LabelData: + ) -> None: try: csv_path = hf_hub_download( repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token @@ -206,25 +200,23 @@ def load_labels_hf( df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"]) self.rating_index = list(np.where(df["category"] == 9)[0]) - self.general_index = list(np.where(df["category"] == 0)[0]), - self.character_index = list(np.where(df["category"] == 4)[0]), + self.general_index = list(np.where(df["category"] == 0)[0]) + self.character_index = list(np.where(df["category"] == 4)[0]) self.tag_names = df["name"].tolist() def load_model(self) -> None: - global transform - if self.tagger_model is not None: return - model: nn.Module = timm.create_model("hf-hub:" + TAGGER_VIT_MODEL_REPO).eval() + self.tagger_model = timm.create_model("hf-hub:" + TAGGER_VIT_MODEL_REPO).eval() state_dict = timm.models.load_state_dict_from_hf(TAGGER_VIT_MODEL_REPO) - model.load_state_dict(state_dict) + self.tagger_model.load_state_dict(state_dict) print("Loading tag list...") self.load_labels_hf(repo_id=TAGGER_VIT_MODEL_REPO) print("Creating data transform...") - transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model)) + self.transform = create_transform(**resolve_data_config(self.tagger_model.pretrained_cfg, model=self.tagger_model)) # self.tagger_model_path = hf_hub_download(repo_id=TAGGER_VIT_MODEL_REPO, filename=MODEL_FILE_NAME) # self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider']) @@ -252,16 +244,23 @@ def predict( character_thresh: float, character_mcut_enabled: bool, ) -> List[str]: - # run the model's input transform to convert to tensor and rescale - inputs: Tensor = transform(images).unsqueeze(0) - # NCHW image RGB to BGR - inputs = inputs[:, [2, 1, 0]] + inputs: Optional[Tensor] = None + for img in images: + img_tmp = self.prepare_image(img) + # run the model's input transform to convert to tensor and rescale + input: Tensor = self.transform(img_tmp).unsqueeze(0) + # NCHW image RGB to BGR + input = input[:, [2, 1, 0]] + if inputs is None: + inputs = input + else: + inputs = torch.cat((inputs, input), 0) print("Running inference...") with torch.inference_mode(): # move model to GPU, if available if torch_device.type != "cpu": - model = model.to(torch_device) + model = self.tagger_model.to(torch_device) inputs = inputs.to(torch_device) # run the model outputs = model.forward(inputs) @@ -269,23 +268,26 @@ def predict( outputs = F.sigmoid(outputs) # move inputs, outputs, and model back to to cpu if we were on GPU if torch_device.type != "cpu": - inputs = inputs.to("cpu") + # inputs = inputs.to("cpu") + # model = model.to("cpu") outputs = outputs.to("cpu") - model = model.to("cpu") print("Processing results...") - caption, taglist, ratings, character, general = get_tags( - probs=outputs.squeeze(0), - labels=labels, - gen_threshold=opts.gen_threshold, - char_threshold=opts.char_threshold, - ) + preds = outputs.numpy() + # caption, taglist, ratings, character, general = get_tags( + # probs=outputs.squeeze(0), + # labels=labels, + # gen_threshold=opts.gen_threshold, + # char_threshold=opts.char_threshold, + # ) + # print(preds) + # exit(1) ret_strings: List[str] = [] for idx in range(0, len(images)): - labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float))) + labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float))) - general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes] + general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_index] if general_mcut_enabled: general_probs: np.ndarray = np.array([x[1] for x in general_names]) @@ -293,7 +295,7 @@ def predict( general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh} - character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes] + character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_index] if character_mcut_enabled: character_probs: np.ndarray = np.array([x[1] for x in character_names]) @@ -361,7 +363,7 @@ def process_directory(self, tarfile_path: str) -> None: cnt += 1 - if cnt - last_cnt >= BATCH_SIZE: + if cnt - last_cnt >= PROGRESS_INTERVAL: now: float = time.perf_counter() print(f'{cnt} files processed') diff: float = now - start @@ -383,10 +385,14 @@ def process_directory(self, tarfile_path: str) -> None: pass -def main(arg_str: str) -> None: +#def main(arg_str: str) -> None: +def main(arg_str: List[str]) -> None: + parser: argparse.ArgumentParser = argparse.ArgumentParser() + parser.add_argument('--dir', nargs=1, required=True, help='tagging target directory path') + args: argparse.Namespace = parser.parse_args(arg_str) predictor: Predictor = Predictor() - predictor.process_directory(arg_str) - + # predictor.process_directory(arg_str) + predictor.process_directory(args.dir[0]) -opts, _ = parse_known_args(ScriptOptions) -main('/content/freepik') \ No newline at end of file +#main('/content/freepik') +main(sys.argv[1:]) \ No newline at end of file