Commit d5c66f8
Use the standard mock dataset for testing the abstract class (#125)
1 parent: 0bd0c9d
Showing 2 changed files with 14 additions and 81 deletions.
The first of the two changed files was deleted.
sparse_autoencoder/source_data/tests/test_abstract_dataset.py (94 changes: 14 additions & 80 deletions)
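In brief: the bespoke MockSourceDataset subclass, and the fixture that monkeypatched Hugging Face's load_dataset to read mocks/text_dataset.txt, are replaced by a pytest fixture returning the shared MockDataset from sparse_autoencoder.source_data.mock_dataset. A sketch of the interface the new tests rely on follows the diff.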
@@ -1,102 +1,36 @@
 """Test the abstract dataset."""
-from pathlib import Path
-from typing import Any, TypedDict
-
-from datasets import IterableDataset, load_dataset
 import pytest
-import torch
 
-from sparse_autoencoder.source_data.abstract_dataset import (
-    SourceDataset,
-    TokenizedPrompts,
-)
+from sparse_autoencoder.source_data.abstract_dataset import SourceDataset
+from sparse_autoencoder.source_data.mock_dataset import MockDataset
 
 
 TEST_CONTEXT_SIZE: int = 4
 
 
-class MockHuggingFaceDatasetItem(TypedDict):
-    """Mock Hugging Face dataset item typed dict."""
-
-    text: str
-    meta: dict
-
-
-class MockSourceDataset(SourceDataset[MockHuggingFaceDatasetItem]):
-    """Mock source dataset for testing the inherited abstract dataset."""
-
-    def preprocess(
-        self,
-        source_batch: MockHuggingFaceDatasetItem,  # noqa: ARG002
-        *,
-        context_size: int,  # noqa: ARG002
-    ) -> TokenizedPrompts:
-        """Preprocess a batch of prompts."""
-        preprocess_batch = 100
-        tokenized_texts = torch.randint(
-            low=0, high=50000, size=(preprocess_batch, TEST_CONTEXT_SIZE)
-        ).tolist()
-        return {"input_ids": tokenized_texts}
-
-    def __init__(
-        self,
-        dataset_path: str = "mock_dataset_path",
-        dataset_split: str = "test",
-        context_size: int = TEST_CONTEXT_SIZE,
-        buffer_size: int = 1000,
-        preprocess_batch_size: int = 1000,
-    ):
-        """Initialise the dataset."""
-        super().__init__(
-            dataset_path,
-            dataset_split,
-            context_size,
-            buffer_size,
-            preprocess_batch_size,
-        )
-
-
 @pytest.fixture()
-def mock_hugging_face_load_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
-    """Mock the `load_dataset` function from Hugging Face.
+def mock_dataset() -> MockDataset:
+    """Fixture to create a default ConsecutiveIntHuggingFaceDataset for testing.
 
-    Instead load the text data from mocks/text_dataset.txt, using a restored `load_dataset` method.
+    Returns:
+        ConsecutiveIntHuggingFaceDataset: An instance of the dataset for testing.
     """
-
-    def mock_load_dataset(*args: Any, **kwargs: Any) -> IterableDataset:  # noqa: ANN401
-        """Mock load dataset function."""
-        mock_path = Path(__file__).parent / "mocks" / "text_dataset.txt"
-        return load_dataset(
-            "text", data_files={"train": [str(mock_path)]}, streaming=True, split="train"
-        )  # type: ignore
-
-    monkeypatch.setattr(
-        "sparse_autoencoder.source_data.abstract_dataset.load_dataset", mock_load_dataset
-    )
+    return MockDataset(context_size=10, buffer_size=100)
 
 
-def test_extended_dataset_initialization(mock_hugging_face_load_dataset: pytest.Function) -> None:
+def test_extended_dataset_initialization(mock_dataset: MockDataset) -> None:
     """Test the initialization of the extended dataset."""
-    data = MockSourceDataset()
-    assert data is not None
-    assert isinstance(data, SourceDataset)
+    assert mock_dataset is not None
+    assert isinstance(mock_dataset, SourceDataset)
 
 
-def test_extended_dataset_iterator(mock_hugging_face_load_dataset: pytest.Function) -> None:
+def test_extended_dataset_iterator(mock_dataset: MockDataset) -> None:
     """Test the iterator of the extended dataset."""
-    data = MockSourceDataset()
-    iterator = iter(data)
+    iterator = iter(mock_dataset)
     assert iterator is not None
 
     first_item = next(iterator)
     assert len(first_item["input_ids"]) == TEST_CONTEXT_SIZE
 
 
-def test_get_dataloader(mock_hugging_face_load_dataset: pytest.Function) -> None:
+def test_get_dataloader(mock_dataset: MockDataset) -> None:
     """Test the get_dataloader method of the extended dataset."""
-    data = MockSourceDataset()
     batch_size = 3
-    dataloader = data.get_dataloader(batch_size=batch_size)
+    dataloader = mock_dataset.get_dataloader(batch_size=batch_size)
     first_item = next(iter(dataloader))["input_ids"]
     assert first_item.shape[0] == batch_size
     assert first_item.shape[-1] == TEST_CONTEXT_SIZE
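For readers without the rest of the repository to hand, the following is a minimal, self-contained sketch of the interface the updated tests depend on. It is inferred purely from the usage in the diff; the class name SketchMockDataset, the consecutive-integer token stream, and the constructor defaults are assumptions, not the repository's actual implementation (the real MockDataset lives in sparse_autoencoder/source_data/mock_dataset.py and subclasses SourceDataset, hence the isinstance assertion above).

    # Hypothetical stand-in, inferred from how the tests use MockDataset.
    from collections.abc import Iterator

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class SketchMockDataset(IterableDataset):
        """Streams consecutive-integer "tokenized prompts", one context window at a time."""

        def __init__(self, context_size: int = 4, buffer_size: int = 1000) -> None:
            self.context_size = context_size
            # Assumption: the real class uses buffer_size for shuffle buffering.
            self.buffer_size = buffer_size

        def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
            token = 0
            while True:
                # Each item is a prompt of exactly `context_size` token ids.
                yield {"input_ids": torch.arange(token, token + self.context_size)}
                token += self.context_size

        def get_dataloader(self, batch_size: int) -> DataLoader:
            # Default collation stacks the per-item tensors into a
            # (batch_size, context_size) tensor under the "input_ids" key,
            # matching the shape assertions in the tests above.
            return DataLoader(self, batch_size=batch_size)


    # Usage mirroring the tests:
    dataset = SketchMockDataset(context_size=10, buffer_size=100)
    first_item = next(iter(dataset))
    assert len(first_item["input_ids"]) == 10

    batch = next(iter(dataset.get_dataloader(batch_size=3)))["input_ids"]
    assert batch.shape == (3, 10)

The design point of the commit is visible in the deletions: building the tests on a shared mock removes the monkeypatching of Hugging Face's load_dataset (and the mocks/text_dataset.txt file it read) from every test module that needs source data.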