[typescript] Save output.data with text content instead of response data
This comes after Sarmad's schema updates in #589. To keep diffs small and easy to review, this commit simply converts model-specific outputs --> pure text. I have a diff in #610 that converts pure text --> the `OutputData` format.


We only needed to update `hf.py` and `openai.py`, because `palm.py` already returns output as a `string | null` type.

Ran the automated yarn tests, but there aren't any specifically for openai. Not sure how to run the `demo.ts` file, which would also be a reasonable way to check that everything there works too.


For the extensions, we only have TypeScript for `hf.ts` (trivial: just changed `response` to `response.generated_text`), while `llama.ts` already outputs text, so no changes were needed there. A before/after sketch of the output shape is below.
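
For context, a minimal before/after sketch of the shape change this commit makes (the sample values come from the updated `hf.test.ts` below; `ExecuteResult` is imported from `aiconfig`, as in the diffs):

```typescript
import { ExecuteResult } from "aiconfig";

// Before: output.data held the model-specific response object.
const before = {
  output_type: "execute_result",
  data: { generated_text: "Test text generation" },
  execution_count: 0,
  metadata: {},
} as ExecuteResult;

// After: output.data holds only the text; any non-text response
// fields are moved into metadata instead.
const after = {
  output_type: "execute_result",
  data: "Test text generation",
  execution_count: 0,
  metadata: {},
} as ExecuteResult;
```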
Rossdan Craig [email protected] committed Dec 25, 2023
1 parent 9b6768f commit 8e3a0be
Showing 4 changed files with 38 additions and 31 deletions.
extensions/HuggingFace/typescript/hf.ts (13 changes: 6 additions & 7 deletions)

```diff
@@ -14,7 +14,7 @@ import {
   ExecuteResult,
   AIConfigRuntime,
   InferenceOptions,
-  CallbackEvent
+  CallbackEvent,
 } from "aiconfig";
 import _ from "lodash";
 import * as aiconfig from "aiconfig";
@@ -211,7 +211,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
       const response = await this.hfClient.textGenerationStream(
         textGenerationArgs
       );
-      output = await ConstructStreamOutput(
+      output = await constructStreamOutput(
         response,
         options as InferenceOptions
       );
@@ -248,8 +248,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
     }

     if (output.output_type === "execute_result") {
-      return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
-        .generated_text as string;
+      return output.data ?? "";
     } else {
       return "";
     }
@@ -262,7 +261,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
  * @param options
  * @returns
  */
-async function ConstructStreamOutput(
+async function constructStreamOutput(
   response: AsyncGenerator<TextGenerationStreamOutput>,
   options: InferenceOptions
 ): Promise<Output> {
@@ -280,7 +279,7 @@ async function ConstructStreamOutput(

   output = {
     output_type: "execute_result",
-    data: delta,
+    data: accumulatedMessage,
     execution_count: index,
     metadata: metadata,
   } as ExecuteResult;
@@ -294,7 +293,7 @@ function constructOutput(response: TextGenerationOutput): Output {

   const output = {
     output_type: "execute_result",
-    data: data,
+    data: data.generated_text,
     execution_count: 0,
     metadata: metadata,
   } as ExecuteResult;
```
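
One subtlety in the streaming path above: `data` now saves the accumulated text rather than the latest `delta`. A rough sketch of the loop inside `constructStreamOutput` (the chunk-handling details here are assumptions, not shown in this diff):

```typescript
import { TextGenerationStreamOutput } from "@huggingface/inference";

// Sketch only: accumulate streamed tokens so the saved output holds the
// full text generated so far, not just the most recent chunk.
async function accumulate(
  response: AsyncGenerator<TextGenerationStreamOutput>
): Promise<string> {
  let accumulatedMessage = "";
  for await (const chunk of response) {
    const delta = chunk.token.text; // per-chunk text from the HF stream
    accumulatedMessage += delta; // before this commit, only `delta` was saved
  }
  return accumulatedMessage;
}
```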
typescript/__tests__/parsers/hf/hf.test.ts (2 changes: 1 addition & 1 deletion)

```diff
@@ -257,7 +257,7 @@ describe("HuggingFaceTextGeneration ModelParser", () => {

     const expectedOutput = {
       output_type: "execute_result",
-      data: { generated_text: "Test text generation" },
+      data: "Test text generation",
       execution_count: 0,
       metadata: {},
     };
```
typescript/lib/parsers/hf.ts (17 changes: 6 additions & 11 deletions)

```diff
@@ -203,7 +203,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
       const response = await this.hfClient.textGenerationStream(
         textGenerationArgs
       );
-      output = await ConstructStreamOutput(
+      output = await constructStreamOutput(
         response,
         options as InferenceOptions
       );
@@ -240,8 +240,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
     }

     if (output.output_type === "execute_result") {
-      return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
-        .generated_text as string;
+      return output.data as string;
     } else {
       return "";
     }
@@ -254,7 +253,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
  * @param options
  * @returns
  */
-async function ConstructStreamOutput(
+async function constructStreamOutput(
   response: AsyncGenerator<TextGenerationStreamOutput>,
   options: InferenceOptions
 ): Promise<Output> {
@@ -272,7 +271,7 @@ async function ConstructStreamOutput(

   output = {
     output_type: "execute_result",
-    data: delta,
+    data: accumulatedMessage,
     execution_count: index,
     metadata: metadata,
   } as ExecuteResult;
@@ -281,16 +280,12 @@ async function ConstructStreamOutput(
 }

 function constructOutput(response: TextGenerationOutput): Output {
-  const metadata = {};
-  const data = response;
-
   const output = {
     output_type: "execute_result",
-    data: data,
+    data: response.generated_text,
     execution_count: 0,
-    metadata: metadata,
+    metadata: _.omit(response, ["generated_text"]),
   } as ExecuteResult;

   return output;
 }
```
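
The net effect of the new `constructOutput` is easy to see in isolation; a small sketch (the `details` field below is illustrative, not taken from this diff):

```typescript
import _ from "lodash";

// Hypothetical HF text-generation response:
const response = {
  generated_text: "Hello there!",
  details: { finish_reason: "eos_token" },
};

const data = response.generated_text;
// => "Hello there!"
const metadata = _.omit(response, ["generated_text"]);
// => { details: { finish_reason: "eos_token" } }
```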
typescript/lib/parsers/openai.ts (37 changes: 25 additions & 12 deletions)

```diff
@@ -365,6 +365,7 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     const input: PromptInput =
       message.role === "user" ? message.content ?? "" : { ...message };

+    const responseWithoutContent = _.omit(assistantResponse, ["content"]);
     const prompt: Prompt = {
       name: `${promptName}_${prompts.length + 1}`,
       input,
@@ -378,7 +379,8 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
           ? [
               {
                 output_type: "execute_result",
-                data: { ...assistantResponse },
+                data: assistantResponse.content,
+                metadata: { ...responseWithoutContent },
               },
             ]
           : undefined,
@@ -536,13 +538,15 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     const outputs: ExecuteResult[] = [];
     const responseWithoutChoices = _.omit(response, "choices");
     for (const choice of response.choices) {
+      const messageWithoutContent = _.omit(choice.message, ["content"]);
       const output: ExecuteResult = {
         output_type: "execute_result",
-        data: { ...choice.message },
+        data: choice.message?.content,
         execution_count: choice.index,
         metadata: {
           finish_reason: choice.finish_reason,
           ...responseWithoutChoices,
+          ...messageWithoutContent,
         },
       };
@@ -585,12 +589,16 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
           /*index*/ choice.index
         );

+        const messageWithoutContent = _.omit(message, ["content"]);
         const output: ExecuteResult = {
           output_type: "execute_result",
-          data: { ...message },
+          // TODO (rossdanlm): Handle ChatCompletionMessage.function_call
+          // too (next diff)
+          data: message?.content,
           execution_count: choice.index,
           metadata: {
             finish_reason: choice.finish_reason,
+            ...messageWithoutContent,
           },
         };
         outputs.set(choice.index, output);
@@ -625,11 +633,10 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     }

     if (output.output_type === "execute_result") {
-      const message = output.data as Chat.ChatCompletionMessageParam;
-      if (message.content != null) {
-        return message.content;
-      } else if (message.function_call) {
-        return JSON.stringify(message.function_call);
+      if (typeof output.data === "string") {
+        return output.data;
+      } else if (output.metadata?.function_call) {
+        return JSON.stringify(output.metadata?.function_call);
       } else {
         return "";
       }
@@ -671,11 +678,17 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
       const output = aiConfig.getLatestOutput(prompt);
       if (output != null) {
         if (output.output_type === "execute_result") {
-          const outputMessage =
-            output.data as unknown as Chat.ChatCompletionMessageParam;
           // If the prompt has output saved, add it to the messages array
-          if (outputMessage.role === "assistant") {
-            messages.push(outputMessage);
+          if (output.metadata?.role === "assistant") {
+            if (typeof output.data === "string") {
+              messages.push({
+                content: output.data,
+                role: output.metadata?.role,
+                function_call: output.metadata?.function_call,
+                name: output.metadata?.name,
+              });
+            }
+            // TODO (rossdanlm): Support function_call
           }
         }
       }
```
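
Taken together, the openai.ts changes round-trip cleanly: the text goes into `data`, the rest of the assistant message (role, function_call, name) goes into `metadata`, and the message is rebuilt from those two pieces on the next request. A minimal sketch of that reconstruction, mirroring the `messages.push(...)` block in the diff above (sample values are made up):

```typescript
// Saved output in the new format.
const output = {
  output_type: "execute_result",
  data: "Sure, here is a haiku about code.",
  execution_count: 0,
  metadata: { role: "assistant", finish_reason: "stop" },
};

// Rebuild the chat message for the next completion request.
if (output.metadata.role === "assistant" && typeof output.data === "string") {
  const message = {
    role: output.metadata.role,
    content: output.data,
  };
  console.log(message); // { role: "assistant", content: "Sure, ..." }
}
```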
