Skip to content

Commit

Permalink
resolved several type checking error.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryogrid committed Oct 14, 2024
1 parent 58ff6af commit 219121a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
31 changes: 14 additions & 17 deletions make-tags-with-wd-tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import re
from typing import List, Tuple, Dict, Any, Optional, Callable, Protocol

from numpy import signedinteger

kaomojis: List[str] = [
"0_0",
"(o)_(o)",
Expand Down Expand Up @@ -40,7 +42,7 @@
def mcut_threshold(probs: np.ndarray) -> float:
sorted_probs: np.ndarray = probs[probs.argsort()[::-1]]
difs: np.ndarray = sorted_probs[:-1] - sorted_probs[1:]
t: int = difs.argmax()
t: signedinteger[Any] = difs.argmax()
thresh: float = (sorted_probs[t] + sorted_probs[t + 1]) / 2
return thresh

Expand Down Expand Up @@ -147,11 +149,11 @@ def predict(
character_thresh: float,
character_mcut_enabled: bool,
) -> str:
image: np.ndarray = self.prepare_image(image)
img: np.ndarray = self.prepare_image(image)

input_name: str = self.tagger_model.get_inputs()[0].name
label_name: str = self.tagger_model.get_outputs()[0].name
preds: np.ndarray = self.tagger_model.run([label_name], {input_name: image})[0]
preds: np.ndarray = self.tagger_model.run([label_name], {input_name: img})[0]

labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0].astype(float)))

Expand All @@ -172,31 +174,26 @@ def predict(

character_res: Dict[str, float] = {x[0]: x[1] for x in character_names if x[1] > character_thresh}

sorted_general_strings: List[str] = sorted(
sorted_general_strings: List[Tuple[str, float]] = sorted(
general_res.items(),
key=lambda x: x[1],
reverse=True,
)
sorted_general_strings = [x[0] for x in sorted_general_strings]
sorted_general_strings = [x.replace(' ', '_') for x in sorted_general_strings]
sorted_general_strings = (
",".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
sorted_general_strings_str : List[str] = [x[0] for x in sorted_general_strings]
sorted_general_strings_str = [x.replace(' ', '_') for x in sorted_general_strings_str]
ret_string: str = (
",".join(sorted_general_strings_str).replace("(", "\(").replace(")", "\)")
)

ret_string: str = sorted_general_strings

if len(character_res) > 0:
sorted_character_strings: List[str] = sorted(
sorted_character_strings: List[Tuple[str, float]] = sorted(
character_res.items(),
key=lambda x: x[1],
reverse=True,
)
sorted_character_strings = [x[0] for x in sorted_character_strings]
sorted_character_strings = [x.replace(' ', '_') for x in sorted_character_strings]
sorted_character_strings = (
",".join(sorted_character_strings).replace("(", "\(").replace(")", "\)")
)
ret_string += "," + sorted_character_strings
sorted_character_strings_str: List[str] = [x[0] for x in sorted_character_strings]
sorted_character_strings_str = [x.replace(' ', '_') for x in sorted_character_strings_str]
ret_string += ",".join(sorted_character_strings_str).replace("(", "\(").replace(")", "\)")

return ret_string

Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
disable_error_code = import-untyped
3 changes: 1 addition & 2 deletions web-ui-image-search-lsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def slideshow() -> None:
print(f'Error: {e}')
ss['slideshow_index'] = (ss['slideshow_index'] + 1) % len(images)
st.rerun()
return

if st.button('Stop'):
ss['slideshow_active'] = False
ss['slideshow_index'] = 0
Expand All @@ -144,7 +144,6 @@ def display_images() -> None:
ss['slideshow_active'] = True
ss['slideshow_index'] = 0
st.rerun()
return

for data_per_page in ss['data'][ss['page_index']]:
cols = st.columns(5)
Expand Down

0 comments on commit 219121a

Please sign in to comment.