diff --git a/chattool/__init__.py b/chattool/__init__.py index a543156..17a27f9 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.1.7' +__version__ = '3.2.0' import os, sys, requests from .chattype import Chat, Resp diff --git a/chattool/asynctool.py b/chattool/asynctool.py index 088d3ab..6154a0f 100644 --- a/chattool/asynctool.py +++ b/chattool/asynctool.py @@ -176,8 +176,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] if clearfile and os.path.exists(chkpoint): os.remove(chkpoint) if api_key is None: - api_key = chattool.api_key - assert api_key is not None, "API key is not provided!" + api_key = chattool.api_key or "" if chat_url is None: if chattool.api_base: chat_url = os.path.join(chattool.api_base, "chat/completions") @@ -187,7 +186,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] raise Exception("chat_url is not provided!") chat_url = chattool.request.normalize_url(chat_url) if 'model' not in options: - options['model'] = chattool.model if chattool.model else "gpt-3.5-turbo" + options['model'] = chattool.model or "" # run async process assert nproc > 0, "nproc must be greater than 0!" max_tries = max(max_tries, max_requests) diff --git a/chattool/chattype.py b/chattool/chattype.py index aa5aa83..3d5200c 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -8,6 +8,7 @@ import aiohttp import os from .functioncall import generate_json_schema, delete_dialogue_assist +from pprint import pformat class Chat(): def __init__( self @@ -17,6 +18,8 @@ def __init__( self , base_url:Union[None, str]=None , chat_url:Union[None, str]=None , model:Union[None, str]=None + , tools:Union[None, List[Dict]]=None + , tool_choice:Union[None, str]=None , functions:Union[None, List[Dict]]=None , function_call:Union[None, str]=None , name2func:Union[None, Dict]=None): @@ -29,9 +32,12 @@ def __init__( self base_url (Union[None, str], optional): base url without suffix "/v1". Defaults to None. Example: "https://api.openai.com" chat_url (Union[None, str], optional): chat completion url. Defaults to None. Example: "https://api.openai.com/v1/chat/completions" model (Union[None, str], optional): model to use. Defaults to None. - functions (Union[None, List[Dict]], optional): functions to use, each function is a JSON Schema. Defaults to None. - function_call (str, optional): method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none'] + tools (Union[None, List[Dict]], optional): tools to use, each tool is a JSON Schema. Defaults to None. + tool_choice (Union[None, str], optional): method to choose the tool. Defaults to None. Choices: ['auto', '$NameOfTheTool', 'none'] name2func (Union[None, Dict], optional): name to function mapping. Defaults to None. + functions (Union[None, List[Dict]], optional): Decrpcated. functions to use, each function is a JSON Schema. Defaults to None. + function_call (str, optional): Decrpcated. method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none'] + Raises: ValueError: msg should be a list of dict, a string or None @@ -65,14 +71,17 @@ def __init__( self self.chat_url = "https://api.openai.com/v1/chat/completions" if functions is not None: assert isinstance(functions, list), "functions should be a list of dict" - self._functions, self._function_call = functions, function_call + if tools is not None: + assert isinstance(tools, list), "tools should be a list of dict" + self.functions, self.tools = functions or [], tools or [] + self._function_call, self._tool_choice = function_call, tool_choice self._name2func, self._resp = name2func, None # Part1: basic operation of the chat object def add(self, role:str, **kwargs): """Add a message to the chat log""" - assert role in ['user', 'assistant', 'system', 'function'],\ - f"role should be one of ['user', 'assistant', 'system', 'function'], but got {role}" + assert role in ['user', 'assistant', 'system', 'tool', 'function'],\ + f"role should be one of ['user', 'assistant', 'system', 'tool'], but got {role}" self._chat_log.append({'role':role, **kwargs}) return self @@ -80,17 +89,19 @@ def user(self, content: Union[List, str]): """User message""" return self.add('user', content=content) - def assistant(self, content:Union[None, str], function_call:Union[None, Dict]=None): + def assistant(self, content:Union[None, str]): """Assistant message""" - if function_call is not None: - assert isinstance(function_call, dict), "function_call should be a dict" - return self.add('assistant', content=content, function_call=function_call) return self.add('assistant', content=content) def function(self, content, name:str, dump:bool=True): """Add a message to the chat log""" if dump: content = json.dumps(content) return self.add('function', content=content, name=name) + + def tool(self, content, name:str, tool_call_id:str, dump:bool=True): + """Add a message to the chat log""" + if dump: content = json.dumps(content) + return self.add('tool', content=content, name=name, tool_call_id=tool_call_id) def system(self, content:str): """System message""" @@ -112,6 +123,8 @@ def deepcopy(self): , model=self.model , functions=self.functions , function_call=self.function_call + , tools=self.tools + , tool_choice=self.tool_choice , name2func=self.name2func , api_base=self.api_base , base_url=self.base_url) @@ -167,13 +180,18 @@ def load(path:str): def display_role_content(dic:dict, sep:Union[str, None]=None): """Show the role and content of the message""" if sep is None: sep = '\n' + '-'*15 + '\n' - role, content = dic['role'], dic['content'] - if role == 'user' or role == 'system' or (role == 'assistant' and 'function_call' not in dic): - return f"{sep}{role}{sep}{dic['content']}" + role, content, name, tools = dic['role'], dic.get('content'), dic.get('name'), dic.get('tool_calls') + if role == 'user' or role == 'system': + return f"{sep}{role}{sep}{content}" + elif role == 'tool': + return f"{sep}{role}{sep}tool:\n\t{name}\nresult:\n\t{content}" elif role == 'function': - return f"{sep}{role}{sep}function:\n\t{dic['name']}\nresult:\n\t{content}" + return f"{sep}{role}{sep}function:\n\t{name}\nresult:\n\t{content}" elif role == 'assistant': - return f"{sep}{role}{sep}calling function:\n\t{dic['function_call']}\ncontent:\n\t{content}" + if 'tool_calls' in dic: + return f"{sep}{role}{sep}calling tool:\n{pformat(tools)}" + if 'function_call' in dic: + return f"{sep}{role}{sep}calling function:\n{pformat(dic['function_call'])}" else: raise Exception(f"Unknown role {role}") @@ -189,7 +207,11 @@ def getresponse( self , timeinterval:int = 0 , update:bool = True , stream:bool = False + , tools:Union[None, List[Dict]]=None + , tool_choice:Union[None, str]=None , max_requests:int=-1 + , functions:Union[None, List[Dict]]=None + , function_call:Union[None, str]=None , **options)->Resp: """Get the API response @@ -205,10 +227,16 @@ def getresponse( self Resp: API response """ # initialize data - api_key, model, chat_url = self.api_key, self.model, self.chat_url - funcs = options.get('functions', self.functions) - func_call = options.get('function_call', self.function_call) - if api_key is None: warnings.warn("API key is not set!") + api_key, chat_url = self.api_key, self.chat_url + if 'model' not in options: options['model'] = self.model + # function call & tool call + tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools + function_call, functions = function_call or self.function_call, functions or self.functions + if tool_choice is not None: + options['tool_choice'], options['tools'] = tool_choice, tools + elif function_call is not None: + options['function_call'], options['functions'] = function_call, functions + # if api_key is None: warnings.warn("API key is not set!") msg, resp, numoftries = self.chat_log, None, 0 max_tries = max(max_tries, max_requests) if stream: # TODO: add the `usage` key to the response @@ -217,10 +245,8 @@ def getresponse( self while max_tries: try: # make API Call - if funcs is not None: options['functions'] = funcs - if func_call is not None: options['function_call'] = func_call response = chat_completion( - api_key=api_key, messages=msg, model=model, + api_key=api_key, messages=msg, chat_url=chat_url, timeout=timeout, **options) resp = Resp(response) assert resp.is_valid(), resp.error_message @@ -256,17 +282,19 @@ async def async_stream_responses( self self.api_key, self.chat_url, self.chat_log, self.model, timeout=timeout, **options): yield resp.delta_content if textonly else resp - # Part3: function call + # Part3: tool call def iswaiting(self): """Whether the response is waiting""" if len(self) == 0: return False - return self[-1]['role'] == 'assistant' and 'function_call' in self[-1] + return self[-1]['role'] == 'assistant' and ('tool_calls' in self[-1] or 'function_call' in self[-1]) @staticmethod def get_name_and_params(dic:dict): - """Get the name and parameters of the function call""" - if 'role' in dic and 'function_call' in dic: - dic = dic['function_call'] + """Get the name and parameters of the tool call""" + if 'role' in dic and 'tool_calls' in dic: + dic = dic['tool_calls'] + elif 'role' in dic and 'function_call' in dic: + dic = dic['function_call'] name, params = dic['name'], json.loads(dic['arguments']) return name, params @@ -276,58 +304,97 @@ def simplify(self): return self def setfuncs(self, funcs:List): - """Initialize function for function call""" + """Initialize function for tool calls""" self.functions = [generate_json_schema(func) for func in funcs] self.function_call = 'auto' self.name2func = {func.__name__:func for func in funcs} return True + + def settools(self, tools:List): + """Initialize tools for tool calls""" + self._functions =[generate_json_schema(func) for func in tools] + self.tool_choice = 'auto' + self.name2func = {tool.__name__:tool for tool in tools} + return True + + def calltools(self, display:bool=False): + """Call all the tools""" + if not self.iswaiting(): + return False, "Not waiting for tool call." + tool_calls = self[-1]['tool_calls'] + allright = True + for tool in tool_calls: + result, name, tool_call_id, status = self.calltool(tool) + self.tool(result, name, tool_call_id) + if display: print(self.display_role_content(self[-1])) + allright = allright and status + return allright + + def calltool(self, tool): + """Call the tool""" + tool_call_id = tool['id'] + tool_name, tool_para = tool['function']['name'], tool['function']['arguments'] + if tool_name not in self.name2func: + return f"Tool {tool_name} not found.", tool_name, tool_call_id, False + try: + tool_args = json.loads(tool_para) + except Exception as e: + return f"Argument parsing failed with error: {e}", tool_name, tool_call_id, False + try: + result = self.name2func[tool_name](**tool_args) + except Exception as e: + return f"Tool {tool_name} failed with error: {e}", tool_name, tool_call_id, False + # succeed finally! + return result, tool_name, tool_call_id, True def callfunction(self): """Calling function""" if not self.iswaiting(): - return False, "Not waiting for function call." + return "Not waiting for function call.", name name = self[-1]['function_call']['name'] if name not in self.name2func: - return False, f"Function {name} not found." + return f"Function {name} not found.", name, False try: args = json.loads(self[-1]['function_call']['arguments']) except Exception as e: - return False, f"Cannot parse the arguments, error: {e}" + return f"Cannot parse the arguments, error: {e}", name, False try: result = self.name2func[name](**args) except Exception as e: - return False, f"Function {name} failed with error: {e}" + return f"Function {name} failed with error: {e}", name, False # succeed finally! - self.function(result, name) - return True, "Function called successfully." + return result, name, True def autoresponse( self , display:bool=False , maxturns:int=3 - , capturerr:bool=True + , use_tool:bool=True , **options): """Get the response automatically Args: display (bool, optional): whether to display the response. Defaults to False. maxturns (int, optional): maximum number of turns. Defaults to 3. - capturerr (bool, optional): if True, use the error message as the response. Defaults to True. options (dict, optional): other options like `temperature`, `top_p`, etc. Returns: bool: whether the response is finished """ - options['functions'], options['function_call'] = self.functions, self.function_call + if use_tool: + options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto' + else: + options['functions'], options['function_call'] = self.functions, self.function_call or 'auto' show = lambda msg: print(self.display_role_content(msg)) resp = self.getresponse(**options) if display: show(resp.message) while self.iswaiting() and maxturns != 0: # call api and update the result - status, msg = self.callfunction() - if not status: # update the error msg - if not capturerr: return False - self.function(msg, 'error') - if display: show(self[-1]) + if use_tool: + self.calltools(display=display) + else: + result, name, _ = self.callfunction() + self.function(result, name) + if display: show(self[-1]) resp = self.getresponse(**options) if display: show(resp.message) maxturns -= 1 @@ -404,6 +471,18 @@ def functions(self): def function_call(self): """Get function call""" return self._function_call + + @property + def tools(self): + """Get tools""" + if self.functions is not None: + return [{'type':'function', 'function': func} for func in self.functions] + return None + + @property + def tool_choice(self): + """Get tool choice""" + return self._tool_choice @property def name2func(self): @@ -451,12 +530,30 @@ def function_call(self, function_call:str): assert 'name' in function_call self._function_call = function_call + @tools.setter + def tools(self, tools:List[Dict]): + """Set tools""" + assert isinstance(tools, list), "tools should be a list of dict" + self._functions = [tool['function'] for tool in tools] + + @tool_choice.setter + def tool_choice(self, tool_choice:str): + """Set tool choice""" + if tool_choice in ['auto', None, 'none']: + self._tool_choice = tool_choice + elif isinstance(tool_choice, str): + self.tool_choice = {'name':tool_choice} + elif isinstance(tool_choice, dict): + assert 'name' in tool_choice + self._tool_choice = tool_choice + @name2func.setter def name2func(self, name2func:Dict): """Set name to function mapping""" assert isinstance(name2func, dict), "name2func should be a dict" self._name2func = name2func + @property def chat_log(self): """Chat history""" diff --git a/chattool/functioncall.py b/chattool/functioncall.py index 6778444..309a1e1 100644 --- a/chattool/functioncall.py +++ b/chattool/functioncall.py @@ -65,10 +65,10 @@ def delete_dialogue_assist(chat_log:List[Dict]): ind = 0 while ind < len(chat_log) - 1: log = chat_log[ind] - if log['role'] == 'assistant' and 'function_call' in log: + if log['role'] == 'assistant' and ('tool_calls' in log or 'function_call' in log): nextind = ind + 1 nextlog = chat_log[nextind] - if nextlog['role'] == 'function': + if nextlog['role'] == 'tool' or nextlog['role'] == 'function': chat_log.pop(nextind) chat_log.pop(ind) else: diff --git a/chattool/response.py b/chattool/response.py index 0c39dad..493d8e0 100644 --- a/chattool/response.py +++ b/chattool/response.py @@ -69,6 +69,11 @@ def function_call(self): """Function call""" return self.message.get('function_call') + @property + def tool_calls(self): + """Tool calls""" + return self.message.get('tool_calls') + @property def delta(self): """Delta""" diff --git a/setup.py b/setup.py index 3bdfc3c..90e6b14 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.1.7' +VERSION = '3.2.0' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8', diff --git a/tests/test_function.py b/tests/test_function.py index cdf88da..34908fe 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -54,12 +54,12 @@ def test_auto_response(): chat = Chat("What's the weather like in Boston?") chat.functions, chat.function_call = functions, 'auto' chat.name2func = name2func - chat.autoresponse(max_tries=2) + chat.autoresponse(max_tries=2, use_tool=False) chat.print_log() chat.clear() # response with nonempty content chat.user("what is the result of 1+1, and What's the weather like in Boston?") - chat.autoresponse(max_tries=2) + chat.autoresponse(max_tries=2, use_tool=False) # generate docstring from functions def add(a: int, b: int) -> int: @@ -89,6 +89,17 @@ def mult(a:int, b:int) -> int: """ return a * b +def test_mix_function_tool(): + chat = Chat("find the sum of 784359345 and 345345345") + chat.setfuncs([add]) + chat.autoresponse(max_tries=3, display=True, timeinterval=2) + chat.clear() + chat.user("find the sum of 784359345 and 345345345") + chat.autoresponse(use_tool=False) + newchat = Chat("find the product of 123124 and 399090") + newchat.settools([mult]) + newchat.autoresponse() + def test_add_and_mult(): functions = [generate_json_schema(add)] chat = Chat("find the sum of 784359345 and 345345345") @@ -100,35 +111,26 @@ def test_add_and_mult(): chat.name2func = {'add': add} # dictionary of functions chat.function_call = 'auto' # auto decision # run until success: maxturns=-1 - chat.autoresponse(max_tries=3, display=True, timeinterval=2) + chat.autoresponse(max_tries=3, display=True, timeinterval=2, use_tool=False) # response should be finished chat.simplify() chat.print_log() # use the setfuncs method chat = Chat("find the value of 124842 * 3423424") chat.setfuncs([add, mult]) # multi choice - chat.autoresponse(max_tries=3, timeinterval=2) + chat.autoresponse(max_tries=3, timeinterval=2, use_tool=False) chat.simplify() # simplify the chat log chat.print_log() # test multichoice chat.clear() - chat.user("find the value of 23723 + 12312, and 23723 * 12312") - # chat.autoresponse(max_tries=3, timeinterval=2) - -def test_mock_resp(): - chat = Chat("find the sum of 1235 and 3423") - chat.setfuncs([add, mult]) - # mock result of the resp - para = {'name': 'add', 'arguments': '{\n "a": 1235,\n "b": 3423\n}'} - chat.assistant(content=None, function_call=para) - chat.callfunction() - chat.getresponse(max_tries=2) + chat.user("find the value of (23723 * 1322312 ) + 12312") + chat.autoresponse(max_tries=3, timeinterval=2, use_tool=False) def test_use_exec_function(): chat = Chat("find the result of sqrt(121314)") chat.setfuncs([exec_python_code]) - chat.autoresponse(max_tries=2, display=True) - + # chat.autoresponse(max_tries=2, display=True, use_tool=False) + def test_find_permutation_group(): pass diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..caf76fa --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,145 @@ +# tests for function call + +from chattool import Chat, generate_json_schema, exec_python_code +import json + +# schema of functions +tools = [ + { + "type": "function", + "function":{ + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }, + } + } +] +weatherinfo = { + "location": "Boston, MA", + "temperature": "72", + "forecast": ["sunny", "windy"], + "unit":"celsius" +} +name2func = { + 'get_current_weather': lambda *kargs, **kwargs: weatherinfo +} + +def test_call_weather(): + chat = Chat("What's the weather like in Boston?") + resp = chat.getresponse(tools=tools, tool_choice='auto', max_tries=3) + if resp.finish_reason == 'tool_calls': + last_tool = chat[-1]['tool_calls'][0] # get the last tool call + parainfo, tool_call_id = last_tool['function'], last_tool['id'] + tool_name, tool_args = parainfo['name'], json.loads(parainfo['arguments']) + assert tool_name == 'get_current_weather' + assert 'location' in tool_args and 'unit' in tool_args + # continue the chat + # tool call result: weatherinfo + chat.tool(weatherinfo, tool_name, tool_call_id) + chat.getresponse() + chat.print_log() + else: + print("No function call found.") + assert True + +def test_auto_response(): + chat = Chat("What's the weather like in Boston?") + chat.tools, chat.tool_choice = tools, 'auto' + chat.name2func = name2func + chat.autoresponse(max_tries=2, display=True) + chat.print_log() + newchat = chat.deepcopy() + newchat.clear() + # response with nonempty content + newchat.user("what is the result of 1+1, and What's the weather like in Boston?") + newchat.autoresponse(max_tries=2, display=True) + +# generate docstring from functions +def add(a: int, b: int) -> int: + """ + This function adds two numbers. + + Parameters: + a (int): The first number. + b (int): The second number. + + Returns: + int: The sum of the two numbers. + """ + return a + b + +# with optional parameters +def mult(a:int, b:int) -> int: + """This function multiplies two numbers. + It is a useful calculator! + + Args: + a (int): The first number. + b (int): The second number. + + Returns: + int: The product of the two numbers. + """ + return a * b + +def test_add_and_mult(): + tools = [{ + 'type':'function', + 'function': generate_json_schema(tool)} for tool in [add, mult]] + chat = Chat("find the sum of 784359345 and 345345345") + chat.tools = tools + chat.tool_choice = None # unset keyword equivalent to "auto" + chat.tool_choice = 'none' + chat.tool_choice = {'name':'add'} + chat.tool_choice = 'add' # specify the function(convert to dict) + chat.tools = tools + chat.name2func = {'add': add} # dictionary of functions + chat.tool_choice = 'auto' # auto decision + # run until success: maxturns=-1 + chat.autoresponse(max_tries=3, display=True, timeinterval=2) + # response should be finished + chat.simplify() + chat.print_log() + # use the setfuncs method + chat2 = Chat("find the value of 124842 * 3423424") + chat2.settools([add, mult]) # multi choice + chat2.autoresponse(max_tries=3, display=True, timeinterval=2) + chat2.simplify() # simplify the chat log + chat2.print_log() + # test multichoice + chat3 = chat2.deepcopy() + chat3.clear() + chat3.user("find the value of 23723 + 12312, and 23723 * 12312") + chat3.autoresponse(max_tries=3, display=True, timeinterval=2) + # test multichoice + chat4 = chat2.deepcopy() + chat4.clear() + chat4.user("find the value of (23723 * 1322312 ) + 12312") + chat4.autoresponse(max_tries=3, display=True, timeinterval=2) + + +def test_use_exec_function(): + chat = Chat("find the result of sqrt(121314)") + chat.settools([exec_python_code]) + chat.autoresponse(max_tries=2, display=True) + +def test_find_permutation_group(): + pass + +def test_interact_with_leandojo(): + pass + +# debug area +# test_generate_docstring() +# test_function_call() +# test_function_call2()