diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index 3855580..9cf28df 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -107,6 +107,27 @@ def get_model_params(self) -> Dict: """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} + def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + """ + Validate response as dict. + + Args: + response: response + request: request + + Return: + response as dict + """ + validated_response = super().validate_response(response, request) + # Handle logprobs + for choice in validated_response["choices"]: + if "logprobs" in choice: + logprobs = choice.pop("logprobs") + if logprobs and "token_logprobs" in logprobs: + choice["token_logprobs"] = logprobs["token_logprobs"] + choice["tokens"] = logprobs["tokens"] + return validated_response + def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]: """Split usage into list of usages for each prompt.""" try: diff --git a/manifest/response.py b/manifest/response.py index 9a35898..f46aaf3 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -52,7 +52,7 @@ class LMModelChoice(BaseModel): text: str token_logprobs: Optional[List[float]] = None - tokens: Optional[List[int]] = None + tokens: Optional[List[str]] = None class ArrayModelChoice(BaseModel):