Skip to content

Commit

Permalink
Merge branch 'some-enhance-and-script-for-colab-env' of https://githu…
Browse files Browse the repository at this point in the history
…b.com/ryogrid/anime-illust-image-searcher into some-enhance-and-script-for-colab-env
  • Loading branch information
ryogrid committed Oct 24, 2024
2 parents c92bf00 + c55ad99 commit dcc4d27
Showing 1 changed file with 131 additions and 125 deletions.
256 changes: 131 additions & 125 deletions colab_env/tagging.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
# https://github.com/neggles/wdv3-timm/blob/main/wdv3_timm.py

import os, time, io
import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download
from PIL import Image
import os, time
import pandas as pd
import argparse
import traceback, sys
import re
from pathlib import Path
from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol
from simple_parsing import field, parse_known_args
from huggingface_hub.utils import HfHubHTTPError

import numpy as np
from numpy import signedinteger
from PIL import Image
import timm
from timm.data import create_transform, resolve_data_config
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from timm.data import create_transform, resolve_data_config

from dataclasses import dataclass
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError

kaomojis: List[str] = [
"0_0",
Expand All @@ -46,66 +42,64 @@
]

TAGGER_VIT_MODEL_REPO: str = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILE_NAME: str = "model.onnx"
LABEL_FILENAME: str = "selected_tags.csv"
# MODEL_FILE_NAME: str = "model.onnx"
# LABEL_FILENAME: str = "selected_tags.csv"

EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"]

BATCH_SIZE: int = 100
BATCH_SIZE: int = 10
PROGRESS_INTERVAL: int = 100

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = None


@dataclass
class LabelData:
    """Tag vocabulary of a tagger model, split by tag category.

    ``rating``, ``general`` and ``character`` hold positions into ``names``
    (and therefore into the model's per-tag probability vector, which is
    zipped against ``names`` by the code that consumes this class).
    """

    names: list[str]  # tag name for each model output, in output order
    rating: list[np.int64]  # positions of the rating tags
    general: list[np.int64]  # positions of the general tags
    character: list[np.int64]  # positions of the character tags


@dataclass
class ScriptOptions:
    """Command-line options, declared with ``simple_parsing.field``.

    The thresholds are the strict lower bounds that a tag's probability must
    exceed to be kept when they are forwarded to ``get_tags``.
    """

    image_file: Path = field(positional=True)  # positional CLI argument
    model: str = field(default="vit")
    gen_threshold: float = field(default=0.35)  # cut-off for general tags
    char_threshold: float = field(default=0.75)  # cut-off for character tags


def get_tags(
    probs: Tensor,
    labels: "LabelData",
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into captions and per-category score maps.

    Args:
        probs: One probability per entry of ``labels.names``, in the same
            order. ``probs.numpy()`` is called on it, so it must be a tensor
            that can be converted to NumPy directly.
        labels: Tag names plus the positions of the rating, general and
            character tags within ``names``.
        gen_threshold: A general tag is kept only if its probability is
            strictly greater than this value.
        char_threshold: Same cut-off, applied to character tags.

    Returns:
        ``(caption, taglist, rating_labels, char_labels, gen_labels)`` where

        * ``caption`` is the kept general tags followed by the kept character
          tags, joined with ``", "`` (each group sorted by descending
          probability; the two groups are not merged into one ranking);
        * ``taglist`` is ``caption`` with underscores turned into spaces and
          parentheses backslash-escaped;
        * ``rating_labels`` maps every rating tag to its probability
          (no threshold applied);
        * ``char_labels`` / ``gen_labels`` map the kept tags to their
          probability, highest first.
    """
    # Pair every tag name with its probability; the index lists in `labels`
    # address this list.
    scored = list(zip(labels.names, probs.numpy()))

    def pick(indices, threshold):
        """Return {name: prob} for entries above `threshold`, best first."""
        kept = {}
        for i in indices:
            name, prob = scored[i]
            if prob > threshold:
                kept[name] = prob
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    # Ratings are reported in full, without any threshold.
    rating_labels = dict(scored[i] for i in labels.rating)
    gen_labels = pick(labels.general, gen_threshold)
    char_labels = pick(labels.character, char_threshold)

    # General tags first, then character tags.
    caption = ", ".join([*gen_labels, *char_labels])
    # Raw strings: "\(" and "\)" are not valid escape sequences and trigger a
    # warning on current Python versions.
    taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return caption, taglist, rating_labels, char_labels, gen_labels

# for apple silicon
# Override the device chosen above when the Metal (MPS) backend is usable.
# The backend module is looked up defensively so that torch builds which do
# not ship `torch.backends.mps` fall through instead of raising.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    torch_device = torch.device("mps")

# @dataclass
# class LabelData:
# names: list[str]
# rating: list[np.int64]
# general: list[np.int64]
# character: list[np.int64]

# @dataclass
# class ScriptOptions:
# image_file: Path = field(positional=True)
# model: str = field(default="vit")
# gen_threshold: float = field(default=0.35)
# char_threshold: float = field(default=0.75)

# def get_tags(
# probs: Tensor,
# labels: LabelData,
# gen_threshold: float,
# char_threshold: float,
# ):
# # Convert indices+probs to labels
# probs = list(zip(labels.names, probs.numpy()))
#
# # First 4 labels are actually ratings
# rating_labels = dict([probs[i] for i in labels.rating])
#
# # General labels, pick any where prediction confidence > threshold
# gen_labels = [probs[i] for i in labels.general]
# gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
# gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
#
# # Character labels, pick any where prediction confidence > threshold
# char_labels = [probs[i] for i in labels.character]
# char_labels = dict([x for x in char_labels if x[1] > char_threshold])
# char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
#
# # Combine general and character labels, sort by confidence
# combined_names = [x for x in gen_labels]
# combined_names.extend([x for x in char_labels])
#
# # Convert to a string suitable for use as a training caption
# caption = ", ".join(combined_names)
# taglist = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
#
# return caption, taglist, rating_labels, char_labels, gen_labels

def mcut_threshold(probs: np.ndarray) -> float:
sorted_probs: np.ndarray = probs[probs.argsort()[::-1]]
Expand All @@ -114,19 +108,17 @@ def mcut_threshold(probs: np.ndarray) -> float:
thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2
return thresh


def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column (tag strings) and a
            ``category`` column (numeric category code).

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` has underscores replaced by spaces, except for entries
        listed in the module-level ``kaomojis`` which are kept verbatim. The
        three index lists are the row positions whose category code is 9
        (rating), 0 (general) and 4 (character) respectively.
    """
    # Build the lookup once instead of scanning the list for every tag.
    keep_verbatim = frozenset(kaomojis)
    tag_names: List[str] = [
        name if name in keep_verbatim else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    category = dataframe["category"]

    def positions_of(code: int) -> List[int]:
        """Row positions whose category equals `code`."""
        return list(np.where(category == code)[0])

    rating_indexes: List[int] = positions_of(9)
    general_indexes: List[int] = positions_of(0)
    character_indexes: List[int] = positions_of(4)
    return tag_names, rating_indexes, general_indexes, character_indexes

# def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]:
# name_series: pd.Series = dataframe["name"]
# name_series = name_series.map(
# lambda x: x.replace("_", " ") if x not in kaomojis else x
# )
# tag_names: List[str] = name_series.tolist()
#
# rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0])
# general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0])
# character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0])
# return tag_names, rating_indexes, general_indexes, character_indexes

def print_traceback() -> None:
tb: traceback.StackSummary = traceback.extract_tb(sys.exc_info()[2])
Expand All @@ -143,14 +135,14 @@ def print_traceback() -> None:

class Predictor:
def __init__(self) -> None:
self.model_target_size: Optional[int] = None
# self.model_target_size: Optional[int] = None
self.last_loaded_repo: Optional[str] = None
self.tagger_model_path: Optional[str] = None
self.tagger_model: Optional[rt.InferenceSession] = None
self.tagger_model: Optional[nn.Module] = None
self.tag_names: Optional[List[str]] = None
self.rating_indexes: Optional[List[int]] = None
self.general_indexes: Optional[List[int]] = None
self.character_indexes: Optional[List[int]] = None
self.rating_index: Optional[List[int]] = None
self.general_index: Optional[List[int]] = None
self.character_index: Optional[List[int]] = None
self.transform: Optional[Callable] = None

def list_files_recursive(self, tarfile_path: str) -> List[str]:
file_list: List[str] = []
Expand All @@ -161,8 +153,8 @@ def list_files_recursive(self, tarfile_path: str) -> List[str]:
file_list.append(file_path)
return file_list

def prepare_image(self, image: Image.Image) -> np.ndarray:
target_size: int = self.model_target_size
def prepare_image(self, image: Image.Image) -> Image.Image:
# target_size: int = self.model_target_size

if image.mode in ('RGBA', 'LA'):
background: Image.Image = Image.new("RGB", image.size, (255, 255, 255))
Expand All @@ -179,23 +171,25 @@ def prepare_image(self, image: Image.Image) -> np.ndarray:
padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
padded_image.paste(image, (pad_left, pad_top))

if max_dim != target_size:
padded_image = padded_image.resize(
(target_size, target_size),
Image.BICUBIC,
)
# if max_dim != target_size:
# padded_image = padded_image.resize(
# (target_size, target_size),
# Image.BICUBIC,
# )

image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1]
# image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
# image_array = image_array[:, :, ::-1]
#
# return np.expand_dims(image_array, axis=0)

return np.expand_dims(image_array, axis=0)
return padded_image

def load_labels_hf(
self,
repo_id: str,
revision: Optional[str] = None,
token: Optional[str] = None,
) -> LabelData:
) -> None:
try:
csv_path = hf_hub_download(
repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
Expand All @@ -206,25 +200,23 @@ def load_labels_hf(

df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
self.rating_index = list(np.where(df["category"] == 9)[0])
self.general_index = list(np.where(df["category"] == 0)[0]),
self.character_index = list(np.where(df["category"] == 4)[0]),
self.general_index = list(np.where(df["category"] == 0)[0])
self.character_index = list(np.where(df["category"] == 4)[0])
self.tag_names = df["name"].tolist()

def load_model(self) -> None:
global transform

if self.tagger_model is not None:
return

model: nn.Module = timm.create_model("hf-hub:" + TAGGER_VIT_MODEL_REPO).eval()
self.tagger_model = timm.create_model("hf-hub:" + TAGGER_VIT_MODEL_REPO).eval()
state_dict = timm.models.load_state_dict_from_hf(TAGGER_VIT_MODEL_REPO)
model.load_state_dict(state_dict)
self.tagger_model.load_state_dict(state_dict)

print("Loading tag list...")
self.load_labels_hf(repo_id=TAGGER_VIT_MODEL_REPO)

print("Creating data transform...")
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
self.transform = create_transform(**resolve_data_config(self.tagger_model.pretrained_cfg, model=self.tagger_model))

# self.tagger_model_path = hf_hub_download(repo_id=TAGGER_VIT_MODEL_REPO, filename=MODEL_FILE_NAME)
# self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider'])
Expand Down Expand Up @@ -252,48 +244,58 @@ def predict(
character_thresh: float,
character_mcut_enabled: bool,
) -> List[str]:
# run the model's input transform to convert to tensor and rescale
inputs: Tensor = transform(images).unsqueeze(0)
# NCHW image RGB to BGR
inputs = inputs[:, [2, 1, 0]]
inputs: Optional[Tensor] = None
for img in images:
img_tmp = self.prepare_image(img)
# run the model's input transform to convert to tensor and rescale
input: Tensor = self.transform(img_tmp).unsqueeze(0)
# NCHW image RGB to BGR
input = input[:, [2, 1, 0]]
if inputs is None:
inputs = input
else:
inputs = torch.cat((inputs, input), 0)

print("Running inference...")
with torch.inference_mode():
# move model to GPU, if available
if torch_device.type != "cpu":
model = model.to(torch_device)
model = self.tagger_model.to(torch_device)
inputs = inputs.to(torch_device)
# run the model
outputs = model.forward(inputs)
# apply the final activation function (timm doesn't support doing this internally)
outputs = F.sigmoid(outputs)
# move inputs, outputs, and model back to cpu if we were on GPU
if torch_device.type != "cpu":
inputs = inputs.to("cpu")
# inputs = inputs.to("cpu")
# model = model.to("cpu")
outputs = outputs.to("cpu")
model = model.to("cpu")

print("Processing results...")
caption, taglist, ratings, character, general = get_tags(
probs=outputs.squeeze(0),
labels=labels,
gen_threshold=opts.gen_threshold,
char_threshold=opts.char_threshold,
)
preds = outputs.numpy()
# caption, taglist, ratings, character, general = get_tags(
# probs=outputs.squeeze(0),
# labels=labels,
# gen_threshold=opts.gen_threshold,
# char_threshold=opts.char_threshold,
# )

# print(preds)
# exit(1)
ret_strings: List[str] = []
for idx in range(0, len(images)):
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float)))
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float)))

general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes]
general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_index]

if general_mcut_enabled:
general_probs: np.ndarray = np.array([x[1] for x in general_names])
general_thresh = mcut_threshold(general_probs)

general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh}

character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes]
character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_index]

if character_mcut_enabled:
character_probs: np.ndarray = np.array([x[1] for x in character_names])
Expand Down Expand Up @@ -361,7 +363,7 @@ def process_directory(self, tarfile_path: str) -> None:

cnt += 1

if cnt - last_cnt >= BATCH_SIZE:
if cnt - last_cnt >= PROGRESS_INTERVAL:
now: float = time.perf_counter()
print(f'{cnt} files processed')
diff: float = now - start
Expand All @@ -383,10 +385,14 @@ def process_directory(self, tarfile_path: str) -> None:
pass


def main(arg_str: str) -> None:
#def main(arg_str: str) -> None:
def main(arg_str: List[str]) -> None:
parser: argparse.ArgumentParser = argparse.ArgumentParser()
parser.add_argument('--dir', nargs=1, required=True, help='tagging target directory path')
args: argparse.Namespace = parser.parse_args(arg_str)
predictor: Predictor = Predictor()
predictor.process_directory(arg_str)

# predictor.process_directory(arg_str)
predictor.process_directory(args.dir[0])

opts, _ = parse_known_args(ScriptOptions)
main('/content/freepik')
#main('/content/freepik')
main(sys.argv[1:])

0 comments on commit dcc4d27

Please sign in to comment.