From 3603745ec7e8ad28827423b98ed7735d62f07bf4 Mon Sep 17 00:00:00 2001 From: Haidar Khan Date: Mon, 15 Jul 2024 00:00:28 -0400 Subject: [PATCH] mathquiz is running now --- main.py | 2 +- zero_sum_eval/games/mathquiz/__init__.py | 2 +- zero_sum_eval/games/mathquiz/mathquiz_game.py | 30 +++++++++++++------ .../games/mathquiz/mathquiz_player.py | 24 ++++++++++----- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 31e6685..ce2fc70 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,7 @@ def main(): game_manager = GameManager(config) - logger.info("Starting a new game of chess.") + logger.info("Starting a new game!") final_state = game_manager.start() diff --git a/zero_sum_eval/games/mathquiz/__init__.py b/zero_sum_eval/games/mathquiz/__init__.py index a333d83..17a3a89 100644 --- a/zero_sum_eval/games/mathquiz/__init__.py +++ b/zero_sum_eval/games/mathquiz/__init__.py @@ -1,3 +1,3 @@ # This file is used to import the classes from the chess module -from .mathquiz_player import MathQuizPlayer +from .mathquiz_player import MathQuizTeacher, MathQuizStudent from .mathquiz_game import MathQuizGame \ No newline at end of file diff --git a/zero_sum_eval/games/mathquiz/mathquiz_game.py b/zero_sum_eval/games/mathquiz/mathquiz_game.py index a1a606f..48d4b5f 100644 --- a/zero_sum_eval/games/mathquiz/mathquiz_game.py +++ b/zero_sum_eval/games/mathquiz/mathquiz_game.py @@ -30,7 +30,7 @@ def __init__(self, roles=None, environment=None, context=None, target=None): {"question": None, "teacher_answer": None, "student_answer": None} self.roles = self.get_next_roles(self.environment) if roles is None else roles self.context = context if context is not None else {"history": [], "message": None} - self.target = target if target is not None else randint(1, 1000) + self.target = target if target is not None else str(randint(1, 1000)) def initialize(self, roles=None, environment=None, context=None, target=None): return MathQuizGame( @@ -70,10 +70,12 @@ def query_game(self): ) def verify_answer(self, answer): - return answer == self.target + return str(answer) == str(self.target.strip()) def validate_game(self): current_role = self.roles[0] + if current_role == "TeacherGenerateQuestion": + return None if self.environment['teacher_answer'] is not None: if current_role == "TeacherAnswerQuestion": if self.verify_answer(self.environment['teacher_answer']): @@ -97,16 +99,26 @@ def get_next_roles(self, environment): return ['StudentAnswerQuestion'] def export(self): - return { - 'roles': self.roles, - 'environment': self.environment, - 'context': self.context - } + current_role = self.roles[0] + if current_role == "TeacherGenerateQuestion": + return { + 'role': self.roles[0], + 'environment': self.target, + 'context': self.context + } + elif current_role in ("TeacherAnswerQuestion", "StudentAnswerQuestion"): + return { + 'role': self.roles[0], + 'environment': self.environment['question'], + 'context': self.context + } + else: + raise ValueError("Invalid role") def display(self): display_str = f"Role to Act: {self.roles[0]}\nMessage: {self.context['message']}\n" - display_str += f"{self.formatted_move_history()}\n" - display_str += f"{self.board}\n" + display_str += f"{self.environment}\n" + display_str += f"Target: {self.target}\n" return display_str diff --git a/zero_sum_eval/games/mathquiz/mathquiz_player.py b/zero_sum_eval/games/mathquiz/mathquiz_player.py index 0a80557..54c1566 100644 --- a/zero_sum_eval/games/mathquiz/mathquiz_player.py +++ b/zero_sum_eval/games/mathquiz/mathquiz_player.py @@ -42,9 +42,10 @@ def forward(self, math_question): @PLAYER_REGISTRY.register("mathquiz", "mathquiz_teacher") class MathQuizTeacher(Player): - def __init__(self, llm_model, max_tries=4, **kwargs): + def __init__(self, lm, max_tries=4, **kwargs): super().__init__(**kwargs) - self.llm_model = llm_model + lm_args = lm["args"] if "args" in lm else {} + self.llm_model = LM_REGISTRY.build(lm["type"], **lm_args) self.max_tries = max_tries self.question_module = GenerateQuestionCoT() self.answer_module = AnswerQuestionCoT() @@ -65,17 +66,20 @@ def make_move(self, game_state): current_role = game_state.roles[0] with dspy.context(lm=self.llm_model): if current_role == "TeacherGenerateQuestion": - trace = self.question_module(export['target']) + trace = self.question_module(export['environment']) return trace.math_question elif current_role == "TeacherAnswerQuestion": trace = self.answer_module(export['environment']) return trace.answer + else: + raise ValueError(f"Invalid role for teacher: {current_role}") @PLAYER_REGISTRY.register("mathquiz", "mathquiz_student") -class MathQuizAnswer(Player): - def __init__(self, llm_model, max_tries=4, **kwargs): +class MathQuizStudent(Player): + def __init__(self, lm, max_tries=4, **kwargs): super().__init__(**kwargs) - self.llm_model = llm_model + lm_args = lm["args"] if "args" in lm else {} + self.llm_model = LM_REGISTRY.build(lm["type"], **lm_args) self.max_tries = max_tries self.answer_module = AnswerQuestionCoT() # self.optimized_module = self.optimize_prompts() if self.optimize else None @@ -92,9 +96,13 @@ def make_move(self, game_state): str: The move made by the player """ export = game_state.export() + current_role = game_state.roles[0] with dspy.context(lm=self.llm_model): - trace = self.answer_module(export['environment']) - return trace.answer + if current_role == "StudentAnswerQuestion": + trace = self.answer_module(export['environment']) + return trace.answer + else: + raise ValueError(f"Invalid role for student: {current_role}")