
Commit

refactor: Solve test issues for python 3.9
lorenzomammana committed May 1, 2024
1 parent 6bf2ff6 · commit 6896266
Showing 24 changed files with 45 additions and 48 deletions.
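Almost all of the changes below remove the `strict=False` keyword from `zip()` calls. `zip()` only gained the `strict` parameter in Python 3.10 (PEP 618), so on Python 3.9 each of those calls fails with a `TypeError` (`zip() takes no keyword arguments`). Because `strict=False` is the default behaviour anyway, dropping the keyword leaves the semantics unchanged. A minimal sketch of the incompatibility (illustrative values, not quadra code):

```python
import sys

axes = ["ax0", "ax1"]
images = [{"image": "img0"}, {"image": "img1"}]

if sys.version_info >= (3, 10):
    # Accepted only on 3.10+; strict=False is already the default behaviour.
    pairs = list(zip(axes, images, strict=False))
else:
    # The only spelling available on Python 3.9.
    pairs = list(zip(axes, images))

print(pairs)
```

The two non-`zip` changes (in quadra/utils/anomaly.py and quadra/utils/models.py) are discussed alongside their diffs below.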
5 changes: 2 additions & 3 deletions quadra/callbacks/anomalib.py
@@ -62,7 +62,7 @@ def generate(self):
self.figure.subplots_adjust(right=0.9)

axes = self.axis if len(self.images) > 1 else [self.axis]
for axis, image_dict in zip(axes, self.images, strict=False):
for axis, image_dict in zip(axes, self.images):
axis.axes.xaxis.set_visible(False)
axis.axes.yaxis.set_visible(False)
axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255)
@@ -200,7 +200,6 @@ def on_test_batch_end(
outputs["label"],
outputs["pred_labels"],
outputs["pred_scores"],
strict=False,
)
):
image = Denormalize()(image.cpu())
@@ -247,7 +246,7 @@ def on_test_batch_end(
visualizer.close()

if self.plot_raw_outputs:
for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"], strict=False):
for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"]):
if raw_name == "segmentation":
raw_output = (raw_output * 255).astype(np.uint8)
raw_output = cv2.cvtColor(raw_output, cv2.COLOR_RGB2BGR)
4 changes: 2 additions & 2 deletions quadra/datamodules/base.py
@@ -260,7 +260,7 @@ def hash_data(self) -> None:
return

# TODO: We need to find a way to annotate the columns of data.
paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data), strict=False)
paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data))

with mp.Pool(min(8, mp.cpu_count() - 1)) as pool:
self.data["hash"] = list(
@@ -355,7 +355,7 @@ def load_augmented_samples(
raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
aug_samples = []
aug_labels = []
for sample, label in zip(samples, targets, strict=False):
for sample, label in zip(samples, targets):
aug_samples.append(sample)
aug_labels.append(label)
if replace_str_from is not None and replace_str_to is not None:
4 changes: 2 additions & 2 deletions quadra/datamodules/classification.py
@@ -243,15 +243,15 @@ def _prepare_data(self) -> None:
samples_test, targets_test = self._read_split(self.test_split_file)
if not self.train_split_file:
samples_train, targets_train = [], []
for sample, target in zip(all_samples, all_targets, strict=False):
for sample, target in zip(all_samples, all_targets):
if sample not in samples_test:
samples_train.append(sample)
targets_train.append(target)
if self.train_split_file:
samples_train, targets_train = self._read_split(self.train_split_file)
if not self.test_split_file:
samples_test, targets_test = [], []
for sample, target in zip(all_samples, all_targets, strict=False):
for sample, target in zip(all_samples, all_targets):
if sample not in samples_train:
samples_test.append(sample)
targets_test.append(target)
12 changes: 6 additions & 6 deletions quadra/datamodules/segmentation.py
@@ -187,7 +187,7 @@ def _prepare_data(self) -> None:
samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
if not self.train_split_file:
samples_train, targets_train, masks_train = [], [], []
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
for sample, target, mask in zip(all_samples, all_targets, all_masks):
if sample not in samples_test:
samples_train.append(sample)
targets_train.append(target)
@@ -197,7 +197,7 @@ def _prepare_data(self) -> None:
samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
if not self.test_split_file:
samples_test, targets_test, masks_test = [], [], []
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
for sample, target, mask in zip(all_samples, all_targets, all_masks):
if sample not in samples_train:
samples_test.append(sample)
targets_test.append(target)
@@ -549,7 +549,7 @@ def _prepare_data(self) -> None:
samples_and_masks_test,
targets_test,
) = iterative_train_test_split(
np.expand_dims(np.array(list(zip(all_samples, all_masks, strict=False))), 1),
np.expand_dims(np.array(list(zip(all_samples, all_masks))), 1),
np.array(all_targets),
test_size=self.test_size,
)
@@ -561,7 +561,7 @@ def _prepare_data(self) -> None:
samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
if not self.train_split_file:
samples_train, targets_train, masks_train = [], [], []
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
for sample, target, mask in zip(all_samples, all_targets, all_masks):
if sample not in samples_test:
samples_train.append(sample)
targets_train.append(target)
@@ -571,7 +571,7 @@ def _prepare_data(self) -> None:
samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
if not self.test_split_file:
samples_test, targets_test, masks_test = [], [], []
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
for sample, target, mask in zip(all_samples, all_targets, all_masks):
if sample not in samples_train:
samples_test.append(sample)
targets_test.append(target)
@@ -583,7 +583,7 @@ def _prepare_data(self) -> None:
raise ValueError("Validation split file is specified but no train or test split file is specified.")
else:
samples_and_masks_train, targets_train, samples_and_masks_val, targets_val = iterative_train_test_split(
np.expand_dims(np.array(list(zip(samples_train, masks_train, strict=False))), 1),
np.expand_dims(np.array(list(zip(samples_train, masks_train))), 1),
np.array(targets_train),
test_size=self.val_size,
)
2 changes: 1 addition & 1 deletion quadra/datasets/classification.py
@@ -215,7 +215,7 @@ def __init__(
class_to_idx = {c: i for i, c in enumerate(range(unique_targets))}
self.class_to_idx = class_to_idx
self.idx_to_class = {v: k for k, v in class_to_idx.items()}
self.samples = list(zip(self.x, self.y, strict=False))
self.samples = list(zip(self.x, self.y))
self.rgb = rgb
self.transform = transform

2 changes: 1 addition & 1 deletion quadra/datasets/patch.py
@@ -58,7 +58,7 @@ def __init__(

cls, counts = np.unique(targets_array, return_counts=True)
max_count = np.max(counts)
for cl, count in zip(cls, counts, strict=False):
for cl, count in zip(cls, counts):
idx_to_pick = list(np.where(targets_array == cl)[0])

if count < max_count:
2 changes: 1 addition & 1 deletion quadra/metrics/segmentation.py
@@ -171,7 +171,7 @@ def segmentation_props(
# Add dummy Dices so LSA is unique and i can compute FP and FN
dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
lsa = linear_sum_assignment(dice_mat, maximize=False)
for row, col in zip(lsa[0], lsa[1], strict=False):
for row, col in zip(lsa[0], lsa[1]):
# More preds than GTs --> False Positive
if row < n_labels_pred and col >= n_labels_mask:
min_row = pred_bbox[row][0]
2 changes: 1 addition & 1 deletion quadra/models/evaluation.py
@@ -208,7 +208,7 @@ def __call__(self, *inputs: np.ndarray | torch.Tensor) -> Any:

onnx_inputs: dict[str, np.ndarray | torch.Tensor] = {}

for onnx_input, current_input in zip(self.model.get_inputs(), inputs, strict=False):
for onnx_input, current_input in zip(self.model.get_inputs(), inputs):
if isinstance(current_input, torch.Tensor):
onnx_inputs[onnx_input.name] = current_input
use_pytorch = True
1 change: 0 additions & 1 deletion quadra/modules/ssl/byol.py
@@ -110,7 +110,6 @@ def update_teacher(self):
for student_ps, teacher_ps in zip(
list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
strict=False,
):
teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data

3 changes: 1 addition & 2 deletions quadra/tasks/anomaly.py
@@ -211,7 +211,7 @@ def _generate_report(self) -> None:
exportable_anomaly_scores = anomaly_scores

# Zip the lists together to create rows for the CSV file
rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores, strict=False)
rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores)
# Specify the CSV file name
csv_file = "test_predictions.csv"
# Write the data to the CSV file
@@ -498,7 +498,6 @@ def generate_report(self) -> None:
self.metadata["image_labels"],
anomaly_scores,
anomaly_maps,
strict=False,
),
total=len(self.metadata["image_paths"]),
):
4 changes: 1 addition & 3 deletions quadra/tasks/classification.py
@@ -629,9 +629,7 @@ def train(self) -> None:
all_labels = all_labels[sorted_indices]

# cycle over all train/test split
for train_dataloader, test_dataloader in zip(
self.train_dataloader_list, self.test_dataloader_list, strict=False
):
for train_dataloader, test_dataloader in zip(self.train_dataloader_list, self.test_dataloader_list):
# Reinit classifier
self.model = self.config.model
self.trainer.change_classifier(self.model)
2 changes: 1 addition & 1 deletion quadra/tasks/segmentation.py
@@ -341,7 +341,7 @@ def test(self) -> None:
if self.datamodule.test_dataset_available:
stages.append("test")
dataloaders.append(self.datamodule.test_dataloader())
for stage, dataloader in zip(stages, dataloaders, strict=False):
for stage, dataloader in zip(stages, dataloaders):
log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
image_list, mask_list, mask_pred_list, label_list = [], [], [], []
for batch in dataloader:
2 changes: 1 addition & 1 deletion quadra/tasks/ssl.py
@@ -547,7 +547,7 @@ def test(self) -> None:
im = interpolate(im, self.embedding_image_size)

images.append(im.cpu())
metadata.extend(zip(targets, class_names, file_paths, strict=False))
metadata.extend(zip(targets, class_names, file_paths))
counter += len(im)
images = torch.cat(images, dim=0)
embeddings = torch.cat(embeddings, dim=0)
4 changes: 3 additions & 1 deletion quadra/utils/anomaly.py
@@ -100,9 +100,11 @@ def _normalize_batch(self, outputs, pl_module):
"""Normalize a batch of predictions."""
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold)
outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold.item())

threshold = pixel_threshold if self.threshold_type == "pixel" else image_threshold
threshold = threshold.item()

if "anomaly_maps" in outputs:
outputs["anomaly_maps"] = normalize_anomaly_score(outputs["anomaly_maps"], threshold)

8 changes: 4 additions & 4 deletions quadra/utils/classification.py
@@ -130,7 +130,7 @@ def find_images_and_targets(
sorted_labels = sorted(unique_labels, key=natural_key)
class_to_idx = {str(c): idx for idx, c in enumerate(sorted_labels)}

images_and_targets = [(f, l) for f, l in zip(filenames, labels, strict=False) if l in class_to_idx]
images_and_targets = [(f, l) for f, l in zip(filenames, labels) if l in class_to_idx]

if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
@@ -210,7 +210,7 @@ def find_test_image(
file_samples.append(sample_path)

test_split = [os.path.join(folder, sample.strip()) for sample in file_samples]
labels = [t for s, t in zip(filenames, labels, strict=False) if s in file_samples]
labels = [t for s, t in zip(filenames, labels) if s in file_samples]
filenames = [s for s in filenames if s in file_samples]
log.info("Selected %d images using test_split_file for the test", len(filenames))
if len(filenames) != len(file_samples):
@@ -353,7 +353,7 @@ def get_split(

cl, counts = np.unique(targets, return_counts=True)

for num, _cl in zip(counts, cl, strict=False):
for num, _cl in zip(counts, cl):
if num == 1:
to_remove = np.where(np.array(targets) == _cl)[0][0]
samples = np.delete(np.array(samples), to_remove)
@@ -378,7 +378,7 @@ def get_split(
file_samples.append(sample_path)

train_split = [os.path.join(image_dir, sample.strip()) for sample in file_samples]
targets = np.array([t for s, t in zip(samples, targets, strict=False) if s in file_samples])
targets = np.array([t for s, t in zip(samples, targets) if s in file_samples])
samples = np.array([s for s in samples if s in file_samples])

if limit_training_data is not None:
4 changes: 2 additions & 2 deletions quadra/utils/evaluation.py
@@ -167,7 +167,7 @@ def calculate_mask_based_metrics(
"Accuracy": [],
}
for idx, (image, pred, mask, thresh_pred, dice_score) in enumerate(
zip(images, preds, masks, thresh_preds, dice_scores, strict=False)
zip(images, preds, masks, thresh_preds, dice_scores)
):
if np.sum(mask) == 0:
good_dice.append(dice_score)
@@ -303,7 +303,7 @@ def create_mask_report(
non_zero_score_idx = sorted_idx[~binary_labels]
zero_score_idx = sorted_idx[binary_labels]
file_paths = []
for name, current_score_idx in zip(["good", "bad"], [zero_score_idx, non_zero_score_idx], strict=False):
for name, current_score_idx in zip(["good", "bad"], [zero_score_idx, non_zero_score_idx]):
if len(current_score_idx) == 0:
continue

8 changes: 5 additions & 3 deletions quadra/utils/models.py
@@ -3,7 +3,7 @@
import math
import warnings
from collections.abc import Callable
from typing import cast
from typing import Union, cast

import numpy as np
import timm
@@ -161,7 +161,9 @@ def get_feature(
x1 = x1.to(feature_extractor.device).to(feature_extractor.model_dtype)

if gradcam:
y_hat = cast(list[torch.Tensor] | tuple[torch.Tensor] | torch.Tensor, feature_extractor(x1).detach())
y_hat = cast(
Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
)
# mypy can't detect that gradcam is true only if we have a features_extractor
if is_vision_transformer(feature_extractor.features_extractor): # type: ignore[union-attr]
grayscale_cam_low_res = grad_rollout(
@@ -176,7 +178,7 @@ def get_feature(
feature_extractor.zero_grad(set_to_none=True) # type: ignore[union-attr]
else:
with torch.no_grad():
y_hat = cast(list[torch.Tensor] | tuple[torch.Tensor] | torch.Tensor, feature_extractor(x1))
y_hat = cast(Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
grayscale_cams = None

if isinstance(y_hat, (list, tuple)):
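This is the other non-`zip` fix: on Python 3.9 the PEP 604 `X | Y` union syntax only works where it is never evaluated, i.e. inside annotations under `from __future__ import annotations`. The first argument of `typing.cast()` is an ordinary expression evaluated at call time, so `list[torch.Tensor] | tuple[torch.Tensor] | torch.Tensor` raises `TypeError: unsupported operand type(s) for |` on 3.9; spelling it as `typing.Union[...]` (hence the added `Union` import above) works on both 3.9 and 3.10+. A small illustration with made-up names:

```python
from __future__ import annotations  # lets `|` appear in annotations on 3.9

from typing import Union, cast


def passthrough(x: int | str) -> int | str:  # fine on 3.9: never evaluated at runtime
    return x


value: object = [1, 2, 3]

# On Python 3.9 the commented line would raise
#   TypeError: unsupported operand type(s) for |
# because cast()'s first argument is evaluated when the call runs:
# y = cast(list[int] | tuple[int] | int, value)

# Equivalent spelling that works on 3.9 and 3.10+:
y = cast(Union[list[int], tuple[int], int], value)
print(passthrough("ok"), y)
```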
2 changes: 1 addition & 1 deletion quadra/utils/patch/dataset.py
@@ -565,7 +565,7 @@ def generate_patch_dataset(
num_workers=num_workers,
)

for phase, split_dict in zip(["val", "test"], [val_data_dictionary, test_data_dictionary], strict=False):
for phase, split_dict in zip(["val", "test"], [val_data_dictionary, test_data_dictionary]):
if len(split_dict) > 0:
log.info("Generating %s set", phase)
generate_patch_sliding_window_dataset(
6 changes: 3 additions & 3 deletions quadra/utils/tests/fixtures/dataset/classification.py
@@ -95,7 +95,7 @@ def _build_classification_dataset(

classes = dataset_arguments.classes if dataset_arguments.classes else range(len(dataset_arguments.samples))

for class_name, samples in zip(classes, dataset_arguments.samples, strict=False):
for class_name, samples in zip(classes, dataset_arguments.samples):
class_path = classification_dataset_path / str(class_name)
class_path.mkdir()
for i in range(samples):
@@ -197,7 +197,7 @@ def _build_multilabel_classification_dataset(

generated_samples = []
counter = 0
for class_name, samples in zip(classes, dataset_arguments.samples, strict=False):
for class_name, samples in zip(classes, dataset_arguments.samples):
for _ in range(samples):
image = _random_image()
image_path = images_path / f"{counter}.png"
@@ -319,7 +319,7 @@ def _build_classification_patch_dataset(

class_to_idx = {class_name: i for i, class_name in enumerate(classes)}

for class_name, samples in zip(classes, dataset_arguments.samples, strict=False):
for class_name, samples in zip(classes, dataset_arguments.samples):
for i in range(samples):
image = _random_image(size=(224, 224))
mask = np.zeros((224, 224), dtype=np.uint8)
6 changes: 2 additions & 4 deletions quadra/utils/tests/fixtures/dataset/segmentation.py
@@ -58,14 +58,12 @@ def _build_segmentation_dataset(
classes = [0] + classes

counter = 0
for split_name, split_samples in zip(
["train", "val", "test"], [train_samples, val_samples, test_samples], strict=False
):
for split_name, split_samples in zip(["train", "val", "test"], [train_samples, val_samples, test_samples]):
if split_samples is None:
continue

with open(segmentation_dataset_path / f"{split_name}.txt", "w") as split_file:
for class_name, samples in zip(classes, split_samples, strict=False):
for class_name, samples in zip(classes, split_samples):
for _ in range(samples):
image = _random_image(size=(224, 224))
mask = np.zeros((224, 224), dtype=np.uint8)
4 changes: 2 additions & 2 deletions quadra/utils/visualization.py
@@ -46,7 +46,7 @@ def __call__(self, tensor: torch.Tensor, make_copy=True) -> torch.Tensor:
new_t = tensor.detach().clone()
else:
new_t = tensor
for t, m, s in zip(new_t, self.mean, self.std, strict=False):
for t, m, s in zip(new_t, self.mean, self.std):
t.mul_(s).add_(m)
# The normalize code -> t.sub_(m).div_(s)
return new_t
@@ -83,7 +83,7 @@ def create_grid_figure(
ax[i][j].get_xaxis().set_ticks([])
ax[i][j].get_yaxis().set_ticks([])
if row_names is not None:
for ax, name in zip(ax[:, 0], row_names, strict=False): # noqa: B020
for ax, name in zip(ax[:, 0], row_names): # noqa: B020
ax.set_ylabel(name, rotation=90)

plt.tight_layout()
2 changes: 1 addition & 1 deletion quadra/utils/vit_explainability.py
@@ -153,7 +153,7 @@ def grad_rollout(
"""
result = torch.eye(attentions[0].size(-1))
with torch.no_grad():
for attention, grad in zip(attentions, gradients, strict=False):
for attention, grad in zip(attentions, gradients):
weights = grad
attention_heads_fused = torch.mean((attention * weights), dim=1)
attention_heads_fused[attention_heads_fused < 0] = 0