Simplify dataparallel approach #191

Merged 7 commits on Jan 26, 2024

This PR drops the custom `DataParallelWithModelAttributes` wrapper in favour of stock `torch.nn.parallel.DataParallel`, passing model dimensions and component counts as explicit arguments instead of reading them off wrapped-model attributes.
5 changes: 4 additions & 1 deletion .vscode/cspell.json
@@ -51,6 +51,7 @@
"jaxtyping",
"kaiming",
"keepdim",
+ "logit",
"lognormal",
"loguniform",
"loguniformvalues",
@@ -73,6 +74,7 @@
"neox",
"nonlinerity",
"numel",
+ "onebit",
"openwebtext",
"optim",
"penality",
@@ -116,6 +118,7 @@
"venv",
"virtualenv",
"virtualenvs",
- "wandb"
+ "wandb",
+ "zoadam"
]
}
1,253 changes: 648 additions & 605 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@
pytest-integration=">=0.2.3"
pytest-timeout=">=2.2.0"
pytest-xdist="^3.5.0"
- ruff=">=0.1.4"
+ ruff=">=0.1.14"
syrupy=">=4.6.0"

[tool.poetry.group.demos.dependencies]
@@ -7,6 +7,7 @@
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
+ from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader

from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
@@ -17,7 +18,6 @@
from sparse_autoencoder.loss.abstract_loss import AbstractLoss
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.train.utils.get_model_device import get_model_device
- from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


@dataclass
@@ -207,7 +207,7 @@ def _get_dead_neuron_indices(
def compute_loss_and_get_activations(
self,
store: ActivationStore,
- autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+ autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
loss_fn: AbstractLoss,
train_batch_size: int,
) -> LossInputActivationsTuple:
@@ -440,7 +440,7 @@ def renormalize_and_scale(
def resample_dead_neurons(
self,
activation_store: ActivationStore,
- autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+ autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
loss_fn: AbstractLoss,
train_batch_size: int,
) -> list[ParameterUpdateResults]:
@@ -530,7 +530,7 @@ def step_resampler(
self,
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
activation_store: ActivationStore,
- autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+ autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
loss_fn: AbstractLoss,
train_batch_size: int,
) -> list[ParameterUpdateResults] | None:
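The signature changes above replace the repo's custom `DataParallelWithModelAttributes` wrapper with stock `torch.nn.parallel.DataParallel`. The wrapper existed because `DataParallel` does not forward arbitrary attributes of the module it wraps, so code like `autoencoder.config.n_components` fails on a wrapped model; the PR instead threads the needed values through as explicit arguments. A minimal sketch of that behaviour, using a toy module that is not from this repo:

```python
from torch import nn


class ToyModule(nn.Module):
    """Stand-in for SparseAutoencoder, carrying a custom attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.n_learned_features = 8  # plain attribute, not a parameter/buffer


wrapped = nn.DataParallel(ToyModule())

print(wrapped.module.n_learned_features)  # 8: still reachable via .module
try:
    print(wrapped.n_learned_features)
except AttributeError:
    print("DataParallel does not forward custom attributes")
```

The `DataParallel[SparseAutoencoder]` annotations type-check because `DataParallel` is generic over the wrapped module type in recent PyTorch versions.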
2 changes: 1 addition & 1 deletion sparse_autoencoder/autoencoder/model.py
@@ -276,7 +276,7 @@ def load(
The loaded model.
"""
# Load the file
- serialized_state = torch.load(file_path)
+ serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
state = SparseAutoencoderState.model_validate(serialized_state)

# Initialise the model
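Pinning `map_location` to CPU lets a checkpoint saved from a CUDA run be loaded on a machine with no GPU: every tensor is materialised on CPU first and can be moved to a device afterwards. A small round-trip sketch, with a hypothetical file name:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "checkpoint.pt")  # hypothetical path

# Without map_location, tensors saved from CUDA try to deserialise back onto
# CUDA and raise on a CPU-only host; pinning to CPU makes loading portable.
state = torch.load("checkpoint.pt", map_location=torch.device("cpu"))
model.load_state_dict(state)
```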
@@ -1,15 +1,13 @@
"""Test the model reconstruction score metric."""

- from jaxtyping import Float
import pytest
from syrupy.session import SnapshotSession
import torch
- from torch import Tensor
+ from torch import Tensor, tensor

from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result
from sparse_autoencoder.metrics.validate.abstract_validate_metric import ValidationMetricData
from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore
- from sparse_autoencoder.tensor_types import Axis


def test_model_reconstruction_score_empty_data() -> None:
@@ -19,9 +17,9 @@ def test_model_reconstruction_score_empty_data() -> None:
is provided (i.e., at the end of training or in similar scenarios).
"""
data = ValidationMetricData(
- source_model_loss=Float[Tensor, Axis.ITEMS]([]),
- source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]([]),
- source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]([]),
+ source_model_loss=tensor([]),
+ source_model_loss_with_reconstruction=tensor([]),
+ source_model_loss_with_zero_ablation=tensor([]),
)
metric = ModelReconstructionScore()
result = metric.calculate(data)
@@ -41,13 +39,9 @@ def test_model_reconstruction_score_empty_data() -> None:
),
(
ValidationMetricData(
- source_model_loss=Float[Tensor, Axis.ITEMS]([[0.5], [1.5], [2.5]]),
- source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS](
-     [[1.5], [2.5], [3.5]]
- ),
- source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS](
-     [[8.0], [7.0], [4.0]]
- ),
+ source_model_loss=tensor([[0.5], [1.5], [2.5]]),
+ source_model_loss_with_reconstruction=tensor([[1.5], [2.5], [3.5]]),
+ source_model_loss_with_zero_ablation=tensor([[8.0], [7.0], [4.0]]),
),
0.79,
),
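The parametrised case expects a score of 0.79. Assuming the usual definition of the reconstruction score, (zero-ablation loss - reconstruction loss) / (zero-ablation loss - source loss) over the mean losses, the fixture values reproduce that number:

```python
import torch

source = torch.tensor([[0.5], [1.5], [2.5]]).mean()  # 1.5
reconstruction = torch.tensor([[1.5], [2.5], [3.5]]).mean()  # 2.5
zero_ablation = torch.tensor([[8.0], [7.0], [4.0]]).mean()  # ~6.33

score = (zero_ablation - reconstruction) / (zero_ablation - source)
print(round(score.item(), 2))  # 0.79
```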
2 changes: 1 addition & 1 deletion sparse_autoencoder/optimizer/adam_with_reset.py
@@ -42,7 +42,7 @@ def __init__( # (extending existing implementation)
lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
- weight_decay: float = 0,
+ weight_decay: float = 0.0,
*,
amsgrad: bool = False,
foreach: bool | None = None,
22 changes: 11 additions & 11 deletions sparse_autoencoder/source_model/replace_activations_hook.py
@@ -1,33 +1,36 @@
"""Replace activations hook."""
from typing import TYPE_CHECKING

- from jaxtyping import Float
from torch import Tensor
+ from torch.nn.parallel import DataParallel
from transformer_lens.hook_points import HookPoint

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
- from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


if TYPE_CHECKING:
from sparse_autoencoder.tensor_types import Axis
+ from jaxtyping import Float


def replace_activations_hook(
value: Tensor,
hook: HookPoint, # noqa: ARG001
- sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+ sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
component_idx: int | None = None,
+ n_components: int | None = None,
) -> Tensor:
"""Replace activations hook.

+ This should be pre-initialised with `functools.partial`.

Args:
value: The activations to replace.
hook: The hook point.
- sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
-     `functools.partial`.
+ sparse_autoencoder: The sparse autoencoder.
component_idx: The component index to replace the activations with, if just replacing
    activations for a single component. Requires the model to have a component axis.
+ n_components: The number of components that the SAE is trained on.

Returns:
Replaced activations.
@@ -43,11 +46,8 @@ def replace_activations_hook(
)

if component_idx is not None:
- if sparse_autoencoder.config.n_components is None:
-     error_message = (
-         "Cannot replace for a specific component, if the model does not have a "
-         "component axis."
-     )
+ if n_components is None:
+     error_message = "The number of model components must be set if component_idx is set."
raise RuntimeError(error_message)

# The approach here is to run a forward pass with dummy values for all components other than
@@ -56,7 +56,7 @@ def replace_activations_hook(
# components.
expanded_shape = [
squashed_value.shape[0],
- sparse_autoencoder.config.n_components,
+ n_components,
squashed_value.shape[-1],
]
expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
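The hunks above swap `sparse_autoencoder.config.n_components` for the new explicit `n_components` argument, since a `DataParallel`-wrapped SAE no longer exposes `.config`. The surrounding comment describes the expand trick used when replacing a single component's activations; a tensor-only sketch of that trick, with made-up shapes:

```python
import torch

batch_size, n_features = 8, 512
n_components, component_idx = 3, 1

squashed_value = torch.randn(batch_size, n_features)

# Expand (a view, no copy) so a multi-component SAE sees the same dummy input
# for every component, then keep only the component actually being replaced.
expanded = squashed_value.unsqueeze(1).expand(batch_size, n_components, n_features)
output_for_component = expanded[:, component_idx]  # shape: (8, 512)
```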
@@ -66,7 +66,12 @@ def test_hook_replaces_activations_2_components() -> None:
fwd_hooks=[
(
"blocks.0.hook_mlp_out",
- partial(replace_activations_hook, sparse_autoencoder=autoencoder, component_idx=1),
+ partial(
+     replace_activations_hook,
+     sparse_autoencoder=autoencoder,
+     component_idx=1,
+     n_components=2,
+ ),
)
],
)
34 changes: 24 additions & 10 deletions sparse_autoencoder/train/pipeline.py
@@ -9,6 +9,7 @@
from pydantic import NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
+ from torch.nn.parallel import DataParallel
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -32,7 +33,6 @@
from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.train.utils.get_model_device import get_model_device
- from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


if TYPE_CHECKING:
@@ -51,9 +51,15 @@ class Pipeline:
activation_resampler: ActivationResampler | None
"""Activation resampler to use."""

- autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
+ autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder]
"""Sparse autoencoder to train."""

+ n_input_features: int
+ """Number of input features in the sparse autoencoder."""

+ n_learned_features: int
+ """Number of learned features in the sparse autoencoder."""

cache_names: list[str]
"""Names of the cache hook points to use in the source model."""

@@ -81,7 +87,7 @@ class Pipeline:
source_dataset: SourceDataset
"""Source dataset to generate activation data from (tokenized prompts)."""

- source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer]
+ source_model: HookedTransformer | DataParallel[HookedTransformer]
"""Source model to get activations from."""

total_activations_trained_on: int = 0
@@ -97,13 +103,15 @@ def n_components(self) -> int:
def __init__(
self,
activation_resampler: ActivationResampler | None,
- autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+ autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
cache_names: list[str],
layer: NonNegativeInt,
loss: AbstractLoss,
optimizer: AbstractOptimizerWithReset,
source_dataset: SourceDataset,
- source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
+ source_model: HookedTransformer | DataParallel[HookedTransformer],
+ n_input_features: int,
+ n_learned_features: int,
run_name: str = "sparse_autoencoder",
checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
lr_scheduler: LRScheduler | None = None,
@@ -124,6 +132,8 @@ def __init__(
optimizer: Optimizer to use.
source_dataset: Source dataset to get data from.
source_model: Source model to get activations from.
+ n_input_features: Number of input features in the sparse autoencoder.
+ n_learned_features: Number of learned features in the sparse autoencoder.
run_name: Name of the run for saving checkpoints.
checkpoint_directory: Directory to save checkpoints to.
lr_scheduler: Learning rate scheduler to use.
@@ -146,6 +156,8 @@ def __init__(
self.source_data_batch_size = source_data_batch_size
self.source_dataset = source_dataset
self.source_model = source_model
+ self.n_input_features = n_input_features
+ self.n_learned_features = n_learned_features

# Create a stateful iterator
source_dataloader = source_dataset.get_dataloader(
@@ -175,9 +187,10 @@ def generate_activations(self, store_size: PositiveInt) -> TensorActivationStore
raise ValueError(error_message)

# Setup the store
- n_neurons: int = self.autoencoder.config.n_input_features
source_model_device: torch.device = get_model_device(self.source_model)
- store = TensorActivationStore(store_size, n_neurons, n_components=self.n_components)
+ store = TensorActivationStore(
+     store_size, self.n_input_features, n_components=self.n_components
+ )

# Add the hook to the model (will automatically store the activations every time the model
# runs)
@@ -225,9 +238,9 @@ def train_autoencoder(
learned_activations_fired_count: Int64[
Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)
] = torch.zeros(
- (self.n_components, self.autoencoder.config.n_learned_features),
+ (self.n_components, self.n_learned_features),
dtype=torch.int64,
- device=autoencoder_device,
+ device=torch.device("cpu"),
)

for store_batch in activations_dataloader:
@@ -260,7 +273,7 @@ def train_autoencoder(
# Store count of how many neurons have fired
with torch.no_grad():
fired = learned_activations > 0
- learned_activations_fired_count.add_(fired.sum(dim=0))
+ learned_activations_fired_count.add_(fired.sum(dim=0).cpu())

# Backwards pass
total_loss.backward()
@@ -358,6 +371,7 @@ def validate_sae(self, validation_n_activations: PositiveInt) -> None:
replace_activations_hook,
sparse_autoencoder=self.autoencoder,
component_idx=component_idx,
+ n_components=self.n_components,
)

with torch.no_grad():
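Two related changes run through this file. First, the pipeline now takes `n_input_features` and `n_learned_features` explicitly, again because a `DataParallel`-wrapped autoencoder hides `.config`. Second, the fired-neuron counter is allocated on CPU and each batch's summed counts are moved off-device before accumulating, so the long-lived counter never occupies accelerator memory. A runnable toy of that accumulation pattern, with made-up shapes:

```python
import torch

n_components, n_learned_features = 2, 16

# Long-lived counter stays on CPU, as in the diff above.
fired_count = torch.zeros((n_components, n_learned_features), dtype=torch.int64)

for _ in range(3):  # stand-in for the activations dataloader loop
    learned_activations = torch.rand(8, n_components, n_learned_features)
    with torch.no_grad():
        fired = learned_activations > 0
        # .cpu() is a no-op here, but moves the per-batch sum off the GPU when
        # the activations live on an accelerator.
        fired_count.add_(fired.sum(dim=0).cpu())
```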