Skip to content

Commit

Permalink
pytorch tagging script worked at M1 MBA.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryogrid committed Oct 23, 2024
1 parent aac4b33 commit f24cccf
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions colab_env/tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"]

BATCH_SIZE: int = 100
BATCH_SIZE: int = 10

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# for apple silicon
Expand Down Expand Up @@ -138,9 +138,9 @@ def __init__(self) -> None:
self.last_loaded_repo: Optional[str] = None
self.tagger_model: Optional[nn.Module] = None
self.tag_names: Optional[List[str]] = None
self.rating_indexes: Optional[List[int]] = None
self.general_indexes: Optional[List[int]] = None
self.character_indexes: Optional[List[int]] = None
self.rating_index: Optional[List[int]] = None
self.general_index: Optional[List[int]] = None
self.character_index: Optional[List[int]] = None
self.transform: Optional[Callable] = None

def list_files_recursive(self, tarfile_path: str) -> List[str]:
Expand All @@ -152,7 +152,7 @@ def list_files_recursive(self, tarfile_path: str) -> List[str]:
file_list.append(file_path)
return file_list

def prepare_image(self, image: Image.Image) -> np.ndarray:
def prepare_image(self, image: Image.Image) -> Image.Image:
# target_size: int = self.model_target_size

if image.mode in ('RGBA', 'LA'):
Expand All @@ -176,10 +176,12 @@ def prepare_image(self, image: Image.Image) -> np.ndarray:
# Image.BICUBIC,
# )

image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1]
# image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
# image_array = image_array[:, :, ::-1]
#
# return np.expand_dims(image_array, axis=0)

return np.expand_dims(image_array, axis=0)
return padded_image

def load_labels_hf(
self,
Expand All @@ -197,8 +199,8 @@ def load_labels_hf(

df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
self.rating_index = list(np.where(df["category"] == 9)[0])
self.general_index = list(np.where(df["category"] == 0)[0]),
self.character_index = list(np.where(df["category"] == 4)[0]),
self.general_index = list(np.where(df["category"] == 0)[0])
self.character_index = list(np.where(df["category"] == 4)[0])
self.tag_names = df["name"].tolist()

def load_model(self) -> None:
Expand Down Expand Up @@ -278,19 +280,21 @@ def predict(
# char_threshold=opts.char_threshold,
# )

# print(preds)
# exit(1)
ret_strings: List[str] = []
for idx in range(0, len(images)):
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float)))
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float)))

general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes]
general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_index]

if general_mcut_enabled:
general_probs: np.ndarray = np.array([x[1] for x in general_names])
general_thresh = mcut_threshold(general_probs)

general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh}

character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes]
character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_index]

if character_mcut_enabled:
character_probs: np.ndarray = np.array([x[1] for x in character_names])
Expand Down

0 comments on commit f24cccf

Please sign in to comment.