Skip to content

Commit

Permalink
修改 load_chats 逻辑, 增加 warning 提示未完成的数量 (#83)
Browse files Browse the repository at this point in the history
* 修改 load_chats 逻辑, 增加 warning 提示未完成的数量

* Update checkpoint.py

更少的行数

* update patch version

---------

Co-authored-by: rex <[email protected]>
  • Loading branch information
Qing25 and RexWzh authored Jun 18, 2024
1 parent 0f34798 commit 3b25b79
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.3.1'
__version__ = '3.3.2'

import os, sys, requests, json
from .chattype import Chat, Resp
Expand Down
25 changes: 12 additions & 13 deletions chattool/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json, warnings, os
import json, os
from typing import List, Dict, Union, Callable, Any
from .chattype import Chat
import tqdm
from loguru import logger

def load_chats( checkpoint:str):
"""Load chats from a checkpoint file
Expand All @@ -23,18 +24,16 @@ def load_chats( checkpoint:str):
if len(txts) == 1 and txts[0] == '': return []
# get the chatlogs
logs = [json.loads(txt) for txt in txts]
chat_size, chatlogs = 1, [None]
for log in logs:
idx = log['index']
if idx >= chat_size: # extend chatlogs
chatlogs.extend([None] * (idx - chat_size + 1))
chat_size = idx + 1
chatlogs[idx] = log['chat_log']
# mapping from index to chat object
idx2chatlog = { log['index']: Chat(log['chat_log']) for log in logs }
max_index = max(idx2chatlog.keys())
chat_objects = [ idx2chatlog.get(index, None) for index in range(max_index+1)]
num_unfinished = chat_objects.count(None)
# check if there are missing chatlogs
if None in chatlogs:
warnings.warn(f"checkpoint file {checkpoint} has unfinished chats")
if num_unfinished > 0:
logger.warning(f"checkpoint file {checkpoint} has {num_unfinished}/{max_index+1} unfinished chats")
# return Chat class
return [Chat(chat_log) if chat_log is not None else None for chat_log in chatlogs]
return chat_objects

def process_chats( data:List[Any]
, data2chat:Callable[[Any], Chat]
Expand All @@ -59,7 +58,7 @@ def process_chats( data:List[Any]
## load chats from the checkpoint file
chats = load_chats(checkpoint)
if len(chats) > len(data):
warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed")
logger.warning(f"checkpoint file {checkpoint} has more chats than the data to be processed")
return chats[:len(data)]
chats.extend([None] * (len(data) - len(chats)))
## process chats
Expand All @@ -69,4 +68,4 @@ def process_chats( data:List[Any]
chat = data2chat(data[i])
chat.save(checkpoint, mode='a', index=i)
chats[i] = chat
return chats
return chats
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.3.1'
VERSION = '3.3.2'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down

0 comments on commit 3b25b79

Please sign in to comment.