Merge pull request #61 from open-thought/composite_dataset
Add composite dataset
andreaskoepf authored Feb 5, 2025
2 parents b84e29a + 4899926 commit c4c0b3b
Showing 5 changed files with 220 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"pyfiglet==1.0.2",
"pytz>=2024.1",
"tabulate==0.9.0",
"pyyaml>=6.0.2",
]
classifiers = [
"Programming Language :: Python :: 3",
2 changes: 1 addition & 1 deletion reasoning_gym/algorithmic/letter_jumble.py
@@ -3,7 +3,7 @@
 import re
 from dataclasses import dataclass
 from random import Random
-from typing import List, Optional
+from typing import Optional

 from reasoning_gym.data import read_data_file
110 changes: 110 additions & 0 deletions reasoning_gym/composite.py
@@ -0,0 +1,110 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional

import yaml

from .dataset import ProceduralDataset
from .factory import create_dataset, register_dataset


@dataclass
class DatasetSpec:
"""Specification for a single dataset within the composite"""

name: str
weight: float
config: dict

def validate(self):
"""Validate dataset specification"""
assert self.name, "Dataset name cannot be empty"
assert self.weight > 0, "Weight must be positive"
assert isinstance(self.config, dict), "Config must be a dictionary"


@dataclass
class CompositeConfig:
"""Configuration for CompositeDataset"""

size: int = 500
seed: Optional[int] = None
    datasets: Optional[List[DatasetSpec]] = None

def validate(self):
"""Validate configuration parameters"""
assert self.size > 0, "size must be positive"
assert self.datasets, "Must specify at least one dataset"
assert len(self.datasets) > 0, "Must specify at least one dataset"

# Validate each dataset spec
for ds in self.datasets:
ds.validate()

@classmethod
def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
"""Load configuration from YAML file"""
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)

# Convert dataset specs to DatasetSpec objects
if "datasets" in data:
data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]

return cls(**data)


class CompositeDataset(ProceduralDataset):
"""A dataset that combines multiple datasets with weighted sampling"""

def __init__(self, config: CompositeConfig):
super().__init__(config=config, seed=config.seed, size=config.size)

# Initialize sub-datasets with incremented seeds
self.datasets = {}
self.weights = []
total_weight = 0.0

for i, ds_spec in enumerate(config.datasets):
# Create dataset with derived seed
ds_config = ds_spec.config.copy()
if "seed" not in ds_config:
ds_config["seed"] = self.seed + i + 1
if "size" not in ds_config:
ds_config["size"] = self.size

self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
total_weight += ds_spec.weight
self.weights.append(ds_spec.weight)

# Normalize weights
self.weights = [w / total_weight for w in self.weights]
self.dataset_names = [ds.name for ds in config.datasets]

def __getitem__(self, idx: int) -> dict:
"""Generate a single dataset item by sampling from sub-datasets"""
# Create deterministic RNG for this index
rng = Random(self.seed + idx)

# Sample dataset according to weights
dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
dataset_name = self.dataset_names[dataset_idx]
dataset = self.datasets[dataset_name]

# Get item from selected dataset
item = dataset[idx]

# Add source dataset info to metadata
item["metadata"]["source_dataset"] = dataset_name
item["metadata"]["source_index"] = idx

return item

def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Forward scoring to appropriate dataset"""
dataset_name = entry["metadata"]["source_dataset"]
return self.datasets[dataset_name].score_answer(answer, entry)


# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)
4 changes: 2 additions & 2 deletions reasoning_gym/factory.py
@@ -53,6 +53,6 @@ def create_dataset(name: str, **kwargs) -> ProceduralDataset:

     dataset_cls, config_cls = DATASETS[name]

-    conifg = config_cls(**kwargs)
+    config = config_cls(**kwargs)

-    return dataset_cls(config=conifg)
+    return dataset_cls(config=config)
106 changes: 106 additions & 0 deletions tests/test_composite.py
@@ -0,0 +1,106 @@
import os

import pytest
import yaml

from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec


def create_test_config(tmp_path):
"""Create a test YAML config file"""
config = {
"size": 100,
"seed": 42,
"datasets": [
{
"name": "chain_sum",
"weight": 0.3,
"config": {
"min_terms": 2,
"max_terms": 4,
},
},
{
"name": "leg_counting",
"weight": 0.7,
"config": {
"min_animals": 1,
"max_animals": 3,
},
},
],
}

config_path = os.path.join(tmp_path, "test_config.yaml")
with open(config_path, "w") as f:
yaml.dump(config, f)

return config_path


def test_composite_config_validation():
"""Test configuration validation"""
with pytest.raises(AssertionError):
config = CompositeConfig(size=-1)
config.validate()

with pytest.raises(AssertionError):
config = CompositeConfig(datasets=[])
config.validate()


def test_composite_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)

dataset1 = CompositeDataset(config)
dataset2 = CompositeDataset(config)

for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]


def test_composite_dataset_metadata():
"""Test that metadata includes source dataset information"""
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)

dataset = CompositeDataset(config)
item = dataset[0]

assert "source_dataset" in item["metadata"]
assert "source_index" in item["metadata"]
assert item["metadata"]["source_dataset"] == "chain_sum"
assert isinstance(item["metadata"]["source_index"], int)


def test_composite_dataset_weights():
"""Test that dataset weights are properly normalized"""
config = CompositeConfig(
size=1000,
seed=42,
datasets=[
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
],
)

dataset = CompositeDataset(config)
assert abs(dataset.weights[0] - 0.4) < 1e-6
assert abs(dataset.weights[1] - 0.6) < 1e-6


def test_yaml_loading(tmp_path):
"""Test loading configuration from YAML"""
config_path = create_test_config(tmp_path)
config = CompositeConfig.from_yaml(config_path)

assert config.size == 100
assert config.seed == 42
assert len(config.datasets) == 2
assert config.datasets[0].name == "chain_sum"
assert config.datasets[1].name == "leg_counting"
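
As a final sketch, loading the same configuration from disk and scoring an answer (the file path and answer string here are illustrative; scoring is forwarded to the sub-dataset recorded in the entry's metadata):

from reasoning_gym.composite import CompositeConfig, CompositeDataset

config = CompositeConfig.from_yaml("test_config.yaml")  # hypothetical path
dataset = CompositeDataset(config)

entry = dataset[0]
reward = dataset.score_answer(answer="42", entry=entry)  # delegates via metadata["source_dataset"]
print(reward)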
