diff --git a/colab_env/tagging.py b/colab_env/tagging.py index 99e722c..2c90517 100644 --- a/colab_env/tagging.py +++ b/colab_env/tagging.py @@ -47,7 +47,7 @@ EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"] -BATCH_SIZE: int = 100 +BATCH_SIZE: int = 10 torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # for apple silicon @@ -138,9 +138,9 @@ def __init__(self) -> None: self.last_loaded_repo: Optional[str] = None self.tagger_model: Optional[nn.Module] = None self.tag_names: Optional[List[str]] = None - self.rating_indexes: Optional[List[int]] = None - self.general_indexes: Optional[List[int]] = None - self.character_indexes: Optional[List[int]] = None + self.rating_index: Optional[List[int]] = None + self.general_index: Optional[List[int]] = None + self.character_index: Optional[List[int]] = None self.transform: Optional[Callable] = None def list_files_recursive(self, tarfile_path: str) -> List[str]: @@ -152,7 +152,7 @@ def list_files_recursive(self, tarfile_path: str) -> List[str]: file_list.append(file_path) return file_list - def prepare_image(self, image: Image.Image) -> np.ndarray: + def prepare_image(self, image: Image.Image) -> Image.Image: # target_size: int = self.model_target_size if image.mode in ('RGBA', 'LA'): @@ -176,10 +176,12 @@ def prepare_image(self, image: Image.Image) -> np.ndarray: # Image.BICUBIC, # ) - image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) - image_array = image_array[:, :, ::-1] + # image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32) + # image_array = image_array[:, :, ::-1] + # + # return np.expand_dims(image_array, axis=0) - return np.expand_dims(image_array, axis=0) + return padded_image def load_labels_hf( self, @@ -197,8 +199,8 @@ def load_labels_hf( df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"]) self.rating_index = list(np.where(df["category"] == 9)[0]) - self.general_index = list(np.where(df["category"] == 0)[0]), - self.character_index = list(np.where(df["category"] == 4)[0]), + self.general_index = list(np.where(df["category"] == 0)[0]) + self.character_index = list(np.where(df["category"] == 4)[0]) self.tag_names = df["name"].tolist() def load_model(self) -> None: @@ -278,11 +280,13 @@ def predict( # char_threshold=opts.char_threshold, # ) + # print(preds) + # exit(1) ret_strings: List[str] = [] for idx in range(0, len(images)): - labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float))) + labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float))) - general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes] + general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_index] if general_mcut_enabled: general_probs: np.ndarray = np.array([x[1] for x in general_names]) @@ -290,7 +294,7 @@ def predict( general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh} - character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes] + character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_index] if character_mcut_enabled: character_probs: np.ndarray = np.array([x[1] for x in character_names])