-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMCTS-Llama.py
130 lines (106 loc) · 5.34 KB
/
MCTS-Llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
"""MCTS-Llama.py"""
import random
import math
import ollama
import numpy as np
from sentence_transformers import SentenceTransformer
class TreeNode:
    """A single node in the MCTS query-refinement tree.

    Holds the query sent to the LLM, the reply text (once generated),
    the search statistics updated during backpropagation, and the
    parent/children links that form the tree.
    """

    def __init__(self, query, response=None):
        # What was asked and (eventually) what the model answered.
        self.query = query
        self.response = response
        # Search statistics: best score seen here, and visit count.
        self.score, self.visits = 0, 0
        # Tree structure: parent back-link is filled in by add_child.
        self.children, self.parent = [], None

    def add_child(self, child_node):
        """Attach child_node beneath this node and set its back-link."""
        child_node.parent = self
        self.children.append(child_node)
class MCT:
def __init__(self, model_name, query, ground_truth):
self.model_name = model_name
self.root = TreeNode(query)
self.ground_truth = ground_truth
self.best_response = [None, -1]
self.token_model = SentenceTransformer('all-mpnet-base-v2')
def ollama_response(self, prompt, model):
response = ollama.chat(
model=model,
messages=[{'role': 'user', 'content': prompt}],
stream=False,
format='json',
)
return response['message']
def expand(self, node: TreeNode, num_responses: int):
if node.visits == 0:
response = self.generate_response(node.query)
response_text = self.extract_text_from_response(response)
node.response = response_text
score = self.simulate(node)
self.backpropagate(node, score)
else:
refined_query = self.refine_query(node.query, node.response, self.ground_truth)
responses = [self.generate_response(refined_query) for _ in range(num_responses)]
for response in responses:
response_text = self.extract_text_from_response(response)
child_node = TreeNode(query=refined_query, response=response_text)
node.add_child(child_node)
def simulate(self, node: TreeNode):
score = self.evaluate_response(node.response)
node.score = score
return score
def backpropagate(self, node: TreeNode, score: float):
while node:
node.visits += 1
node.score = max(node.score, score)
if score > self.best_response[1]:
self.best_response = [node.response, score]
node = node.parent
print(self.best_response)
def evaluate_response(self, response: str) -> float:
a = self.token_model.encode(response)
b = self.token_model.encode(self.ground_truth)
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
return similarity
def select(self, node: TreeNode) -> TreeNode:
while node.children:
node = max(node.children, key=lambda child: self.ucb1(child))
return node
def ucb1(self, node: TreeNode) -> float:
if node.visits == 0:
return float('inf')
ucb_score = node.score + math.sqrt(2 * math.log(node.parent.visits) / node.visits)
return ucb_score
def generate_response(self, query: str) -> str:
response = self.ollama_response(prompt=query, model=self.model_name)
return response
def extract_text_from_response(self, response) -> str:
return response
def refine_query(self, query: str, response: str, ground_truth: str) -> str:
refined_query = f"""I have a query, an initial result generated by an LLM, and the correct answer to the query. Your task is to generate a refined query that would lead the LLM towards providing a response closer to the ground truth answer.
Query: ```{query}```
Initial Result Generated by the LLM: ```{response}```
Ground Truth Answer: ```{ground_truth}```
Instructions:
Understand the Context:
Carefully analyze the original query to understand its intent and the specific information it seeks.
Evaluate the initial result generated by the LLM, identifying any inaccuracies or deviations from the ground truth answer.
Compare the initial result with the ground truth answer to pinpoint the key differences and understand the correct information needed.
Generate a Refined Query:
Formulate a new query that is more specific, clear, and targeted to address the gaps or inaccuracies in the initial result.
Ensure the refined query is logical, precise, and concise, directly aiming to elicit a response that aligns with the ground truth answer.
Avoid introducing new information or context that deviates from the original query’s intent.
IMPORTANT :
the new query you generate should not lose its meaning which it had began with"""
return refined_query
def search(self, iterations: int):
for _ in range(iterations):
node = self.select(self.root)
self.expand(node, num_responses=2)
if self.best_response[1] >= 0.90:
print("Best response found with score 0.90 or higher. Ending search.")
break
def mcts(model_name, query, ground_truth, iterations=2):
    """Run an MCTS-guided query refinement.

    Builds a search tree rooted at *query*, refines it against
    *ground_truth* for up to *iterations* rounds, and returns the
    best ``[response, score]`` pair found.
    """
    searcher = MCT(model_name=model_name, query=query, ground_truth=ground_truth)
    searcher.search(iterations=iterations)
    return searcher.best_response