Commit d1080ec

ryogrid committed Oct 14, 2024
2 parents e59d2a4 + c46d3d7 commit d1080ec

Showing 4 changed files with 123 additions and 165 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
@@ -29,4 +29,4 @@ jobs:
- name: Mypy Check
uses: jpetrucciani/mypy-check@master
with:
path: /home/runner/work/anime-illust-image-searcher/anime-illust-image-searcher/*.py
path: .
16 changes: 8 additions & 8 deletions README.md
@@ -1,15 +1,15 @@
# Anime Style Illustration Specific Image Search App with Vit Tagger x LSI
# Anime Style Illustration Specific Image Search App with ViT Tagger x LSI
## What's This?
- Anime Style Illustration Specific Image Search App with ML Technique
- can be used for photos too, but flexible photo search is already offered by Google Photos etc. :)
- Search capabilities of cloud photo album services are poor for illustration image files for some reason
- So, I wrote these simple scripts

## Method
- Search Images matching with Query Texts on Latent Representation Vectors
- Search Images Matching with Query Texts on Latent Semantic Representation Vector Space
- Vectors are generated with an embedding model: Vision Transformer (ViT) Tagger x Latent Semantic Indexing (LSI)
- LSI is Ssed for Covering Tagging Presision
- You can use tags to search which are difficult for tagging because search index is applyed LSI
- LSI is used for covering tagging precision
- You can search with tags that are difficult for the tagger to assign, because LSI is applied to the index vectors
- implemented with the Gensim lib (see the sketch below)
- ( Web UI is implemented with Streamlit )
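
For reference, here is a minimal sketch of the LSI-over-tags idea with Gensim. The tag lists and num_topics value are hypothetical; the actual index-building script may differ in detail:

```python
from gensim import corpora, models, similarities

# Each image is a "document" whose words are its predicted tags
docs = [
    ["1girl", "blue hair", "school uniform"],
    ["1boy", "sword", "armor"],
]
dictionary = corpora.Dictionary(docs)
bow_corpus = [dictionary.doc2bow(doc) for doc in docs]

# Project bag-of-tags vectors into a low-rank latent space
lsi = models.LsiModel(bow_corpus, id2word=dictionary, num_topics=100)
index = similarities.MatrixSimilarity(lsi[bow_corpus])

# A query is mapped into the same space; cosine similarity ranks images
query_bow = dictionary.doc2bow(["blue hair", "school uniform"])
print(index[lsi[query_bow]])
```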

@@ -18,7 +18,7 @@
- $ python make-tags-with-wd-tagger.py --dir "IMAGE FILES CONTAINED DIR PATH"
- The script searches the directory structure recursively :)
- This takes quite a while...
- About 1 file/s at middle spec desktop PC (GPU is not used)
- About 0.5 sec/file on a mid-spec desktop PC (GPU is not used)
- AMD Ryzen 7 5700X 8-Core Processor 4.50 GHz
- You may speed this up by editing the script to use CUDAExecutionProvider, CoreMLExecutionProvider, etc. (see the sketch below) :)
- Please see [here](https://onnxruntime.ai/docs/execution-providers/)
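
For example, the session construction in make-tags-with-wd-tagger.py could be adapted roughly like this (a sketch; the CUDA provider needs the onnxruntime-gpu package):

```python
import onnxruntime as rt

# Providers are tried in order; ONNX Runtime falls back to the next
# entry when a provider is not available on the machine.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
session = rt.InferenceSession("model.onnx", providers=providers)
```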
@@ -43,7 +43,7 @@
- Solution
- Search for the words you want to use in tags-wd-tagger.txt with grep, an editor, or similar, to check that they exist (a Python alternative is sketched below)
- If they exist, there is no problem. If not, you should think of similar words and search for them in the same manner :)
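
If you prefer Python to grep, a quick existence check could look like this (a sketch; "blue hair" stands in for whatever tag you want to verify):

```python
# tags-wd-tagger.txt holds one CSV line per image: path,tag1,tag2,...
with open("tags-wd-tagger.txt", encoding="utf-8") as f:
    print(any("blue hair" in line.split(",", 1)[-1] for line in f))
```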
- Charcter code of file pathes
- Character code of file paths
- If a file path contains characters which can't be converted to Unicode or UTF-8, the scripts may output an error message when processing that file
- But it doesn't mean that your script usage is wrong, though these files are ignored or not displayed on the Web UI :|
- This is a problem of the current implementation. When you use the scripts on Windows and the character code of directory/file names isn't UTF-8, the problem may occur
@@ -55,11 +55,11 @@

## TODO
- [ ] <del>Search on latent representation generated by CLIP model</del>
- This was tried but precition with current public available CLIP models which are not fit for anime style illust was bad :|
- This was already tried, but precision was not good because currently available public CLIP models do not fit anime style illustrations :|
- [ ] Weight specification per keyword, like the prompt format of Stable Diffusion Web UI
- The current implementation weights all keywords equally. But there are many cases where users want to emphasize a specific keyword and can't get appropriate results without that!
- [ ] Incremental index updating when image files are added
- [ ] Similar image search with specifying target image file
- [ ] Similar image search by specifying an image file
- [ ] Feature to export the list of found files
- As a text file. Once you have the list, you can use many other tools and viewers you like :)
- [ ] Making a binary package of this app which doesn't require building a Python environment
202 changes: 87 additions & 115 deletions make-tags-with-wd-tagger.py
@@ -9,8 +9,7 @@
import re
from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol

# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
kaomojis: List[str] = [
"0_0",
"(o)_(o)",
"+_+",
@@ -32,122 +31,108 @@
"||_||",
]

VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
MODEL_FILE_NAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# extensions of image files to be processed
EXTENSIONS = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"]

def mcut_threshold(probs):
"""
Maximum Cut Thresholding (MCut)
Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
for Multi-label Classification. In 11th International Symposium, IDA 2012
(pp. 172-183).
"""
sorted_probs = probs[probs.argsort()[::-1]]
difs = sorted_probs[:-1] - sorted_probs[1:]
t = difs.argmax()
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
VIT_MODEL_DSV3_REPO: str = "SmilingWolf/wd-vit-tagger-v3"
MODEL_FILE_NAME: str = "model.onnx"
LABEL_FILENAME: str = "selected_tags.csv"

EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"]

def mcut_threshold(probs: np.ndarray) -> float:
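# MCut (Largeron, Moulin & Gery, 2012): sort probabilities in descending
# order, find the largest gap between adjacent values, and place the
# threshold at the midpoint of that gap. E.g. for [0.9, 0.8, 0.3, 0.1]
# the largest gap is 0.8 -> 0.3, so the threshold is (0.8 + 0.3) / 2 = 0.55.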
sorted_probs: np.ndarray = probs[probs.argsort()[::-1]]
difs: np.ndarray = sorted_probs[:-1] - sorted_probs[1:]
t: int = difs.argmax()
thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2
return thresh

def load_labels(dataframe) -> Tuple[List[str], List[str], List[str], List[str]]:
name_series = dataframe["name"]
def load_labels(dataframe: pd.DataFrame) -> Tuple[List[str], List[int], List[int], List[int]]:
name_series: pd.Series = dataframe["name"]
name_series = name_series.map(
lambda x: x.replace("_", " ") if x not in kaomojis else x
)
tag_names: List[str] = name_series.tolist()

rating_indexes: List[str] = list(np.where(dataframe["category"] == 9)[0])
general_indexes: List[str] = list(np.where(dataframe["category"] == 0)[0])
character_indexes: List[str] = list(np.where(dataframe["category"] == 4)[0])
rating_indexes: List[int] = list(np.where(dataframe["category"] == 9)[0])
general_indexes: List[int] = list(np.where(dataframe["category"] == 0)[0])
character_indexes: List[int] = list(np.where(dataframe["category"] == 4)[0])
return tag_names, rating_indexes, general_indexes, character_indexes

def print_traceback():
tb = traceback.extract_tb(sys.exc_info()[2])
trace = traceback.format_list(tb)
def print_traceback() -> None:
tb: traceback.StackSummary = traceback.extract_tb(sys.exc_info()[2])
trace: List[str] = traceback.format_list(tb)
print('---- traceback ----')
for line in trace:
if '~^~' in line:
print(line.rstrip())
else:
text = re.sub(r'\n\s*', ' ', line.rstrip())
text: str = re.sub(r'\n\s*', ' ', line.rstrip())
print(text)
print('-------------------')

# list up files and filter by extension
def list_files_recursive(directory):
file_list = []
def list_files_recursive(directory: str) -> List[str]:
file_list: List[str] = []
for root, _, files in os.walk(directory):
for file in files:
file_path = os.path.join(root, file)
file_path: str = os.path.join(root, file)
if any(file_path.endswith(ext) for ext in EXTENSIONS):
file_list.append(file_path)
return file_list

class Predictor:
def __init__(self):
self.model_target_size = None
self.last_loaded_repo = None
self.tagger_model_path = None
self.tagger_model = None
self.tag_names = None
self.rating_indexes = None
self.general_indexes = None
self.character_indexes = None

def prepare_image(self, image):
target_size = self.model_target_size
def __init__(self) -> None:
self.model_target_size: Optional[int] = None
self.last_loaded_repo: Optional[str] = None
self.tagger_model_path: Optional[str] = None
self.tagger_model: Optional[rt.InferenceSession] = None
self.tag_names: Optional[List[str]] = None
self.rating_indexes: Optional[List[int]] = None
self.general_indexes: Optional[List[int]] = None
self.character_indexes: Optional[List[int]] = None

def prepare_image(self, image: Image.Image) -> np.ndarray:
target_size: int = self.model_target_size

if image.mode in ('RGBA', 'LA'):
background = Image.new("RGB", image.size, (255, 255, 255))
background: Image.Image = Image.new("RGB", image.size, (255, 255, 255))
background.paste(image, mask=image.split()[-1])
image = background
else:
image = image.convert("RGB")

# Pad image to square
image_shape = image.size
max_dim = max(image_shape)
pad_left = (max_dim - image_shape[0]) // 2
pad_top = (max_dim - image_shape[1]) // 2
image_shape: Tuple[int, int] = image.size
max_dim: int = max(image_shape)
pad_left: int = (max_dim - image_shape[0]) // 2
pad_top: int = (max_dim - image_shape[1]) // 2

padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
padded_image.paste(image, (pad_left, pad_top))

# Resize
if max_dim != target_size:
padded_image = padded_image.resize(
(target_size, target_size),
Image.BICUBIC,
)

# Convert to numpy array
image_array = np.asarray(padded_image, dtype=np.float32)

# Convert PIL-native RGB to BGR
image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1]

return np.expand_dims(image_array, axis=0)

def load_model(self):
def load_model(self) -> None:
if self.tagger_model is not None:
return

self.tagger_model_path = hf_hub_download(repo_id=VIT_MODEL_DSV3_REPO, filename=MODEL_FILE_NAME)

self.tagger_model_path = hf_hub_download(repo_id=VIT_MODEL_DSV3_REPO, filename=MODEL_FILE_NAME)
self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CPUExecutionProvider'])
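# The ONNX input is NHWC; its height is used as the square target size in prepare_image()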
_, height, _, _ = self.tagger_model.get_inputs()[0].shape

self.model_target_size = height

csv_path = hf_hub_download(
csv_path: str = hf_hub_download(
VIT_MODEL_DSV3_REPO,
LABEL_FILENAME,
)
tags_df = pd.read_csv(csv_path)
sep_tags = load_labels(tags_df)
tags_df: pd.DataFrame = pd.read_csv(csv_path)
sep_tags: Tuple[List[str], List[int], List[int], List[int]] = load_labels(tags_df)

self.tag_names = sep_tags[0]
self.rating_indexes = sep_tags[1]
@@ -156,46 +141,38 @@ def load_model(self):

def predict(
self,
image,
general_thresh,
general_mcut_enabled,
character_thresh,
character_mcut_enabled,
):
image = self.prepare_image(image)
image: Image.Image,
general_thresh: float,
general_mcut_enabled: bool,
character_thresh: float,
character_mcut_enabled: bool,
) -> str:
image: np.ndarray = self.prepare_image(image)

input_name = self.tagger_model.get_inputs()[0].name
label_name = self.tagger_model.get_outputs()[0].name
preds = self.tagger_model.run([label_name], {input_name: image})[0]
input_name: str = self.tagger_model.get_inputs()[0].name
label_name: str = self.tagger_model.get_outputs()[0].name
preds: np.ndarray = self.tagger_model.run([label_name], {input_name: image})[0]

labels = list(zip(self.tag_names, preds[0].astype(float)))
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0].astype(float)))

# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = [labels[i] for i in self.rating_indexes]
# rating = dict(ratings_names)

# Then we have general tags: pick any where prediction confidence > threshold
general_names = [labels[i] for i in self.general_indexes]
general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes]

if general_mcut_enabled:
general_probs = np.array([x[1] for x in general_names])
general_probs: np.ndarray = np.array([x[1] for x in general_names])
general_thresh = mcut_threshold(general_probs)

general_res = [x for x in general_names if x[1] > general_thresh]
general_res = dict(general_res)
general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh}

# Everything else is characters: pick any where prediction confidence > threshold
character_names = [labels[i] for i in self.character_indexes]
character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes]

if character_mcut_enabled:
character_probs = np.array([x[1] for x in character_names])
character_probs: np.ndarray = np.array([x[1] for x in character_names])
character_thresh = mcut_threshold(character_probs)
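# Floor of 0.15 so MCut cannot admit very low-confidence character tags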
character_thresh = max(0.15, character_thresh)

character_res = [x for x in character_names if x[1] > character_thresh]
character_res = dict(character_res)
character_res: Dict[str, float] = {x[0]: x[1] for x in character_names if x[1] > character_thresh}

sorted_general_strings = sorted(
sorted_general_strings: List[str] = sorted(
general_res.items(),
key=lambda x: x[1],
reverse=True,
@@ -206,10 +183,10 @@ def predict(
",".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
)

ret_string = sorted_general_strings
ret_string: str = sorted_general_strings

if len(character_res) > 0:
sorted_character_strings = sorted(
sorted_character_strings: List[str] = sorted(
character_res.items(),
key=lambda x: x[1],
reverse=True,
@@ -223,57 +200,52 @@ def predict(

return ret_string
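
# Usage sketch (not in the original file): load the model once, then
# call predict() per image:
#   p = Predictor(); p.load_model()
#   csv_tags = p.predict(Image.open("img.png"), 0.3, True, 0.3, True)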

# write a line to file
def write_to_file(self, csv_line):
def write_to_file(self, csv_line: str) -> None:
self.f.write(csv_line + '\n')
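# Flush after every line so progress survives interruption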
self.f.flush()

# root function
def process_directory(self, directory):
file_list = list_files_recursive(directory)
def process_directory(self, directory: str) -> None:
file_list: List[str] = list_files_recursive(directory)
print(f'{len(file_list)} files found')

# file for tagged results
self.f = open('tags-wd-tagger.txt', 'a', encoding='utf-8')

self.load_model()

start = time.perf_counter()
cnt = 0
# process each image file
start: float = time.perf_counter()
cnt: int = 0
for file_path in file_list:
try:
img = Image.open(file_path)
results_in_csv_format = self.predict(img, 0.3, True, 0.3, True)

# write result to file
img: Image.Image = Image.open(file_path)
results_in_csv_format: str = self.predict(img, 0.3, True, 0.3, True)

self.write_to_file(file_path + ',' + results_in_csv_format)

if cnt % 100 == 0:
now = time.perf_counter()
now: float = time.perf_counter()
print(f'{cnt} files processed')
diff = now - start
diff: float = now - start
print('{:.2f} seconds elapsed'.format(diff))
if cnt > 0:
time_per_file = diff / cnt
time_per_file: float = diff / cnt
print('{:.4f} seconds per file'.format(time_per_file))
print("", flush=True)

cnt += 1
except Exception as e:
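# Log the failure and continue with the remaining files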
error_class = type(e)
error_description = str(e)
err_msg = '%s: %s' % (error_class, error_description)
error_class: type = type(e)
error_description: str = str(e)
err_msg: str = '%s: %s' % (error_class, error_description)
print(err_msg)
print_traceback()
pass

def main():
parser = argparse.ArgumentParser()
def main() -> None:
parser: argparse.ArgumentParser = argparse.ArgumentParser()
parser.add_argument('--dir', nargs=1, required=True, help='tagging target directory path')
args = parser.parse_args()
args: argparse.Namespace = parser.parse_args()

predictor = Predictor()
predictor: Predictor = Predictor()
predictor.process_directory(args.dir[0])

if __name__ == "__main__":
main()