diff --git a/askchat/askchat.py b/askchat/askchat.py index ac1fbb8..2369d54 100644 --- a/askchat/askchat.py +++ b/askchat/askchat.py @@ -52,6 +52,7 @@ def main(): parser.add_argument("--api-key", default=None, help="API key") ## Chat with history parser.add_argument('-c', action='store_true', help='Continue the last conversation') + parser.add_argument('-r', action='store_true', help='Regenerate the last conversation') parser.add_argument('-s', "--save", default=None, help="Save the conversation to a file") parser.add_argument("-l", "--load", default=None, help="Load the conversation from a file") parser.add_argument("-p", "--print", default=None, nargs='*', help="Print the conversation from " +\ @@ -166,11 +167,15 @@ def main(): if isinstance(msg, list): msg = ' '.join(msg) assert len(msg.strip()), 'Please specify message' - if args.c and os.path.exists(LAST_CHAT_FILE): + if os.path.exists(LAST_CHAT_FILE): with open(LAST_CHAT_FILE, "r") as f: chatlog = json.load(f) - chatlog.append({"role":"user", "content":msg}) - msg = chatlog + if args.c: + msg = chatlog + [{"role":"user", "content":msg}] + elif args.r: + if len(chatlog) > 0:chatlog.pop() + if len(chatlog) > 0:chatlog.pop() + msg = chatlog + [{"role":"user", "content":msg}] # call the function chat = Chat(msg)