diff --git a/backend/src/chess/engine.ts b/backend/src/chess/engine.ts new file mode 100644 index 0000000..1687eda --- /dev/null +++ b/backend/src/chess/engine.ts @@ -0,0 +1,87 @@ +import { Chess, validateFen } from "chess.js"; +import { llmInterpretPrompt } from "./llm"; +import { assertUnreachable } from "../utils/assertions"; + + +class NormalMove { + square1: string; + square2: string; + constructor(square1: string, square2: string) { + this.square1 = square1; + this.square2 = square2; + } +} + +class PromotionMove extends NormalMove { + piece: "q" | "r" | "b" | "n"; + constructor(square1: string, square2: string, piece: "q" | "r" | "b" | "n") { + super(square1, square2) + this.piece = piece; + } +} + +class InvalidMove { + prompt: string; + constructor(prompt: string) { + this.prompt = prompt; + } +} + +class FailedMove { + error: string; + constructor(error: string) { + this.error = error; + } + +} + +type Move = InvalidMove | NormalMove | PromotionMove; + +const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7"; +const chess = new Chess(FEN); + +function movePiece(square1, square2): string | FailedMove { + const piece = chess.remove(square1); + if (!piece) { + console.log(piece) + return null; + } + chess.put(piece, square2); + const validate = validateFen(chess.fen()); + if (validate.ok) { + return chess.fen(); + }else { + console.log(validateFen(chess.fen())); + return new FailedMove(validate.error); + } +} + +function promotePiece(square1, square2, piece): string | FailedMove { + const pawn = chess.remove(square1); + if (!pawn) { + console.log(pawn) + return null; + } + chess.put({ type: piece, color: pawn.color }, square2); + const validate = validateFen(chess.fen()); + if (validate.ok) { + return chess.fen(); + } else { + console.log(validateFen(chess.fen())); + return new FailedMove(validate.error); + } +} + +export async function interpretMove(prompt: string, fen: string): Promise { + const move = await llmInterpretPrompt(prompt, fen); + if (move instanceof NormalMove) { + return movePiece(move.square1, move.square2); + } else if (move instanceof PromotionMove) { + return promotePiece(move.square1, move.square2, move.piece); + } else if (move instanceof InvalidMove) { + assertUnreachable(move); + } +} + + +export {Move, NormalMove, PromotionMove, InvalidMove} \ No newline at end of file diff --git a/backend/src/chess/llm.ts b/backend/src/chess/llm.ts new file mode 100644 index 0000000..4005d40 --- /dev/null +++ b/backend/src/chess/llm.ts @@ -0,0 +1,132 @@ +import { GoogleGenerativeAI } from "@google/generative-ai"; +import { Piece } from "chess.js"; +import { assertUnreachable, assertNever } from "../utils/assertions.js"; +import dotenv from "dotenv"; +import { InvalidMove, Move, NormalMove, PromotionMove } from "./engine.js"; +dotenv.config(); + +async function displayTokenCount(model, request) { + const { totalTokens } = await model.countTokens(request); + console.log("Token count: ", totalTokens); +} + +async function displayChatTokenCount(model, chat, msg) { + const history = await chat.getHistory(); + const msgContent = { role: "user", parts: [{ text: msg }] }; + await displayTokenCount(model, { contents: [...history, msgContent] }); +} + +// Access your API key as an environment variable (see "Set up your API key" above) +const genAI = new GoogleGenerativeAI(process.env.GEMINI_KEY); +const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7"; + +export async function llmInterpretPrompt( + prompt: string, + fen: string +): Promise { + // For text-only input, use the gemini-pro model + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); + const chat = model.startChat({ + history: [ + { + role: "user", + parts: `Assume the role of an Next Generation Chess Interpreter. Players will describe their moves within 15 words and you are to parse them into valid chess moves. Your response can be one of the following: + + 1. (square, square), to move a piece from the first square to the second square. For example, ('e2', 'e4') + 2. (square, square, piece), to promote a pawn to a piece. For example, ('e7', 'e8', 'q'). This promotes the pawn at e7 to a queen. The piece can be a 'q' (queen), 'r' (rook), 'b' (bishop), or 'n' (knight). + 3. 'Invalid move', if the move does not make sense or is illegal. + If you understand, respond with 'Yes, I understand'. The current game state is provided by the following FEN: ${fen}`, + }, + { + role: "model", + parts: "Yes, I understand.", + }, + ], + generationConfig: { + maxOutputTokens: 1000, + }, + }); + + try { + const result = await chat.sendMessage(prompt); + const response = await result.response; + const text = response.text(); + const safe = await llmCheckMoveValidity(text, fen); + console.log(parseResponseMove(text), safe); + if (safe) { + return parseResponseMove(text); + } else { + return new InvalidMove("Illegal Move: " + text); + } + } catch (e) { + console.log(e); + } +} + +async function llmCheckMoveValidity( + prompt: string, + fen: string +): Promise { + // For text-only input, use the gemini-pro model + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); + const chat = model.startChat({ + history: [ + { + role: "user", + parts: `Assume the role of an Next Generation Chess Interpreter. Given the FEN of the current game, you are to determine whether a move is legal. The input can be one of the following formats: + + 1. (square, square), to move a piece from the first square to the second square. For example, ('e2', 'e4') moves the piece at e2 to e4. + 2. (square, square, piece), to promote a pawn to a piece. For example, ('e7', 'e8', 'q'), promotes the pawn at e7 to a queen. The piece can be a 'q' (queen), 'r' (rook), 'b' (bishop), or 'n' (knight). + + If the move is legal, respond with 'True'. If the move is illegal, respond with 'False'. You should only have either 'True' or 'False' in your response. + If you understand, respond with 'Yes, I understand'. The current game state is provided by the following FEN: ${fen}`, + }, + { + role: "model", + parts: "Yes, I understand.", + }, + ], + generationConfig: { + maxOutputTokens: 100, + }, + }); + + try { + const result = await chat.sendMessage(prompt); + const response = await result.response; + const text = response.text(); + return text === "True"; + } catch (e) { + console.log(e); + } +} + +function parseResponseMove(response: string): Move { + // if response contains 'Invalid move', return 'Invalid move' + if (response.includes("Invalid move")) { + return new InvalidMove(response); + } + + // check if response is in the format (square, square) + const moveRegex = /\(\'([abcdefgh]\d)\', \'([abcdefgh]\d)\'\)/; + const moveMatch = response.match(moveRegex); + if (moveMatch) { + const [_, square1, square2] = moveMatch; + return new NormalMove(square1, square2); + } + + // check if response is in the format (square, square) + const promotionRegex = + /\(\'([abcdefgh]\d)\', \'([abcdefgh])\d\', '([qrbn])'\)/; + const promotionMatch = response.match(promotionRegex); + if (promotionMatch) { + const [_, square1, square2, piece] = promotionMatch; + if (piece === "q" || piece === "r" || piece === "b" || piece === "n") { + return new PromotionMove(square1, square2, piece); + } else { + assertNever(); + } + } + + return new InvalidMove(`Illegal Move: \n ${response}`); +} diff --git a/backend/src/llm.ts b/backend/src/llm.ts deleted file mode 100644 index 148030e..0000000 --- a/backend/src/llm.ts +++ /dev/null @@ -1,18 +0,0 @@ -const { GoogleGenerativeAI } = require("@google/generative-ai"); - -// Access your API key as an environment variable (see "Set up your API key" above) -const genAI = new GoogleGenerativeAI(process.env.GEMINI_KEY); - -async function run() { - // For text-only input, use the gemini-pro model - const model = genAI.getGenerativeModel({ model: "gemini-pro"}); - - const prompt = "Write a story about a magic backpack." - - const result = await model.generateContent(prompt); - const response = await result.response; - const text = response.text(); - console.log(text); -} - -run(); \ No newline at end of file diff --git a/backend/src/utils/assertions.ts b/backend/src/utils/assertions.ts new file mode 100644 index 0000000..a1519e3 --- /dev/null +++ b/backend/src/utils/assertions.ts @@ -0,0 +1,13 @@ +/** + * Throws an exception indicating that a code path is unreachable. + * This is mainly used for TypeScript compiler checks. + * @param {never} value - The value that should never be reached. + * @returns {never} + */ +export function assertUnreachable(value) { + throw new Error(`Unreachable code reached with value: ${value}`); +} + +export function assertNever() { + throw new Error(`Unreachable code reached`); +} \ No newline at end of file diff --git a/backend/test/chessEngine.js b/backend/test/chessEngine.js new file mode 100644 index 0000000..a47735c --- /dev/null +++ b/backend/test/chessEngine.js @@ -0,0 +1,19 @@ +import { Chess } from 'chess.js' + +const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7"; +const chess = new Chess(FEN) + +function move(square1, square2) { + const piece = chess.remove(square1); + chess.put(piece, square2); + return chess.fen(); +} + +function promote(square1, square2, piece) { + const pawn = chess.remove(square1); + chess.put({ type: piece, color: pawn.color}, square2); + return chess.fen(); +} + +console.log(move('e2', 'e4')); +console.log(promote('e7', 'e8', 'q')); \ No newline at end of file diff --git a/backend/test/illegalChess.js b/backend/test/illegalChess.js index 8f7d387..b7a7f34 100644 --- a/backend/test/illegalChess.js +++ b/backend/test/illegalChess.js @@ -1,7 +1,35 @@ -import { Chess } from 'chess.js' +import { Chess, validateFen } from "chess.js"; const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7"; -const chess = new Chess(FEN) -console.log(chess.ascii()) +const chess = new Chess(FEN); -chess.put({}) \ No newline at end of file +function move(square1, square2) { + const piece = chess.remove(square1); + if (!piece) { + console.log(piece) + return null; + } + chess.put(piece, square2); + if (!validateFen(chess.fen()).ok) { + console.log(validateFen(chess.fen())); + // return null; + } + return chess.fen(); +} + +function promote(square1, square2, piece) { + const pawn = chess.remove(square1); + if (!pawn) { + console.log(pawn) + return null; + } + chess.put({ type: piece, color: pawn.color }, square2); + if (!validateFen(chess.fen()).ok) { + console.log(validateFen(chess.fen())); + // return null; + } + return chess.fen(); +} + +console.log(move("e2", "e4")); +console.log(promote("g7", "g8", "q")); diff --git a/backend/test/llm.js b/backend/test/llm.js index af0f692..86010c4 100644 --- a/backend/test/llm.js +++ b/backend/test/llm.js @@ -42,12 +42,14 @@ async function run(prompt, fen) { }); try { - // const result = await model.generateContent(prompt); const result = await chat.sendMessage(prompt) const response = await result.response; const text = response.text(); const safe = await check(text); console.log(parseResponse(text), safe); + if (safe) { + return parseResponse(text); + } } catch (e) { console.log(e); } @@ -74,12 +76,11 @@ async function check(prompt, fen) { }, ], generationConfig: { - maxOutputTokens: 1000, + maxOutputTokens: 100, }, }); try { - // const result = await model.generateContent(prompt); const result = await chat.sendMessage(prompt) const response = await result.response; const text = response.text();