From e62c470c9bc7908db3e1a73173d31cfbed71794c Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:43:15 -0300 Subject: [PATCH] Default to average metric aggregation across components (#180) --- sparse_autoencoder/metrics/abstract_metric.py | 7 ++++--- .../train/tests/__snapshots__/test_l0_norm_metric.ambr | 2 +- .../tests/__snapshots__/test_neuron_activity_metric.ambr | 8 ++++---- .../__snapshots__/test_model_reconstruction_score.ambr | 9 ++++----- sparse_autoencoder/train/tests/test_pipeline.py | 6 ++++++ 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sparse_autoencoder/metrics/abstract_metric.py b/sparse_autoencoder/metrics/abstract_metric.py index 00b4bc86..6dba237d 100644 --- a/sparse_autoencoder/metrics/abstract_metric.py +++ b/sparse_autoencoder/metrics/abstract_metric.py @@ -13,6 +13,7 @@ from jaxtyping import Float, Int import numpy as np from strenum import LowercaseStrEnum, SnakeCaseStrEnum +import torch from torch import Tensor from wandb import data_types @@ -99,7 +100,7 @@ def __init__( | Int[Tensor, Axis.names(Axis.COMPONENT)], name: str, location: MetricLocation, - aggregate_approach: ComponentAggregationApproach | None = ComponentAggregationApproach.ALL, + aggregate_approach: ComponentAggregationApproach | None = ComponentAggregationApproach.MEAN, aggregate_value: Any | None = None, # noqa: ANN401 postfix: str | None = None, ) -> None: @@ -195,9 +196,9 @@ def aggregate_value( # noqa: PLR0911 ): match self.aggregate_approach: case ComponentAggregationApproach.MEAN: - return self.component_wise_values.mean(dim=0) + return self.component_wise_values.mean(dim=0, dtype=torch.float32) case ComponentAggregationApproach.SUM: - return self.component_wise_values.sum(dim=0) + return self.component_wise_values.sum(dim=0, dtype=torch.float32) case ComponentAggregationApproach.ALL: return self.component_wise_values case _: diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr index 8f2f6149..b3ffd14f 100644 --- a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr @@ -8,7 +8,7 @@ 'component_3/train/learned_activations_l0_norm': tensor(8.), 'component_4/train/learned_activations_l0_norm': tensor(8.), 'component_5/train/learned_activations_l0_norm': tensor(8.), - 'train/learned_activations_l0_norm': tensor([8., 8., 8., 8., 8., 8.]), + 'train/learned_activations_l0_norm/component_mean': tensor(8.), }), ]) # --- diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr index 55bc5ffb..57897c29 100644 --- a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr @@ -8,7 +8,7 @@ 'component_3/train/learned_neuron_activity/dead_over_10_activations': tensor(0), 'component_4/train/learned_neuron_activity/dead_over_10_activations': tensor(0), 'component_5/train/learned_neuron_activity/dead_over_10_activations': tensor(0), - 'train/learned_neuron_activity/dead_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + 'train/learned_neuron_activity/dead_over_10_activations/component_mean': tensor(0.), }), dict({ 'component_0/train/learned_neuron_activity/alive_over_10_activations': tensor(8), @@ -17,7 +17,7 @@ 'component_3/train/learned_neuron_activity/alive_over_10_activations': tensor(8), 'component_4/train/learned_neuron_activity/alive_over_10_activations': tensor(8), 'component_5/train/learned_neuron_activity/alive_over_10_activations': tensor(8), - 'train/learned_neuron_activity/alive_over_10_activations': tensor([8, 8, 8, 8, 8, 8]), + 'train/learned_neuron_activity/alive_over_10_activations/component_mean': tensor(8.), }), dict({ 'component_0/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( @@ -366,7 +366,7 @@ 'component_3/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), 'component_4/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), 'component_5/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), - 'train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + 'train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations/component_mean': tensor(0.), }), dict({ 'component_0/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), @@ -375,7 +375,7 @@ 'component_3/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), 'component_4/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), 'component_5/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), - 'train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + 'train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations/component_mean': tensor(0.), }), ]) # --- diff --git a/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr b/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr index 6284e1e2..e7a54b42 100644 --- a/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr +++ b/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr @@ -8,7 +8,7 @@ 'component_3/validate/reconstruction_score/baseline_loss': tensor(0.4598), 'component_4/validate/reconstruction_score/baseline_loss': tensor(0.4281), 'component_5/validate/reconstruction_score/baseline_loss': tensor(0.4961), - 'validate/reconstruction_score/baseline_loss': tensor([0.3800, 0.5251, 0.4923, 0.4598, 0.4281, 0.4961]), + 'validate/reconstruction_score/baseline_loss/component_mean': tensor(0.4636), }), dict({ 'component_0/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6111), @@ -17,7 +17,7 @@ 'component_3/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6497), 'component_4/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4929), 'component_5/validate/reconstruction_score/loss_with_reconstruction': tensor(0.3723), - 'validate/reconstruction_score/loss_with_reconstruction': tensor([0.6111, 0.5219, 0.4063, 0.6497, 0.4929, 0.3723]), + 'validate/reconstruction_score/loss_with_reconstruction/component_mean': tensor(0.5090), }), dict({ 'component_0/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.2891), @@ -26,7 +26,7 @@ 'component_3/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.4740), 'component_4/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5452), 'component_5/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3733), - 'validate/reconstruction_score/loss_with_zero_ablation': tensor([0.2891, 0.3879, 0.5850, 0.4740, 0.5452, 0.3733]), + 'validate/reconstruction_score/loss_with_zero_ablation/component_mean': tensor(0.4424), }), dict({ 'component_0/validate/reconstruction_score': tensor(3.5422), @@ -35,8 +35,7 @@ 'component_3/validate/reconstruction_score': tensor(-12.3338), 'component_4/validate/reconstruction_score': tensor(0.4468), 'component_5/validate/reconstruction_score': tensor(-0.0081), - 'validate/reconstruction_score': tensor([ 3.5422e+00, 9.7672e-01, 1.9278e+00, -1.2334e+01, 4.4681e-01, - -8.1113e-03]), + 'validate/reconstruction_score/component_mean': tensor(-0.9081), }), ]) # --- diff --git a/sparse_autoencoder/train/tests/test_pipeline.py b/sparse_autoencoder/train/tests/test_pipeline.py index c1be1b66..da7d9128 100644 --- a/sparse_autoencoder/train/tests/test_pipeline.py +++ b/sparse_autoencoder/train/tests/test_pipeline.py @@ -334,6 +334,12 @@ def calculate(self, data: ValidationMetricData) -> list[MetricResult]: dummy_metric.data.source_model_loss_with_zero_ablation is not None ), "Source model loss with zero ablation should be calculated." + # Check the dimensions are correct + ndim_with_component = 2 + assert dummy_metric.data.source_model_loss.ndim == ndim_with_component + assert dummy_metric.data.source_model_loss_with_reconstruction.ndim == ndim_with_component + assert dummy_metric.data.source_model_loss_with_zero_ablation.ndim == ndim_with_component + class TestSaveCheckpoint: """Test the save_checkpoint method."""