From c9621439b2b82348a84833420602be95e29ec4a9 Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Wed, 22 Nov 2023 23:02:31 -0500 Subject: [PATCH] Fix all docstrings for docs (#97) --- .github/workflows/checks.yml | 2 + .github/workflows/gh-pages.yml | 4 +- .vscode/cspell.json | 1 + docs/gen_ref_pages.py | 17 +++-- poetry.lock | 62 ++++++++++++++----- pyproject.toml | 20 +++++- .../abstract_activation_resampler.py | 4 +- .../activation_resampler.py | 6 ++ .../activation_store/base_store.py | 5 -- .../activation_store/disk_store.py | 9 ++- .../activation_store/list_store.py | 41 ++++++++---- .../activation_store/tensor_store.py | 28 +++++---- .../autoencoder/abstract_autoencoder.py | 6 -- .../components/abstract_decoder.py | 3 - .../components/abstract_encoder.py | 3 - .../components/abstract_outer_bias.py | 2 - .../components/unit_norm_decoder.py | 4 ++ sparse_autoencoder/loss/abstract_loss.py | 1 - sparse_autoencoder/loss/reducer.py | 2 +- .../metrics/train/abstract_train_metric.py | 10 ++- .../validate/abstract_validate_metric.py | 1 - .../optimizer/abstract_optimizer.py | 2 - .../optimizer/adam_with_reset.py | 35 ++++++++--- .../source_data/abstract_dataset.py | 4 +- .../source_data/pretokenized_dataset.py | 3 + .../source_data/text_dataset.py | 3 + .../source_model/store_activations_hook.py | 3 + sparse_autoencoder/train/abstract_pipeline.py | 44 ++++++++++--- sparse_autoencoder/train/pipeline.py | 19 +++++- .../train/utils/get_model_device.py | 4 +- 30 files changed, 241 insertions(+), 107 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6aad8828..593859ed 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -52,6 +52,8 @@ jobs: run: poetry run pyright - name: Ruff lint run: poetry run ruff check . --output-format=github + - name: Docstrings lint + run: poetry run pydoclint . - name: Ruff format run: poetry run ruff format . --check - name: Pytest diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 79728bdc..cbb409ef 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -27,10 +27,10 @@ jobs: cache: "poetry" - name: Install poe run: pip install poethepoet - - name: Install mkdocs - run: pip install mkdocs - name: Install dependencies run: poetry install --with docs + - name: Generate docs + run: poe gen-docs - name: Build Docs run: poe make-docs - name: Upload Docs Artifact diff --git a/.vscode/cspell.json b/.vscode/cspell.json index c9c28e01..22ffcbe5 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -87,6 +87,7 @@ "typecheck", "ultralow", "uncopyrighted", + "ungraphed", "unsqueeze", "venv", "virtualenv", diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 996ac748..54d2435f 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -22,7 +22,7 @@ def is_source_file(file: Path) -> bool: """Check if the provided file is a source file for Sparse Encoder. Args: - file (Path): The file path to check. + file: The file path to check. Returns: bool: True if the file is a source file, False otherwise. @@ -34,11 +34,10 @@ def process_path(path: Path) -> tuple[Path, Path, Path]: """Process the given path for documentation generation. Args: - path (Path): The file path to process. + path: The file path to process. Returns: - tuple[Path, Path, Path]: A tuple containing module path, documentation path, - and full documentation path. 
+ A tuple containing module path, documentation path, and full documentation path. """ module_path = path.relative_to(PROJECT_ROOT).with_suffix("") doc_path = path.relative_to(PROJECT_ROOT).with_suffix(".md") @@ -56,9 +55,9 @@ def generate_documentation(path: Path, module_path: Path, full_doc_path: Path) - """Generate documentation for the given source file. Args: - path (Path): The source file path. - module_path (Path): The module path. - full_doc_path (Path): The full documentation file path. + path: The source file path. + module_path: The module path. + full_doc_path: The full documentation file path. """ if module_path.name == "__main__": return @@ -77,8 +76,8 @@ def generate_nav_file(nav: mkdocs_gen_files.nav.Nav, reference_dir: Path) -> Non """Generate the navigation file for the documentation. Args: - nav (mkdocs_gen_files.Nav): The navigation object. - reference_dir (Path): The directory to write the navigation file. + nav: The navigation object. + reference_dir: The directory to write the navigation file. """ with mkdocs_gen_files.open(reference_dir / "SUMMARY.md", "w") as nav_file: nav_file.writelines(nav.build_literate_nav()) diff --git a/poetry.lock b/poetry.lock index e7116742..2cff6ea7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -141,13 +141,13 @@ frozenlist = ">=1.1.0" [[package]] name = "anyio" -version = "4.0.0" +version = "4.1.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, - {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, + {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, + {file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, ] [package.dependencies] @@ -156,9 +156,9 @@ idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.22)"] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] [[package]] name = "appdirs" @@ -790,6 +790,17 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "docstring-parser-fork" +version = "0.0.5" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser_fork-0.0.5-py3-none-any.whl", hash = "sha256:d521dea9b9cc6c60ab5569fa0c1115e3b84a83e6413266fb111a7c81cb935997"}, + {file = "docstring_parser_fork-0.0.5.tar.gz", hash = "sha256:395ae8ee6a359e268670ebc4fe9a40dab917a94f6decd7cda8e86f9bea5c9456"}, +] + [[package]] name = "einops" version = "0.7.0" @@ -1459,13 +1470,13 @@ jupyter-server = ">=1.1.2" [[package]] name = "jupyter-server" -version = "2.11.0" +version = "2.10.1" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." 
optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server-2.11.0-py3-none-any.whl", hash = "sha256:c9bd6e6d71dc5a2a25df167dc323422997f14682b008bfecb5d7920a55020ea7"}, - {file = "jupyter_server-2.11.0.tar.gz", hash = "sha256:78c97ec8049f9062f0151725bc8a1364dfed716646a66819095e0e8a24793eba"}, + {file = "jupyter_server-2.10.1-py3-none-any.whl", hash = "sha256:20519e355d951fc5e1b6ac5952854fe7620d0cfb56588fa4efe362a758977ed3"}, + {file = "jupyter_server-2.10.1.tar.gz", hash = "sha256:e6da2657a954a7879eed28cc08e0817b01ffd81d7eab8634660397b55f926472"}, ] [package.dependencies] @@ -1849,13 +1860,13 @@ recommended = ["mkdocs-minify-plugin (>=0.7,<1.0)", "mkdocs-redirects (>=1.2,<2. [[package]] name = "mkdocs-material-extensions" -version = "1.3" +version = "1.3.1" description = "Extension pack for Python Markdown and MkDocs Material." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material_extensions-1.3-py3-none-any.whl", hash = "sha256:0297cc48ba68a9fdd1ef3780a3b41b534b0d0df1d1181a44676fda5f464eeadc"}, - {file = "mkdocs_material_extensions-1.3.tar.gz", hash = "sha256:f0446091503acb110a7cab9349cbc90eeac51b58d1caa92a704a81ca1e24ddbd"}, + {file = "mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31"}, + {file = "mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443"}, ] [[package]] @@ -2803,6 +2814,25 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pydoclint" +version = "0.3.8" +description = "A Python docstring linter that checks arguments, returns, yields, and raises sections" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydoclint-0.3.8-py2.py3-none-any.whl", hash = "sha256:8e5e020071bb64056fd3f1d68f3b1162ffeb8a3fd6424f73fef7272dac62c166"}, + {file = "pydoclint-0.3.8.tar.gz", hash = "sha256:5a9686a5fb410343e998402686b87cc07df647ea3ab92528c0b0cf8505584e44"}, +] + +[package.dependencies] +click = ">=8.0.0" +docstring-parser-fork = ">=0.0.5" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +flake8 = ["flake8 (>=4)"] + [[package]] name = "pygments" version = "2.17.2" @@ -2838,13 +2868,13 @@ extra = ["pygments (>=2.12)"] [[package]] name = "pyright" -version = "1.1.336" +version = "1.1.337" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.336-py3-none-any.whl", hash = "sha256:8f6a8f365730c8d6c1af840d937371fd5cf0137b6e1827b8b066bc0bb7327aa6"}, - {file = "pyright-1.1.336.tar.gz", hash = "sha256:f92d6d6845e4175833ea60dee5b1ef4d5d66663438fdaedccc1c3ba0f8efa3e3"}, + {file = "pyright-1.1.337-py3-none-any.whl", hash = "sha256:8cbd4ef71797258f816a8393a758c9c91213479f472082d0e3a735ef7ab5f65a"}, + {file = "pyright-1.1.337.tar.gz", hash = "sha256:81d81f839d1750385390c4c4a7b84b062ece2f9a078f87055d4d2a5914ef2a08"}, ] [package.dependencies] @@ -4804,4 +4834,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.12" -content-hash = "e802589e54cdda73b75d2853f5d1ce3ad9a4a5b07cfb4debbd278c99b47de016" +content-hash = "aa4b0f481f5d94eab4e6708bf9e21b9c6f006b5e231a21f9335b7a69666a4d85" diff --git a/pyproject.toml b/pyproject.toml index 9a4582e2..c2a9a964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ jupyter=">=1" plotly=">=5" 
poethepoet=">=0.24.2" - pyright=">=1.1.334" + pydoclint="^0.3.8" + pyright=">=1.1.337" pytest=">=7" pytest-cov=">=4" pytest-timeout=">=2.2.0" @@ -42,14 +43,14 @@ mkdocs-section-index=">=0.3.8" mkdocstrings={extras=["python"], version=">=0.24.0"} mkdocstrings-python=">=1.7.3" + mknotebooks="^0.8.0" pytkdocs-tweaks=">=0.0.7" - mknotebooks = "^0.8.0" [tool.poe.tasks] [tool.poe.tasks.check] help="All checks" ignore_fail=false - sequence=["check-lock", "format", "lint", "test", "typecheck"] + sequence=["check-lock", "docstring-lint", "format", "lint", "test", "typecheck"] [tool.poe.tasks.format] cmd="ruff format ." @@ -59,6 +60,10 @@ cmd="ruff check . --fix" help="Lint (with autofix)" + [tool.poe.tasks.docstring-lint] + cmd="pydoclint ." + help="Lint docstrings" + [tool.poe.tasks.ruff] help=" [alias for lint && format]" ignore_fail=false @@ -95,6 +100,7 @@ [tool.poe.tasks.gen-docs] help="Cleans out the automatically generated docs." script="docs.gen_ref_pages:run" + [tool.poe.tasks.make-docs] cmd="mkdocs build" help="Generates our docs" @@ -225,3 +231,11 @@ strictListInference=true strictParameterNoneValue=true strictSetInference=true + +[tool.pydoclint] + allow-init-docstring=true + arg-type-hints-in-docstring=false + check-return-types=false + check-yield-types=false + exclude='\.venv' + style="google" diff --git a/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py b/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py index 2f9758e5..5219395d 100644 --- a/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py +++ b/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py @@ -68,5 +68,7 @@ def resample_dead_neurons( autoencoder: Sparse autoencoder model. loss_fn: Loss function. train_batch_size: Train batch size (also used for resampling). + + Returns: + Indices of dead neurons, and the updates for the encoder and decoder weights and biases. """ - raise NotImplementedError diff --git a/sparse_autoencoder/activation_resampler/activation_resampler.py b/sparse_autoencoder/activation_resampler/activation_resampler.py index 1c76ae6c..e2c082da 100644 --- a/sparse_autoencoder/activation_resampler/activation_resampler.py +++ b/sparse_autoencoder/activation_resampler/activation_resampler.py @@ -96,6 +96,9 @@ def compute_loss_and_get_activations( Returns: A tuple containing the loss per item, and all input activations. + + Raises: + ValueError: If the number of items in the store is less than the number of inputs """ with torch.no_grad(): loss_batches: list[TrainBatchStatistic] = [] @@ -274,6 +277,9 @@ def resample_dead_neurons( autoencoder: Sparse autoencoder model. loss_fn: Loss function. train_batch_size: Train batch size (also used for resampling). + + Returns: + Indices of dead neurons, and the updates for the encoder and decoder weights and biases. 
""" with torch.no_grad(): dead_neuron_indices = self.get_dead_neuron_indices(neuron_activity) diff --git a/sparse_autoencoder/activation_store/base_store.py b/sparse_autoencoder/activation_store/base_store.py index 652d52f0..ba5d466e 100644 --- a/sparse_autoencoder/activation_store/base_store.py +++ b/sparse_autoencoder/activation_store/base_store.py @@ -51,27 +51,22 @@ class ActivationStore(Dataset[InputOutputActivationVector], ABC): @abstractmethod def append(self, item: InputOutputActivationVector) -> Future | None: """Add a Single Item to the Store.""" - raise NotImplementedError @abstractmethod def extend(self, batch: InputOutputActivationBatch) -> Future | None: """Add a Batch to the Store.""" - raise NotImplementedError @abstractmethod def empty(self) -> None: """Empty the Store.""" - raise NotImplementedError @abstractmethod def __len__(self) -> int: """Get the Length of the Store.""" - raise NotImplementedError @abstractmethod def __getitem__(self, index: int) -> InputOutputActivationVector: """Get an Item from the Store.""" - raise NotImplementedError def shuffle(self) -> None: """Optional shuffle method.""" diff --git a/sparse_autoencoder/activation_store/disk_store.py b/sparse_autoencoder/activation_store/disk_store.py index f052bf13..d831b54e 100644 --- a/sparse_autoencoder/activation_store/disk_store.py +++ b/sparse_autoencoder/activation_store/disk_store.py @@ -251,9 +251,12 @@ def __len__(self) -> int: """Length Dunder Method. Example: - >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) - >>> print(len(store)) - 0 + >>> store = DiskActivationStore(max_cache_size=1, empty_dir=True) + >>> print(len(store)) + 0 + + Returns: + The number of activation vectors in the dataset. """ # Calculate the length if not cached if self._disk_n_activation_vectors.value == -1: diff --git a/sparse_autoencoder/activation_store/list_store.py b/sparse_autoencoder/activation_store/list_store.py index 52562940..6ec660c8 100644 --- a/sparse_autoencoder/activation_store/list_store.py +++ b/sparse_autoencoder/activation_store/list_store.py @@ -138,19 +138,23 @@ def __len__(self) -> int: Returns the number of activation vectors in the dataset. Example: - >>> import torch - >>> store = ListActivationStore() - >>> store.append(torch.randn(100)) - >>> store.append(torch.randn(100)) - >>> len(store) - 2 + >>> import torch + >>> store = ListActivationStore() + >>> store.append(torch.randn(100)) + >>> store.append(torch.randn(100)) + >>> len(store) + 2 + + Returns: + The number of activation vectors in the dataset. """ return len(self._data) def __sizeof__(self) -> int: """Sizeof Dunder Method. - Returns the size of the dataset in bytes. + Returns: + The size of the dataset in bytes. """ # The list of tensors is really a list of pointers to tensors, so we need to account for # this as well as the size of the tensors themselves. @@ -222,6 +226,10 @@ def append(self, item: InputOutputActivationVector) -> Future | None: Args: item: The item to append to the dataset. + + Returns: + Future that completes when the activation vector has queued to be written to disk, and + if needed, written to disk. """ self._data.append(item.to(self._device)) @@ -254,6 +262,10 @@ def extend(self, batch: SourceModelActivations) -> Future | None: Args: batch: A batch of items to add to the dataset. + + Returns: + Future that completes when the activation vectors have queued to be written to disk, and + if needed, written to disk. 
""" # Schedule _extend to run in a separate process if self._pool: @@ -269,12 +281,15 @@ def wait_for_writes_to_complete(self) -> None: Wait for any non-blocking writes (e.g. calls to :meth:`append`) to complete. Example: - >>> import torch - >>> store = ListActivationStore(multiprocessing_enabled=True) - >>> store.extend(torch.randn(3, 100)) - >>> store.wait_for_writes_to_complete() - >>> len(store) - 3 + >>> import torch + >>> store = ListActivationStore(multiprocessing_enabled=True) + >>> store.extend(torch.randn(3, 100)) + >>> store.wait_for_writes_to_complete() + >>> len(store) + 3 + + Raises: + RuntimeError: If any exceptions occurred in the background workers. """ # Restart the pool if self._pool: diff --git a/sparse_autoencoder/activation_store/tensor_store.py b/sparse_autoencoder/activation_store/tensor_store.py index 3b91c79f..2d6140d6 100644 --- a/sparse_autoencoder/activation_store/tensor_store.py +++ b/sparse_autoencoder/activation_store/tensor_store.py @@ -88,25 +88,29 @@ def __len__(self) -> int: Returns the number of activation vectors in the dataset. Example: - >>> import torch - >>> store = TensorActivationStore(max_items=10_000_000, num_neurons=100) - >>> store.append(torch.randn(100)) - >>> store.append(torch.randn(100)) - >>> len(store) - 2 + >>> import torch + >>> store = TensorActivationStore(max_items=10_000_000, num_neurons=100) + >>> store.append(torch.randn(100)) + >>> store.append(torch.randn(100)) + >>> len(store) + 2 + + Returns: + The number of activation vectors in the dataset. """ return self.items_stored def __sizeof__(self) -> int: """Sizeof Dunder Method. - Returns the size of the underlying tensor in bytes. - Example: - >>> import torch - >>> store = TensorActivationStore(max_items=2, num_neurons=100) - >>> store.__sizeof__() # Pre-allocated tensor of 2x100 - 800 + >>> import torch + >>> store = TensorActivationStore(max_items=2, num_neurons=100) + >>> store.__sizeof__() # Pre-allocated tensor of 2x100 + 800 + + Returns: + The size of the underlying tensor in bytes. """ return self._data.element_size() * self._data.nelement() diff --git a/sparse_autoencoder/autoencoder/abstract_autoencoder.py b/sparse_autoencoder/autoencoder/abstract_autoencoder.py index a2411a66..0478a2c1 100644 --- a/sparse_autoencoder/autoencoder/abstract_autoencoder.py +++ b/sparse_autoencoder/autoencoder/abstract_autoencoder.py @@ -19,25 +19,21 @@ class AbstractAutoencoder(Module, ABC): @abstractmethod def encoder(self) -> AbstractEncoder: """Encoder.""" - raise NotImplementedError @property @abstractmethod def decoder(self) -> AbstractDecoder: """Decoder.""" - raise NotImplementedError @property @abstractmethod def pre_encoder_bias(self) -> AbstractOuterBias: """Pre-encoder bias.""" - raise NotImplementedError @property @abstractmethod def post_decoder_bias(self) -> AbstractOuterBias: """Post-decoder bias.""" - raise NotImplementedError @abstractmethod def forward( @@ -55,9 +51,7 @@ def forward( Returns: Tuple of learned activations and decoded activations. 
""" - raise NotImplementedError @abstractmethod def reset_parameters(self) -> None: """Reset the parameters.""" - raise NotImplementedError diff --git a/sparse_autoencoder/autoencoder/components/abstract_decoder.py b/sparse_autoencoder/autoencoder/components/abstract_decoder.py index 49104e23..befd0e00 100644 --- a/sparse_autoencoder/autoencoder/components/abstract_decoder.py +++ b/sparse_autoencoder/autoencoder/components/abstract_decoder.py @@ -24,7 +24,6 @@ class AbstractDecoder(Module, ABC): @abstractmethod def weight(self) -> DecoderWeights: """Weight.""" - raise NotImplementedError @abstractmethod def forward( @@ -39,12 +38,10 @@ def forward( Returns: Decoded activations. """ - raise NotImplementedError @abstractmethod def reset_parameters(self) -> None: """Reset the parameters.""" - raise NotImplementedError @final def update_dictionary_vectors( diff --git a/sparse_autoencoder/autoencoder/components/abstract_encoder.py b/sparse_autoencoder/autoencoder/components/abstract_encoder.py index f681dbdb..e3b41b54 100644 --- a/sparse_autoencoder/autoencoder/components/abstract_encoder.py +++ b/sparse_autoencoder/autoencoder/components/abstract_encoder.py @@ -27,13 +27,11 @@ class AbstractEncoder(Module, ABC): @abstractmethod def weight(self) -> EncoderWeights: """Weight.""" - raise NotImplementedError @property @abstractmethod def bias(self) -> LearntActivationVector: """Bias.""" - raise NotImplementedError @abstractmethod def forward(self, x: InputOutputActivationBatch) -> LearnedActivationBatch: @@ -45,7 +43,6 @@ def forward(self, x: InputOutputActivationBatch) -> LearnedActivationBatch: Returns: Resulting activations. """ - raise NotImplementedError @final def update_dictionary_vectors( diff --git a/sparse_autoencoder/autoencoder/components/abstract_outer_bias.py b/sparse_autoencoder/autoencoder/components/abstract_outer_bias.py index 687f227a..bed752e2 100644 --- a/sparse_autoencoder/autoencoder/components/abstract_outer_bias.py +++ b/sparse_autoencoder/autoencoder/components/abstract_outer_bias.py @@ -22,7 +22,6 @@ def bias(self) -> InputOutputActivationVector: May be a reference to a bias parameter in the parent module, if using e.g. a tied bias. """ - raise NotImplementedError @abstractmethod def forward( @@ -37,4 +36,3 @@ def forward( Returns: Resulting activations. """ - raise NotImplementedError diff --git a/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py b/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py index 78f9af87..1e1e1b87 100644 --- a/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py +++ b/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py @@ -125,6 +125,10 @@ def _weight_backward_hook( Args: grad: Gradient with respect to the weights. + + Returns: + Gradient with respect to the weights, with the component parallel to the dictionary + vectors removed. """ # Project the gradients onto the dictionary vectors. Intuitively the dictionary vectors can # be thought of as vectors that end on the circumference of a hypersphere. The projection of diff --git a/sparse_autoencoder/loss/abstract_loss.py b/sparse_autoencoder/loss/abstract_loss.py index d0934a2e..ee8978e2 100644 --- a/sparse_autoencoder/loss/abstract_loss.py +++ b/sparse_autoencoder/loss/abstract_loss.py @@ -55,7 +55,6 @@ def forward( Returns: Loss per batch item. 
""" - raise NotImplementedError @final def batch_scalar_loss( diff --git a/sparse_autoencoder/loss/reducer.py b/sparse_autoencoder/loss/reducer.py index 8b3aec8d..3fa68334 100644 --- a/sparse_autoencoder/loss/reducer.py +++ b/sparse_autoencoder/loss/reducer.py @@ -45,7 +45,7 @@ def __init__( """Initialize the loss reducer. Args: - loss_modules: Loss modules to reduce. + *loss_modules: Loss modules to reduce. Raises: ValueError: If the loss reducer has no loss modules. diff --git a/sparse_autoencoder/metrics/train/abstract_train_metric.py b/sparse_autoencoder/metrics/train/abstract_train_metric.py index 6fda3dcf..bb8e014f 100644 --- a/sparse_autoencoder/metrics/train/abstract_train_metric.py +++ b/sparse_autoencoder/metrics/train/abstract_train_metric.py @@ -25,5 +25,11 @@ class AbstractTrainMetric(ABC): @abstractmethod def calculate(self, data: TrainMetricData) -> dict[str, Any]: - """Calculate any metrics.""" - raise NotImplementedError + """Calculate any metrics. + + Args: + data: Train metric data. + + Returns: + Dictionary of metrics. + """ diff --git a/sparse_autoencoder/metrics/validate/abstract_validate_metric.py b/sparse_autoencoder/metrics/validate/abstract_validate_metric.py index 239d51bb..6f6d9177 100644 --- a/sparse_autoencoder/metrics/validate/abstract_validate_metric.py +++ b/sparse_autoencoder/metrics/validate/abstract_validate_metric.py @@ -19,4 +19,3 @@ class AbstractValidationMetric(ABC): @abstractmethod def calculate(self, data: ValidationMetricData) -> dict[str, Any]: """Calculate any metrics.""" - raise NotImplementedError diff --git a/sparse_autoencoder/optimizer/abstract_optimizer.py b/sparse_autoencoder/optimizer/abstract_optimizer.py index dd7413cc..aaecaa05 100644 --- a/sparse_autoencoder/optimizer/abstract_optimizer.py +++ b/sparse_autoencoder/optimizer/abstract_optimizer.py @@ -21,7 +21,6 @@ def reset_state_all_parameters(self) -> None: Resets any optimizer state (e.g. momentum). This is for use after manually editing model parameters (e.g. with activation resampling). """ - raise NotImplementedError @abstractmethod def reset_neurons_state( @@ -46,4 +45,3 @@ def reset_neurons_state( Raises: ValueError: If the parameter name is not found. """ - raise NotImplementedError diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py index 45c7af39..73331267 100644 --- a/sparse_autoencoder/optimizer/adam_with_reset.py +++ b/sparse_autoencoder/optimizer/adam_with_reset.py @@ -32,7 +32,7 @@ class AdamWithReset(Adam, AbstractOptimizerWithReset): The names of the parameters, so that we can find them later when resetting the state. """ - def __init__( # noqa: PLR0913 , D417 (extending existing implementation) + def __init__( # noqa: PLR0913 (extending existing implementation) self, params: params_t, lr: float | Tensor = 1e-3, @@ -64,9 +64,31 @@ def __init__( # noqa: PLR0913 , D417 (extending existing implementation) >>> optimizer.reset_state_all_parameters() Args: - named_parameters (Iterator[tuple[str, Parameter]]): An iterator over the named - parameters of the model. This is used to find the parameters when resetting their - state. You should set this as `model.named_parameters()`. + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate. A Tensor LR is not yet fully supported for all implementations. Use a + float LR unless specifying fused=True or capturable=True. + betas: Coefficients used for computing running averages of gradient and its square. 
+ eps: Term added to the denominator to improve numerical stability. + weight_decay: Weight decay (L2 penalty). + amsgrad: Whether to use the AMSGrad variant of this algorithm from the paper "On the + Convergence of Adam and Beyond". + foreach: Whether foreach implementation of optimizer is used. If None, foreach is used + over the for-loop implementation on CUDA if more performant. Note that foreach uses + more peak memory. + maximize: If True, maximizes the parameters based on the objective, instead of + minimizing. + capturable: Whether this instance is safe to capture in a CUDA graph. True can impair + ungraphed performance. + differentiable: Whether autograd should occur through the optimizer step in training. + Setting to True can impair performance. + fused: Whether the fused implementation (CUDA only) is used. Supports torch.float64, + torch.float32, torch.float16, and torch.bfloat16. + named_parameters: An iterator over the named parameters of the model. This is used to + find the parameters when resetting their state. You should set this as + `model.named_parameters()`. + + Raises: + ValueError: If the number of parameter names does not match the number of parameters. """ # Initialise the parent class (note we repeat the parameter names so that type hints work). super().__init__( @@ -126,7 +148,7 @@ def _get_parameter_name_idx(self, parameter_name: str) -> int: """Get the index of a parameter name. Args: - parameter_name (str): The name of the parameter. + parameter_name: The name of the parameter. Returns: int: The index of the parameter name. @@ -177,9 +199,6 @@ def reset_neurons_state( parameter_group: The index of the parameter group to reset (typically this is just zero, unless you have setup multiple parameter groups for e.g. different learning rates for different parameters). - - Raises: - ValueError: If the parameter name is not found. """ # Get the state of the parameter group = self.param_groups[parameter_group] diff --git a/sparse_autoencoder/source_data/abstract_dataset.py b/sparse_autoencoder/source_data/abstract_dataset.py index 86825f19..f38fafaa 100644 --- a/sparse_autoencoder/source_data/abstract_dataset.py +++ b/sparse_autoencoder/source_data/abstract_dataset.py @@ -94,8 +94,10 @@ def preprocess( context_size: The context size to use when returning a list of tokenized prompts. *Towards Monosemanticity: Decomposing Language Models With Dictionary Learning* used a context size of 250. + + Returns: + Tokenized prompts. """ - raise NotImplementedError @abstractmethod def __init__( diff --git a/sparse_autoencoder/source_data/pretokenized_dataset.py b/sparse_autoencoder/source_data/pretokenized_dataset.py index ed5e3ef5..f88757b5 100644 --- a/sparse_autoencoder/source_data/pretokenized_dataset.py +++ b/sparse_autoencoder/source_data/pretokenized_dataset.py @@ -46,6 +46,9 @@ def preprocess( Args: source_batch: A batch of source data. context_size: The context size to use for tokenized prompts. + + Returns: + Tokenized prompts. """ tokenized_prompts: list[list[int]] = source_batch["tokens"] diff --git a/sparse_autoencoder/source_data/text_dataset.py b/sparse_autoencoder/source_data/text_dataset.py index f5f6b8f3..d2c97d0e 100644 --- a/sparse_autoencoder/source_data/text_dataset.py +++ b/sparse_autoencoder/source_data/text_dataset.py @@ -54,6 +54,9 @@ def preprocess( Args: source_batch: A batch of source data, including 'text' with a list of strings. context_size: Context size for tokenized prompts. + + Returns: + Tokenized prompts. 
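+
+        Example:
+            A sketch, assuming an instantiated dataset (250 matches the context
+            size used in *Towards Monosemanticity*):
+
+                batch = {"text": ["Some example text."]}
+                tokenized = dataset.preprocess(batch, context_size=250)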
""" prompts: list[str] = source_batch["text"] diff --git a/sparse_autoencoder/source_model/store_activations_hook.py b/sparse_autoencoder/source_model/store_activations_hook.py index 119ddc57..b8487262 100644 --- a/sparse_autoencoder/source_model/store_activations_hook.py +++ b/sparse_autoencoder/source_model/store_activations_hook.py @@ -46,6 +46,9 @@ def store_activations_hook( value: The activations to store. hook: The hook point. store: The activation store. This should be pre-initialised with `functools.partial`. + + Returns: + Unmodified activations. """ store.extend(value) diff --git a/sparse_autoencoder/train/abstract_pipeline.py b/sparse_autoencoder/train/abstract_pipeline.py index fae05bf6..2eb38d70 100644 --- a/sparse_autoencoder/train/abstract_pipeline.py +++ b/sparse_autoencoder/train/abstract_pipeline.py @@ -99,15 +99,28 @@ def __init__( # noqa: PLR0913 @abstractmethod def generate_activations(self, store_size: int) -> TensorActivationStore: - """Generate activations.""" - raise NotImplementedError + """Generate activations. + + Args: + store_size: Number of activations to generate. + + Returns: + Activation store for the train section. + """ @abstractmethod def train_autoencoder( self, activation_store: TensorActivationStore, train_batch_size: int ) -> NeuronActivity: - """Train the sparse autoencoder.""" - raise NotImplementedError + """Train the sparse autoencoder. + + Args: + activation_store: Activation store from the generate section. + train_batch_size: Train batch size. + + Returns: + Number of times each neuron fired. + """ @final def resample_neurons( @@ -153,7 +166,6 @@ def resample_neurons( @abstractmethod def validate_sae(self) -> None: """Get validation metrics.""" - raise NotImplementedError @final def save_checkpoint(self) -> None: @@ -174,7 +186,18 @@ def run_pipeline( validate_frequency: int | None = None, checkpoint_frequency: int | None = None, ) -> None: - """Run the full training pipeline.""" + """Run the full training pipeline. + + Args: + train_batch_size: Train batch size. + max_store_size: Maximum size of the activation store. + max_activations: Maximum total number of activations to train on (the original paper + used 8bn, although others have had success with 100m+). + resample_frequency: Frequency at which to resample dead neurons (the original paper used + every 200m). + validate_frequency: Frequency at which to get validation metrics. + checkpoint_frequency: Frequency at which to save a checkpoint. + """ last_resampled: int = 0 last_validated: int = 0 last_checkpoint: int = 0 @@ -207,9 +230,9 @@ def run_pipeline( neuron_activity = detached_neuron_activity # Update the counters - last_resampled += store_size - last_validated += store_size - last_checkpoint += store_size + last_resampled += len(activation_store) + last_validated += len(activation_store) + last_checkpoint += len(activation_store) # Resample dead neurons (if needed) progress_bar.set_postfix({"stage": "resample"}) @@ -273,5 +296,8 @@ def stateful_dataloader_iterable( Returns: Stateful iterable over the data in the dataloader. + + Yields: + Data from the dataloader. 
""" yield from dataloader diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index ef279b70..6a4032a0 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -25,7 +25,14 @@ class Pipeline(AbstractPipeline): total_training_steps: int = 1 def generate_activations(self, store_size: int) -> TensorActivationStore: - """Generate activations.""" + """Generate activations. + + Args: + store_size: Number of activations to generate. + + Returns: + Activation store for the train section. + """ num_neurons: int = 256 source_model_device: torch.device = get_model_device(self.source_model) @@ -57,7 +64,15 @@ def generate_activations(self, store_size: int) -> TensorActivationStore: def train_autoencoder( self, activation_store: TensorActivationStore, train_batch_size: int ) -> NeuronActivity: - """Train the sparse autoencoder.""" + """Train the sparse autoencoder. + + Args: + activation_store: Activation store from the generate section. + train_batch_size: Train batch size. + + Returns: + Number of times each neuron fired. + """ autoencoder_device: torch.device = get_model_device(self.autoencoder) activations_dataloader = DataLoader( diff --git a/sparse_autoencoder/train/utils/get_model_device.py b/sparse_autoencoder/train/utils/get_model_device.py index 8099e2f3..0d49bbd5 100644 --- a/sparse_autoencoder/train/utils/get_model_device.py +++ b/sparse_autoencoder/train/utils/get_model_device.py @@ -7,10 +7,10 @@ def get_model_device(model: Module) -> torch.device: """Get the device on which a PyTorch model is on. Args: - model (nn.Module): The PyTorch model. + model: The PyTorch model. Returns: - torch.device: The device ('cuda' or 'cpu') where the model is located. + The device ('cuda' or 'cpu') where the model is located. Raises: ValueError: If the model has no parameters.