Skip to content

Commit

Permalink
Resolve pyi024 warnings (#2146)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 20, 2023
1 parent 3189266 commit fdaab96
Show file tree
Hide file tree
Showing 69 changed files with 538 additions and 394 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ ignore = [
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo
"B905", # todo: `zip()` without an explicit `strict=` parameter
"PYI024", # todo: Use `typing.NamedTuple` instead of `collections.namedtuple`
]
# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down
35 changes: 28 additions & 7 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
# License under BSD 2-clause
import inspect
import os
from collections import namedtuple
from typing import List, NamedTuple, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -86,13 +85,21 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None

def forward(self, x: Tensor) -> NamedTuple:
"""Process input."""
squeeze_output = namedtuple("squeeze_output", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"])

class _SqueezeOutput(NamedTuple):
relu1: Tensor
relu2: Tensor
relu3: Tensor
relu4: Tensor
relu5: Tensor
relu6: Tensor
relu7: Tensor

relus = []
for slice_ in self.slices:
x = slice_(x)
relus.append(x)
return squeeze_output(*relus)
return _SqueezeOutput(*relus)


class Alexnet(torch.nn.Module):
Expand Down Expand Up @@ -134,8 +141,15 @@ def forward(self, x: Tensor) -> NamedTuple:
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("alexnet_outputs", ["relu1", "relu2", "relu3", "relu4", "relu5"])
return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

class _AlexnetOutputs(NamedTuple):
relu1: Tensor
relu2: Tensor
relu3: Tensor
relu4: Tensor
relu5: Tensor

return _AlexnetOutputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)


class Vgg16(torch.nn.Module):
Expand Down Expand Up @@ -177,8 +191,15 @@ def forward(self, x: Tensor) -> NamedTuple:
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("vgg_outputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

class _VGGOutputs(NamedTuple):
relu1_2: Tensor
relu2_2: Tensor
relu3_3: Tensor
relu4_3: Tensor
relu5_3: Tensor

return _VGGOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)


def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor:
Expand Down
16 changes: 16 additions & 0 deletions tests/unittests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os.path
from typing import NamedTuple

import numpy
import torch
from torch import Tensor

from unittests.conftest import (
BATCH_SIZE,
Expand All @@ -26,9 +28,23 @@
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class _Input(NamedTuple):
preds: Tensor
target: Tensor


class _GroupInput(NamedTuple):
preds: Tensor
target: Tensor
groups: Tensor


__all__ = [
"BATCH_SIZE",
"EXTRA_DIM",
"_Input",
"_GroupInput",
"NUM_BATCHES",
"NUM_CLASSES",
"NUM_PROCESSES",
Expand Down
6 changes: 2 additions & 4 deletions tests/unittests/audio/test_c_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.io import wavfile
from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
inputs = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2),
)
Expand Down
7 changes: 3 additions & 4 deletions tests/unittests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
Expand All @@ -22,21 +21,21 @@
from torchmetrics.audio import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio import perceptual_evaluation_speech_quality

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

# for 8k sample rate, need at least 8k/4=2000 samples
inputs_8k = Input(
inputs_8k = _Input(
preds=torch.rand(2, 3, 2100),
target=torch.rand(2, 3, 2100),
)
# for 16k sample rate, need at least 16k/4=4000 samples
inputs_16k = Input(
inputs_16k = _Input(
preds=torch.rand(2, 3, 4100),
target=torch.rand(2, 3, 4100),
)
Expand Down
8 changes: 3 additions & 5 deletions tests/unittests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Callable, Tuple

Expand All @@ -31,23 +30,22 @@
_find_best_perm_by_linear_sum_assignment,
)

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

TIME = 10

Input = namedtuple("Input", ["preds", "target"])

# three speaker examples to test _find_best_perm_by_linear_sum_assignment
inputs1 = Input(
inputs1 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
)
# two speaker examples to test _find_best_perm_by_exhuastive_method
inputs2 = Input(
inputs2 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unittests/audio/test_sa_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
Expand All @@ -24,17 +23,16 @@
source_aggregated_signal_distortion_ratio,
)

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

NUM_SAMPLES = 100 # the number of samples

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
inputs = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, NUM_SAMPLES),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, NUM_SAMPLES),
)
Expand Down
8 changes: 4 additions & 4 deletions tests/unittests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Callable

Expand All @@ -25,19 +24,20 @@
from torchmetrics.functional import signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs_1spk = Input(
inputs_1spk = _Input(
preds=torch.rand(2, 1, 1, 500),
target=torch.rand(2, 1, 1, 500),
)
inputs_2spk = Input(

inputs_2spk = _Input(
preds=torch.rand(2, 1, 2, 500),
target=torch.rand(2, 1, 2, 500),
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unittests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
Expand All @@ -21,17 +20,16 @@
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

NUM_SAMPLES = 100

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
inputs = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES),
)
Expand Down
6 changes: 2 additions & 4 deletions tests/unittests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
Expand All @@ -21,17 +20,16 @@
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

NUM_SAMPLES = 100

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
inputs = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES),
)
Expand Down
5 changes: 2 additions & 3 deletions tests/unittests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
from typing import Callable

Expand All @@ -22,14 +21,14 @@
from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional.audio import signal_noise_ratio

from unittests import _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
inputs = _Input(
preds=torch.rand(2, 1, 1, 25),
target=torch.rand(2, 1, 1, 25),
)
Expand Down
7 changes: 3 additions & 4 deletions tests/unittests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
Expand All @@ -22,19 +21,19 @@
from torchmetrics.audio import ShortTimeObjectiveIntelligibility
from torchmetrics.functional.audio import short_time_objective_intelligibility

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs_8k = Input(
inputs_8k = _Input(
preds=torch.rand(2, 3, 8000),
target=torch.rand(2, 3, 8000),
)
inputs_16k = Input(
inputs_16k = _Input(
preds=torch.rand(2, 3, 16000),
target=torch.rand(2, 3, 16000),
)
Expand Down
Loading

0 comments on commit fdaab96

Please sign in to comment.