Skip to content

Commit

Permalink
Fix type checkling
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney committed Jan 26, 2024
1 parent cfcb8a8 commit 9c5f715
Showing 1 changed file with 7 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Test the model reconstruction score metric."""

from jaxtyping import Float
import pytest
from syrupy.session import SnapshotSession
import torch
from torch import Tensor
from torch import Tensor, tensor

from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result
from sparse_autoencoder.metrics.validate.abstract_validate_metric import ValidationMetricData
from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore
from sparse_autoencoder.tensor_types import Axis


def test_model_reconstruction_score_empty_data() -> None:
Expand All @@ -19,9 +17,9 @@ def test_model_reconstruction_score_empty_data() -> None:
is provided (i.e., at the end of training or in similar scenarios).
"""
data = ValidationMetricData(
source_model_loss=Float[Tensor, Axis.ITEMS]([]),
source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]([]),
source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]([]),
source_model_loss=tensor([]),
source_model_loss_with_reconstruction=tensor([]),
source_model_loss_with_zero_ablation=tensor([]),
)
metric = ModelReconstructionScore()
result = metric.calculate(data)
Expand All @@ -41,13 +39,9 @@ def test_model_reconstruction_score_empty_data() -> None:
),
(
ValidationMetricData(
source_model_loss=Float[Tensor, Axis.ITEMS]([[0.5], [1.5], [2.5]]),
source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS](
[[1.5], [2.5], [3.5]]
),
source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS](
[[8.0], [7.0], [4.0]]
),
source_model_loss=tensor([[0.5], [1.5], [2.5]]),
source_model_loss_with_reconstruction=tensor([[1.5], [2.5], [3.5]]),
source_model_loss_with_zero_ablation=tensor([[8.0], [7.0], [4.0]]),
),
0.79,
),
Expand Down

0 comments on commit 9c5f715

Please sign in to comment.