Merge pull request #87 from cubenlp/rex/dev
update async method
RexWzh authored Sep 29, 2024
2 parents 731934b + 797f7b8 · commit a195149
Showing 9 changed files with 70 additions and 18 deletions.
chattool/__init__.py: 8 additions, 1 deletion
@@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.3.3'
__version__ = '3.3.4'

import os, sys, requests, json
from .chattype import Chat, Resp
@@ -87,6 +87,13 @@ def save_envs(env_file:str):
elif platform.startswith("darwin"):
platform = "macos"

# is jupyter notebook
try:
get_ipython
is_jupyter = True
except:
is_jupyter = False

def default_prompt(msg:str):
"""Default prompt message for the API call
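The new `is_jupyter` flag works because IPython and Jupyter define a `get_ipython` builtin in every interactive session, so merely referencing that name raises `NameError` in a plain interpreter. A minimal standalone sketch of the same detection trick (catching `NameError` explicitly, where the committed code uses a bare `except`):

# Runnable sketch: `get_ipython` exists only inside IPython/Jupyter sessions,
# so looking it up anywhere else raises NameError.
try:
    get_ipython  # noqa: F821
    is_jupyter = True
except NameError:
    is_jupyter = False

print("Jupyter/IPython session" if is_jupyter else "plain Python interpreter")

Elsewhere in this commit the flag guards the new `Chat.stream_responses`, which refuses to run inside a notebook because Jupyter already has an event loop running.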
chattool/chattype.py: 42 additions, 2 deletions
@@ -10,6 +10,7 @@
from .functioncall import generate_json_schema, delete_dialogue_assist
from pprint import pformat
from loguru import logger
import asyncio

class Chat():
def __init__( self
@@ -235,7 +236,7 @@ def getresponse( self
max_tries = max(max_tries, max_requests)
if options.get('stream'):
options['stream'] = False
warnings.warn("Use `async_stream_responses()` instead.")
warnings.warn("Use `stream_responses` instead.")
options = self._init_options(**options)
# make requests
api_key, chat_log, chat_url = self.api_key, self.chat_log, self.chat_url
@@ -258,11 +259,49 @@ async def async_stream_responses( self
Returns:
str: response text
Examples:
>>> chat = Chat("Hello")
>>> # in Jupyter notebook
>>> async for resp in chat.async_stream_responses():
>>> print(resp)
"""
async for resp in _async_stream_responses(
self.api_key, self.chat_url, self.chat_log, self.model, timeout=timeout, **options):
yield resp.delta_content if textonly else resp

def stream_responses(self, timeout:int=0, textonly:bool=True, **options):
"""Post request synchronously and stream the responses
Args:
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
textonly (bool, optional): whether to only return the text. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.
Returns:
str: response text
Examples:
>>> chat = Chat("Hello")
>>> for resp in chat.stream_responses():
>>> print(resp)
"""
assert not chattool.is_jupyter, "use `await chat.async_stream_responses()` in Jupyter notebook"
async_gen = self.async_stream_responses(timeout=timeout, textonly=textonly, **options)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
# Run the async generator to get each response
response = loop.run_until_complete(async_gen.__anext__())
yield response
except StopAsyncIteration:
# End the generator when the async generator is exhausted
break
finally:
loop.close()

# Part3: tool call
def iswaiting(self):
"""Whether the response is waiting"""
@@ -396,7 +435,8 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
model_url = os.path.join(self.api_base, 'models')
elif self.base_url:
model_url = os.path.join(self.base_url, 'v1/models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)
model_list = valid_models(self.api_key, model_url, gpt_only=gpt_only)
return sorted(set(model_list))

def get_curl(self, use_env_key:bool=False, **options):
"""Get the curl command
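The new `Chat.stream_responses` wraps the existing `async_stream_responses` generator for synchronous callers: it creates a private event loop and pulls one item at a time via `__anext__` until `StopAsyncIteration`, so plain scripts can stream a reply with an ordinary `for` loop. A minimal, chattool-independent sketch of that pattern (the `fake_stream` async generator is an invented stand-in for the real streaming call):

import asyncio

async def fake_stream():
    # invented stand-in for async_stream_responses(): yields pieces of a reply
    for chunk in ("Hel", "lo", "!"):
        await asyncio.sleep(0)  # pretend to await network I/O
        yield chunk

def sync_stream():
    # same pattern as Chat.stream_responses: drive the async generator
    # from a dedicated event loop until it is exhausted
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    agen = fake_stream()
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

for piece in sync_stream():
    print(piece, end="")

This is also why the method asserts `not chattool.is_jupyter`: a notebook already runs an event loop, so callers there should `await` the async variant instead, as the docstring examples show. The other change in this file makes `get_valid_models` return a sorted, de-duplicated list rather than the raw list from the models endpoint.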
setup.py: 1 addition, 1 deletion
@@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.3.3'
VERSION = '3.3.4'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
tests/conftest.py: 8 additions, 0 deletions
@@ -0,0 +1,8 @@
import pytest
from chattool import *

TEST_PATH = 'tests/testfiles/'

@pytest.fixture(scope="session")
def testpath():
return TEST_PATH
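This session-scoped fixture replaces the module-level `testpath = 'tests/testfiles/'` constant that the commit deletes from each test module below; pytest injects the value into any test that declares a `testpath` parameter. A hypothetical test showing the mechanism:

def test_testpath_fixture(testpath):
    # pytest resolves the `testpath` argument against the fixture in conftest.py
    assert testpath == 'tests/testfiles/'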
tests/test_async.py: 6 additions, 5 deletions
@@ -8,7 +8,6 @@
chatlogs = [
[{"role": "user", "content": f"Print hello using {lang}"}] for lang in langs
]
testpath = 'tests/testfiles/'

def test_simple():
# set api_key in the environment variable
@@ -32,6 +31,8 @@ async def show_resp(chat):
async for resp in chat.async_stream_responses():
print(resp.delta_content, end='')
asyncio.run(show_resp(chat))
for resp in chat.stream_responses():
print(resp, end='')

def test_async_typewriter():
def typewriter_effect(text, delay):
@@ -51,23 +52,23 @@ async def show_resp(chat):
chat = Chat("Print hello using Python")
asyncio.run(show_resp(chat))

def test_async_process():
def test_async_process(testpath):
chkpoint = testpath + "test_async.jsonl"
t = time.time()
async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
async_chat_completion(chatlogs, chkpoint, nproc=3)
print(f"Time elapsed: {time.time() - t:.2f}s")

# broken test
def test_failed_async():
def test_failed_async(testpath):
api_key = chattool.api_key
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3)
chattool.api_key = api_key

def test_async_process_withfunc():
def test_async_process_withfunc(testpath):
chkpoint = testpath + "test_async_withfunc.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
def msg2log(msg):
@@ -77,7 +78,7 @@ def msg2log(msg):
return chat.chat_log
async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log)

def test_normal_process():
def test_normal_process(testpath):
chkpoint = testpath + "test_nomal.jsonl"
def data2chat(data):
chat = Chat(data)
tests/test_chattool.py: 1 addition, 2 deletions
@@ -7,7 +7,6 @@
from chattool import cli
from chattool import Chat, Resp, findcost
import pytest
testpath = 'tests/testfiles/'


def test_command_line_interface():
@@ -21,7 +20,7 @@ def test_command_line_interface():
assert '--help Show this message and exit.' in help_result.output

# test for the chat class
def test_chat():
def test_chat(testpath):
# initialize
chat = Chat()
assert chat.chat_log == []
tests/test_checkpoint.py: 2 additions, 3 deletions
@@ -1,8 +1,7 @@
import os, responses
from chattool import Chat, load_chats, process_chats, api_key
testpath = 'tests/testfiles/'

def test_with_checkpoint():
def test_with_checkpoint(testpath):
# save chats without chatid
chat = Chat()
checkpath = testpath + "tmp.jsonl"
@@ -38,7 +37,7 @@ def test_with_checkpoint():
]
assert chats == [Chat(log) if log is not None else None for log in chat_logs]

def test_process_chats():
def test_process_chats(testpath):
def msg2chat(msg):
chat = Chat()
chat.system("You are a helpful translator for numbers.")
tests/test_envs.py: 1 addition, 2 deletions
@@ -2,7 +2,6 @@

import chattool
from chattool import Chat, save_envs, load_envs
testpath = 'tests/testfiles/'

def test_model_api_key():
api_key, model = chattool.api_key, chattool.model
@@ -43,7 +42,7 @@ def test_apibase():

chattool.api_base, chattool.base_url = api_base, base_url

def test_env_file():
def test_env_file(testpath):
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
tests/test_request.py: 1 addition, 2 deletions
@@ -7,7 +7,6 @@
)
import pytest, chattool, os
api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base
testpath = 'tests/testfiles/'

def test_valid_models():
if chattool.api_base:
@@ -40,7 +39,7 @@ def test_normalize_url():
assert normalize_url("api.openai.com") == "https://api.openai.com"
assert normalize_url("example.com/foo/bar") == "https://example.com/foo/bar"

def test_broken_requests():
def test_broken_requests(testpath):
"""Test the broken requests"""
with open(testpath + "test.txt", "w") as f:
f.write("hello world")
