From 48999261dd8742a0d4f793ffbb63960ed6d366b2 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Tue, 4 Feb 2025 19:17:34 +0100 Subject: [PATCH] register composite dataset --- reasoning_gym/algorithmic/letter_jumble.py | 2 +- reasoning_gym/composite.py | 55 ++++++++++++---------- reasoning_gym/factory.py | 17 +------ tests/test_composite.py | 47 ++++++++---------- 4 files changed, 54 insertions(+), 67 deletions(-) diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 9919ccf0..f92d8f2b 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import List, Optional +from typing import Optional from reasoning_gym.data import read_data_file diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index f9ce7b7f..7ea581b2 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -1,15 +1,17 @@ from dataclasses import dataclass from random import Random -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional + import yaml from .dataset import ProceduralDataset -from .factory import create_dataset +from .factory import create_dataset, register_dataset @dataclass class DatasetSpec: """Specification for a single dataset within the composite""" + name: str weight: float config: dict @@ -24,6 +26,7 @@ def validate(self): @dataclass class CompositeConfig: """Configuration for CompositeDataset""" + size: int = 500 seed: Optional[int] = None datasets: List[DatasetSpec] = None @@ -33,21 +36,21 @@ def validate(self): assert self.size > 0, "size must be positive" assert self.datasets, "Must specify at least one dataset" assert len(self.datasets) > 0, "Must specify at least one dataset" - + # Validate each dataset spec for ds in self.datasets: ds.validate() @classmethod - def from_yaml(cls, yaml_path: str) -> 'CompositeConfig': + def from_yaml(cls, yaml_path: str) -> "CompositeConfig": """Load configuration from YAML file""" - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: data = yaml.safe_load(f) - + # Convert dataset specs to DatasetSpec objects - if 'datasets' in data: - data['datasets'] = [DatasetSpec(**ds) for ds in data['datasets']] - + if "datasets" in data: + data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]] + return cls(**data) @@ -56,24 +59,24 @@ class CompositeDataset(ProceduralDataset): def __init__(self, config: CompositeConfig): super().__init__(config=config, seed=config.seed, size=config.size) - + # Initialize sub-datasets with incremented seeds self.datasets = {} self.weights = [] total_weight = 0.0 - + for i, ds_spec in enumerate(config.datasets): # Create dataset with derived seed ds_config = ds_spec.config.copy() - if 'seed' not in ds_config: - ds_config['seed'] = self.seed + i + 1 - if 'size' not in ds_config: - ds_config['size'] = self.size - + if "seed" not in ds_config: + ds_config["seed"] = self.seed + i + 1 + if "size" not in ds_config: + ds_config["size"] = self.size + self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config) total_weight += ds_spec.weight self.weights.append(ds_spec.weight) - + # Normalize weights self.weights = [w / total_weight for w in self.weights] self.dataset_names = [ds.name for ds in config.datasets] @@ -82,22 +85,26 @@ def __getitem__(self, idx: int) -> dict: """Generate a single dataset item by sampling from sub-datasets""" # Create deterministic RNG for this index rng = Random(self.seed + idx) - + # Sample dataset according to weights dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0] dataset_name = self.dataset_names[dataset_idx] dataset = self.datasets[dataset_name] - + # Get item from selected dataset item = dataset[idx] - + # Add source dataset info to metadata - item['metadata']['source_dataset'] = dataset_name - item['metadata']['source_index'] = idx - + item["metadata"]["source_dataset"] = dataset_name + item["metadata"]["source_index"] = idx + return item def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: """Forward scoring to appropriate dataset""" - dataset_name = entry['metadata']['source_dataset'] + dataset_name = entry["metadata"]["source_dataset"] return self.datasets[dataset_name].score_answer(answer, entry) + + +# Register the dataset +register_dataset("composite", CompositeDataset, CompositeConfig) diff --git a/reasoning_gym/factory.py b/reasoning_gym/factory.py index 3d0c91ac..8a9b8f5b 100644 --- a/reasoning_gym/factory.py +++ b/reasoning_gym/factory.py @@ -2,7 +2,6 @@ from typing import Dict, Type, TypeVar from .dataset import ProceduralDataset -from .composite import CompositeDataset, CompositeConfig # Type variables for generic type hints ConfigT = TypeVar("ConfigT") @@ -54,18 +53,6 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset: dataset_cls, config_cls = DATASETS[name] - conifg = config_cls(**kwargs) + config = config_cls(**kwargs) - return dataset_cls(config=conifg) - - -# Register composite dataset -register_dataset("composite", CompositeDataset, CompositeConfig) - if name not in DATASETS: - raise ValueError(f"Dataset '{name}' not registered") - - dataset_cls, config_cls = DATASETS[name] - - conifg = config_cls(**kwargs) - - return dataset_cls(config=conifg) + return dataset_cls(config=config) diff --git a/tests/test_composite.py b/tests/test_composite.py index e96b056b..cbfec38a 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -1,9 +1,9 @@ import os + import pytest -import tempfile import yaml -from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec +from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec def create_test_config(tmp_path): @@ -18,7 +18,7 @@ def create_test_config(tmp_path): "config": { "min_terms": 2, "max_terms": 4, - } + }, }, { "name": "leg_counting", @@ -26,15 +26,16 @@ def create_test_config(tmp_path): "config": { "min_animals": 1, "max_animals": 3, - } - } - ] + }, + }, + ], } - + config_path = os.path.join(tmp_path, "test_config.yaml") - with open(config_path, 'w') as f: + print(config_path) + with open(config_path, "w") as f: yaml.dump(config, f) - + return config_path @@ -43,7 +44,7 @@ def test_composite_config_validation(): with pytest.raises(AssertionError): config = CompositeConfig(size=-1) config.validate() - + with pytest.raises(AssertionError): config = CompositeConfig(datasets=[]) config.validate() @@ -52,16 +53,12 @@ def test_composite_config_validation(): def test_composite_dataset_deterministic(): """Test that dataset generates same items with same seed""" config = CompositeConfig( - size=10, - seed=42, - datasets=[ - DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4}) - ] + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})] ) - + dataset1 = CompositeDataset(config) dataset2 = CompositeDataset(config) - + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -69,16 +66,12 @@ def test_composite_dataset_deterministic(): def test_composite_dataset_metadata(): """Test that metadata includes source dataset information""" config = CompositeConfig( - size=10, - seed=42, - datasets=[ - DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4}) - ] + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})] ) - + dataset = CompositeDataset(config) item = dataset[0] - + assert "source_dataset" in item["metadata"] assert "source_index" in item["metadata"] assert item["metadata"]["source_dataset"] == "chain_sum" @@ -93,9 +86,9 @@ def test_composite_dataset_weights(): datasets=[ DatasetSpec("chain_sum", 2.0, {"min_terms": 2}), DatasetSpec("chain_sum", 3.0, {"min_terms": 3}), - ] + ], ) - + dataset = CompositeDataset(config) assert abs(dataset.weights[0] - 0.4) < 1e-6 assert abs(dataset.weights[1] - 0.6) < 1e-6 @@ -105,7 +98,7 @@ def test_yaml_loading(tmp_path): """Test loading configuration from YAML""" config_path = create_test_config(tmp_path) config = CompositeConfig.from_yaml(config_path) - + assert config.size == 100 assert config.seed == 42 assert len(config.datasets) == 2