-
Notifications
You must be signed in to change notification settings - Fork 483
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat] support math/gsm8k agent config (#494)
* support math agent * support gsm8k agent * support gsm8k agent * minor fix * minor fix * minor fix * Update configs/eval_codeagent.py
- Loading branch information
Showing
6 changed files
with
369 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import AgentInferencer | ||
from opencompass.openicl.icl_evaluator import AccEvaluator | ||
from opencompass.datasets import HFDataset, gsm8k_postprocess, gsm8k_dataset_postprocess | ||
|
||
# This config is for code interpreter
# Few-shot demonstration passed to AgentInferencer as ``example``.
# Single-brace names ({thought}, {action}, {action_input}, {response},
# {finish}) are format-style placeholders and doubled braces stand for literal
# braces -- presumably filled in by the agent protocol; confirm against the
# inferencer. The embedded snippets are kept as valid, runnable Python so the
# demonstrated tool calls match what the interpreter can actually execute.
gsm8k_example = """
<HUMAN>A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
<ASSISTANT>{thought} We need to calculate the total number of fruits. The total number of fruits in the first three baskets is given, while for the fourth basket, we need to subtract 2 from each fruit category. We can solve this problem using simple arithmetic.
{action} PythonInterpreter
{action_input}
```python
def solution():
    # Fruits in the first three baskets
    apples_first_three = 9
    oranges_first_three = 15
    bananas_first_three = 14
    # Fruits in the fourth basket
    apples_fourth = apples_first_three - 2
    oranges_fourth = oranges_first_three - 2
    bananas_fourth = bananas_first_three - 2
    # Total fruits
    total_fruits = ((apples_first_three + oranges_first_three + bananas_first_three) * 3 +
                    apples_fourth + oranges_fourth + bananas_fourth)
    return {{"total_fruits": total_fruits}}
```
<SYSTEM>{response}{{'total_fruits': 146}}
<ASSISTANT> {thought} By adding the given numbers of apples, oranges, and bananas in the first three baskets, then subtracting 2 from each category for the fourth basket, we have found the total number of fruits.
{finish} 146
<HUMAN>Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
<ASSISTANT>{thought} This is a problem that requires solving equations. We know the relationship between the number of marbles, frisbees, and deck cards. Bella has twice as many marbles as frisbees, and 20 more frisbees than deck cards. Finally, we are told Bella buys 2/5 times more of each item. This purchasing will increase the number of each type of item.
{action} PythonInterpreter
{action_input}
```python
def solution():
    # Given number of marbles
    marbles_now = 60
    # Calculate number of frisbees and deck cards now
    frisbees_now = marbles_now / 2
    cards_now = frisbees_now - 20
    # Calculate number of each item after buying more
    marbles_then = marbles_now + (2/5) * marbles_now
    frisbees_then = frisbees_now + (2/5) * frisbees_now
    cards_then = cards_now + (2/5)*cards_now
    # Total number of items then
    total_items = marbles_then + frisbees_then + cards_then
    return {{"total_items": total_items}}
```
<SYSTEM>{response}{{'total_items': 140.0}}
<ASSISTANT>{thought} By establishing the relationships between the numbers of marbles, frisbees, and deck cards that Bella currently has, we can calculate how many of each item she will have after buying 2/5 more of each. Adding these quantities together gives us the total number of items.
{finish} 140
"""
|
||
# Reader: the 'question' column is the model input, 'answer' the reference.
gsm8k_reader_cfg = {
    'input_columns': ['question'],
    'output_column': 'answer',
}

# Inference: the raw question is the prompt, no in-context retrieval is done,
# and the agent inferencer receives the few-shot demonstration string.
gsm8k_infer_cfg = {
    'prompt_template': {'type': PromptTemplate, 'template': '{question}'},
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': AgentInferencer, 'example': gsm8k_example},
}

# Evaluation: accuracy after post-processing both predictions and references.
gsm8k_eval_cfg = {
    'evaluator': {'type': AccEvaluator},
    'pred_postprocessor': {'type': gsm8k_postprocess},
    'dataset_postprocessor': {'type': gsm8k_dataset_postprocess},
}

# Dataset entries exported by this config file.
gsm8k_datasets = [
    {
        'abbr': 'gsm8k',
        'type': HFDataset,
        'path': 'gsm8k',
        'name': 'main',
        'reader_cfg': gsm8k_reader_cfg,
        'infer_cfg': gsm8k_infer_cfg,
        'eval_cfg': gsm8k_eval_cfg,
    },
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import AgentInferencer | ||
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess | ||
|
||
# This config is for code interpreter
# Few-shot demonstration passed to AgentInferencer as ``example``.
# Single-brace names ({thought}, {action}, {action_input}, {response},
# {finish}) are format-style placeholders and doubled braces stand for literal
# braces -- presumably filled in by the agent protocol; confirm against the
# inferencer. Backslashes are written as ``\\`` throughout so the literal
# contains no invalid escape sequences; the resulting text is unchanged.
math_example = """
<HUMAN>Find the domain of the expression $\\frac{{\\sqrt{{x-2}}}}{{\\sqrt{{5-x}}}}$.
<ASSISTANT>{thought} The domain restrictions are determined by:
The square root in the numerator must be non-negative.
The square root in the denominator must be positive (because we can't have 0 in the denominator).
The value inside a square root (the radicand) must be non-negative. We can use `sympy` to determine the domain of the expression.
{action} PythonInterpreter
{action_input}
```python
from sympy import symbols, solveset, S, And
def solution():
    # Define the variable
    x = symbols('x')
    # Define the inequalities for the domain based on the expression
    inequality1 = x-2 >= 0 # because of the square root in the numerator
    inequality2 = 5-x > 0 # because of the square root in the denominator
    # Solve the inequalities
    domain1 = solveset(inequality1, x, domain=S.Reals)
    domain2 = solveset(inequality2, x, domain=S.Reals)
    # Find the intersection of the two domains
    final_domain = domain1.intersect(domain2)
    return final_domain
```
<SYSTEM>{response}'Interval.Ropen(2, 5)'
<ASSISTANT> {thought} Therefore, the domain of the expression is $\\boxed{{[2,5)}}$
{finish} [2,5)
"""
|
||
# Inference: the raw problem text is the prompt, no in-context retrieval is
# done, and the agent inferencer receives the few-shot demonstration string.
math_infer_cfg = {
    'prompt_template': {'type': PromptTemplate, 'template': '{problem}'},
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': AgentInferencer, 'example': math_example},
}

# Evaluation: MATH evaluator applied to post-processed predictions.
math_eval_cfg = {
    'evaluator': {'type': MATHEvaluator},
    'pred_postprocessor': {'type': math_postprocess},
}

# Dataset entries exported by this config file.
math_datasets = [
    {
        'type': MATHDataset,
        'abbr': 'math',
        'path': './data/math/math.json',
        'reader_cfg': {
            'input_columns': ['problem'],
            'output_column': 'solution',
        },
        'infer_cfg': math_infer_cfg,
        'eval_cfg': math_eval_cfg,
    },
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from mmengine.config import read_base | ||
from opencompass.partitioners import SizePartitioner | ||
from opencompass.runners import LocalRunner | ||
from opencompass.tasks import OpenICLInferTask | ||
from opencompass.models import OpenAI, HuggingFaceCausalLM | ||
from opencompass.models.lagent import CodeAgent | ||
|
||
with read_base(): | ||
from .datasets.math.math_gen_6cca30 import math_datasets | ||
from .datasets.gsm8k.gsm8k_gen_e7ef64 import gsm8k_datasets | ||
|
||
# Evaluate on GSM8K first, then MATH.
datasets = [*gsm8k_datasets, *math_datasets]

# Each entry wraps a backend LLM in the CodeAgent model type.
models = [
    {
        'abbr': 'gpt-3.5-react',
        'type': CodeAgent,
        'llm': {
            'type': OpenAI,
            'path': 'gpt-3.5-turbo',
            'key': 'ENV',
            'query_per_second': 1,
            'max_seq_len': 4096,
        },
        'batch_size': 8,
    },
    {
        'abbr': 'WizardCoder-Python-13B-V1.0-react',
        'type': CodeAgent,
        'llm': {
            'type': HuggingFaceCausalLM,
            'path': "WizardLM/WizardCoder-Python-13B-V1.0",
            'tokenizer_path': 'WizardLM/WizardCoder-Python-13B-V1.0',
            'tokenizer_kwargs': {
                'padding_side': 'left',
                'truncation_side': 'left',
                'trust_remote_code': True,
            },
            'max_seq_len': 2048,
            'model_kwargs': {
                'trust_remote_code': True,
                'device_map': 'auto',
            },
        },
        'batch_size': 8,
        'run_cfg': {'num_gpus': 2, 'num_procs': 1},
    },
]

# Inference is size-partitioned and executed by the local runner.
infer = {
    'partitioner': {'type': SizePartitioner, 'max_task_size': 40000},
    'runner': {
        'type': LocalRunner,
        'max_num_workers': 16,
        'task': {'type': OpenICLInferTask},
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from typing import Union | ||
|
||
from lagent.actions import ActionExecutor | ||
from lagent.agents.base_agent import BaseAgent | ||
from lagent.agents.react import ReActProtocol | ||
from lagent.llms.base_api import BaseAPIModel | ||
from lagent.llms.base_llm import BaseModel | ||
from lagent.schema import ActionReturn, AgentReturn | ||
|
||
|
||
class ReAct(BaseAgent):
    """An implementation of ReAct (https://arxiv.org/abs/2210.03629)

    Args:
        llm (BaseModel or BaseAPIModel): a LLM service which can chat
            and act as backend.
        action_executor (ActionExecutor): an action executor to manage
            all actions and their response.
        protocol (ReActProtocol, optional): a wrapper to generate prompt and
            parse the response from LLM / actions. When omitted, a fresh
            ``ReActProtocol()`` is created for this agent.
        max_turn (int): the maximum number of trials for LLM to generate
            plans that can be successfully parsed by ReAct protocol.
    """

    def __init__(self,
                 llm: Union[BaseModel, BaseAPIModel],
                 action_executor: ActionExecutor,
                 protocol: Union[ReActProtocol, None] = None,
                 max_turn: int = 2) -> None:
        self.max_turn = max_turn
        # Build the default protocol per instance rather than once at
        # definition time, so agents never share one protocol object.
        if protocol is None:
            protocol = ReActProtocol()
        super().__init__(llm=llm,
                         action_executor=action_executor,
                         protocol=protocol)

    def opencompass_adapter(self, prompt):
        """Convert a list of message dicts into an OpenCompass ``PromptList``.

        Every message that has a ``'content'`` key is emitted with that value
        stored under ``'prompt'`` instead. The conversion works on copies, so
        the dicts handed in (which may be the agent's own history entries)
        are left untouched.

        Args:
            prompt: the formatted prompt; only ``list`` inputs are converted.

        Returns:
            A ``PromptList`` for list input, otherwise ``prompt`` unchanged.
        """
        # adapter for prompt parsing
        from opencompass.utils.prompt import PromptList
        if isinstance(prompt, list):
            converted = []
            for item in prompt:
                if 'content' in item:
                    item = dict(item)
                    item['prompt'] = item.pop('content')
                converted.append(item)
            prompt = PromptList(converted)
        return prompt

    def chat(self, message: str) -> AgentReturn:
        """Run the think/act loop for one user message.

        Args:
            message (str): the user input for this round.

        Returns:
            AgentReturn: the executed actions plus the final response. The
            response is the finish action's text when the model finishes
            within ``max_turn`` turns, otherwise a fixed fallback message.
        """
        self._inner_history = [dict(role='user', content=message)]
        agent_return = AgentReturn()
        default_response = '对不起,我无法回答你的问题'
        for turn in range(self.max_turn):
            # Ask the protocol to force a final answer on the last allowed
            # turn; the flag has to be known before the prompt is built.
            force_stop = turn == self.max_turn - 1
            prompt = self._protocol.format(
                chat_history=self.session_history,
                inner_step=self._inner_history,
                action_executor=self._action_executor,
                force_stop=force_stop)
            prompt = self.opencompass_adapter(prompt)
            # allow single generation
            response = self._llm.generate_from_template([prompt], 512)[0]
            self._inner_history.append(
                dict(role='assistant', content=response))
            thought, action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
                action, action_input)
            action_return.thought = thought
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
                agent_return.response = action_return.result['text']
                # NOTE(review): this early return skips the session-history
                # update below, so only fallback answers are recorded there.
                # Kept as-is because callers may rely on it -- confirm.
                return agent_return
            self._inner_history.append(
                dict(role='system',
                     content=self._protocol.format_response(action_return)))
        agent_return.response = default_response
        # only append the user and final response
        self._session_history.append(dict(role='user', content=message))
        self._session_history.append(
            dict(role='assistant', content=agent_return.response))
        return agent_return
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.