Skip to content

Commit

Permalink
trying gpu use and tagging implentation with batch: fin (includes mod…
Browse files Browse the repository at this point in the history
… for ryo_grid env).
  • Loading branch information
ryogrid committed Oct 21, 2024
1 parent abad9b7 commit e879759
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def load_model(self) -> None:
return

self.tagger_model_path = hf_hub_download(repo_id=TAGGER_VIT_MODEL_REPO, filename=MODEL_FILE_NAME)
# self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CPUExecutionProvider'])
self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider'])
self.tagger_model = rt.InferenceSession(self.tagger_model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
_, height, _, _ = self.tagger_model.get_inputs()[0].shape

self.model_target_size = height
Expand Down Expand Up @@ -163,10 +162,15 @@ def predict(
inputs = np.vstack([inputs, self.prepare_image(img)])
print(inputs.shape)

preds: List[np.ndarray] = self.tagger_model.run([label_name], {input_name: np.array(inputs)})
start = time.perf_counter()
preds: np.ndarray = self.tagger_model.run([label_name], {input_name: inputs})
end = time.perf_counter()
# print time per file
print('{:.4f} seconds per file'.format((end - start) / len(images)), flush=True)

ret_strings: List[str] = []
for idx in range(len(preds)):
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[idx].astype(float)))
for idx in range(0, len(images)):
labels: List[Tuple[str, float]] = list(zip(self.tag_names, preds[0][idx].astype(float)))

general_names: List[Tuple[str, float]] = [labels[i] for i in self.general_indexes]

Expand Down Expand Up @@ -222,9 +226,13 @@ def process_directory(self, directory: str) -> None:

self.load_model()

# reverse the list to process from last (for ryo_grid env)
file_list.reverse()

imgs: List[Image.Image] = []
fpathes: List[str] = []
start: float = time.perf_counter()
last_cnt: int = 0
cnt: int = 0
for file_path in file_list:
try:
Expand All @@ -239,7 +247,9 @@ def process_directory(self, directory: str) -> None:
imgs = []
fpathes = []

if cnt % 100 == 0:
cnt += 1

if last_cnt - cnt >= 100:
now: float = time.perf_counter()
print(f'{cnt} files processed')
diff: float = now - start
Expand All @@ -248,8 +258,8 @@ def process_directory(self, directory: str) -> None:
time_per_file: float = diff / cnt
print('{:.4f} seconds per file'.format(time_per_file))
print("", flush=True)
last_cnt = cnt

cnt += 1
except Exception as e:
error_class: type = type(e)
error_description: str = str(e)
Expand Down

0 comments on commit e879759

Please sign in to comment.