Commit

Revert-deepspeed (#189)
This didn't improve performance enough to warrant the addition.
alan-cooney authored Jan 26, 2024
1 parent f55652e commit 41697a0
Showing 17 changed files with 342 additions and 683 deletions.
5 changes: 1 addition & 4 deletions .vscode/cspell.json
@@ -51,7 +51,6 @@
     "jaxtyping",
     "kaiming",
     "keepdim",
-    "logit",
     "lognormal",
     "loguniform",
     "loguniformvalues",
@@ -74,7 +73,6 @@
     "neox",
     "nonlinerity",
     "numel",
-    "onebit",
     "openwebtext",
     "optim",
     "penality",
@@ -118,7 +116,6 @@
     "venv",
     "virtualenv",
     "virtualenvs",
-    "wandb",
-    "zoadam"
+    "wandb"
   ]
 }
542 changes: 207 additions & 335 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 1 addition & 7 deletions pyproject.toml
@@ -7,20 +7,17 @@
 readme="README.md"
 version="0.0.0"

-# Note: Zstandard is required for downloading datasets such as The Pile
 [tool.poetry.dependencies]
 datasets=">=2.15.0"
-deepspeed={version=">=0.12.6", extras=["deepspeed"], optional=false}
 einops=">=0.6"
-mpi4py={version=">=3.1.5", extras=["deepspeed"], optional=true}
 pydantic=">=2.5.2"
 python=">=3.10, <3.12"
 strenum=">=0.4.15"
 tokenizers=">=0.15.0"
 torch=">=2.1.1"
 transformers=">=4.35.2"
 wandb=">=0.16.1"
-zstandard=">=0.22.0"
+zstandard=">=0.22.0" # Required for downloading datasets such as The Pile

 [tool.poetry.group]
 [tool.poetry.group.dev.dependencies]
@@ -57,9 +54,6 @@
 pymdown-extensions=">=10.5"
 pytkdocs-tweaks=">=0.0.7"

-[tool.poetry.extras]
-deepspeed=["deepspeed", "mpi4py"]
-
 [tool.poetry.scripts]
 join-sae-sweep='sparse_autoencoder.train.join_sweep:run'
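Dropping the `[tool.poetry.extras]` table removes the optional install target along with the dependencies themselves: before this revert, the `deepspeed` extra grouped `deepspeed` and `mpi4py` so they could be pulled in on demand (something like `pip install sparse_autoencoder[deepspeed]` or `poetry install --extras deepspeed`, assuming the published package name matches the repository).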
2 changes: 0 additions & 2 deletions sparse_autoencoder/__init__.py
@@ -9,7 +9,6 @@
 from sparse_autoencoder.metrics.train.capacity import CapacityMetric
 from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
 from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
-from sparse_autoencoder.optimizer.deepspeed_adam_with_reset import ZeroOneAdamWithReset
 from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
 from sparse_autoencoder.source_data.text_dataset import TextDataset
 from sparse_autoencoder.train.pipeline import Pipeline
@@ -84,5 +83,4 @@
     "TensorActivationStore",
     "TextDataset",
     "TrainBatchFeatureDensityMetric",
-    "ZeroOneAdamWithReset",
 ]
sparse_autoencoder/activation_resampler/activation_resampler.py
@@ -2,13 +2,11 @@
 from dataclasses import dataclass
 from typing import Annotated, NamedTuple

-from deepspeed import DeepSpeedEngine
 from einops import rearrange
 from jaxtyping import Bool, Float, Int64
 from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
 import torch
 from torch import Tensor
-from torch.nn.parallel import DataParallel
 from torch.utils.data import DataLoader

 from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
@@ -19,6 +17,7 @@
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 @dataclass
@@ -208,7 +207,7 @@ def _get_dead_neuron_indices(
     def compute_loss_and_get_activations(
         self,
         store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> LossInputActivationsTuple:
@@ -441,7 +440,7 @@ def renormalize_and_scale(
     def resample_dead_neurons(
         self,
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults]:
@@ -531,7 +530,7 @@ def step_resampler(
     def step_resampler(
         self,
         batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:
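Throughout this file, the `SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine` union collapses to `SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]`. The `DataParallelWithModelAttributes` helper itself isn't rendered in this commit; as a hedged sketch of its likely shape (an assumption, not the repository's actual implementation), it is a `DataParallel` subclass that forwards attribute lookups to the wrapped model, so call sites such as `sparse_autoencoder.config` keep working under data parallelism:

    from typing import Any, Generic, TypeVar

    from torch import nn

    T = TypeVar("T", bound=nn.Module)


    class DataParallelWithModelAttributes(nn.DataParallel, Generic[T]):
        """DataParallel wrapper that also exposes the wrapped model's attributes."""

        def __getattr__(self, name: str) -> Any:
            try:
                # nn.Module.__getattr__ resolves registered parameters/submodules.
                return super().__getattr__(name)
            except AttributeError:
                # Fall back to the wrapped model's own attributes (e.g. `config`).
                return getattr(self.module, name)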
2 changes: 1 addition & 1 deletion sparse_autoencoder/autoencoder/model.py
@@ -276,7 +276,7 @@ def load(
             The loaded model.
         """
         # Load the file
-        serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
+        serialized_state = torch.load(file_path)
         state = SparseAutoencoderState.model_validate(serialized_state)

         # Initialise the model
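One side effect of this hunk worth noting: `torch.load` without `map_location` restores tensors to the device they were saved from, so a checkpoint saved on GPU will fail to load on a CPU-only machine. A minimal illustration (the checkpoint path is hypothetical):

    import torch

    # Pin all tensors to CPU regardless of where the checkpoint was saved.
    serialized_state = torch.load("sae_checkpoint.pt", map_location=torch.device("cpu"))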
2 changes: 1 addition & 1 deletion sparse_autoencoder/optimizer/adam_with_reset.py
@@ -42,7 +42,7 @@ def __init__(  # (extending existing implementation)
         lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
-        weight_decay: float = 0.0,
+        weight_decay: float = 0,
         *,
         amsgrad: bool = False,
         foreach: bool | None = None,
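For readers unfamiliar with why the project carries an `AdamWithReset` optimizer at all: Adam keeps per-parameter moment estimates, and when the activation resampler re-initialises dead neurons those estimates are stale and would drag the fresh weights back toward their old values. A sketch of the underlying idea using plain `torch.optim.Adam` state (indices and shapes are illustrative, not the repository's API):

    import torch

    weights = torch.nn.Parameter(torch.randn(10, 4))
    optimizer = torch.optim.Adam([weights])

    # A training step populates optimizer.state with moment estimates.
    loss = (weights**2).sum()
    loss.backward()
    optimizer.step()

    # After re-initialising "dead" rows, zero their Adam state too, so stale
    # momentum doesn't immediately overwrite the fresh weights.
    dead_rows = torch.tensor([3, 7])
    with torch.no_grad():
        weights[dead_rows] = torch.randn(2, 4) * 0.01
    state = optimizer.state[weights]
    state["exp_avg"][dead_rows] = 0.0
    state["exp_avg_sq"][dead_rows] = 0.0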
194 changes: 0 additions & 194 deletions sparse_autoencoder/optimizer/deepspeed_adam_with_reset.py

This file was deleted.
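The deleted module's contents aren't rendered, but the `__init__.py` diff above shows it exported `ZeroOneAdamWithReset`, and the removed `onebit`/`zoadam` dictionary entries point at DeepSpeed's 0/1 Adam optimizer. Purely as a hedged reconstruction of the general shape (the real file is gone from this page, so the names and details below are assumptions):

    # Hypothetical sketch, not the deleted file's actual contents.
    from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
    from torch import Tensor


    class ZeroOneAdamWithReset(ZeroOneAdam):
        """DeepSpeed 0/1 Adam with resettable per-neuron optimizer state."""

        def reset_neurons_state(self, parameter, neuron_indices: Tensor, axis: int = 0) -> None:
            """Zero the stored moment estimates for resampled neurons."""
            state = self.state[parameter]
            for key in ("exp_avg", "exp_avg_sq"):
                if key in state:  # Key names assumed to mirror torch.optim.Adam.
                    state[key].index_fill_(axis, neuron_indices, 0)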

23 changes: 11 additions & 12 deletions sparse_autoencoder/source_model/replace_activations_hook.py
@@ -1,37 +1,33 @@
 """Replace activations hook."""
 from typing import TYPE_CHECKING

-from deepspeed import DeepSpeedEngine
-from jaxtyping import Float
 from torch import Tensor
-from torch.nn.parallel import DataParallel
 from transformer_lens.hook_points import HookPoint

 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
     from sparse_autoencoder.tensor_types import Axis
+    from jaxtyping import Float

 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
+    sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
     component_idx: int | None = None,
     n_components: int | None = None,
 ) -> Tensor:
     """Replace activations hook.

-    This should be pre-initialised with `functools.partial`.

     Args:
         value: The activations to replace.
         hook: The hook point.
-        sparse_autoencoder: The sparse autoencoder.
+        sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
+            `functools.partial`.
         component_idx: The component index to replace the activations with, if just replacing
             activations for a single component. Requires the model to have a component axis.
         n_components: The number of components that the SAE is trained on.

     Returns:
         Replaced activations.
@@ -47,8 +43,11 @@ def replace_activations_hook(
     )

     if component_idx is not None:
-        if n_components is None:
-            error_message = "The number of model components must be set if component_idx is set."
+        if sparse_autoencoder.config.n_components is None:
+            error_message = (
+                "Cannot replace for a specific component, if the model does not have a "
+                "component axis."
+            )
             raise RuntimeError(error_message)

         # The approach here is to run a forward pass with dummy values for all components other than
@@ -57,7 +56,7 @@
         # components.
         expanded_shape = [
             squashed_value.shape[0],
-            n_components,
+            sparse_autoencoder.config.n_components,
             squashed_value.shape[-1],
         ]
         expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
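The docstring change above highlights the intended call pattern: bind the SAE with `functools.partial` and register the result as a transformer-lens forward hook. A usage sketch (the model choice, prompt, and hook point are illustrative):

    from functools import partial

    from transformer_lens import HookedTransformer

    from sparse_autoencoder import SparseAutoencoder
    from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook

    model = HookedTransformer.from_pretrained("gpt2")
    autoencoder: SparseAutoencoder = ...  # a trained SAE for blocks.0.hook_mlp_out

    # Run the model with MLP-out activations round-tripped through the SAE.
    loss = model.run_with_hooks(
        "The quick brown fox",
        return_type="loss",
        fwd_hooks=[
            (
                "blocks.0.hook_mlp_out",
                partial(replace_activations_hook, sparse_autoencoder=autoencoder),
            )
        ],
    )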
(The remaining changed files are not rendered here.)
