Skip to content

Commit

Permalink
PyTorch tagging script now works on an M1 MacBook Air.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryogrid committed Oct 23, 2024
1 parent aac4b33 commit c55ad99
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions colab_env/tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@

EXTENSIONS: List[str] = ['.png', '.jpg', '.jpeg', ".PNG", ".JPG", ".JPEG"]

BATCH_SIZE: int = 100
BATCH_SIZE: int = 10
PROGRESS_INTERVAL: int = 100

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# for Apple Silicon
Expand Down Expand Up @@ -138,9 +139,9 @@ def __init__(self) -> None:
self.last_loaded_repo: Optional[str] = None
self.tagger_model: Optional[nn.Module] = None
self.tag_names: Optional[List[str]] = None
self.rating_indexes: Optional[List[int]] = None
self.general_indexes: Optional[List[int]] = None
self.character_indexes: Optional[List[int]] = None
self.rating_index: Optional[List[int]] = None
self.general_index: Optional[List[int]] = None
self.character_index: Optional[List[int]] = None
self.transform: Optional[Callable] = None

def list_files_recursive(self, tarfile_path: str) -> List[str]:
Expand All @@ -152,7 +153,7 @@ def list_files_recursive(self, tarfile_path: str) -> List[str]:
file_list.append(file_path)
return file_list

def prepare_image(self, image: Image.Image) -> np.ndarray:
def prepare_image(self, image: Image.Image) -> Image.Image:
# target_size: int = self.model_target_size

if image.mode in ('RGBA', 'LA'):
Expand All @@ -176,10 +177,12 @@ def prepare_image(self, image: Image.Image) -> np.ndarray:
# Image.BICUBIC,
# )

image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1]
# image_array: np.ndarray = np.asarray(padded_image, dtype=np.float32)
# image_array = image_array[:, :, ::-1]
#
# return np.expand_dims(image_array, axis=0)

return np.expand_dims(image_array, axis=0)
return padded_image

def load_labels_hf(
self,
Expand All @@ -197,8 +200,8 @@ def load_labels_hf(

df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
self.rating_index = list(np.where(df["category"] == 9)[0])
self.general_index = list(np.where(df["category"] == 0)[0]),
self.character_index = list(np.where(df["category"] == 4)[0]),
self.general_index = list(np.where(df["category"] == 0)[0])
self.character_index = list(np.where(df["category"] == 4)[0])
self.tag_names = df["name"].tolist()

def load_model(self) -> None:
Expand Down Expand Up @@ -278,19 +281,21 @@ def predict(
# char_threshold=opts.char_threshold,
# )

# print(preds)
# exit(1)
ret_strings: List[str] = []
for idx in range(0, len(images)):
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float)))
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float)))

general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes]
general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_index]

if general_mcut_enabled:
general_probs: np.ndarray = np.array([x[1] for x in general_names])
general_thresh = mcut_threshold(general_probs)

general_res: Dict[str, float] = {x[0]: x[1] for x in general_names if x[1] > general_thresh}

character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_indexes]
character_names: List[Tuple[str, float]] = [labels[i] for i in self.character_index]

if character_mcut_enabled:
character_probs: np.ndarray = np.array([x[1] for x in character_names])
Expand Down Expand Up @@ -357,7 +362,7 @@ def process_directory(self, tarfile_path: str) -> None:

cnt += 1

if cnt - last_cnt >= BATCH_SIZE:
if cnt - last_cnt >= PROGRESS_INTERVAL:
now: float = time.perf_counter()
print(f'{cnt} files processed')
diff: float = now - start
Expand Down

0 comments on commit c55ad99

Please sign in to comment.