Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 13, 2025
1 parent 9d3407a commit 13ab1bd
Show file tree
Hide file tree
Showing 18 changed files with 171 additions and 186 deletions.
4 changes: 2 additions & 2 deletions examples/04_Perceptual_distance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@
" reference_filemap = {s.lower(): s for s in os.listdir(folder / \"reference_images\")}\n",
" distorted_filemap = {s.lower(): s for s in os.listdir(folder / \"distorted_images\")}\n",
" for i in range(25):\n",
" reference_filename = reference_filemap[f\"i{i+1:02d}.bmp\"]\n",
" reference_filename = reference_filemap[f\"i{i + 1:02d}.bmp\"]\n",
" reference_images[i] = (\n",
" torch.as_tensor(\n",
" np.asarray(\n",
Expand All @@ -436,7 +436,7 @@
" for j in range(24):\n",
" for k in range(5):\n",
" distorted_filename = distorted_filemap[\n",
" f\"i{i+1:02d}_{j+1:02d}_{k+1}.bmp\"\n",
" f\"i{i + 1:02d}_{j + 1:02d}_{k + 1}.bmp\"\n",
" ]\n",
" distorted_images[i, j, k] = (\n",
" torch.as_tensor(\n",
Expand Down
6 changes: 2 additions & 4 deletions src/plenoptic/metric/perceptual_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def _ssim_parts(img1, img2, pad=False):

if not img1.ndim == img2.ndim == 4:
raise Exception(
"Input images should have four dimensions: (batch, channel,"
" height, width)"
"Input images should have four dimensions: (batch, channel, height, width)"
)
if img1.shape[-2:] != img2.shape[-2:]:
raise Exception("img1 and img2 must have the same height and width!")
Expand Down Expand Up @@ -455,8 +454,7 @@ def nlpd(img1, img2):

if not img1.ndim == img2.ndim == 4:
raise Exception(
"Input images should have four dimensions: (batch, channel,"
" height, width)"
"Input images should have four dimensions: (batch, channel, height, width)"
)
if img1.shape[-2:] != img2.shape[-2:]:
raise Exception("img1 and img2 must have the same height and width!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,9 @@ def forward(
scales = self.scales
scale_ints = [s for s in scales if isinstance(s, int)]
if len(scale_ints) != 0:
assert (max(scale_ints) < self.num_scales) and (
min(scale_ints) >= 0
), "Scales must be within 0 and num_scales-1"
assert (max(scale_ints) < self.num_scales) and (min(scale_ints) >= 0), (
"Scales must be within 0 and num_scales-1"
)
angle = self.angle.copy()
log_rad = self.log_rad.copy()
lo0mask = self.lo0mask.clone()
Expand Down Expand Up @@ -619,9 +619,9 @@ def _recon_levels_check(
)
levs_nums = np.array([int(i) for i in levels if isinstance(i, int)])
assert (levs_nums >= 0).all(), "Level numbers must be non-negative."
assert (
levs_nums < self.num_scales
).all(), f"Level numbers must be in the range [0, {self.num_scales - 1:d}]"
assert (levs_nums < self.num_scales).all(), (
f"Level numbers must be in the range [0, {self.num_scales - 1:d}]"
)
levs_tmp = list(np.sort(levs_nums)) # we want smallest first
if "residual_highpass" in levels:
levs_tmp = ["residual_highpass"] + levs_tmp
Expand Down Expand Up @@ -661,15 +661,13 @@ def _recon_bands_check(self, bands: Literal["all"] | list[int]) -> list[int]:
if isinstance(bands, str):
if bands != "all":
raise TypeError(
"bands must be a list of ints or the string 'all' but got"
f" {bands}"
f"bands must be a list of ints or the string 'all' but got {bands}"
)
bands = np.arange(self.num_orientations)
else:
if not hasattr(bands, "__iter__"):
raise TypeError(
"bands must be a list of ints or the string 'all' but got"
f" {bands}"
f"bands must be a list of ints or the string 'all' but got {bands}"
)
bands: NDArray = np.array(bands, ndmin=1)
assert (bands >= 0).all(), "Error: band numbers must be larger than 0."
Expand Down Expand Up @@ -949,9 +947,9 @@ def steer_coeffs(
will have the same keys as `resteered_coeffs`.
"""
assert (
pyr_coeffs[(0, 0)].dtype not in complex_types
), "steering only implemented for real coefficients"
assert pyr_coeffs[(0, 0)].dtype not in complex_types, (

Check warning on line 950 in src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py

View check run for this annotation

Codecov / codecov/patch

src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py#L950

Added line #L950 was not covered by tests
"steering only implemented for real coefficients"
)
resteered_coeffs = {}
resteering_weights = {}
num_scales = self.num_scales
Expand Down
6 changes: 3 additions & 3 deletions src/plenoptic/simulate/models/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def __init__(
center_std = torch.ones(out_channels) * center_std
if isinstance(surround_std, float) or surround_std.shape == torch.Size([]):
surround_std = torch.ones(out_channels) * surround_std
assert (
len(center_std) == out_channels and len(surround_std) == out_channels
), "stds must correspond to each out_channel"
assert len(center_std) == out_channels and len(surround_std) == out_channels, (
"stds must correspond to each out_channel"
)
assert amplitude_ratio >= 1.0, "ratio of amplitudes must at least be 1."

self.on_center = on_center
Expand Down
6 changes: 3 additions & 3 deletions src/plenoptic/synthesize/eigendistortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,9 @@ def _indexer(self, idx: int) -> int:

all_idx = self.eigenindex
assert i in all_idx, "eigenindex must be the index of one of the vectors"
assert (
all_idx is not None and len(all_idx) != 0
), "No eigendistortions synthesized"
assert all_idx is not None and len(all_idx) != 0, (
"No eigendistortions synthesized"
)
return int(np.where(all_idx == i)[0])

def save(self, file_path: str):
Expand Down
3 changes: 1 addition & 2 deletions src/plenoptic/synthesize/mad_competition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,8 +1195,7 @@ def animate(
)
if mad.mad_image.ndim not in [3, 4]:
raise ValueError(
"animate() expects 3 or 4d data; unexpected"
" behavior will result otherwise!"
"animate() expects 3 or 4d data; unexpected behavior will result otherwise!"
)
_check_included_plots(included_plots, "included_plots")
_check_included_plots(width_ratios, "width_ratios")
Expand Down
3 changes: 1 addition & 2 deletions src/plenoptic/synthesize/metamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,7 @@ def animate(
)
if metamer.metamer.ndim not in [3, 4]:
raise ValueError(
"animate() expects 3 or 4d data; unexpected"
" behavior will result otherwise!"
"animate() expects 3 or 4d data; unexpected behavior will result otherwise!"
)
_check_included_plots(included_plots, "included_plots")
_check_included_plots(width_ratios, "width_ratios")
Expand Down
2 changes: 1 addition & 1 deletion src/plenoptic/synthesize/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def load(
"different! On two random tensors: "
f"Initialized: {init_loss}, Saved: "
f"{saved_loss}, difference: "
f"{init_loss-saved_loss}"
f"{init_loss - saved_loss}"
)
for k, v in tmp_dict.items():
setattr(self, k, v)
Expand Down
7 changes: 3 additions & 4 deletions src/plenoptic/tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def load_images(paths: str | list[str], as_gray: bool = True) -> Tensor:
im = imageio.imread(p)
except ValueError:
warnings.warn(
f"Unable to load in file {p}, it's probably not "
"an image, skipping..."
f"Unable to load in file {p}, it's probably not an image, skipping..."
)
continue
# make it a float32 array with values between 0 and 1
Expand Down Expand Up @@ -143,7 +142,7 @@ def load_images(paths: str | list[str], as_gray: bool = True) -> Tensor:
if as_gray:
if images.ndimension() != 3:
raise ValueError(
"For loading in images as grayscale, this should be a 3d" " tensor!"
"For loading in images as grayscale, this should be a 3d tensor!"
)
images = images.unsqueeze(1)
else:
Expand All @@ -155,7 +154,7 @@ def load_images(paths: str | list[str], as_gray: bool = True) -> Tensor:
images = images.unsqueeze(0) if len(paths) > 1 else images.unsqueeze(1)
if images.ndimension() != 4:
raise ValueError(
"Somehow ended up with other than 4 dimensions! Not sure how we" " got here"
"Somehow ended up with other than 4 dimensions! Not sure how we got here"
)
return images

Expand Down
4 changes: 2 additions & 2 deletions src/plenoptic/tools/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def imshow(
if as_rgb:
if im.shape[1] not in [3, 4]:
raise Exception(
"If as_rgb is True, then channel must have 3 " "or 4 elements!"
"If as_rgb is True, then channel must have 3 or 4 elements!"
)
im = im.transpose(0, 2, 3, 1)
# want to insert a fake "channel" dimension here, so our putting it
Expand Down Expand Up @@ -343,7 +343,7 @@ def animshow(
if as_rgb:
if vid.shape[1] not in [3, 4]:
raise Exception(
"If as_rgb is True, then channel must have 3 " "or 4 elements!"
"If as_rgb is True, then channel must have 3 or 4 elements!"
)
vid = vid.transpose(0, 2, 3, 4, 1)
# want to insert a fake "channel" dimension here, so our putting it
Expand Down
5 changes: 2 additions & 3 deletions src/plenoptic/tools/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def validate_model(
allowed_dtypes = [torch.float64, torch.complex128]
else:
raise TypeError(
"Only float or complex dtypes are allowed but got type" f" {image_dtype}"
f"Only float or complex dtypes are allowed but got type {image_dtype}"
)
if model(test_img).dtype not in allowed_dtypes:
raise TypeError("model changes precision of input, don't do that!")
Expand Down Expand Up @@ -250,8 +250,7 @@ def validate_coarse_to_fine(
)
except TypeError:
raise TypeError(
"model forward method does not accept scales argument"
f" {sc} {msg}"
f"model forward method does not accept scales argument {sc} {msg}"
)


Expand Down
104 changes: 51 additions & 53 deletions tests/test_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ def test_deviation_from_line_and_brownian_bridge(self):
stop = torch.randn(1, d).reshape(1, 1, sqrt_d, sqrt_d).to(DEVICE)
b = po.tools.sample_brownian_bridge(start, stop, t, d**0.5)
a, f = po.tools.deviation_from_line(b, normalize=True)
assert torch.abs(a[t // 2] - 0.5) < 1e-2, f"{a[t//2]}"
assert torch.abs(f[t // 2] - 2**0.5 / 2) < 1e-2, f"{f[t//2]}"
assert torch.abs(a[t // 2] - 0.5) < 1e-2, f"{a[t // 2]}"
assert torch.abs(f[t // 2] - 2**0.5 / 2) < 1e-2, f"{f[t // 2]}"

@pytest.mark.parametrize("normalize", [True, False])
def test_deviation_from_line_multichannel(self, normalize, einstein_img):
einstein_img = einstein_img.repeat(1, 3, 1, 1)
seq = po.tools.translation_sequence(einstein_img)
dist_along, dist_from = po.tools.deviation_from_line(seq, normalize)
assert (
dist_along.shape[0] == seq.shape[0]
), "Distance along line has wrong number of transitions!"
assert (
dist_from.shape[0] == seq.shape[0]
), "Distance from line has wrong number of transitions!"
assert dist_along.shape[0] == seq.shape[0], (
"Distance along line has wrong number of transitions!"
)
assert dist_from.shape[0] == seq.shape[0], (
"Distance from line has wrong number of transitions!"
)

@pytest.mark.parametrize("n_steps", [1, 10])
@pytest.mark.parametrize("max_norm", [0, 1, 10])
Expand Down Expand Up @@ -125,12 +125,12 @@ def test_translation_sequence(self, einstein_img, n_steps, multichannel):
einstein_img = einstein_img.repeat(1, 3, 1, 1)
with expectation:
shifted = po.tools.translation_sequence(einstein_img, n_steps)
assert torch.equal(
shifted[0], einstein_img[0]
), "somehow first frame changed!"
assert torch.equal(
shifted[1, 0, :, 1], shifted[0, 0, :, 0]
), "wrong dimension was translated!"
assert torch.equal(shifted[0], einstein_img[0]), (
"somehow first frame changed!"
)
assert torch.equal(shifted[1, 0, :, 1], shifted[0, 0, :, 0]), (
"wrong dimension was translated!"
)

@pytest.mark.parametrize(
"func",
Expand Down Expand Up @@ -178,15 +178,15 @@ def test_endpoints_dont_change(self, einstein_small_seq, model):
"straight",
)
moog.synthesize(max_iter=5)
assert torch.equal(
moog.geodesic[0], einstein_small_seq[0]
), "Somehow first endpoint changed!"
assert torch.equal(
moog.geodesic[-1], einstein_small_seq[-1]
), "Somehow last endpoint changed!"
assert not torch.equal(
moog.pixelfade[1:-1], moog.geodesic[1:-1]
), "Somehow middle of geodesic didn't changed!"
assert torch.equal(moog.geodesic[0], einstein_small_seq[0]), (
"Somehow first endpoint changed!"
)
assert torch.equal(moog.geodesic[-1], einstein_small_seq[-1]), (
"Somehow last endpoint changed!"
)
assert not torch.equal(moog.pixelfade[1:-1], moog.geodesic[1:-1]), (
"Somehow middle of geodesic didn't changed!"
)

@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -245,9 +245,7 @@ def test_save_load(self, einstein_small_seq, model, fail, tmp_path):
range_penalty = 0.5
expectation = pytest.raises(
ValueError,
match=(
"Saved and initialized range_penalty_lambda are" " different"
),
match=("Saved and initialized range_penalty_lambda are different"),
)
moog_copy = po.synth.Geodesic(
img_a,
Expand Down Expand Up @@ -358,9 +356,9 @@ def test_funcs_external_tensor(self, einstein_small_seq, model, func):
if func == "calculate_jerkiness":
arg_tensor.requires_grad_()
with_arg = getattr(moog, func)(arg_tensor)
assert not torch.equal(
no_arg, with_arg
), f"{func} is not using the input tensor!"
assert not torch.equal(no_arg, with_arg), (
f"{func} is not using the input tensor!"
)

@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True)
def test_continue(self, einstein_small_seq, model):
Expand Down Expand Up @@ -390,25 +388,25 @@ def test_store_progress(self, einstein_small_seq, model, store_progress):
if store_progress == 3:
max_iter = 6
moog.synthesize(max_iter=max_iter, store_progress=store_progress)
assert len(moog.step_energy) == np.ceil(
max_iter / store_progress
), "Didn't end up with enough step_energy after first synth!"
assert len(moog.dev_from_line) == np.ceil(
max_iter / store_progress
), "Didn't end up with enough dev_from_line after first synth!"
assert (
len(moog.losses) == max_iter
), "Didn't end up with enough losses after first synth!"
assert len(moog.step_energy) == np.ceil(max_iter / store_progress), (
"Didn't end up with enough step_energy after first synth!"
)
assert len(moog.dev_from_line) == np.ceil(max_iter / store_progress), (
"Didn't end up with enough dev_from_line after first synth!"
)
assert len(moog.losses) == max_iter, (
"Didn't end up with enough losses after first synth!"
)
moog.synthesize(max_iter=max_iter, store_progress=store_progress)
assert len(moog.step_energy) == np.ceil(
2 * max_iter / store_progress
), "Didn't end up with enough step_energy after second synth!"
assert len(moog.dev_from_line) == np.ceil(
2 * max_iter / store_progress
), "Didn't end up with enough dev_from_line after second synth!"
assert (
len(moog.losses) == 2 * max_iter
), "Didn't end up with enough losses after second synth!"
assert len(moog.step_energy) == np.ceil(2 * max_iter / store_progress), (
"Didn't end up with enough step_energy after second synth!"
)
assert len(moog.dev_from_line) == np.ceil(2 * max_iter / store_progress), (
"Didn't end up with enough dev_from_line after second synth!"
)
assert len(moog.losses) == 2 * max_iter, (
"Didn't end up with enough losses after second synth!"
)

@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True)
def test_stop_criterion(self, einstein_small_seq, model):
Expand All @@ -419,9 +417,9 @@ def test_stop_criterion(self, einstein_small_seq, model):
einstein_small_seq[:1], einstein_small_seq[-1:], model, 5
)
moog.synthesize(max_iter=10, stop_criterion=0.06, stop_iters_to_check=1)
assert (
abs(moog.pixel_change_norm[-1:]) < 0.06
).all(), "Didn't stop when hit criterion!"
assert (
abs(moog.pixel_change_norm[:-1]) > 0.06
).all(), "Stopped after hit criterion!"
assert (abs(moog.pixel_change_norm[-1:]) < 0.06).all(), (
"Didn't stop when hit criterion!"
)
assert (abs(moog.pixel_change_norm[:-1]) > 0.06).all(), (
"Stopped after hit criterion!"
)
Loading

0 comments on commit 13ab1bd

Please sign in to comment.