-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #61 from open-thought/composite_dataset
Add composite dataset
- Loading branch information
Showing
5 changed files
with
220 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from dataclasses import dataclass | ||
from random import Random | ||
from typing import Any, Dict, List, Optional | ||
|
||
import yaml | ||
|
||
from .dataset import ProceduralDataset | ||
from .factory import create_dataset, register_dataset | ||
|
||
|
||
@dataclass
class DatasetSpec:
    """Specification for a single dataset within the composite

    Attributes:
        name: Dataset name handed to ``create_dataset`` by CompositeDataset.
        weight: Relative sampling weight; must be positive.
        config: Keyword arguments forwarded to ``create_dataset``.
    """

    name: str
    weight: float
    config: dict

    def validate(self):
        """Validate dataset specification

        Raises:
            AssertionError: If the name is empty, the weight is not positive,
                or the config is not a dictionary.
        """
        # Raise explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O.
        if not self.name:
            raise AssertionError("Dataset name cannot be empty")
        # Written as ``not (w > 0)`` so that NaN is rejected as well.
        if not (self.weight > 0):
            raise AssertionError("Weight must be positive")
        if not isinstance(self.config, dict):
            raise AssertionError("Config must be a dictionary")
|
||
|
||
@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset

    Attributes:
        size: Number of items in the composite dataset; must be positive.
        seed: Optional seed for the composite dataset.
        datasets: Specifications of the sub-datasets to combine. Defaults to
            ``None``, which fails validation.
    """

    size: int = 500
    seed: Optional[int] = None
    datasets: Optional[List["DatasetSpec"]] = None

    def validate(self):
        """Validate configuration parameters

        Raises:
            AssertionError: If ``size`` is not positive, no datasets are
                given, or any dataset specification is invalid.
        """
        # Raise explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O.
        if not (self.size > 0):
            raise AssertionError("size must be positive")
        # One truthiness check covers both ``None`` and an empty list.
        if not self.datasets:
            raise AssertionError("Must specify at least one dataset")

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
        """Load configuration from YAML file

        Args:
            yaml_path: Path of the YAML file to read.

        Returns:
            A CompositeConfig built from the file's top-level mapping, with
            each entry under ``datasets`` converted to a DatasetSpec.

        Raises:
            TypeError: If the top level of the document is not a mapping, or
                an entry does not match the expected fields.
        """
        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty document loads as None; treat it as "all defaults".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Top level of {yaml_path!r} must be a mapping, got {type(data).__name__}"
            )

        # Work on a copy so the loaded mapping is not mutated.
        fields = dict(data)

        # Convert dataset specs to DatasetSpec objects
        raw_specs = fields.get("datasets")
        if raw_specs is not None:
            fields["datasets"] = [DatasetSpec(**ds) for ds in raw_specs]

        return cls(**fields)
|
||
|
||
class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling

    For every index one sub-dataset is chosen with a random generator seeded
    from ``self.seed + idx``, and the item at that same index is taken from
    the chosen sub-dataset.
    """

    def __init__(self, config: CompositeConfig):
        """Build all sub-datasets and normalize their sampling weights.

        Args:
            config: Composite configuration; each entry of ``config.datasets``
                yields one sub-dataset created through ``create_dataset``.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Name -> sub-dataset mapping used by score_answer. If several specs
        # share a name, this mapping keeps only the last one.
        self.datasets = {}
        # Sub-datasets in spec order, parallel to dataset_names and weights,
        # so specs sharing a name still sample from their own instance.
        self._dataset_list = []
        self.dataset_names = []
        raw_weights = []

        for i, ds_spec in enumerate(config.datasets):
            # Copy so the caller's spec is not modified by the defaults below.
            ds_config = ds_spec.config.copy()
            # Derive a distinct seed per sub-dataset unless one was given.
            if "seed" not in ds_config:
                ds_config["seed"] = self.seed + i + 1
            if "size" not in ds_config:
                ds_config["size"] = self.size

            sub_dataset = create_dataset(ds_spec.name, **ds_config)
            self.datasets[ds_spec.name] = sub_dataset
            self._dataset_list.append(sub_dataset)
            self.dataset_names.append(ds_spec.name)
            raw_weights.append(ds_spec.weight)

        # Normalize weights
        total_weight = sum(raw_weights, 0.0)
        self.weights = [w / total_weight for w in raw_weights]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets

        Args:
            idx: Index of the item; also used to seed the sub-dataset choice.

        Returns:
            The item at ``idx`` of the selected sub-dataset, with
            ``source_dataset`` and ``source_index`` added to its metadata.
        """
        # Create deterministic RNG for this index
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self._dataset_list)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]

        # Look up by position rather than by name so duplicate names do not
        # collapse onto a single sub-dataset.
        item = self._dataset_list[dataset_idx][idx]

        # Add source dataset info to metadata
        item["metadata"]["source_dataset"] = dataset_name
        item["metadata"]["source_index"] = idx

        return item

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Forward scoring to appropriate dataset

        Args:
            answer: The answer to score.
            entry: An item produced by this dataset; its
                ``metadata["source_dataset"]`` names the sub-dataset to use.

        Returns:
            The score returned by the sub-dataset's ``score_answer``.
        """
        # NOTE: resolved by name, so with duplicate names the last sub-dataset
        # created under that name does the scoring.
        dataset_name = entry["metadata"]["source_dataset"]
        return self.datasets[dataset_name].score_answer(answer, entry)
|
||
|
||
# Register the dataset | ||
register_dataset("composite", CompositeDataset, CompositeConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import os | ||
|
||
import pytest | ||
import yaml | ||
|
||
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec | ||
|
||
|
||
def create_test_config(tmp_path):
    """Create a test YAML config file

    Args:
        tmp_path: Directory in which ``test_config.yaml`` is written.

    Returns:
        Path of the written YAML file.
    """
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [
            {
                "name": "chain_sum",
                "weight": 0.3,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                },
            },
            {
                "name": "leg_counting",
                "weight": 0.7,
                "config": {
                    "min_animals": 1,
                    "max_animals": 3,
                },
            },
        ],
    }

    config_path = os.path.join(tmp_path, "test_config.yaml")
    # Explicit encoding keeps the file independent of the platform default.
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    return config_path
|
||
|
||
def test_composite_config_validation():
    """Invalid configurations are rejected with an AssertionError"""
    invalid_kwargs = (
        {"size": -1},  # non-positive size
        {"datasets": []},  # no datasets at all
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            CompositeConfig(**kwargs).validate()
|
||
|
||
def test_composite_dataset_deterministic():
    """Two datasets built from the same config yield identical items"""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    cfg = CompositeConfig(size=10, seed=42, datasets=[spec])

    first = CompositeDataset(cfg)
    second = CompositeDataset(cfg)

    # Compare item by item over the full length of the dataset
    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||
|
||
def test_composite_dataset_metadata():
    """Items carry the source dataset name and index in their metadata"""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    cfg = CompositeConfig(size=10, seed=42, datasets=[spec])

    metadata = CompositeDataset(cfg)[0]["metadata"]

    for key in ("source_dataset", "source_index"):
        assert key in metadata
    assert metadata["source_dataset"] == "chain_sum"
    assert isinstance(metadata["source_index"], int)
|
||
|
||
def test_composite_dataset_weights():
    """Raw weights 2.0 and 3.0 are normalized to 0.4 and 0.6"""
    specs = [
        DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
        DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
    ]
    cfg = CompositeConfig(size=1000, seed=42, datasets=specs)

    normalized = CompositeDataset(cfg).weights

    # Index explicitly so a missing weight still fails the test
    for position, expected in enumerate((0.4, 0.6)):
        assert abs(normalized[position] - expected) < 1e-6
|
||
|
||
def test_yaml_loading(tmp_path):
    """A config written to YAML round-trips through CompositeConfig.from_yaml"""
    loaded = CompositeConfig.from_yaml(create_test_config(tmp_path))

    # Top-level scalar fields
    assert loaded.size == 100
    assert loaded.seed == 42

    # Dataset specs keep their count and order
    assert len(loaded.datasets) == 2
    first_spec, second_spec = loaded.datasets[0], loaded.datasets[1]
    assert first_spec.name == "chain_sum"
    assert second_spec.name == "leg_counting"