diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index abc177ca..1b296d02 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -1,5 +1,5 @@ import random -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import date, datetime, time, timedelta from typing import List, Optional @@ -19,14 +19,12 @@ class TimeIntervalsConfig: min_date: date = date(1900, 1, 1) max_date: date = date(3000, 1, 1) max_date_difference_days: int = 100 - task_types: List[str] = None + task_types: List[str] = field( + default_factory=lambda: ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"] + ) seed: Optional[int] = None size: int = 500 - def __post_init__(self): - if self.task_types is None: - self.task_types = ["time", "time_seconds", "time_ms", "date", "datetime", "datetime_tz"] - def validate(self) -> None: """Validate configuration parameters""" assert self.size > 0, "size must be positive" diff --git a/reasoning_gym/geometry/advanced_geometry.py b/reasoning_gym/geometry/advanced_geometry.py index 3d88a348..6a34d1db 100644 --- a/reasoning_gym/geometry/advanced_geometry.py +++ b/reasoning_gym/geometry/advanced_geometry.py @@ -1,5 +1,5 @@ import random -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional import sympy @@ -21,16 +21,11 @@ class AdvancedGeometryConfig: # Probability or list of tasks we want to generate # For demonstration, we have three categories: - task_types: List[str] = None - - def __post_init__(self): - if self.task_types is None: - # Default set of advanced tasks - self.task_types = [ - "orthocenter", - "incircle_radius", - "angle_measure", - ] + task_types: List[str] = field(default_factory=lambda: [ + "orthocenter", + "incircle_radius", + "angle_measure", + ]) def validate(self): assert self.min_coord < self.max_coord, "min_coord must be < max_coord." diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 6ba042a8..ba1809af 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -1,5 +1,5 @@ import random -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import StrEnum from itertools import count from typing import List, Optional, Set, Tuple @@ -37,12 +37,8 @@ class Person: gender: Gender id: int spouse: Optional["Person"] = None - parents: List["Person"] = None - children: List["Person"] = None - - def __post_init__(self): - self.parents = self.parents or [] - self.children = self.children or [] + parents: List["Person"] = field(default_factory=list) + children: List["Person"] = field(default_factory=list) def __hash__(self): return self.id @@ -69,109 +65,25 @@ class FamilyRelationshipsConfig: min_family_size: int = 4 max_family_size: int = 8 - male_names: List[str] = None - female_names: List[str] = None + male_names: List[str] = field(default_factory=lambda: [ + "James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph", + "Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew", + "George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver", + "Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan", + "Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion", + "Phoenix", "Atlas", "Axel", "Ryder", "Finn" + ]) + female_names: List[str] = field(default_factory=lambda: [ + "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", + "Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria", + "Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily", + "Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe", + "Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow", + "Aurora", "Sage", "River", "Winter", "Sky", "Rain" + ]) seed: Optional[int] = None size: int = 500 - def __post_init__(self): - # Default name lists if none provided - default_male_names = [ - "James", - "John", - "Robert", - "Michael", - "William", - "David", - "Richard", - "Joseph", - "Thomas", - "Charles", - "Peter", - "Daniel", - "Matthew", - "Christopher", - "Andrew", - "George", - "Edward", - "Benjamin", - "Henry", - "Samuel", - "Alexander", - "Oliver", - "Jack", - "Harry", - "Jacob", - "Noah", - "Ethan", - "Lucas", - "Mason", - "Logan", - "Sebastian", - "Theodore", - "Owen", - "Liam", - "Aiden", - "Kai", - "Jayden", - "Zion", - "Phoenix", - "Atlas", - "Axel", - "Ryder", - "Finn", - ] - default_female_names = [ - "Mary", - "Patricia", - "Jennifer", - "Linda", - "Elizabeth", - "Barbara", - "Susan", - "Jessica", - "Sarah", - "Karen", - "Emma", - "Lisa", - "Anna", - "Margaret", - "Victoria", - "Charlotte", - "Sophia", - "Isabella", - "Olivia", - "Ava", - "Mia", - "Emily", - "Abigail", - "Amelia", - "Eleanor", - "Grace", - "Alice", - "Lucy", - "Chloe", - "Sophie", - "Lily", - "Hannah", - "Zoe", - "Luna", - "Nova", - "Aria", - "Willow", - "Aurora", - "Sage", - "River", - "Winter", - "Sky", - "Rain", - ] - - if self.male_names is None: - self.male_names = default_male_names - if self.female_names is None: - self.female_names = default_female_names - def validate(self) -> None: """Validate configuration parameters""" assert self.min_family_size >= 3, "min_family_size must be at least 3"