Skip to content

Commit

Permalink
feat(env): Count Bits Curriculum (#267)
Browse files Browse the repository at this point in the history
* add min n

* count bits
  • Loading branch information
zafstojano authored Mar 5, 2025
1 parent 8ccc4d7 commit 9bb6d02
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 8 deletions.
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset
from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset
from .dice import DiceConfig, DiceDataset
Expand Down Expand Up @@ -48,6 +48,7 @@
"TimeIntervalsDataset",
"CountBitsConfig",
"CountBitsDataset",
"CountBitsCurriculum",
"DiceConfig",
"DiceDataset",
"NumberFormatConfig",
Expand Down
25 changes: 23 additions & 2 deletions reasoning_gym/arithmetic/count_bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from random import Random
from typing import Optional

from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset

QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
Expand All @@ -13,14 +14,15 @@
class CountBitsConfig:
"""Configuration for Count Bits dataset generation"""

min_n: int = 1 # Minimum number to consider
max_n: int = 2**31 - 1 # Maximum number to consider

size: int = 500 # Virtual dataset size
seed: Optional[int] = None

def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1"
assert 1 <= self.min_n <= self.max_n, "min_n must be between 1 and max_n"


class CountBitsDataset(ProceduralDataset):
Expand All @@ -33,7 +35,7 @@ def __getitem__(self, idx: int) -> dict:
"""Generate a single Count Bits question"""
rng = Random(self.seed + idx)

number = rng.randint(1, self.config.max_n)
number = rng.randint(self.config.min_n, self.config.max_n)
binary = bin(number)[2:]
answer = binary.count("1")

Expand All @@ -44,4 +46,23 @@ def __getitem__(self, idx: int) -> dict:
}


class CountBitsCurriculum(BaseCurriculum):
    """Curriculum for the Count Bits task.

    Registers one range attribute, ``n``, whose lower and upper bounds are
    written to the ``min_n`` / ``max_n`` fields of ``CountBitsConfig``.
    """

    def __init__(self):
        super().__init__(CountBitsCurriculum.__name__, CountBitsConfig)

        # Upper-bound milestones for the sampled number; the last one equals
        # the default ``max_n`` of CountBitsConfig (2**31 - 1).
        n_levels = [1_000, 1_000_000, 100_000_000, 2**31 - 1]

        # Single range attribute controlling how large the number may be.
        # NOTE(review): how levels are turned into (min_n, max_n) is decided
        # by BaseCurriculum / RangeAttributeDefinition, not by this class.
        n_attribute = RangeAttributeDefinition(
            name="n",
            levels=n_levels,
            default_level=0,
            description="Number to count bits in",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_n",
            upper_field_name="max_n",
        )

        self._define_attributes(n_attribute)


register_dataset("count_bits", CountBitsDataset, CountBitsConfig)
30 changes: 25 additions & 5 deletions tests/test_count_bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@

import pytest

from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsDataset
from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset


def test_count_bits_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = CountBitsConfig(max_n=-1) # Negative not allowed
config = CountBitsConfig(min_n=-1) # Negative not allowed
config.validate()

with pytest.raises(AssertionError):
config = CountBitsConfig(max_n=0) # Zero not allowed
config = CountBitsConfig(min_n=0) # Zero not allowed
config.validate()

with pytest.raises(AssertionError):
config = CountBitsConfig(min_n=10, max_n=5) # min_n > max_n
config.validate()


Expand All @@ -28,7 +32,7 @@ def test_count_bits_dataset_deterministic():

def test_count_bits_dataset_items():
"""Test basic properties of generated items"""
config = CountBitsConfig(max_n=10, size=10, seed=42)
config = CountBitsConfig(min_n=1, max_n=1_000_000, size=10, seed=42)
dataset = CountBitsDataset(config)

for i in range(len(dataset)):
Expand All @@ -49,7 +53,7 @@ def test_count_bits_dataset_items():
binary = item["metadata"]["binary"]

# Verify values
assert number <= config.max_n
assert config.min_n <= number <= config.max_n
assert solution >= 0
assert set(binary) <= {"0", "1"}

Expand Down Expand Up @@ -81,3 +85,19 @@ def test_count_bits_answer():
count += number & 1
number >>= 1
assert solution == count


def test_count_bits_curriculum():
    """Check the default curriculum config and the effect of raising level ``n``."""
    curriculum = CountBitsCurriculum()
    overrides = {"size": 150, "seed": 1}

    # Default level: size/seed are passed through and both bounds sit at 1_000.
    default_cfg: CountBitsConfig = curriculum.generate_configuration(overrides)
    assert default_cfg.seed == 1
    assert default_cfg.size == 150
    assert (default_cfg.min_n, default_cfg.max_n) == (1_000, 1_000)

    # After one increment the upper bound moves to the next level while the
    # lower bound stays where it was.
    curriculum.increment_attr_level("n")
    harder_cfg = curriculum.generate_configuration(overrides)
    assert (harder_cfg.min_n, harder_cfg.max_n) == (1_000, 1_000_000)

0 comments on commit 9bb6d02

Please sign in to comment.