fix: logprobs from openai (#85)
lorr1 authored Apr 24, 2023
1 parent d7401c6 commit e559c8f
Showing 2 changed files with 22 additions and 1 deletion.
21 changes: 21 additions & 0 deletions manifest/clients/openai.py
@@ -107,6 +107,27 @@ def get_model_params(self) -> Dict:
         """
         return {"model_name": self.NAME, "engine": getattr(self, "engine")}
 
+    def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
+        """
+        Validate response as dict.
+
+        Args:
+            response: response
+            request: request
+
+        Return:
+            response as dict
+        """
+        validated_response = super().validate_response(response, request)
+        # Handle logprobs
+        for choice in validated_response["choices"]:
+            if "logprobs" in choice:
+                logprobs = choice.pop("logprobs")
+                if logprobs and "token_logprobs" in logprobs:
+                    choice["token_logprobs"] = logprobs["token_logprobs"]
+                    choice["tokens"] = logprobs["tokens"]
+        return validated_response
+
     def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
         """Split usage into list of usages for each prompt."""
         try:
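
For readers tracing the change: a minimal sketch of what the new validate_response does to a single choice, assuming an OpenAI-style completion payload (the dict below is illustrative, not taken from the commit):

    # Illustrative choice dict shaped like an OpenAI Completions response;
    # the field values here are made up for the example.
    choice = {
        "text": " world",
        "logprobs": {
            "tokens": [" world"],
            "token_logprobs": [-0.12],
            "top_logprobs": None,
            "text_offset": [5],
        },
    }

    # Mirrors the loop body added in openai.py: pop the nested logprobs dict
    # and promote the token-level fields onto the choice itself.
    logprobs = choice.pop("logprobs")
    if logprobs and "token_logprobs" in logprobs:
        choice["token_logprobs"] = logprobs["token_logprobs"]
        choice["tokens"] = logprobs["tokens"]

    print(choice)
    # {'text': ' world', 'token_logprobs': [-0.12], 'tokens': [' world']}
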
2 changes: 1 addition & 1 deletion manifest/response.py
@@ -52,7 +52,7 @@ class LMModelChoice(BaseModel):
 
     text: str
     token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[int]] = None
+    tokens: Optional[List[str]] = None
 
 
 class ArrayModelChoice(BaseModel):
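
The type change above matters because OpenAI reports tokens as strings rather than integer ids. A rough, trimmed-down construction under that assumption (the real LMModelChoice lives in manifest/response.py and may carry more fields):

    from typing import List, Optional

    from pydantic import BaseModel


    # Trimmed-down stand-in for LMModelChoice, for illustration only.
    class LMModelChoice(BaseModel):
        text: str
        token_logprobs: Optional[List[float]] = None
        tokens: Optional[List[str]] = None


    # With List[str], the token strings returned by OpenAI validate cleanly.
    choice = LMModelChoice(text=" world", token_logprobs=[-0.12], tokens=[" world"])
    print(choice.tokens)  # [' world']
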
