From e879759dacb1ab6f1625b9ad2017e5fd59fa630e Mon Sep 17 00:00:00 2001 From: Ryo Kanbayashi Date: Mon, 21 Oct 2024 20:12:19 +0900 Subject: [PATCH] trying gpu use and tagging implementation with batch: fin (includes mod for ryo_grid env). --- tagging.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tagging.py b/tagging.py index 22f2ebe..fef8b13 100644 --- a/tagging.py +++ b/tagging.py @@ -126,8 +126,7 @@ def load_model(self) -> None: return self.tagger_model_path = hf_hub_download(repo_id=TAGGER_VIT_MODEL_REPO, filename=MODEL_FILE_NAME) - # self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CPUExecutionProvider']) - self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider']) + self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) _, height, _, _ = self.tagger_model.get_inputs()[0].shape self.model_target_size = height @@ -163,10 +162,15 @@ def predict( inputs = np.vstack([inputs, self.prepare_image(img)]) print(inputs.shape) - preds: List[np.ndarray] = self.tagger_model.run([label_name], {input_name: np.array(inputs)}) + start = time.perf_counter() + preds: np.ndarray = self.tagger_model.run([label_name], {input_name: inputs}) + end = time.perf_counter() + # print time per file + print('{:.4f} seconds per file'.format((end - start) / len(images)), flush=True) + ret_strings: List[str] = [] - for idx in range(len(preds)): - labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float))) + for idx in range(0, len(images)): + labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float))) general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes] @@ -222,9 +226,13 @@ def process_directory(self, directory: str) -> None: self.load_model() + # reverse the list to process from last (for ryo_grid env) + file_list.reverse() + 
imgs: List[Image.Image] = [] fpathes: List[str] = [] start: float = time.perf_counter() + last_cnt: int = 0 cnt: int = 0 for file_path in file_list: try: @@ -239,7 +247,9 @@ def process_directory(self, directory: str) -> None: imgs = [] fpathes = [] - if cnt % 100 == 0: + cnt += 1 + + if last_cnt - cnt >= 100: now: float = time.perf_counter() print(f'{cnt} files processed') diff: float = now - start @@ -248,8 +258,8 @@ def process_directory(self, directory: str) -> None: time_per_file: float = diff / cnt print('{:.4f} seconds per file'.format(time_per_file)) print("", flush=True) + last_cnt = cnt - cnt += 1 except Exception as e: error_class: type = type(e) error_description: str = str(e)