Skip to content

Commit

Permalink
refactor: Apply updated pre-commits
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzomammana committed May 2, 2024
1 parent 80017e6 commit e4905ce
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 59 deletions.
46 changes: 24 additions & 22 deletions quadra/callbacks/anomalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,10 @@ def on_test_batch_end(
threshold = pl_module.pixel_metrics.F1Score.threshold
else:
raise AttributeError("Metric has no threshold attribute")
elif hasattr(pl_module.image_metrics.F1Score, "threshold"):
threshold = pl_module.image_metrics.F1Score.threshold
else:
if hasattr(pl_module.image_metrics.F1Score, "threshold"):
threshold = pl_module.image_metrics.F1Score.threshold
else:
raise AttributeError("Metric has no threshold attribute")
raise AttributeError("Metric has no threshold attribute")

for (
filename,
Expand All @@ -202,34 +201,36 @@ def on_test_batch_end(
outputs["pred_scores"],
)
):
image = Denormalize()(image.cpu())
true_mask = true_mask.cpu().numpy()
anomaly_map = anomaly_map.cpu().numpy()
denormalized_image = Denormalize()(image.cpu())
current_true_mask = true_mask.cpu().numpy()
current_anomaly_map = anomaly_map.cpu().numpy()

output_label_folder = "ok" if pred_label == gt_label else "wrong"

if self.plot_only_wrong and output_label_folder == "ok":
continue

heatmap = superimpose_anomaly_map(anomaly_map, image, normalize=not self.inputs_are_normalized)
heatmap = superimpose_anomaly_map(
current_anomaly_map, denormalized_image, normalize=not self.inputs_are_normalized
)

if isinstance(threshold, float):
pred_mask = compute_mask(anomaly_map, threshold)
pred_mask = compute_mask(current_anomaly_map, threshold)
else:
raise TypeError("Threshold should be float")
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")
vis_img = mark_boundaries(denormalized_image, pred_mask, color=(1, 0, 0), mode="thick")
visualizer = Visualizer()

if self.task == "segmentation":
visualizer.add_image(image=image, title="Image")
visualizer.add_image(image=denormalized_image, title="Image")
if "mask" in outputs:
true_mask = true_mask * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
current_true_mask = current_true_mask * 255
visualizer.add_image(image=current_true_mask, color_map="gray", title="Ground Truth")
visualizer.add_image(image=heatmap, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
elif self.task == "classification":
gt_im = add_anomalous_label(image) if gt_label else add_normal_label(image)
gt_im = add_anomalous_label(denormalized_image) if gt_label else add_normal_label(denormalized_image)
visualizer.add_image(gt_im, title="Image/True label")
if anomaly_score >= threshold:
image_classified = add_anomalous_label(heatmap, anomaly_score)
Expand All @@ -239,27 +240,28 @@ def on_test_batch_end(

visualizer.generate()
visualizer.figure.suptitle(
f"F1 threshold: {threshold}, Mask_max: {anomaly_map.max():.3f}, Anomaly_score: {anomaly_score:.3f}"
f"F1 threshold: {threshold}, Mask_max: {current_anomaly_map.max():.3f}, "
f"Anomaly_score: {anomaly_score:.3f}"
)
filename = Path(filename)
self._add_images(visualizer, filename, output_label_folder)
path_filename = Path(filename)
self._add_images(visualizer, path_filename, output_label_folder)
visualizer.close()

if self.plot_raw_outputs:
for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"]):
if raw_name == "segmentation":
raw_output = (raw_output * 255).astype(np.uint8)
raw_output = cv2.cvtColor(raw_output, cv2.COLOR_RGB2BGR)
current_raw_output = (raw_output * 255).astype(np.uint8)
current_raw_output = cv2.cvtColor(raw_output, cv2.COLOR_RGB2BGR)
raw_filename = (
Path(self.output_path)
/ "images"
/ output_label_folder
/ filename.parent.name
/ path_filename.parent.name
/ "raw_outputs"
/ Path(filename.stem + f"_{raw_name}.png")
/ Path(path_filename.stem + f"_{raw_name}.png")
)
raw_filename.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(raw_filename), raw_output)
cv2.imwrite(str(raw_filename), current_raw_output)

def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Sync logs.
Expand Down
5 changes: 3 additions & 2 deletions quadra/datamodules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,10 @@ def load_augmented_samples(
for sample, label in zip(samples, targets):
aug_samples.append(sample)
aug_labels.append(label)
final_sample = sample
if replace_str_from is not None and replace_str_to is not None:
sample = sample.replace(replace_str_from, replace_str_to)
base, ext = os.path.splitext(sample)
final_sample = final_sample.replace(replace_str_from, replace_str_to)
base, ext = os.path.splitext(final_sample)
for k in range(self.n_aug_to_take):
aug_samples.append(base + "_" + str(k + 1) + ext)
aug_labels.append(label)
Expand Down
2 changes: 1 addition & 1 deletion quadra/datamodules/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def _filter_images_and_targets(
samples: list[str] = []
targets: list[str] = []
idx_to_class = {v: k for k, v in class_to_idx.items()}
images_and_targets = [(str(image_path), target) for image_path, target in images_and_targets]
for image_path, target in images_and_targets:
image_path = str(image_path)
target_class = idx_to_class[target]
if self.exclude_filter is not None and any(
exclude_filter in image_path for exclude_filter in self.exclude_filter
Expand Down
9 changes: 4 additions & 5 deletions quadra/datasets/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,11 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
# If good images have no associated mask create an empty one
if label_index == 0:
mask = np.zeros(shape=original_image_shape[:2])
elif os.path.isfile(mask_path):
mask = cv2.imread(mask_path, flags=0) / 255.0 # type: ignore[operator]
else:
if os.path.isfile(mask_path):
mask = cv2.imread(mask_path, flags=0) / 255.0 # type: ignore[operator]
else:
# We need ones in the mask to compute correctly at least image level f1 score
mask = np.ones(shape=original_image_shape[:2])
# We need ones in the mask to compute correctly at least image level f1 score
mask = np.ones(shape=original_image_shape[:2])

if self.valid_area_mask is not None:
mask = mask * self.valid_area_mask
Expand Down
7 changes: 4 additions & 3 deletions quadra/models/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,11 @@ def generate_session_options(self) -> ort.SessionOptions:
dict[str, Any], OmegaConf.to_container(self.config.session_options, resolve=True)
)
for key, value in session_options_dict.items():
final_value = value
if isinstance(value, dict) and "_target_" in value:
value = instantiate(value)
final_value = instantiate(final_value)

setattr(session_options, key, value)
setattr(session_options, key, final_value)

return session_options

Expand Down Expand Up @@ -240,7 +241,7 @@ def _forward_from_pytorch(self, input_dict: dict[str, torch.Tensor]):
for k, v in input_dict.items():
if not v.is_contiguous():
# If not contiguous onnx give wrong results
v = v.contiguous()
v = v.contiguous() # noqa: PLW2901

io_binding.bind_input(
name=k,
Expand Down
7 changes: 3 additions & 4 deletions quadra/schedulers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ def set_lr(self, lr: tuple[float, ...]):
lr_to_set = self.init_lr[0]
else:
lr_to_set = self.init_lr[i]
elif len(lr) == 1:
lr_to_set = lr[0]
else:
if len(lr) == 1:
lr_to_set = lr[0]
else:
lr_to_set = lr[i]
lr_to_set = lr[i]
g["lr"] = lr_to_set

def get_lr(self):
Expand Down
13 changes: 6 additions & 7 deletions quadra/utils/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,13 @@ def export_onnx_model(

if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
dynamic_axes = None
else:
if dynamic_axes is None:
dynamic_axes = {}
for i, _ in enumerate(input_names):
dynamic_axes[input_names[i]] = {0: "batch_size"}
elif dynamic_axes is None:
dynamic_axes = {}
for i, _ in enumerate(input_names):
dynamic_axes[input_names[i]] = {0: "batch_size"}

for i, _ in enumerate(output_names):
dynamic_axes[output_names[i]] = {0: "batch_size"}
for i, _ in enumerate(output_names):
dynamic_axes[output_names[i]] = {0: "batch_size"}

onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

Expand Down
5 changes: 2 additions & 3 deletions quadra/utils/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,8 @@ def register_best_model(
if mode == "max":
if run.data.metrics[metric] > best_run.data.metrics[metric]:
best_run = run
else:
if run.data.metrics[metric] < best_run.data.metrics[metric]:
best_run = run
elif run.data.metrics[metric] < best_run.data.metrics[metric]:
best_run = run

if best_run is None:
log.error("No runs found for experiment %s with the given metric", experiment_name)
Expand Down
13 changes: 7 additions & 6 deletions quadra/utils/patch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def _map_files(files: list[Any]):
"""Convert a list of dict to a list of PatchDatasetFileFormat."""
mapped_files = []
for file in files:
current_file = file
if isinstance(file, dict):
file = PatchDatasetFileFormat(**file)
mapped_files.append(file)
current_file = PatchDatasetFileFormat(**current_file)
mapped_files.append(current_file)

return mapped_files

def __post_init__(self):
Expand Down Expand Up @@ -308,11 +310,10 @@ def __save_patch_dataset(
missing_classes = set(classes_in_mask).difference(class_to_idx.values())

assert len(missing_classes) == 0, f"Found index in mask that has no corresponding class {missing_classes}"
elif mask_patches is not None:
reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
else:
if mask_patches is not None:
reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
else:
raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")
raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")

log.debug("Classes from mask: %s", reference_classes)
class_to_idx = {v: k for k, v in reference_classes.items()}
Expand Down
5 changes: 2 additions & 3 deletions quadra/utils/patch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ def save_classification_result(
if is_polygon:
if len(reconstruction["prediction"]) == 0:
continue
else:
if reconstruction["prediction"].sum() == 0:
continue
elif reconstruction["prediction"].sum() == 0:
continue

if counter > 5:
break
Expand Down
5 changes: 2 additions & 3 deletions quadra/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def create_grid_figure(
_, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=fig_size, squeeze=False)
for i, row in enumerate(images):
for j, image in enumerate(row):
if len(image.shape) == 3 and image.shape[0] == 1:
image = image[0]
ax[i][j].imshow(image, vmin=bounds[i][0], vmax=bounds[i][1])
image_to_plot = image[0] if len(image.shape) == 3 and image.shape[0] == 1 else image
ax[i][j].imshow(image_to_plot, vmin=bounds[i][0], vmax=bounds[i][1])
ax[i][j].get_xaxis().set_ticks([])
ax[i][j].get_yaxis().set_ticks([])
if row_names is not None:
Expand Down

0 comments on commit e4905ce

Please sign in to comment.