-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathagents.py
97 lines (77 loc) · 3.14 KB
/
agents.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Dict, List, Optional, Tuple
import openai
class Prover:
    """An agent that uses an interactive theorem prover to generate proofs."""

    def __init__(
        self,
        goal: str,
        system_prompt_path: str = "prompts/prover.txt",
        model: str = "gpt-3.5-turbo-0613",
    ):
        """Set up the conversation context for proving *goal*.

        Args:
            goal: The statement the prover should attempt to prove.
            system_prompt_path: Path to a text file holding the system prompt.
            model: OpenAI chat model identifier.
        """
        self._model = model
        self._context: List[Dict[str, str]] = []
        # Load the system prompt from the path.
        with open(system_prompt_path, "r", encoding="utf-8") as f:
            system_prompt = f.read()
        self._context.append({"role": "system", "content": system_prompt})
        # Seed the conversation with the goal to prove.
        self._context.append({"role": "user", "content": goal})

    def _call_model(self, temperature: float = 0.5) -> str:
        """Call the chat model with the current context and return its text."""
        response = openai.ChatCompletion.create(
            model=self._model,
            messages=self._context,
            temperature=temperature,
        )
        return response["choices"][0]["message"]["content"]

    def step(self, input: Optional[str], max_retries: int = 5) -> Tuple[str, str]:
        """Run one prover step and return ``(natural, coq)``.

        Args:
            input: Optional feedback (e.g. checker output) appended to the
                conversation before querying the model; ``None`` to skip.
            max_retries: Maximum number of model calls before giving up on
                getting the expected ```-fenced format.  Bounds what was
                previously an infinite retry loop.

        Returns:
            A tuple of the prose before the first ``` fence and the
            whitespace-stripped code between the fences.

        Raises:
            RuntimeError: If the model never produces exactly one fenced
                code block within ``max_retries`` attempts.
        """
        if input is not None:
            # Add the user input to the context.
            self._context.append({"role": "user", "content": input})
        for _ in range(max_retries):
            model_output = self._call_model()
            print(f"Raw model output: {model_output}")
            # Add the model output to the context.
            self._context.append({"role": "assistant", "content": model_output})
            # Expected shape is <natural>```<coq>```, so splitting on the
            # fence yields exactly three pieces.
            split = model_output.split("```")
            if len(split) == 3:
                # Pieces of split("```") can never contain the fence itself,
                # so stripping surrounding whitespace is all that's needed.
                return split[0], split[1].strip()
            # Step again, but remind the model to use the format.
            self._context.append(
                {
                    "role": "user",
                    "content": "Please use the specified format. Do not leave out any of the sections. Do not add any additional output.",
                }
            )
        raise RuntimeError(
            f"Model did not produce the expected format after {max_retries} attempts."
        )
class Checker:
    """A stateless agent that evaluates the proofs generated by the prover."""

    def __init__(
        self,
        system_prompt_path: str = "prompts/checker.txt",
        model: str = "gpt-3.5-turbo-0613",
    ):
        """Read the checker's system prompt and remember which model to use."""
        self._model = model
        with open(system_prompt_path, "r") as f:
            self._system_prompt = f.read()

    def check(self, input: str) -> Tuple[str, bool]:
        """Return feedback on the proof."""
        # Stateless: a fresh two-message conversation per call.
        messages = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": input},
        ]
        response = openai.ChatCompletion.create(
            model=self._model,
            messages=messages,
        )
        feedback = response["choices"][0]["message"]["content"]
        # The verdict is signalled by the literal token ACCEPTED in the reply.
        accepted = "ACCEPTED" in feedback
        return feedback, accepted