Skip to content

Commit

Permalink
Explicit str type checking in model parsers
Browse files Browse the repository at this point in the history
See comment here for context: #611 (comment)

Separating this diff from #611 to make it easier and get that one unblocked
  • Loading branch information
Rossdan Craig [email protected] committed Dec 28, 2023
1 parent 7329eef commit 87e5390
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ def construct_stream_output(
)
accumulated_message = ""
for new_text in streamer:
accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
output.data = accumulated_message
if isinstance(new_text, str):
accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
output.data = accumulated_message
return output


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,15 @@ async def run_inference(
output_text = self.tokenizer.decode(
response[0][prompt_len:], skip_special_tokens=True
)
output_data_content: str = ''
if isinstance(output_text, str):
output_data_content = output_text
else:
raise ValueError(f"Output {output_text} needs to be of type 'str' but is of type: {type(output_text)}")
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": output_text,
"data": output_data_content,
"execution_count": 0,
"metadata": {},
}
Expand Down
8 changes: 7 additions & 1 deletion extensions/llama/python/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,15 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]:
if options:
options.stream_callback(data, acc, index)
print(flush=True)

output_data_value: str = ''
if isinstance(acc, str):
output_data_value = acc
else:
raise ValueError(f"Output {acc} needs to be of type 'str' but is of type: {type(acc)}")
return ExecuteResult(
output_type="execute_result",
data=acc,
data=output_data_value,
metadata={}
)
else:
Expand Down

0 comments on commit 87e5390

Please sign in to comment.