Skip to content

Commit

Permalink
register composite dataset
Browse files (browse the repository at this point in the history)
  • Loading branch information
andreaskoepf committed Feb 4, 2025
1 parent 0c08ee5 commit 4899926
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 67 deletions.
2 changes: 1 addition & 1 deletion reasoning_gym/algorithmic/letter_jumble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from typing import Optional

from reasoning_gym.data import read_data_file

Expand Down
55 changes: 31 additions & 24 deletions reasoning_gym/composite.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from dataclasses import dataclass
from random import Random
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

import yaml

from .dataset import ProceduralDataset
from .factory import create_dataset
from .factory import create_dataset, register_dataset


@dataclass
class DatasetSpec:
"""Specification for a single dataset within the composite"""

name: str
weight: float
config: dict
Expand All @@ -24,6 +26,7 @@ def validate(self):
@dataclass
class CompositeConfig:
"""Configuration for CompositeDataset"""

size: int = 500
seed: Optional[int] = None
datasets: List[DatasetSpec] = None
Expand All @@ -33,21 +36,21 @@ def validate(self):
assert self.size > 0, "size must be positive"
assert self.datasets, "Must specify at least one dataset"
assert len(self.datasets) > 0, "Must specify at least one dataset"

# Validate each dataset spec
for ds in self.datasets:
ds.validate()

@classmethod
def from_yaml(cls, yaml_path: str) -> 'CompositeConfig':
def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
"""Load configuration from YAML file"""
with open(yaml_path, 'r') as f:
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)

# Convert dataset specs to DatasetSpec objects
if 'datasets' in data:
data['datasets'] = [DatasetSpec(**ds) for ds in data['datasets']]
if "datasets" in data:
data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]

return cls(**data)


Expand All @@ -56,24 +59,24 @@ class CompositeDataset(ProceduralDataset):

def __init__(self, config: CompositeConfig):
super().__init__(config=config, seed=config.seed, size=config.size)

# Initialize sub-datasets with incremented seeds
self.datasets = {}
self.weights = []
total_weight = 0.0

for i, ds_spec in enumerate(config.datasets):
# Create dataset with derived seed
ds_config = ds_spec.config.copy()
if 'seed' not in ds_config:
ds_config['seed'] = self.seed + i + 1
if 'size' not in ds_config:
ds_config['size'] = self.size
if "seed" not in ds_config:
ds_config["seed"] = self.seed + i + 1
if "size" not in ds_config:
ds_config["size"] = self.size

self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
total_weight += ds_spec.weight
self.weights.append(ds_spec.weight)

# Normalize weights
self.weights = [w / total_weight for w in self.weights]
self.dataset_names = [ds.name for ds in config.datasets]
Expand All @@ -82,22 +85,26 @@ def __getitem__(self, idx: int) -> dict:
"""Generate a single dataset item by sampling from sub-datasets"""
# Create deterministic RNG for this index
rng = Random(self.seed + idx)

# Sample dataset according to weights
dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
dataset_name = self.dataset_names[dataset_idx]
dataset = self.datasets[dataset_name]

# Get item from selected dataset
item = dataset[idx]

# Add source dataset info to metadata
item['metadata']['source_dataset'] = dataset_name
item['metadata']['source_index'] = idx
item["metadata"]["source_dataset"] = dataset_name
item["metadata"]["source_index"] = idx

return item

def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Forward scoring to appropriate dataset"""
dataset_name = entry['metadata']['source_dataset']
dataset_name = entry["metadata"]["source_dataset"]
return self.datasets[dataset_name].score_answer(answer, entry)


# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)
17 changes: 2 additions & 15 deletions reasoning_gym/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict, Type, TypeVar

from .dataset import ProceduralDataset
from .composite import CompositeDataset, CompositeConfig

# Type variables for generic type hints
ConfigT = TypeVar("ConfigT")
Expand Down Expand Up @@ -54,18 +53,6 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset:

dataset_cls, config_cls = DATASETS[name]

conifg = config_cls(**kwargs)
config = config_cls(**kwargs)

return dataset_cls(config=conifg)


# Register composite dataset
register_dataset("composite", CompositeDataset, CompositeConfig)
if name not in DATASETS:
raise ValueError(f"Dataset '{name}' not registered")

dataset_cls, config_cls = DATASETS[name]

conifg = config_cls(**kwargs)

return dataset_cls(config=conifg)
return dataset_cls(config=config)
47 changes: 20 additions & 27 deletions tests/test_composite.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

import pytest
import tempfile
import yaml

from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec


def create_test_config(tmp_path):
Expand All @@ -18,23 +18,24 @@ def create_test_config(tmp_path):
"config": {
"min_terms": 2,
"max_terms": 4,
}
},
},
{
"name": "leg_counting",
"weight": 0.7,
"config": {
"min_animals": 1,
"max_animals": 3,
}
}
]
},
},
],
}

config_path = os.path.join(tmp_path, "test_config.yaml")
with open(config_path, 'w') as f:
print(config_path)
with open(config_path, "w") as f:
yaml.dump(config, f)

return config_path


Expand All @@ -43,7 +44,7 @@ def test_composite_config_validation():
with pytest.raises(AssertionError):
config = CompositeConfig(size=-1)
config.validate()

with pytest.raises(AssertionError):
config = CompositeConfig(datasets=[])
config.validate()
Expand All @@ -52,33 +53,25 @@ def test_composite_config_validation():
def test_composite_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = CompositeConfig(
size=10,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
]
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)

dataset1 = CompositeDataset(config)
dataset2 = CompositeDataset(config)

for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]


def test_composite_dataset_metadata():
"""Test that metadata includes source dataset information"""
config = CompositeConfig(
size=10,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
]
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)

dataset = CompositeDataset(config)
item = dataset[0]

assert "source_dataset" in item["metadata"]
assert "source_index" in item["metadata"]
assert item["metadata"]["source_dataset"] == "chain_sum"
Expand All @@ -93,9 +86,9 @@ def test_composite_dataset_weights():
datasets=[
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
]
],
)

dataset = CompositeDataset(config)
assert abs(dataset.weights[0] - 0.4) < 1e-6
assert abs(dataset.weights[1] - 0.6) < 1e-6
Expand All @@ -105,7 +98,7 @@ def test_yaml_loading(tmp_path):
"""Test loading configuration from YAML"""
config_path = create_test_config(tmp_path)
config = CompositeConfig.from_yaml(config_path)

assert config.size == 100
assert config.seed == 42
assert len(config.datasets) == 2
Expand Down

0 comments on commit 4899926

Please sign in to comment.