Skip to content

Commit

Permalink
mathquiz is running now
Browse files Browse the repository at this point in the history
  • Loading branch information
Haidar Khan committed Jul 27, 2024
1 parent c4cc6f3 commit 3603745
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 19 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():

game_manager = GameManager(config)

logger.info("Starting a new game of chess.")
logger.info("Starting a new game!")

final_state = game_manager.start()

Expand Down
2 changes: 1 addition & 1 deletion zero_sum_eval/games/mathquiz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# This file is used to import the classes from the mathquiz module
from .mathquiz_player import MathQuizPlayer
from .mathquiz_player import MathQuizTeacher, MathQuizStudent
from .mathquiz_game import MathQuizGame
30 changes: 21 additions & 9 deletions zero_sum_eval/games/mathquiz/mathquiz_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, roles=None, environment=None, context=None, target=None):
{"question": None, "teacher_answer": None, "student_answer": None}
self.roles = self.get_next_roles(self.environment) if roles is None else roles
self.context = context if context is not None else {"history": [], "message": None}
self.target = target if target is not None else randint(1, 1000)
self.target = target if target is not None else str(randint(1, 1000))

def initialize(self, roles=None, environment=None, context=None, target=None):
return MathQuizGame(
Expand Down Expand Up @@ -70,10 +70,12 @@ def query_game(self):
)

def verify_answer(self, answer):
    """Return True when *answer* matches the game's target value.

    Both sides are converted to ``str`` and stripped of surrounding
    whitespace before comparison, so an integer target, a string target,
    and an answer with stray spaces or newlines all compare consistently.

    Args:
        answer: The answer produced by a player (any type convertible
            with ``str``).

    Returns:
        bool: ``True`` if the normalized answer equals the normalized target.
    """
    # Convert first, then strip: calling .strip() directly on the target
    # would fail when a caller supplied a non-string target (e.g. an int).
    return str(answer).strip() == str(self.target).strip()

def validate_game(self):
current_role = self.roles[0]
if current_role == "TeacherGenerateQuestion":
return None
if self.environment['teacher_answer'] is not None:
if current_role == "TeacherAnswerQuestion":
if self.verify_answer(self.environment['teacher_answer']):
Expand All @@ -97,16 +99,26 @@ def get_next_roles(self, environment):
return ['StudentAnswerQuestion']

def export(self):
    """Export the slice of game state that the acting player needs.

    The ``'environment'`` entry depends on the role at the head of
    ``self.roles``:

    * ``"TeacherGenerateQuestion"`` -> the target value the question
      must evaluate to.
    * ``"TeacherAnswerQuestion"`` / ``"StudentAnswerQuestion"`` -> the
      question text stored in ``self.environment['question']``.

    Returns:
        dict: ``{'role': <current role>, 'environment': <payload>,
        'context': self.context}``.

    Raises:
        ValueError: If the current role is not one of the roles above.
    """
    current_role = self.roles[0]
    if current_role == "TeacherGenerateQuestion":
        payload = self.target
    elif current_role in ("TeacherAnswerQuestion", "StudentAnswerQuestion"):
        payload = self.environment['question']
    else:
        raise ValueError("Invalid role")
    return {
        'role': current_role,
        'environment': payload,
        'context': self.context,
    }

def display(self):
    """Return a printable, newline-terminated summary of the game state.

    The summary lists, one per line: the role that acts next, the current
    context message, the formatted move history, the environment dict and
    the target value.

    Returns:
        str: The multi-line summary.
    """
    # NOTE(review): formatted_move_history() is not defined in the visible
    # part of this class; presumably provided elsewhere - confirm.
    lines = [
        f"Role to Act: {self.roles[0]}",
        f"Message: {self.context['message']}",
        f"{self.formatted_move_history()}",
        # The quiz state lives in `environment`; this game has no board.
        f"{self.environment}",
        f"Target: {self.target}",
    ]
    return "\n".join(lines) + "\n"


Expand Down
24 changes: 16 additions & 8 deletions zero_sum_eval/games/mathquiz/mathquiz_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def forward(self, math_question):

@PLAYER_REGISTRY.register("mathquiz", "mathquiz_teacher")
class MathQuizTeacher(Player):
def __init__(self, llm_model, max_tries=4, **kwargs):
def __init__(self, lm, max_tries=4, **kwargs):
super().__init__(**kwargs)
self.llm_model = llm_model
lm_args = lm["args"] if "args" in lm else {}
self.llm_model = LM_REGISTRY.build(lm["type"], **lm_args)
self.max_tries = max_tries
self.question_module = GenerateQuestionCoT()
self.answer_module = AnswerQuestionCoT()
Expand All @@ -65,17 +66,20 @@ def make_move(self, game_state):
current_role = game_state.roles[0]
with dspy.context(lm=self.llm_model):
if current_role == "TeacherGenerateQuestion":
trace = self.question_module(export['target'])
trace = self.question_module(export['environment'])
return trace.math_question
elif current_role == "TeacherAnswerQuestion":
trace = self.answer_module(export['environment'])
return trace.answer
else:
raise ValueError(f"Invalid role for teacher: {current_role}")

@PLAYER_REGISTRY.register("mathquiz", "mathquiz_student")
class MathQuizAnswer(Player):
def __init__(self, llm_model, max_tries=4, **kwargs):
class MathQuizStudent(Player):
def __init__(self, lm, max_tries=4, **kwargs):
super().__init__(**kwargs)
self.llm_model = llm_model
lm_args = lm["args"] if "args" in lm else {}
self.llm_model = LM_REGISTRY.build(lm["type"], **lm_args)
self.max_tries = max_tries
self.answer_module = AnswerQuestionCoT()
# self.optimized_module = self.optimize_prompts() if self.optimize else None
Expand All @@ -92,9 +96,13 @@ def make_move(self, game_state):
str: The move made by the player
"""
export = game_state.export()
current_role = game_state.roles[0]
with dspy.context(lm=self.llm_model):
trace = self.answer_module(export['environment'])
return trace.answer
if current_role == "StudentAnswerQuestion":
trace = self.answer_module(export['environment'])
return trace.answer
else:
raise ValueError(f"Invalid role for student: {current_role}")



0 comments on commit 3603745

Please sign in to comment.