Skip to content

Commit

Permalink
refactor: Use field default_factory TimeIntervalsConfig, AdvancedGeom…
Browse files Browse the repository at this point in the history
…etryConfig
  • Loading branch information
andreaskoepf committed Feb 2, 2025
1 parent 8202f23 commit 4e9fc4b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 124 deletions.
10 changes: 4 additions & 6 deletions reasoning_gym/arithmetic/time_intervals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import date, datetime, time, timedelta
from typing import List, Optional

Expand All @@ -19,14 +19,12 @@ class TimeIntervalsConfig:
min_date: date = date(1900, 1, 1)
max_date: date = date(3000, 1, 1)
max_date_difference_days: int = 100
task_types: List[str] = None
task_types: List[str] = field(
default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]
)
seed: Optional[int] = None
size: int = 500

def __post_init__(self):
if self.task_types is None:
self.task_types = ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"]

def validate(self) -> None:
"""Validate configuration parameters"""
assert self.size > 0, "size must be positive"
Expand Down
17 changes: 6 additions & 11 deletions reasoning_gym/geometry/advanced_geometry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Optional

import sympy
Expand All @@ -21,16 +21,11 @@ class AdvancedGeometryConfig:

# Probability or list of tasks we want to generate
# For demonstration, we have three categories:
task_types: List[str] = None

def __post_init__(self):
if self.task_types is None:
# Default set of advanced tasks
self.task_types = [
"orthocenter",
"incircle_radius",
"angle_measure",
]
task_types: List[str] = field(default_factory=lambda: [
"orthocenter",
"incircle_radius",
"angle_measure",
])

def validate(self):
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
Expand Down
126 changes: 19 additions & 107 deletions reasoning_gym/graphs/family_relationships.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import StrEnum
from itertools import count
from typing import List, Optional, Set, Tuple
Expand Down Expand Up @@ -37,12 +37,8 @@ class Person:
gender: Gender
id: int
spouse: Optional["Person"] = None
parents: List["Person"] = None
children: List["Person"] = None

def __post_init__(self):
self.parents = self.parents or []
self.children = self.children or []
parents: List["Person"] = field(default_factory=list)
children: List["Person"] = field(default_factory=list)

def __hash__(self):
return self.id
Expand All @@ -69,109 +65,25 @@ class FamilyRelationshipsConfig:

min_family_size: int = 4
max_family_size: int = 8
male_names: List[str] = None
female_names: List[str] = None
male_names: List[str] = field(default_factory=lambda: [
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph",
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew",
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver",
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan",
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion",
"Phoenix", "Atlas", "Axel", "Ryder", "Finn"
])
female_names: List[str] = field(default_factory=lambda: [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria",
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily",
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe",
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow",
"Aurora", "Sage", "River", "Winter", "Sky", "Rain"
])
seed: Optional[int] = None
size: int = 500

def __post_init__(self):
# Default name lists if none provided
default_male_names = [
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Peter",
"Daniel",
"Matthew",
"Christopher",
"Andrew",
"George",
"Edward",
"Benjamin",
"Henry",
"Samuel",
"Alexander",
"Oliver",
"Jack",
"Harry",
"Jacob",
"Noah",
"Ethan",
"Lucas",
"Mason",
"Logan",
"Sebastian",
"Theodore",
"Owen",
"Liam",
"Aiden",
"Kai",
"Jayden",
"Zion",
"Phoenix",
"Atlas",
"Axel",
"Ryder",
"Finn",
]
default_female_names = [
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Karen",
"Emma",
"Lisa",
"Anna",
"Margaret",
"Victoria",
"Charlotte",
"Sophia",
"Isabella",
"Olivia",
"Ava",
"Mia",
"Emily",
"Abigail",
"Amelia",
"Eleanor",
"Grace",
"Alice",
"Lucy",
"Chloe",
"Sophie",
"Lily",
"Hannah",
"Zoe",
"Luna",
"Nova",
"Aria",
"Willow",
"Aurora",
"Sage",
"River",
"Winter",
"Sky",
"Rain",
]

if self.male_names is None:
self.male_names = default_male_names
if self.female_names is None:
self.female_names = default_female_names

def validate(self) -> None:
"""Validate configuration parameters"""
assert self.min_family_size >= 3, "min_family_size must be at least 3"
Expand Down

0 comments on commit 4e9fc4b

Please sign in to comment.