diff --git a/make-tags-with-wd-tagger.py b/make-tags-with-wd-tagger.py index 0b8bf51..e5f165f 100644 --- a/make-tags-with-wd-tagger.py +++ b/make-tags-with-wd-tagger.py @@ -9,6 +9,8 @@ import re from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol +from numpy import signedinteger + kaomojis: List[str] = [ "0_0", "(o)_(o)", @@ -40,7 +42,7 @@ def mcut_threshold(probs: np.ndarray) -> float: sorted_probs: np.ndarray = probs[probs.argsort()[::-1]] difs: np.ndarray = sorted_probs[:-1] - sorted_probs[1:] - t: int = difs.argmax() + t: signedinteger[Any] = difs.argmax() thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh @@ -147,11 +149,11 @@ def predict( character_thresh: float, character_mcut_enabled: bool, ) -> str: - image: np.ndarray = self.prepare_image(image) + img: np.ndarray = self.prepare_image(image) input_name: str = self.tagger_model.get_inputs()[0].name label_name: str = self.tagger_model.get_outputs()[0].name - preds: np.ndarray = self.tagger_model.run([label_name], {input_name: image})[0] + preds: np.ndarray = self.tagger_model.run([label_name], {input_name: img})[0] labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0].astype(float))) @@ -172,31 +174,26 @@ def predict( character_res: Dict[str, float] = {x[0]: x[1] for x in character_names if x[1] > character_thresh} - sorted_general_strings: List[str] = sorted( + sorted_general_strings: List[Tuple[str, float]] = sorted( general_res.items(), key=lambda x: x[1], reverse=True, ) - sorted_general_strings = [x[0] for x in sorted_general_strings] - sorted_general_strings = [x.replace(' ', '_') for x in sorted_general_strings] - sorted_general_strings = ( - ",".join(sorted_general_strings).replace("(", "\(").replace(")", "\)") + sorted_general_strings_str : List[str] = [x[0] for x in sorted_general_strings] + sorted_general_strings_str = [x.replace(' ', '_') for x in sorted_general_strings_str] + ret_string: str = ( + ",".join(sorted_general_strings_str).replace("(", "\(").replace(")", "\)") ) - ret_string: str = sorted_general_strings - if len(character_res) > 0: - sorted_character_strings: List[str] = sorted( + sorted_character_strings: List[Tuple[str, float]] = sorted( character_res.items(), key=lambda x: x[1], reverse=True, ) - sorted_character_strings = [x[0] for x in sorted_character_strings] - sorted_character_strings = [x.replace(' ', '_') for x in sorted_character_strings] - sorted_character_strings = ( - ",".join(sorted_character_strings).replace("(", "\(").replace(")", "\)") - ) - ret_string += "," + sorted_character_strings + sorted_character_strings_str: List[str] = [x[0] for x in sorted_character_strings] + sorted_character_strings_str = [x.replace(' ', '_') for x in sorted_character_strings_str] + ret_string += ",".join(sorted_character_strings_str).replace("(", "\(").replace(")", "\)") return ret_string diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..d787271 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +disable_error_code = import-untyped diff --git a/web-ui-image-search-lsi.py b/web-ui-image-search-lsi.py index d829fad..58e122a 100644 --- a/web-ui-image-search-lsi.py +++ b/web-ui-image-search-lsi.py @@ -121,7 +121,7 @@ def slideshow() -> None: print(f'Error: {e}') ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images) st.rerun() - return + if st.button('Stop'): ss['slideshow_active'] = False ss['slideshow_index'] = 0 @@ -144,7 +144,6 @@ def display_images() -> None: ss['slideshow_active'] = True ss['slideshow_index'] = 0 st.rerun() - return for data_per_page in ss['data'][ss['page_index']]: cols = st.columns(5)