Skip to content

Commit

Permalink
added score_answer implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joesharratt1229 committed Feb 2, 2025
1 parent f5838da commit b0d21cf
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 26 deletions.
31 changes: 29 additions & 2 deletions reasoning_gym/algebra/intermediate_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional

import sympy

Expand Down Expand Up @@ -221,16 +221,43 @@ def __getitem__(self, index: int):
integrand = self._generate_repeated_parts(rng, x)

answer = sympy.integrate(integrand, x)
answer_str = str(answer) + " + C"

return {
"question": rng.choice(self.prompt_template).format(integrand=integrand),
"answer": str(answer) + " + C",
"answer": answer_str,
"metadata": {
"integrand": str(integrand),
"problem_type": problem_type,
"variable": str(x),
"type": substitution_type if problem_type == "substitution" else parts_type,
"expected_answer_expression": answer,
},
}

def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
    """Score a proposed antiderivative by differentiating it.

    The answer is parsed with sympy, differentiated with respect to
    ``metadata["variable"]`` and compared against ``metadata["integrand"]``.
    Any symbol other than the integration variable (``C``, ``K``, ...) is
    treated as a constant, so the integration constant may be named freely
    or omitted.

    Args:
        answer: The proposed antiderivative as a string, or ``None``.
        metadata: Mapping with the keys ``"variable"`` (name of the
            integration variable) and ``"integrand"`` (integrand as a
            sympy-parsable string).

    Returns:
        ``0.0`` when ``answer`` is ``None``; ``1.0`` when the derivative of
        the answer simplifies to the integrand; ``0.05`` when the answer
        parses but is not a valid antiderivative; ``0.01`` when the answer
        or the metadata cannot be processed.
    """
    if answer is None:
        return 0.0

    try:
        var = metadata["variable"]
        x = sympy.Symbol(var)
        # NOTE(security): sympy.parse_expr evaluates the string via eval();
        # `answer` is externally supplied, so this must only run in a context
        # where executing untrusted expressions is acceptable.
        # 'C' is bound explicitly so it is always a plain integration constant.
        user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")})
        derivative = sympy.diff(user_expr, x)
        integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x})

        # Equivalence check: d/dx(answer) - integrand must simplify to zero.
        if sympy.simplify(derivative - integrand) == 0:
            return 1.0
        # Parsed fine but mathematically wrong: small credit for formatting.
        return 0.05 if answer.strip() else 0.01
    except Exception:
        # Malformed answer or metadata. Deliberately not a bare `except:` so
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        return 0.01


register_dataset("intermediate_integration", IntermediateIntegrationDataset, IntermediateIntegrationConfig)
32 changes: 30 additions & 2 deletions reasoning_gym/algebra/simple_integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional
from typing import Any, Dict, Optional

import sympy

Expand Down Expand Up @@ -73,8 +73,36 @@ def __getitem__(self, idx: int) -> dict:
return {
"question": rng.choice(self._prompt_templates).format(integrand=derivative),
"answer": str(polynomial) + " + C",
"metadata": {"integrand": str(derivative), "variable": str(symbol), "antiderivative": str(polynomial)},
"metadata": {
"integrand": str(derivative),
"variable": str(symbol),
"expected_answer_expression": polynomial,
},
}

def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
    """Score a proposed antiderivative by differentiating it.

    The answer is parsed with sympy, differentiated with respect to
    ``metadata["variable"]`` and compared against ``metadata["integrand"]``.
    Any symbol other than the integration variable (``C``, ``K``, ...) is
    treated as a constant, so the integration constant may be named freely
    or omitted.

    Args:
        answer: The proposed antiderivative as a string, or ``None``.
        metadata: Mapping with the keys ``"variable"`` (name of the
            integration variable) and ``"integrand"`` (integrand as a
            sympy-parsable string).

    Returns:
        ``0.0`` when ``answer`` is ``None``; ``1.0`` when the derivative of
        the answer simplifies to the integrand; ``0.05`` when the answer
        parses but is not a valid antiderivative; ``0.01`` when the answer
        or the metadata cannot be processed.
    """
    if answer is None:
        return 0.0

    try:
        var = metadata["variable"]
        x = sympy.Symbol(var)
        # NOTE(security): sympy.parse_expr evaluates the string via eval();
        # `answer` is externally supplied, so this must only run in a context
        # where executing untrusted expressions is acceptable.
        # 'C' is bound explicitly so it is always a plain integration constant.
        user_expr = sympy.parse_expr(answer, local_dict={var: x, "C": sympy.Symbol("C")})
        derivative = sympy.diff(user_expr, x)
        integrand = sympy.parse_expr(metadata["integrand"], local_dict={var: x})

        # Equivalence check: d/dx(answer) - integrand must simplify to zero.
        if sympy.simplify(derivative - integrand) == 0:
            return 1.0
        # Parsed fine but mathematically wrong: small credit for formatting.
        return 0.05 if answer.strip() else 0.01
    except Exception:
        # Malformed answer or metadata. Deliberately not a bare `except:` so
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        return 0.01


register_dataset("simple_integration", SimpleIntegrationDataset, SimpleIntegrationConfig)
53 changes: 44 additions & 9 deletions tests/test_intermediate_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,50 @@ def test_intermediate_integration_dataset_items():
assert isinstance(parse_expr(answer), sympy.Expr)


def test_solution_verification():
"""Test for solution verification of each answer"""
config = IntermediateIntegrationConfig(seed=42, size=10)
def test_verify_answer():
    """Each generated item's own reference answer must earn the full score."""
    dataset = IntermediateIntegrationDataset(IntermediateIntegrationConfig(seed=42))
    for index in range(len(dataset)):
        entry = dataset[index]
        # The dataset's stored answer is scored against its own metadata.
        assert dataset.score_answer(entry["answer"], entry["metadata"]) == 1.0

for item in dataset:
integrand = parse_expr(item["metadata"]["integrand"])
variable = sympy.Symbol(item["metadata"]["variable"])
answer = parse_expr(item["answer"].replace(" + C", ""))

# Verify that the derivative of the answer equals the integrand
assert sympy.simplify(sympy.diff(answer, variable) - integrand) == 0
def test_score_answer_cases():
    """Test various answer scoring scenarios.

    Each case is ``(answer, metadata, expected_score)`` and is passed straight
    to ``score_answer``; the metadata dicts are hand-written, so the cases do
    not depend on the items the dataset generates.
    """
    config = IntermediateIntegrationConfig(seed=42)
    dataset = IntermediateIntegrationDataset(config)

    # Test cases: (answer, metadata, expected_score)
    test_cases = [
        # Correct answers
        ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0),
        ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Correct without explicit constant
        ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
        # Incorrect but properly formatted
        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
        # Malformed expressions
        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
        # Empty answer
        ("", {"variable": "x", "integrand": "2*x"}, 0.01),
        # Case sensitivity
        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        # Alternative constant notation
        ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Simplification required
        ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0),
    ]

    for answer, metadata, expected in test_cases:
        score = dataset.score_answer(answer, metadata)
        assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"
58 changes: 45 additions & 13 deletions tests/test_simple_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import random
from fractions import Fraction

import pytest
import sympy
from sympy.parsing.sympy_parser import parse_expr
Expand Down Expand Up @@ -63,23 +60,58 @@ def test_simple_integration_dataset_items():

assert "integrand" in item["metadata"]
assert "variable" in item["metadata"]
assert "antiderivative" in item["metadata"]
assert "expected_answer_expression" in item["metadata"]

# Verify answer is a mathematical expression
answer = item["answer"]
answer = answer.replace(" + C", "")
assert isinstance(parse_expr(answer), sympy.Expr)


def test_simple_integration_solution_verification():
"""Test for solution verification of each answer"""
config = SimpleIntegrationConfig(seed=42, size=10)
def test_verify_answer():
    """Each generated item's own reference answer must earn the full score."""
    dataset = SimpleIntegrationDataset(SimpleIntegrationConfig(seed=42))
    for index in range(len(dataset)):
        entry = dataset[index]
        # The dataset's stored answer is scored against its own metadata.
        assert dataset.score_answer(entry["answer"], entry["metadata"]) == 1.0

for item in dataset:
integrand = parse_expr(item["metadata"]["integrand"])
variable = sympy.Symbol(item["metadata"]["variable"])
answer = parse_expr(item["answer"].replace(" + C", ""))

# Verify that the derivative of the answer equals the integrand
assert sympy.simplify(sympy.diff(answer, variable) - integrand) == 0
def test_score_answer_cases():
    """Test various answer scoring scenarios.

    Each case is ``(answer, metadata, expected_score)`` and is passed straight
    to ``score_answer``; the metadata dicts are hand-written, so the cases do
    not depend on the items the dataset generates.
    """
    config = SimpleIntegrationConfig(seed=42)
    dataset = SimpleIntegrationDataset(config)

    # Test cases: (answer, metadata, expected_score)
    test_cases = [
        # Correct answers
        ("x**2 + C", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("X**3 - 5*X + C", {"variable": "X", "integrand": "3*X**2 - 5"}, 1.0),
        ("sin(x) + C", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Correct without explicit constant
        ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
        # Incorrect but properly formatted
        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
        # Malformed expressions
        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
        # Empty answer
        ("", {"variable": "x", "integrand": "2*x"}, 0.01),
        # Case sensitivity
        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
        # Alternative constant notation
        ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),
        # Simplification required
        ("x**2 + C + 5 - 5", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("(x**3)/3 - 2*x + C", {"variable": "x", "integrand": "x**2 - 2"}, 1.0),
    ]

    for answer, metadata, expected in test_cases:
        score = dataset.score_answer(answer, metadata)
        assert score == expected, f"Failed case: {answer} | Expected {expected}, got {score}"

0 comments on commit b0d21cf

Please sign in to comment.