Skip to content

Commit

Permalink
more prompt engineering
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceankoh committed Jan 20, 2024
1 parent 2c1de39 commit c4f72d9
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 80 deletions.
43 changes: 22 additions & 21 deletions backend/src/chess/llm.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { BlockReason, GoogleGenerativeAI } from "@google/generative-ai";
import { Piece } from "chess.js";
import { assertUnreachable, assertNever } from "../utils/assertions.js";
import { BlockReason, FinishReason, GoogleGenerativeAI } from "@google/generative-ai";
import { assertNever, assertUnreachable } from "../utils/assertions.js";
import dotenv from "dotenv";
import { InvalidMove, Move, NormalMove, PromotionMove } from "./engine.js";
dotenv.config();
Expand All @@ -18,7 +17,6 @@ async function displayChatTokenCount(model, chat, msg) {

// Access your API key as an environment variable (see "Set up your API key" above)
const genAI = new GoogleGenerativeAI(process.env.GEMINI_KEY);
const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7";

export async function llmInterpretPrompt(
prompt: string,
Expand Down Expand Up @@ -56,27 +54,32 @@ export async function llmInterpretPrompt(
const result = await chat.sendMessage(prompt);
const response = await result.response;

if (
response.promptFeedback &&
response.promptFeedback.blockReason !==
BlockReason.BLOCKED_REASON_UNSPECIFIED
) {
if (response.candidates[0].finishReason === FinishReason.MAX_TOKENS) {
return new InvalidMove(
"Blocked Prompt: The response returned was too long. Please try again."
);
} else if (response.candidates[0].finishReason === FinishReason.SAFETY) {
return new InvalidMove(
"Blocked Prompt: " + response.promptFeedback.blockReasonMessage
"Blocked Prompt: The prompt was flagged as harmful. Please try again."
);
}

const text = response.text();
const parsed = parseResponseMove(text);
if (parsed instanceof InvalidMove) {
return parsed;
}
const safe = await llmCheckMoveValidity(parsed, fen);
console.log(parseResponseMove(text), safe);
if (safe) {
return parseResponseMove(text);
} else if (
parsed instanceof PromotionMove ||
parsed instanceof NormalMove
) {
const safe = await llmCheckMoveValidity(parsed, fen);
if (safe) {
return parsed;
} else {
return new InvalidMove(`Illegal Move: ${text}`);
}
} else {
return new InvalidMove(`Illegal Move: ${text}`);
assertUnreachable(parsed);
}
}

Expand Down Expand Up @@ -113,11 +116,9 @@ async function llmCheckMoveValidity(
);

const response = await result.response;
if (
response.promptFeedback &&
response.promptFeedback.blockReason !==
BlockReason.BLOCKED_REASON_UNSPECIFIED
) {
if (response.candidates[0].finishReason === FinishReason.MAX_TOKENS) {
return false;
} else if (response.candidates[0].finishReason === FinishReason.SAFETY) {
return false;
}

Expand Down
172 changes: 113 additions & 59 deletions backend/test/llm.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,52 @@
import { GoogleGenerativeAI } from "@google/generative-ai";
import {
BlockReason,
FinishReason,
GoogleGenerativeAI,
} from "@google/generative-ai";
import dotenv from "dotenv";

dotenv.config();
/** A simple origin-to-destination chess move parsed from the LLM reply. */
class NormalMove {
  /** Origin square, e.g. "e2". */
  square1;
  /** Destination square, e.g. "e4". */
  square2;

  constructor(square1, square2) {
    this.square1 = square1;
    this.square2 = square2;
  }

  /** Formats as "(from, to)" — the shape the validator prompt consumes. */
  toString() {
    return "(" + this.square1 + ", " + this.square2 + ")";
  }
}
/** A pawn move that also promotes; extends NormalMove with the piece letter. */
class PromotionMove extends NormalMove {
  /** Promotion target piece: 'q', 'r', 'b' or 'n'. */
  piece;

  constructor(square1, square2, piece) {
    super(square1, square2);
    this.piece = piece;
  }

  /** Formats as "(from, to, 'piece')" — note the quotes around the piece only. */
  toString() {
    return "(" + this.square1 + ", " + this.square2 + ", '" + this.piece + "')";
  }
}

/** Wraps a move request that could not be parsed or was judged illegal. */
class InvalidMove {
  /** Human-readable reason (or the offending response text). */
  prompt;

  constructor(prompt) {
    this.prompt = prompt;
  }

  // Deliberately does not leak `prompt` when stringified.
  toString() {
    return "Invalid Move";
  }
}

/** Records an error raised while attempting to produce a move. */
class FailedMove {
  /** The underlying error value; no custom toString is provided. */
  error;

  constructor(error) { this.error = error; }
}

async function displayTokenCount(model, request) {
const { totalTokens } = await model.countTokens(request);
Expand All @@ -17,8 +63,8 @@ async function displayChatTokenCount(model, chat, msg) {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_KEY);
const FEN = "rnbqkb1r/1p2ppp1/3p4/p2n3p/3P4/3B1N2/PPP2PPP/RNBQK2R w KQkq - 0 7";

async function run(prompt, fen) {
console.log("prompt: ", prompt);
export async function llmInterpretPrompt(prompt, fen) {
console.log(prompt);
// For text-only input, use the gemini-pro model
const model = genAI.getGenerativeModel({ model: "gemini-pro" });
const chat = model.startChat({
Expand All @@ -27,45 +73,55 @@ async function run(prompt, fen) {
role: "user",
parts: `Assume the role of an Next Generation Chess Interpreter. Players will describe their moves and you are to parse them into valid chess moves.
The current game state is provided by the following FEN: ${fen}
Your response must be one of the following:
The current game state is provided by the following FEN: ${fen}
Your response must be one of the following:
1. (<square>, <square>), to move a piece from the first square to the second square. For example, ('e2', 'e4')
2. (<square>, <square>, <piece>), to promote a pawn to a piece. For example, ('e7', 'e8', 'q'). This promotes the pawn at e7 to a queen. The piece can be a 'q' (queen), 'r' (rook), 'b' (bishop), or 'n' (knight).
1. (<square>, <square>), to move a piece from the first square to the second square. For example, ('e2', 'e4')
2. (<square>, <square>, <piece>), to promote a pawn to a piece. For example, ('e7', 'e8', 'q'). This promotes the pawn at e7 to a queen. The piece can be a 'q' (queen), 'r' (rook), 'b' (bishop), or 'n' (knight).
This is very important: You should only have either a move formatted as (<square>, <square>) or (<square>, <square>, <piece>) in your response.
This is very important: You should only have either a move formatted as (<square>, <square>) or (<square>, <square>, <piece>) in your response.
If you understand, respond with 'Yes, I understand'.`,
If you understand, respond with 'Yes, I understand'.`,
},
{
role: "model",
parts: "Yes, I understand.",
},
],
generationConfig: {
maxOutputTokens: 1000,
maxOutputTokens: 500,
},
});

try {
const result = await chat.sendMessage(prompt);
const response = await result.response;
const text = response.text();
console.log("before parsing: ", text);
const parsed = parseResponse(text);
console.log("after parsing: ", parsed);
const safe = await check(parsed, fen);
if (safe) {
return parsed;
}
} catch (e) {
console.log(e);
const result = await chat.sendMessage(prompt);
const response = await result.response;

if (response.candidates[0].finishReason === FinishReason.MAX_TOKENS) {
return new InvalidMove(
"Blocked Prompt: The response returned was too long. Please try again."
);
} else if (response.candidates[0].finishReason === FinishReason.SAFETY) {
return new InvalidMove(
"Blocked Prompt: The prompt was flagged as harmful. Please try again."
);
}

const text = response.text();
const parsed = parseResponseMove(text);
console.log("parsed", parsed);
if (parsed instanceof InvalidMove) {
return parsed;
}
const safe = await llmCheckMoveValidity(parsed, fen);
if (safe) {
return parsed;
} else {
return new InvalidMove(`Illegal Move: ${text}`);
}
console.log(`\n\n\n\n`);
}

async function check(prompt, fen) {
async function llmCheckMoveValidity(prompt, fen) {
// For text-only input, use the gemini-pro model
const model = genAI.getGenerativeModel({ model: "gemini-pro" });
const chat = model.startChat({
Expand All @@ -90,38 +146,33 @@ async function check(prompt, fen) {
},
});

try {
console.log(fen);
const result = await chat.sendMessage(
`The current game state is provided by the following FEN: ${fen}. The move to be made is ${prompt}`
const result = await chat.sendMessage(
`The current game state is provided by the following FEN: ${fen}. The move to be made is ${prompt.toString()}`
);
const response = await result.response;

if (response.candidates[0].finishReason === FinishReason.MAX_TOKENS) {
return new InvalidMove(
"Blocked Prompt: The response returned was too long. Please try again."
);
} else if (response.candidates[0].finishReason === FinishReason.SAFETY) {
return new InvalidMove(
"Blocked Prompt: The prompt was flagged as harmful. Please try again."
);
const response = await result.response;
const text = response.text();
console.log("safety check", text);
return text === "True";
} catch (e) {
console.log(e);
}
}

function parseResponse(response) {
console.log(response);
response = response.trim();
// accept only (square, square) or (square, square, piece) or 'Invalid move'

// if response contains 'Invalid move', return 'Invalid move'
if (response.includes("Invalid move")) {
// return response;
return "Invalid Move";
}
const text = response.text();
console.log("safety check", text);
return text === "True";
}

function parseResponseMove(response) {
// check if response is in the format (square, square)
const moveRegex = /\(\'?([abcdefgh]\d)\'?,\s?\'?([abcdefgh]\d)\'?\)/;
const moveMatch = response.match(moveRegex);
if (moveMatch) {
const [_, square1, square2] = moveMatch;
// return { square1, square2 };
return `(${square1}, ${square2})`;
return new NormalMove(square1, square2);
}

// check if response is in the format (square, square)
Expand All @@ -130,22 +181,25 @@ function parseResponse(response) {
const promotionMatch = response.match(promotionRegex);
if (promotionMatch) {
const [_, square1, square2, piece] = promotionMatch;
// return { square1, square2, piece };
return `(${square1}, ${square2}, ${piece})`;
if (piece === "q" || piece === "r" || piece === "b" || piece === "n") {
return new PromotionMove(square1, square2, piece);
} else {
// assertNever();
}
}

// return `Illegal Response: \n ${response}`;
return `Illegal Response`;
console.log("Invalid Response: ", response);
return new InvalidMove(`Invalid Response: ${response}`);
}

// user prompt
const prompt1 = "capture the opponent's rook";
const prompt2 = "advance and promote all my pawns";
const prompt3 = "deliver a checkmate";
const prompt4 = "('e2', 'e8', 'q')";
const prompt5 = "move the piece at d1 to d2";

await run(prompt1, FEN);
await run(prompt2, FEN);
await run(prompt3, FEN);
await run(prompt4, FEN);
await run(prompt5, FEN);
await llmInterpretPrompt(prompt1, FEN);
await llmInterpretPrompt(prompt2, FEN);
await llmInterpretPrompt(prompt3, FEN);
await llmInterpretPrompt(prompt4, FEN);
await llmInterpretPrompt(prompt5, FEN);

0 comments on commit c4f72d9

Please sign in to comment.