Skip to content

Commit

Permalink
implementation nearly complete
Browse files Browse the repository at this point in the history
  • Loading branch information
Haidar Khan committed Jul 11, 2024
1 parent e34287c commit 7638fa6
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 173 deletions.
139 changes: 63 additions & 76 deletions zero_sum_eval/games/mathquiz/mathquiz_game.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,98 @@
from zero_sum_eval.game_state import GameState
from random import randint


class MathQuizGame(GameState):
def __init__(self, roles=None, environment=None, context=None):
"""
This is a two player game where the players take turns to answer math questions.
In each round:
1. the environment is initialized with a target number
2. The first player to move creates a math question with the answer as the target number
and proves that the question is valid.
3. If the first player succeeds, the second player is given a chance to answer the question.
4. The game continues for a fixed number of rounds.
The roles for this game are:
TeacherGenerateQuestion
TeacherAnswerQuestion
StudentAnswerQuestion
The environment for this game is:
question: a math question
teacher_answer: the teacher's answer to the math question
student_answer: the student's answer to the math question
"""
def __init__(self, roles=None, environment=None, context=None, target=None):
super().__init__()
self.environment = environment
self.environment = environment if environment is not None else \
{"question": None, "teacher_answer": None, "student_answer": None}
self.roles = self.get_next_roles(self.environment) if roles is None else roles
self.context = context if context is not None else {"history": [], "message": None}
self.target = target if target is not None else randint(1, 1000)

def initialize(self, environment, context=None ,roles=None):
def initialize(self, roles=None, environment=None, context=None, target=None):
return MathQuizGame(
roles=roles,
environment=environment,
context=context
context=context,
target=target
)

def update_game(self, move):
new_context = self.context.copy()
new_environment = self.environment.copy()
if self.roles[0] == "TeacherGenerateQuestion":
new_environment['question'] = move
elif self.roles[0] == "TeacherAnswerQuestion":
new_environment['teacher_answer'] = move
elif self.roles[0] == "StudentAnswerQuestion":
new_environment['student_answer'] = move

try:
chess_move = self.board.parse_san(move)
if self.board.is_legal(chess_move):
self.board.push(chess_move)
new_context['history'].append(f"{move}")
new_context['message'] = None
#maybe fix message here
else:
new_context['message'] = f"Your move {move} is an illegal move"
except ValueError as e:
new_context['message'] = f"Your move {move} caused an error: {str(e)} "

new_environment = self.board.fen()
return self.initialize(
roles=self.roles,
environment=new_environment,
context=new_context
context=new_context,
target=self.target
)

def query_game(self):
new_context = self.context.copy()
new_roles = [self.roles[0]]
new_roles = self.get_next_roles(self.environment)
msg = self.validate_game()
new_context['message'] = msg if msg is not None else f"You will move as {self.get_next_roles(self.environment)[0]}"
new_context['message'] = msg if msg is not None else f"You will move as {new_roles[0]}"

return self.initialize(
environment=self.environment,
context=new_context,
roles=new_roles
)

def verify_answer(self, answer):
return answer == self.target

def validate_game(self):
if self.board.is_checkmate():
return "Checkmate"
elif self.board.is_stalemate():
return "Stalemate"
elif self.board.is_insufficient_material():
return "Insufficient material"
elif self.board.is_seventyfive_moves():
return "75-move rule"
elif self.board.is_fivefold_repetition():
return "Fivefold repetition"
elif not self.board.is_valid():
return "Invalid"
current_role = self.roles[0]
if self.environment['teacher_answer'] is not None:
if current_role == "TeacherAnswerQuestion":
if self.verify_answer(self.environment['teacher_answer']):
return "TeacherCorrect"
else:
return "TeacherIncorrect"
elif current_role == "StudentAnswerQuestion":
if self.verify_answer(self.environment['teacher_answer']):
return "StudentCorrect"
else:
return "StudentIncorrect"
else:
return self.context['message']

def get_next_roles(self, fen):
turn = fen.split()[1] # 'w' or 'b'
return ['White', 'Black'] if turn == 'w' else ['Black', 'White']

def formatted_move_history(self):
history = self.context['history']
formatted_history = ""
moves = len(history)//2+1
for i in range(1, moves+1):
j = (i-1)*2
formatted_history+=f"{i}."
if j < len(history):
formatted_history+=f"{history[j]} "
if j+1 < len(history):
formatted_history+=f"{history[j+1]} "
return formatted_history.strip()
def get_next_roles(self, environment):
if environment['question'] is None:
return ['TeacherGenerateQuestion', 'TeacherAnswerQuestion', 'StudentAnswerQuestion']
elif environment['teacher_answer'] is None:
return ['TeacherAnswerQuestion', 'StudentAnswerQuestion']
else:
return ['StudentAnswerQuestion']

def export(self):
return {
Expand All @@ -95,31 +109,4 @@ def display(self):


if __name__ == "__main__":
chess_game = ChessGame().initialize(chess.Board().fen())
print(chess_game.export())

# 1. e4 e5
chess_game = chess_game.update_game("e4")
print(chess_game.export())

print(chess_game.query_game().export())
chess_game = chess_game.update_game("e5")
print(chess_game.export())

# 2. Nf3 Nc6
chess_game = chess_game.update_game("Nf3")
print(chess_game.query_game().export())
chess_game = chess_game.update_game("Nc6")
print(chess_game.export())

# 3. Nxe5
chess_game = chess_game.update_game("Nxe5")
print(chess_game.export())

validation_result = chess_game.validate_game()
if validation_result:
print(f"Game validation result: {validation_result}")
else:
print("Game is valid.")


pass
152 changes: 55 additions & 97 deletions zero_sum_eval/games/mathquiz/mathquiz_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,73 +5,49 @@
import dspy
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
from dspy.teleprompt import LabeledFewShot, BootstrapFewShot, MIPROv2, BootstrapFewShotWithRandomSearch
import copy
import chess
from chess import IllegalMoveError, InvalidMoveError, AmbiguousMoveError
import functools, json
from random import shuffle

# TODO: add support for resigning

def validate_move(example, prediction, trace=None):
pred_move = prediction.move
true_move = example.move
board_state = example.board_state
board = chess.Board(board_state)
if true_move is not None and pred_move == true_move:
return 1
elif board.is_legal(board.parse_san(pred_move)):
return 0
else:
return -1

class NextMove(dspy.Signature):
"""Given a board state, role, and move history, produce the next best valid move"""
class GenerateQuestion(dspy.Signature):
"""Given a target answer, generate a challenging math question that has the target answer"""

board_state = dspy.InputField(desc="FEN formatted current board state")
role = dspy.InputField(desc="role of the player making the next move")
history = dspy.InputField(desc="move history")
move = dspy.OutputField(desc="a valid SAN formatted move without move number or elipses")
target_answer = dspy.InputField(desc="target answer for the question")
math_question = dspy.OutputField(desc="math question with the target answer")

class AnswerQuestion(dspy.Signature):
"""Given a math question, answer the question"""

math_question = dspy.InputField(desc="math question to answer")
answer = dspy.OutputField(desc="answer to the math question")

class ChessCoT(dspy.Module):
class GenerateQuestionCoT(dspy.Module):
def __init__(self):
super().__init__()
self.cot_move = dspy.ChainOfThought(NextMove)
self.cot_question = dspy.ChainOfThought(GenerateQuestion)


def forward(self, board_state, role, history):
cot_out = self.cot_move(board_state=board_state,
role=role,
history=history)
cot_out.move = cot_out.move.replace(".", "")
try:
board = chess.Board(board_state)
move = board.parse_san(cot_out.move)
except IllegalMoveError:
dspy.Suggest(
False,
f"{cot_out.move} is an illegal move, choose a different move.",
)
except InvalidMoveError:
dspy.Suggest(
False,
f"{cot_out.move} is an invalid move, choose a different move."
)
except AmbiguousMoveError:
dspy.Suggest(
False,
f"{cot_out.move} is an ambiguous move, choose a different move."
)
def forward(self, target_answer):
cot_out = self.cot_question(target_answer=target_answer)
return cot_out

class AnswerQuestionCoT(dspy.Module):
def __init__(self):
super().__init__()
self.cot_answer = dspy.ChainOfThought(AnswerQuestion)

def forward(self, math_question):
cot_out = self.cot_answer(math_question=math_question)
return cot_out

class ChessPlayer(Player):
class MathQuizTeacher(Player):
def __init__(self, llm_model, max_tries=4, **kwargs):
super().__init__(**kwargs)
self.llm_model = llm_model
self.max_tries = max_tries
self.module = assert_transform_module(ChessCoT(), functools.partial(backtrack_handler, max_backtracks=max_tries))
self.optimized_module = self.optimize_prompts() if self.optimize else None
self.question_module = GenerateQuestionCoT()
self.answer_module = AnswerQuestionCoT()
# self.optimized_module = self.optimize_prompts() if self.optimize else None
dspy.configure(trace=[])

def make_move(self, game_state):
Expand All @@ -85,56 +61,38 @@ def make_move(self, game_state):
str: The move made by the player
"""
export = game_state.export()
current_role = game_state.roles[0]
with dspy.context(lm=self.llm_model):
if self.optimize:
trace = self.optimized_module(board_state=export['environment'],
role=export['roles'][0],
history=game_state.formatted_move_history())
else:
trace = self.module(board_state=export['environment'],
role=export['roles'][0],
history=game_state.formatted_move_history())
return trace.move
if current_role == "TeacherGenerateQuestion":
trace = self.question_module(export['target'])
return trace.math_question
elif current_role == "TeacherAnswerQuestion":
trace = self.answer_module(export['environment'])
return trace.answer

def optimize_prompts(self):
filename = 'data/chess/stockfish_examples.jsonl'
dataset = self.create_dataset(filename)
shuffle(dataset)
# config = dict(max_bootstrapped_demos=4, max_labeled_demos=16)
# teleprompter = BootstrapFewShot(metric=validate_move, **config)
with dspy.context(lm=self.llm_model):

teleprompter = MIPROv2(metric=validate_move, prompt_model=self.llm_model, task_model=self.llm_model,
num_candidates=5, minibatch_size=20, minibatch_full_eval_steps=10,
verbose=True)
return teleprompter.compile(self.module, max_bootstrapped_demos=1, max_labeled_demos=1,
trainset=dataset, eval_kwargs={})

def load_examples(self, filename):
examples = []
with open(filename, 'r') as f:
for line in f:
examples.append(json.loads(line))
return examples
class MathQuizQuestion(Player):
def __init__(self, llm_model, max_tries=4, **kwargs):
super().__init__(**kwargs)
self.llm_model = llm_model
self.max_tries = max_tries
self.answer_module = AnswerQuestionCoT()
# self.optimized_module = self.optimize_prompts() if self.optimize else None
dspy.configure(trace=[])

def create_dataset(self, filename):
examples = self.load_examples(filename)
dataset = []
for example in [examples[i] for i in range(0, len(examples), len(examples)//10)]:
if self.role =="White": # white to move
if not example['turn']:
continue
else: # black to move
if example['turn']:
continue
example = dspy.Example(board_state=example['board_state'],
role=f"{self.role}",
history=example['history'],
move=example['move']
).with_inputs("board_state", "role", "history")
dataset.append(example)
return dataset

def make_move(self, game_state):
"""
Abstract method for making a move based on the current game state.
Parameters:
game_state (GameState): The current state of the game
Returns:
str: The move made by the player
"""
export = game_state.export()
with dspy.context(lm=self.llm_model):
trace = self.answer_module(export['environment'])
return trace.answer



0 comments on commit 7638fa6

Please sign in to comment.