Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 1, 2025
1 parent 4b7fdd7 commit 91cf5ab
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 65 deletions.
2 changes: 1 addition & 1 deletion reasoning_gym/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset
from .advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset
from .simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset

__all__ = [
"SimpleGeometryConfig",
Expand Down
44 changes: 17 additions & 27 deletions reasoning_gym/geometry/advanced_geometry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import random
from dataclasses import dataclass
from typing import Optional, List
from typing import List, Optional

import sympy
from sympy.geometry import Point, Triangle, Segment
from sympy.geometry import Point, Segment, Triangle

from ..factory import ProceduralDataset, register_dataset

Expand All @@ -13,9 +13,10 @@ class AdvancedGeometryConfig:
"""
Configuration for generating advanced geometry tasks.
"""

min_coord: int = -10 # Minimum x/y coordinate
max_coord: int = 10 # Maximum x/y coordinate
size: int = 50 # Number of problems to generate
max_coord: int = 10 # Maximum x/y coordinate
size: int = 50 # Number of problems to generate
seed: Optional[int] = None

# Probability or list of tasks we want to generate
Expand Down Expand Up @@ -98,16 +99,12 @@ def _generate_non_degenerate_triangle(self, rng: random.Random):
x = rng.randint(self.config.min_coord, self.config.max_coord)
y = rng.randint(self.config.min_coord, self.config.max_coord)
points.append(Point(x, y))

A, B, C = points

# Calculate signed area to check for non-degeneracy
# Using the formula: 1/2 * |x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2)|
area = abs(
A.x * (B.y - C.y) +
B.x * (C.y - A.y) +
C.x * (A.y - B.y)
) / 2
area = abs(A.x * (B.y - C.y) + B.x * (C.y - A.y) + C.x * (A.y - B.y)) / 2

if area > 0:
return A, B, C
Expand All @@ -121,21 +118,19 @@ def _build_orthocenter_task(self, rng: random.Random, A: Point, B: Point, C: Poi
# Convert segments to lines
BC_line = sympy.Line(B, C)
CA_line = sympy.Line(C, A)

# Calculate altitudes by creating lines perpendicular from each vertex
alt_A = BC_line.perpendicular_line(A)
alt_B = CA_line.perpendicular_line(B)

# Find orthocenter (intersection of any two altitudes, e.g. alt_A and alt_B)
ortho = alt_A.intersection(alt_B)[0]

x_ortho_approx = float(ortho.x.evalf())
y_ortho_approx = float(ortho.y.evalf())

question_template = rng.choice(self._prompt_templates["orthocenter"])
question = question_template.format(
A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)
)
question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y))
answer_str = f"({x_ortho_approx:.3f}, {y_ortho_approx:.3f})"

metadata = {
Expand All @@ -147,7 +142,6 @@ def _build_orthocenter_task(self, rng: random.Random, A: Point, B: Point, C: Poi
}
return question, answer_str, metadata


def _build_incircle_radius_task(self, rng: random.Random, A: Point, B: Point, C: Point):
"""
Build a question about finding the incircle radius of triangle ABC.
Expand All @@ -156,23 +150,21 @@ def _build_incircle_radius_task(self, rng: random.Random, A: Point, B: Point, C:
a = B.distance(C)
b = C.distance(A)
c = A.distance(B)

# Semi-perimeter
s = (a + b + c) / 2

# Area using Heron's formula
area = sympy.sqrt(s * (s - a) * (s - b) * (s - c))

# Radius of incircle = Area / Semi-perimeter
radius = area / s

# Convert to float for final answer
radius_approx = float(radius.evalf())

question_template = rng.choice(self._prompt_templates["incircle_radius"])
question = question_template.format(
A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)
)
question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y))
answer_str = f"{radius_approx:.3f}"

metadata = {
Expand Down Expand Up @@ -211,9 +203,7 @@ def _build_angle_measure_task(self, rng: random.Random, A: Point, B: Point, C: P
angle_deg = float(angle_rad.evalf() * 180 / sympy.pi)

question_template = rng.choice(self._prompt_templates["angle_measure"])
question = question_template.format(
A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y)
)
question = question_template.format(A=(A.x, A.y), B=(B.x, B.y), C=(C.x, C.y))

answer_str = f"{angle_deg:.2f}°"
metadata = {
Expand Down
16 changes: 8 additions & 8 deletions reasoning_gym/geometry/simple_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ..factory import ProceduralDataset, register_dataset


@dataclass
class SimpleGeometryConfig:
"""
Expand All @@ -12,12 +13,12 @@ class SimpleGeometryConfig:
for the first (N-1) sides, and asks the solver to find the last angle.
"""

min_sides: int = 3 # Minimum number of sides (e.g. triangle)
max_sides: int = 6 # Maximum number of sides (e.g. hexagon)
min_angle: int = 10 # Minimum angle (in degrees) for each of the first (N-1) angles
max_angle: int = 170 # Maximum angle (in degrees) for each of the first (N-1) angles
min_sides: int = 3 # Minimum number of sides (e.g. triangle)
max_sides: int = 6 # Maximum number of sides (e.g. hexagon)
min_angle: int = 10 # Minimum angle (in degrees) for each of the first (N-1) angles
max_angle: int = 170 # Maximum angle (in degrees) for each of the first (N-1) angles
seed: Optional[int] = None # Random seed
size: int = 100 # Number of geometry tasks to generate
size: int = 100 # Number of geometry tasks to generate

def validate(self) -> None:
"""
Expand Down Expand Up @@ -85,9 +86,7 @@ def __getitem__(self, idx: int) -> dict:
# Build the question string
angle_list_str = ", ".join(f"{a:.1f}°" for a in known_angles)
prompt = rng.choice(self._prompt_templates).format(
n_sides=n_sides,
n_minus_1=n_sides - 1,
angle_list=angle_list_str
n_sides=n_sides, n_minus_1=n_sides - 1, angle_list=angle_list_str
)

# Round the missing angle to one decimal place or integer if it is very close to an integer
Expand Down Expand Up @@ -136,5 +135,6 @@ def _generate_valid_angles(self, rng: random.Random, n_sides: int, total_sum: in
f"with total sum {total_sum} within {max_attempts} attempts."
)


# Register the dataset so it can be accessed similarly to the others
register_dataset("simple_geometry", SimpleGeometryDataset, SimpleGeometryConfig)
19 changes: 7 additions & 12 deletions tests/test_advanced_geometry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest

from reasoning_gym.geometry.advanced_geometry import (
AdvancedGeometryDataset,
AdvancedGeometryConfig,
)
from reasoning_gym.geometry.advanced_geometry import AdvancedGeometryConfig, AdvancedGeometryDataset


def test_advanced_geometry_config_validation():
"""Test that invalid configs raise appropriate errors."""
Expand Down Expand Up @@ -35,8 +33,7 @@ def test_advanced_geometry_dataset_deterministic():

for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i], (
f"Item mismatch at index {i} for same seed. "
f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}"
f"Item mismatch at index {i} for same seed. " f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}"
)


Expand All @@ -55,16 +52,14 @@ def test_advanced_geometry_dataset_items():

# Basic metadata checks
metadata = item["metadata"]
assert "A" in metadata and "B" in metadata and "C" in metadata, (
"Metadata should contain coordinates for points A, B, and C."
)
assert (
"A" in metadata and "B" in metadata and "C" in metadata
), "Metadata should contain coordinates for points A, B, and C."

# Check answer format depending on task type
# For angle measure tasks, answer should end with '°'
if "angle_measure" in item["question"].lower() or "angle at" in item["question"].lower():
assert item["answer"].endswith("°"), (
f"Expected angle measure in degrees, got {item['answer']}"
)
assert item["answer"].endswith("°"), f"Expected angle measure in degrees, got {item['answer']}"


def test_advanced_geometry_dataset_iteration():
Expand Down
22 changes: 5 additions & 17 deletions tests/test_simple_geometry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest

from reasoning_gym.geometry.simple_geometry import (
SimpleGeometryDataset,
SimpleGeometryConfig,
)
from reasoning_gym.geometry.simple_geometry import SimpleGeometryConfig, SimpleGeometryDataset


def test_simple_geometry_config_validation():
"""Test invalid configs raise appropriate errors."""
Expand Down Expand Up @@ -35,21 +33,13 @@ def test_simple_geometry_dataset_deterministic():

for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i], (
f"Item mismatch at index {i} for same seed. "
f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}"
f"Item mismatch at index {i} for same seed. " f"Dataset1: {dataset1[i]} vs Dataset2: {dataset2[i]}"
)


def test_simple_geometry_dataset_items():
"""Test basic properties of generated items."""
config = SimpleGeometryConfig(
min_sides=3,
max_sides=5,
min_angle=10,
max_angle=120,
size=10,
seed=123
)
config = SimpleGeometryConfig(min_sides=3, max_sides=5, min_angle=10, max_angle=120, size=10, seed=123)
dataset = SimpleGeometryDataset(config)

for i in range(len(dataset)):
Expand All @@ -62,9 +52,7 @@ def test_simple_geometry_dataset_items():

metadata = item["metadata"]
assert "n_sides" in metadata, "Metadata should contain 'n_sides'."
assert "missing_angle_rounded" in metadata, (
"Metadata should contain the computed 'missing_angle_rounded'."
)
assert "missing_angle_rounded" in metadata, "Metadata should contain the computed 'missing_angle_rounded'."

# Check that the missing angle is a valid float or integer
missing_angle = float(item["answer"])
Expand Down

0 comments on commit 91cf5ab

Please sign in to comment.