
Commit

implemented cfeature-based reranking version; testing is yet to be done.
ryogrid committed Nov 7, 2024
1 parent 92f57c0 commit d596258
Showing 2 changed files with 137 additions and 110 deletions.
93 changes: 22 additions & 71 deletions gen_cfeatures.py
@@ -1,3 +1,5 @@
# https://huggingface.co/spaces/deepghs/ccip/blob/f7d50a4f5dd3d4681984187308d70839ff0d3f5b/ccip.py

import datetime
import os, time

@@ -8,23 +10,28 @@

import json
import os.path
from functools import lru_cache
from io import TextIOWrapper
from typing import Union, List, Optional
from typing import List, Optional

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfFileSystem

from imgutils.data import MultiImagesTyping, load_images, ImageTyping
from imgutils.utils import open_onnx_model
from onnxruntime import InferenceSession
from gensim.similarities import Similarity

try:
from imgutils.data import load_images, ImageTyping
from imgutils.utils import open_onnx_model
from onnxruntime import InferenceSession
except (ModuleNotFoundError, ImportError):
print('Please install the imgutils and onnxruntime packages to use character feature extraction.')

try:
from typing import Literal
except (ModuleNotFoundError, ImportError):
from typing_extensions import Literal
try:
from typing_extensions import Literal
except (ModuleNotFoundError, ImportError):
pass

hf_fs = HfFileSystem()

@@ -73,28 +80,6 @@ def list_files_recursive(self, dir_path: str) -> List[str]:
file_list.append(file_path)
return file_list

# def prepare_image(self, image: Image.Image) -> Image.Image:
# #target_size: int = self.model_target_size
#
# if image.mode in ('RGBA', 'LA'):
# background: Image.Image = Image.new("RGB", image.size, (255, 255, 255))
# background.paste(image, mask=image.split()[-1])
# image = background
# else:
# # copy image to avoid error at convert method call
# image = image.copy()
# image = image.convert("RGB")
#
# image_shape: Tuple[int, int] = image.size
# max_dim: int = max(image_shape)
# pad_left: int = (max_dim - image_shape[0]) // 2
# pad_top: int = (max_dim - image_shape[1]) // 2
#
# padded_image: Image.Image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
# padded_image.paste(image, (pad_left, pad_top))
#
# return padded_image

def write_to_file(self, csv_line: str) -> None:
self.f.write(csv_line + '\n')

@@ -120,7 +105,6 @@ def _preprocess_image(self, image: Image.Image, size: int = 384):

return data

#@lru_cache()
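# download the CCIP feature-extraction ONNX model from deepghs/ccip_onnx on Hugging Face
# and open it as an ONNX Runtime session (CUDA execution provider requested below)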
def _open_feat_model(self, model) -> InferenceSession:
return open_onnx_model(hf_hub_download(
f'deepghs/ccip_onnx',
@@ -129,22 +113,9 @@ def _open_feat_model(self, model) -> InferenceSession:
mode = 'CUDAExecutionProvider',
)

# @lru_cache()
# def _open_metric_model(self, model):
# return open_onnx_model(hf_hub_download(
# f'deepghs/ccip_onnx',
# f'{model}/model_metrics.onnx',
# ))
#
# @lru_cache()
def _open_metrics(self, model):
with open(hf_hub_download(f'deepghs/ccip_onnx', f'{model}/metrics.json'), 'r') as f:
return json.load(f)
#
# @lru_cache()
# def _open_cluster_metrics(self, model):
# with open(hf_hub_download(f'deepghs/ccip_onnx', f'{model}/cluster.json'), 'r') as f:
# return json.load(f)

#def ccip_batch_extract_features(self, images: MultiImagesTyping, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
def ccip_batch_extract_features(self, images: List[np.ndarray], size: int = 384,
@@ -218,14 +189,6 @@ def ccip_default_threshold(self, model: str = _DEFAULT_MODEL_NAMES) -> float:
"""
return self._open_metrics(model)['threshold']

# _FeatureOrImage = Union[ImageTyping, np.ndarray]

# def _p_feature(self, x: _FeatureOrImage, size: int = 384, model: str = _DEFAULT_MODEL_NAMES):
# if isinstance(x, np.ndarray): # if feature
# return x
# else: # is image or path
# return self.ccip_extract_feature(x, size, model)

def predict(
self,
images: List[np.ndarray],
@@ -234,25 +197,6 @@
ret = self.ccip_batch_extract_features(images)
print("Processing results...")
return ret
# batched_tensor = torch.stack(tensors, dim=0)
#
# print("Running inference...")
# with torch.inference_mode():
# # move model to GPU, if available
# model = self.tagger_model
# if torch_device.type != "cpu":
# model = self.tagger_model.to(torch_device)
# batched_tensor = batched_tensor.to(torch_device)
# # run the model
# outputs = model.forward(batched_tensor)
# # apply the final activation function (timm doesn't support doing this internally)
# outputs = F.sigmoid(outputs)
# # move inputs, outputs, and model back to cpu if we were on GPU
# if torch_device.type != "cpu":
# outputs = outputs.to("cpu")
#
# print("Processing results...")
# preds = outputs.numpy()

def gen_image_ndarray(self, file_path) -> np.ndarray | None:
try:
@@ -266,6 +210,14 @@ def gen_image_ndarray(self, file_path) -> np.ndarray | None:
print(err_msg)
return None

def get_image_feature(self, file_path: str) -> np.ndarray:
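# lazily load the prebuilt character-feature index and default similarity threshold on first call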
if self.cindex is None:
self.cindex = Similarity.load('charactor-featues-idx')
self.threshold = self.ccip_default_threshold(_DEFAULT_MODEL_NAMES)

img: np.ndarray = self.gen_image_ndarray(file_path)
return self.predict([img])[0]

def write_vecs_to_index(self, vecs: np.ndarray) -> bool:
for vec in vecs:
if self.cindex is None:
@@ -277,7 +229,6 @@ def process_directory(self, dir_path: str, added_date: datetime.date | None = No
file_list: List[str] = self.list_files_recursive(dir_path)
print(f'{len(file_list)} files found')

# self.load_model()
self.embed_model = self._open_feat_model(_DEFAULT_MODEL_NAMES)
self.threshold = self.ccip_default_threshold(_DEFAULT_MODEL_NAMES)
self.f = open('charactor-featues-idx.csv', 'a', encoding='utf-8')
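For context, a minimal hypothetical sketch (not part of this commit) of how the 'charactor-featues-idx' gensim Similarity index loaded by get_image_feature could be built from extracted CCIP feature vectors. The 768-dim feature length is an assumption; the actual value is whatever the CCIP model emits.

import numpy as np
from typing import List
from gensim.similarities import Similarity

FEATURE_DIM = 768  # assumed CCIP feature dimensionality

def build_cfeature_index(vecs: List[np.ndarray]) -> Similarity:
    # gensim indexes sparse (index, value) vectors, so convert each dense feature vector
    corpus = [[(i, float(v)) for i, v in enumerate(vec)] for vec in vecs]
    # shards are written under the 'charactor-featues-idx' prefix that this commit loads
    index = Similarity('charactor-featues-idx', corpus, num_features=FEATURE_DIM)
    index.save('charactor-featues-idx')
    return index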
154 changes: 115 additions & 39 deletions webui.py
@@ -1,4 +1,5 @@
import math
import os
import sys

from gensim import corpora
@@ -14,6 +15,9 @@
import time
from typing import List, Tuple, Dict, Any, Optional, Protocol

# for character-feature-vector based reranking
from gen_cfeatures import Predictor

# $ streamlit run webui.py

ss: SessionStateProxy = st.session_state
@@ -23,6 +27,11 @@
index: Optional[MatrixSimilarity] = None
dictionary: Optional[corpora.Dictionary] = None

cfeatures_idx: Optional[MatrixSimilarity] = None
cfeature_filepath_idx: Optional[List[str]] = None
predictor: Optional[Predictor] = None
cfeature_reranking_mode = False
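# flipped to True in find_similar_documents when the character-feature index files exist on disk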

NG_WORDS: List[str] = ['language', 'english_text', 'pixcel_art']

class Arguments(Protocol):
@@ -174,51 +183,21 @@ def get_embedded_vector_by_doc_id(doc_id: int) -> List[Tuple[int, float]]:
doc_doc2vec: List[Tuple[int, float]] = [(ii, val) for ii, val in enumerate(embed_vec)]
return doc_doc2vec

def find_similar_documents(new_doc: str, topn: int = 50) -> List[Tuple[int, float]]:
# get embed vector using Doc2Vec model
vec_doc2vec: List[Tuple[int, float]] = normalize_and_apply_weight_doc2vec(new_doc)

# Existing similarity scores using Doc2Vec model
sims_doc2vec: ndarray = index[vec_doc2vec]

splited_term = [x for x in new_doc.split(' ')]
query_term_and_weight: Dict[int, float] = {}
for term in splited_term:
term_splited: List[str] = term.split(':')
if len(term_splited) >= 2 and ((term_splited[-1].startswith('+') or term_splited[-1].startswith('-') or term_splited[-1].isdigit())):
if term_splited[-1].startswith('+'):
# '+' marks the term as required; to enforce this, its weight is set to REQUIRE_TAG_MAGIC_NUMBER + weight
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited) - 1])]] = REQUIRE_TAG_MAGIC_NUMBER + int(term_splited[-1])
else:
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited) - 1])]] = int(term_splited[-1])
else:
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited)])]] = 1

# BM25 scores
bm25_scores = compute_bm25_scores(query_weights=query_term_and_weight)

# Normalize scores
if sims_doc2vec.max() > 0:
sims_doc2vec = sims_doc2vec / sims_doc2vec.max()
if bm25_scores.max() > 0:
bm25_scores = bm25_scores / bm25_scores.max()

# Combine scores
final_scores = BM25_WEIGHT * bm25_scores + DOC2VEC_WEIGHT * sims_doc2vec

def get_doc2vec_based_reranked_scores(final_scores, topn) -> List[Tuple[int, float]]:
# Get top documents
sims: List[Tuple[int, float]] = list(enumerate(final_scores))
sims = sorted(sims, key=lambda item: -item[1])

if len(sims) > 10:
# Perform rescoring
top10_sims = sims[:10] # Top 10 documents
top10_doc_ids: List[int] = [doc_id for doc_id, _ in top10_sims]
top10_doc_ids_set = set(top10_doc_ids)
top10_doc_vectors: List[List[Tuple[int, float]]] = [get_embedded_vector_by_doc_id(doc_id + 1) for doc_id in top10_doc_ids]
top10_doc_vectors: List[List[Tuple[int, float]]] = [get_embedded_vector_by_doc_id(doc_id + 1) for doc_id in
top10_doc_ids]
weighted_mean_vec: ndarray = np.average(top10_doc_vectors, axis=0, weights=[score for _, score in top10_sims])
weighted_mean_vec = weighted_mean_vec / np.linalg.norm(weighted_mean_vec)
weighted_mean_vec_with_docid: List[Tuple[int, float]] = [(round(docid), val) for docid, val in weighted_mean_vec.tolist()]
weighted_mean_vec_with_docid: List[Tuple[int, float]] = [(round(docid), val) for docid, val in
weighted_mean_vec.tolist()]

reranked_scores: ndarray = index[weighted_mean_vec_with_docid]

@@ -262,7 +241,57 @@ def sorting_key(item):
if ret_len > len(final_sims):
ret_len = len(final_sims)
return final_sims[:ret_len]
else:
# Apply threshold filtering
sims = filter_searched_result(sims)
ret_len: int = topn
if ret_len > len(sims):
ret_len = len(sims)
return sims[:ret_len]

def get_cfeatures_based_reranked_scores(final_scores, topn) -> List[Tuple[int, float]]:
global cfeature_filepath_idx
global cfeatures_idx
global predictor
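# lazily load the doc_id -> file-path mapping (CSV) and the prebuilt character-feature index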

if cfeature_filepath_idx is None:
cfeature_filepath_idx = []
with open('charactor-featues-idx.csv', 'r', encoding='utf-8') as f:
for line in f:
cfeature_filepath_idx.append(line.strip())

if cfeatures_idx is None:
cfeatures_idx = MatrixSimilarity.load('charactor-featues-idx')

if predictor is None:
predictor = Predictor()

# when final_scores has more than 10 entries, compute the mean character-feature vector of the top-10 images,
# compute the similarity between that mean vector and all indexed images,
# then sort by similarity and return the images whose similarity exceeds the threshold

# Get top documents
sims: List[Tuple[int, float]] = list(enumerate(final_scores))
sims = sorted(sims, key=lambda item: -item[1])
if len(sims) > 10:
# Perform rescoring
top10_sims = sims[:10] # Top 10 documents
top10_doc_ids: List[int] = [doc_id for doc_id, _ in top10_sims]

# aggregate file paths of the top-10 images
top10_files = [image_files_name_tags_arr[doc_id - 1].split(',')[0] for doc_id in top10_doc_ids]

# extract character features for each of the top-10 images
top10_cfeatures: List[np.ndarray] = [predictor.get_image_feature(file) for file in top10_files]
weighted_mean_cfeatures: np.ndarray = np.average(top10_cfeatures, axis=0, weights=[score for _, score in top10_sims])
weighted_mean_cfeatures = weighted_mean_cfeatures / np.linalg.norm(weighted_mean_cfeatures)
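# query the gensim index with the normalized mean vector converted to sparse (index, value) pairs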
conved_mean_cfeatures: List[Tuple[int, float]] = [(ii, val) for ii, val in enumerate(weighted_mean_cfeatures)]
sims_by_cfeature: np.ndarray = cfeatures_idx[conved_mean_cfeatures]
sorted_sims: List[Tuple[int, float]] = list(enumerate(sims_by_cfeature))
sorted_sims = sorted(sorted_sims, key=lambda item: -item[1])
# filter by threshold
ret_sims = [(doc_id, score) for doc_id, score in sorted_sims if score > predictor.threshold]
return ret_sims
else:
# Apply threshold filtering
sims = filter_searched_result(sims)
@@ -272,6 +301,48 @@ def sorting_key(item):
return sims[:ret_len]
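# --- illustration (not part of this commit) ---
# the rescoring above hinges on a score-weighted mean of the top-10 feature
# vectors; a toy sketch of that computation with made-up 3-dim vectors:
#
#   import numpy as np
#   top_vecs = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
#   scores = [0.9, 0.1]  # similarity scores act as weights
#   mean_vec = np.average(top_vecs, axis=0, weights=scores)  # -> [0.9, 0.1, 0.0]
#   mean_vec = mean_vec / np.linalg.norm(mean_vec)  # unit length for cosine search
# --- end illustration ---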


def find_similar_documents(new_doc: str, topn: int = 50) -> List[Tuple[int, float]]:
global cfeature_reranking_mode

# get embed vector using Doc2Vec model
vec_doc2vec: List[Tuple[int, float]] = normalize_and_apply_weight_doc2vec(new_doc)

# Existing similarity scores using Doc2Vec model
sims_doc2vec: ndarray = index[vec_doc2vec]

splited_term = [x for x in new_doc.split(' ')]
query_term_and_weight: Dict[int, float] = {}
for term in splited_term:
term_splited: List[str] = term.split(':')
if len(term_splited) >= 2 and ((term_splited[-1].startswith('+') or term_splited[-1].startswith('-') or term_splited[-1].isdigit())):
if term_splited[-1].startswith('+'):
# '+' marks the term as required; to enforce this, its weight is set to REQUIRE_TAG_MAGIC_NUMBER + weight
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited) - 1])]] = REQUIRE_TAG_MAGIC_NUMBER + int(term_splited[-1])
else:
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited) - 1])]] = int(term_splited[-1])
else:
query_term_and_weight[dictionary.token2id[':'.join(term_splited[0:len(term_splited)])]] = 1

# BM25 scores
bm25_scores = compute_bm25_scores(query_weights=query_term_and_weight)

# Normalize scores
if sims_doc2vec.max() > 0:
sims_doc2vec = sims_doc2vec / sims_doc2vec.max()
if bm25_scores.max() > 0:
bm25_scores = bm25_scores / bm25_scores.max()

# Combine scores
final_scores = BM25_WEIGHT * bm25_scores + DOC2VEC_WEIGHT * sims_doc2vec

# Rerank scores
if os.path.exists('charactor-featues-idx') and os.path.exists('charactor-featues-idx.csv'):
# character-feature index files exist: switch to cfeature-based reranking
cfeature_reranking_mode = True
return get_cfeatures_based_reranked_scores(final_scores, topn)
else:
return get_doc2vec_based_reranked_scores(final_scores, topn)

def init_session_state(data: List[Any] = []) -> None:
global ss
if 'data' not in ss:
@@ -464,10 +535,15 @@ def show_search_result() -> None:
found_docs_info: List[Dict[str, Any]] = []
for doc_id, similarity in similar_docs:
try:
found_img_info_splited: List[str] = image_files_name_tags_arr[doc_id].split(',')
if is_include_ng_word(found_img_info_splited):
continue
found_fpath: str = found_img_info_splited[0]
if cfeature_reranking_mode:
# cfeature reranking mode: doc_id indexes the cfeature file-path list
found_fpath: str = cfeature_filepath_idx[doc_id]
else:
found_img_info_splited: List[str] = image_files_name_tags_arr[doc_id].split(',')
if is_include_ng_word(found_img_info_splited):
continue
found_fpath: str = found_img_info_splited[0]

if args is not None and args.rep:
found_fpath = found_fpath.replace(args.rep[0], args.rep[1])
found_docs_info.append({
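For illustration, a hypothetical end-to-end use of the new search path (the query string is made up; find_similar_documents and the index file names come from this commit):

# if both 'charactor-featues-idx' and 'charactor-featues-idx.csv' exist on disk,
# find_similar_documents enables cfeature_reranking_mode and reranks by character
# features; otherwise it falls back to the Doc2Vec-based rescoring
similar_docs = find_similar_documents('1girl blue_hair smile', topn=50)
for doc_id, score in similar_docs:
    # in cfeature reranking mode doc_id indexes cfeature_filepath_idx,
    # otherwise it indexes image_files_name_tags_arr (see show_search_result)
    print(doc_id, score)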
