diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 344c79b3..90bd6d8f 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -99,7 +99,7 @@ def chat( self, inputs: Union[List[dict], List[List[dict]]], **gen_params, - ) -> List[str]: + ) -> Union[str, List[str]]: """Generate responses given the contexts. Args: @@ -108,7 +108,7 @@ def chat( gen_params: additional generation configuration Returns: - List[str]: A list of generated strings. + Union[str, List[str]]: generated string(s) """ assert isinstance(inputs, list) if isinstance(inputs[0], dict): @@ -120,7 +120,8 @@ def chat( for messages in inputs ] wait(tasks) - return [task.result() for task in tasks] + ret = [task.result() for task in tasks] + return ret[0] if isinstance(inputs[0], dict) else ret def _chat(self, messages: List[dict], **gen_params) -> str: """Generate completion from a list of templates.