-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from haidark/mathquiz
Initial Mathquiz Implementation
- Loading branch information
Showing
9 changed files
with
287 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
manager: | ||
args: | ||
max_rounds: 10 | ||
win_conditions: | ||
- StudentCorrect | ||
game: | ||
name: mathquiz | ||
players: | ||
- name: mathquiz_teacher | ||
args: | ||
id: gpt4 teacher | ||
roles: | ||
- TeacherGenerateQuestion | ||
- TeacherAnswerQuestion | ||
lm: | ||
type: AzureOpenAI | ||
args: | ||
api_base: https://allam-swn-gpt-01.openai.azure.com/ | ||
api_version: 2023-07-01-preview | ||
deployment_id: gpt-4o-900ptu | ||
max_tokens: 800 | ||
temperature: 0.8 | ||
top_p: 0.95 | ||
frequency_penalty: 0 | ||
presence_penalty: 0 | ||
- name: mathquiz_student | ||
args: | ||
id: gpt4 student | ||
roles: | ||
- StudentAnswerQuestion | ||
lm: | ||
type: AzureOpenAI | ||
args: | ||
api_base: https://allam-swn-gpt-01.openai.azure.com/ | ||
api_version: 2023-07-01-preview | ||
deployment_id: gpt-4o-900ptu | ||
max_tokens: 800 | ||
temperature: 0.8 | ||
top_p: 0.95 | ||
frequency_penalty: 0 | ||
presence_penalty: 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# This file is used to import the classes from the mathquiz module | ||
from .mathquiz_player import MathQuizTeacher, MathQuizStudent | ||
from .mathquiz_game import MathQuizGame |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from zero_sum_eval.game_state import GameState | ||
from random import randint | ||
from zero_sum_eval.registry import GAME_REGISTRY | ||
|
||
|
||
@GAME_REGISTRY.register("mathquiz")
class MathQuizGame(GameState):
    """
    Two-player math quiz played between a teacher and a student.

    Every state carries a target number (kept as a string). Play proceeds as:
        1. The teacher writes a math question whose answer should be the target.
        2. The teacher answers its own question; a mismatch with the target
           is reported as "TeacherIncorrect".
        3. The student answers the question; the result is reported as
           "StudentCorrect" or "StudentIncorrect".

    The roles for this game are:
        TeacherGenerateQuestion
        TeacherAnswerQuestion
        StudentAnswerQuestion

    The environment for this game is:
        question: a math question
        teacher_answer: the teacher's answer to the math question
        student_answer: the student's answer to the math question

    States are treated as immutable: `update_game` and `query_game` return a
    fresh `MathQuizGame` instead of modifying `self`.
    """

    def __init__(self, roles=None, environment=None, context=None, target=None):
        """
        Build a game state, filling in defaults for anything not supplied.

        Parameters:
            roles (list[str] | None): Remaining roles, next to act first.
                Derived from the environment when omitted.
            environment (dict | None): Dict with the keys "question",
                "teacher_answer" and "student_answer". All None when omitted.
            context (dict | None): Dict with the keys "history" and "message".
            target (str | None): The number the question must evaluate to.
                A random integer in [1, 1000], as a string, when omitted.
        """
        super().__init__()
        if environment is None:
            environment = {"question": None, "teacher_answer": None, "student_answer": None}
        self.environment = environment
        self.roles = self.get_next_roles(self.environment) if roles is None else roles
        self.context = context if context is not None else {"history": [], "message": None}
        self.target = target if target is not None else str(randint(1, 1000))

    def initialize(self, roles=None, environment=None, context=None, target=None):
        """Return a new MathQuizGame built from the given pieces."""
        return MathQuizGame(
            roles=roles,
            environment=environment,
            context=context,
            target=target,
        )

    def update_game(self, move):
        """
        Record `move` in the environment slot owned by the role to act.

        The roles are deliberately left unchanged so that a following
        `query_game` call can validate the move of the role that just acted.

        Parameters:
            move (str): The question or answer produced by the current role.
        Returns:
            MathQuizGame: A new state containing the recorded move.
        """
        # Shallow copies: the new state gets its own dicts, but nested
        # objects (e.g. the "history" list) are still shared.
        new_context = self.context.copy()
        new_environment = self.environment.copy()
        if self.roles[0] == "TeacherGenerateQuestion":
            new_environment['question'] = move
        elif self.roles[0] == "TeacherAnswerQuestion":
            new_environment['teacher_answer'] = move
        elif self.roles[0] == "StudentAnswerQuestion":
            new_environment['student_answer'] = move

        return self.initialize(
            roles=self.roles,
            environment=new_environment,
            context=new_context,
            target=self.target,
        )

    def query_game(self):
        """
        Validate the latest move and advance to the next role.

        Returns:
            MathQuizGame: A new state whose roles are recomputed from the
            environment and whose context "message" holds either the
            validation verdict or a prompt naming the next role to move.
        """
        new_context = self.context.copy()
        new_roles = self.get_next_roles(self.environment)
        msg = self.validate_game()
        new_context['message'] = msg if msg is not None else f"You will move as {new_roles[0]}"

        return self.initialize(
            environment=self.environment,
            context=new_context,
            roles=new_roles,
            target=self.target,
        )

    def verify_answer(self, answer):
        """Return True when `answer`, compared as a string, equals the target."""
        return str(answer) == str(self.target)

    def validate_game(self):
        """
        Judge the move made by the current role (`self.roles[0]`).

        Returns:
            str | None: None when play should simply continue,
            "TeacherIncorrect" when the teacher's answer misses the target,
            or "StudentCorrect" / "StudentIncorrect" for the student's answer.
            None is also returned for an unrecognised role.
        """
        current_role = self.roles[0]
        if current_role == "TeacherGenerateQuestion":
            # A freshly generated question is not checked here.
            return None
        if current_role == "TeacherAnswerQuestion":
            if self.verify_answer(self.environment['teacher_answer']):
                return None
            return "TeacherIncorrect"
        if current_role == "StudentAnswerQuestion":
            # Grade the student's own answer. Checking the teacher's answer
            # here would always succeed, because the teacher has already
            # been verified by the time the student moves.
            if self.verify_answer(self.environment['student_answer']):
                return "StudentCorrect"
            return "StudentIncorrect"
        return None

    def get_next_roles(self, environment):
        """
        Work out which roles still have to act, next to act first.

        Parameters:
            environment (dict): Environment with "question" and
                "teacher_answer" keys.
        Returns:
            list[str]: The remaining roles in playing order.
        """
        if environment['question'] is None:
            return ['TeacherGenerateQuestion', 'TeacherAnswerQuestion', 'StudentAnswerQuestion']
        if environment['teacher_answer'] is None:
            return ['TeacherAnswerQuestion', 'StudentAnswerQuestion']
        return ['StudentAnswerQuestion']

    def export(self):
        """
        Build the view of the game handed to the player about to move.

        The question generator is shown the target number; both answering
        roles are shown only the question.

        Returns:
            dict: Dict with the keys "role", "environment" and "context".
        Raises:
            ValueError: If the current role is not one of the three known roles.
        """
        current_role = self.roles[0]
        if current_role == "TeacherGenerateQuestion":
            return {
                'role': self.roles[0],
                'environment': self.target,
                'context': self.context,
            }
        if current_role in ("TeacherAnswerQuestion", "StudentAnswerQuestion"):
            return {
                'role': self.roles[0],
                'environment': self.environment['question'],
                'context': self.context,
            }
        raise ValueError("Invalid role")

    def display(self):
        """Return a human-readable summary of the role to act, message, environment and target."""
        display_str = f"Role to Act: {self.roles[0]}\nMessage: {self.context['message']}\n"
        display_str += f"{self.environment}\n"
        display_str += f"Target: {self.target}\n"
        return display_str
|
||
|
||
if __name__ == "__main__":
    # Nothing to do when executed directly; this module is meant to be imported.
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import dspy | ||
from zero_sum_eval.player import Player | ||
from zero_sum_eval.registry import PLAYER_REGISTRY | ||
|
||
class GenerateQuestion(dspy.Signature):
    """Given a target number, generate a challenging math question with the target number as the answer. Make sure not to include the answer in the question."""

    # NOTE: per the DSPy Signature docs, the docstring above is the task
    # instruction given to the language model and the `desc` strings below
    # describe each field to it. They are runtime text, not just
    # documentation, so editing them changes the prompt.

    # Input: the number the generated question must evaluate to.
    target_number = dspy.InputField(desc="target number")
    # Output: the question text produced by the model.
    math_question = dspy.OutputField(desc="math question with the target number as the answer")
|
||
class AnswerQuestion(dspy.Signature):
    """Given a math question, give the answer to the question as a number only"""

    # NOTE: per the DSPy Signature docs, the docstring above is the task
    # instruction given to the language model and the `desc` strings below
    # describe each field to it. They are runtime text, not just
    # documentation, so editing them changes the prompt.

    # Input: the question to be solved.
    math_question = dspy.InputField(desc="math question")
    # Output: the model's answer, requested as a bare number.
    answer = dspy.OutputField(desc="answer to the math question with number only")
|
||
class GenerateQuestionCoT(dspy.Module):
    """Chain-of-thought wrapper around the GenerateQuestion signature."""

    def __init__(self):
        super().__init__()
        self.cot_question = dspy.ChainOfThought(GenerateQuestion)

    def forward(self, target_number):
        """Run the question generator for `target_number` and return its prediction unchanged."""
        return self.cot_question(target_number=target_number)
|
||
class AnswerQuestionCoT(dspy.Module):
    """Chain-of-thought wrapper around the AnswerQuestion signature."""

    def __init__(self):
        super().__init__()
        self.cot_answer = dspy.ChainOfThought(AnswerQuestion)

    def forward(self, math_question):
        """Run the answerer on `math_question` and return its prediction unchanged."""
        return self.cot_answer(math_question=math_question)
|
||
@PLAYER_REGISTRY.register("mathquiz", "mathquiz_teacher")
class MathQuizTeacher(Player):
    """Teacher player: writes the quiz question and then answers it."""

    def _build_modules(self, **module_args):
        """Create the question and answer modules and return them as a list."""
        self.question_module = GenerateQuestionCoT()
        self.answer_module = AnswerQuestionCoT()
        return [self.question_module, self.answer_module]

    def make_move(self, game_state):
        """
        Produce the teacher's move for the role currently at the front of
        `game_state.roles`.

        Parameters:
            game_state (GameState): The current state of the game
        Returns:
            str: The generated question for "TeacherGenerateQuestion", or
            the teacher's answer for "TeacherAnswerQuestion"
        Raises:
            ValueError: If the current role is not a teacher role
        """
        state_view = game_state.export()
        role = game_state.roles[0]
        # Run the module with this player's own language model.
        with dspy.context(lm=self.llm_model):
            if role == "TeacherGenerateQuestion":
                # For this role the exported environment is the target number.
                return self.question_module(state_view['environment']).math_question
            if role == "TeacherAnswerQuestion":
                # For this role the exported environment is the question text.
                return self.answer_module(state_view['environment']).answer
            raise ValueError(f"Invalid role for teacher: {role}")
|
||
|
||
@PLAYER_REGISTRY.register("mathquiz", "mathquiz_student")
class MathQuizStudent(Player):
    """Student player: answers the question posed by the teacher."""

    def _build_modules(self, **module_args):
        """Create the answer module and return it in a one-element list."""
        self.answer_module = AnswerQuestionCoT()
        return [self.answer_module]

    def make_move(self, game_state):
        """
        Produce the student's answer to the current question.

        Parameters:
            game_state (GameState): The current state of the game
        Returns:
            str: The student's answer
        Raises:
            ValueError: If the current role is not "StudentAnswerQuestion"
        """
        state_view = game_state.export()
        role = game_state.roles[0]
        # Run the module with this player's own language model.
        with dspy.context(lm=self.llm_model):
            if role != "StudentAnswerQuestion":
                raise ValueError(f"Invalid role for student: {role}")
            # For this role the exported environment is the question text.
            return self.answer_module(state_view['environment']).answer
|
||
|
||
|
Oops, something went wrong.