Skip to content

Commit

Permalink
[ez] Auto-format typescript files (#592)
Browse files Browse the repository at this point in the history
[ez] Auto-format typescript files

Doing this before next diff to keep things separate
  • Loading branch information
rossdanlm authored Dec 22, 2023
2 parents 2bf5458 + 395fad8 commit aa463e2
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 30 deletions.
29 changes: 19 additions & 10 deletions typescript/__tests__/parsers/palm-text/palm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import { getAPIKeyFromEnv } from "../../../lib/utils";

const PALM_CONFIG_PATH = path.join(__dirname, "palm-text.aiconfig.json");

const mockGetApiKeyFromEnv = getAPIKeyFromEnv as jest.MockedFunction<typeof getAPIKeyFromEnv>;
const mockGetApiKeyFromEnv = getAPIKeyFromEnv as jest.MockedFunction<
typeof getAPIKeyFromEnv
>;

// This could probably be abstracted out into a test util
jest.mock("../../../lib/utils", () => {
Expand All @@ -26,22 +28,29 @@ describe("PaLM Text ModelParser", () => {
// no need to instantiate model parser. Load will instantiate it for us since it's a default parser
const aiConfig = AIConfigRuntime.load(PALM_CONFIG_PATH);

const completionParams: protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest = {
model: "models/text-bison-001",
// Note: top_p matches global config settings for the model and temperature is different
topP: 0.9,
temperature: 0.8,
prompt: { text: "What are 5 interesting things to do in Toronto?" },
};
const completionParams: protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest =
{
model: "models/text-bison-001",
// Note: top_p matches global config settings for the model and temperature is different
topP: 0.9,
temperature: 0.8,
prompt: { text: "What are 5 interesting things to do in Toronto?" },
};

// Casting as JSONObject since the type of completionParams, protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest, doesn't conform to the JSONObject shape even though it looks like it does
const prompts = (await aiConfig.serialize("models/text-bison-001", completionParams as JSONObject, "interestingThingsToronto")) as Prompt[];
const prompts = (await aiConfig.serialize(
"models/text-bison-001",
completionParams as JSONObject,
"interestingThingsToronto"
)) as Prompt[];

expect(prompts).toHaveLength(1);
const prompt = prompts[0];

expect(prompt.name).toEqual("interestingThingsToronto");
expect(prompt.input).toEqual("What are 5 interesting things to do in Toronto?");
expect(prompt.input).toEqual(
"What are 5 interesting things to do in Toronto?"
);
expect(prompt.metadata?.model).toEqual({
name: "models/text-bison-001",
settings: {
Expand Down
1 change: 0 additions & 1 deletion typescript/demo/test-hf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,3 @@ async function run() {
}

run();

71 changes: 55 additions & 16 deletions typescript/lib/parsers/palm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ export class PaLMTextParser extends ParameterizedModelParser {
super();
}

public serialize(promptName: string, data: JSONObject, aiConfig: AIConfigRuntime, params?: JSONObject | undefined): Prompt | Prompt[] {
public serialize(
promptName: string,
data: JSONObject,
aiConfig: AIConfigRuntime,
params?: JSONObject | undefined
): Prompt | Prompt[] {
const startEvent = {
name: "on_serialize_start",
file: __filename,
Expand All @@ -38,14 +43,18 @@ export class PaLMTextParser extends ParameterizedModelParser {

// input type was found by looking at the impl of the text generation api. When calling textGeneration, step into the definition of the function and look at the type of the input parameter
// ModelParser abstract method types data as JSONObject, but we know that the data is going to be of type protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest.
const input = data as protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest;
const input =
data as protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest;

const prompt = input.prompt?.text as string;
const modelName = input.model as string;

let modelMetadata: ModelMetadata;
// Once relevant attributes are parsed, we no longer need them. These attributes get moved to their respective fields and the rest of the attributes are passed to the model as settings (model metadata).
modelMetadata = aiConfig.getModelMetadata(_.omit(input, ["model", "prompt"]) as JSONObject, modelName);
modelMetadata = aiConfig.getModelMetadata(
_.omit(input, ["model", "prompt"]) as JSONObject,
modelName
);

// Super simple since palm text generation is just one shot prompting.
const prompts: Prompt[] = [
Expand All @@ -71,7 +80,11 @@ export class PaLMTextParser extends ParameterizedModelParser {
return prompts;
}

public deserialize(prompt: Prompt, aiConfig: AIConfigRuntime, params?: JSONObject | undefined): any {
public deserialize(
prompt: Prompt,
aiConfig: AIConfigRuntime,
params?: JSONObject | undefined
): any {
// TODO: @ankush-lastmile PaLM unable to set output type to Text Generation API request type, `protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest` looks like it conforms to JSONObject type but it doesn't. Returns any for now.
const startEvent = {
name: "on_deserialize_start",
Expand All @@ -87,10 +100,16 @@ export class PaLMTextParser extends ParameterizedModelParser {

// Get Prompt Template (aka prompt string), parameterize it, and set it in completionParams
const promptTemplate = prompt.input as string;
const promptText = this.resolvePromptTemplate(promptTemplate, prompt, aiConfig, params);
const promptText = this.resolvePromptTemplate(
promptTemplate,
prompt,
aiConfig,
params
);
completionParams.prompt = { text: promptText };

const refinedCompletionParams = refineTextGenerationParams(completionParams);
const refinedCompletionParams =
refineTextGenerationParams(completionParams);

const endEvent = {
name: "on_deserialize_end",
Expand Down Expand Up @@ -147,7 +166,11 @@ export class PaLMTextParser extends ParameterizedModelParser {
return outputs;
}

public getOutputText(aiConfig: AIConfigRuntime, output?: Output, prompt?: Prompt): string {
public getOutputText(
aiConfig: AIConfigRuntime,
output?: Output,
prompt?: Prompt
): string {
if (output == null && prompt != null) {
output = aiConfig.getLatestOutput(prompt);
}
Expand All @@ -168,21 +191,37 @@ export class PaLMTextParser extends ParameterizedModelParser {
* Refines the completion params for the PALM text generation api. Removes any unsupported params.
* The supported keys were found by looking at the PaLM text generation api. `INSERT TYPE HERE`
*/
export function refineTextGenerationParams(params: JSONObject): protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest {
export function refineTextGenerationParams(
params: JSONObject
): protos.google.ai.generativelanguage.v1beta2.IGenerateTextRequest {
return {
model: params.model as string | null,
prompt: params.prompt as google.ai.generativelanguage.v1beta2.ITextPrompt | null,
temperature: params.temperature != null ? (params.temperature as number) : null,
candidateCount: params.candidateCount != null ? (params.candidateCount as number) : null,
maxOutputTokens: params.maxOutputTokens != null ? (params.maxOutputTokens as number) : null,
prompt:
params.prompt as google.ai.generativelanguage.v1beta2.ITextPrompt | null,
temperature:
params.temperature != null ? (params.temperature as number) : null,
candidateCount:
params.candidateCount != null ? (params.candidateCount as number) : null,
maxOutputTokens:
params.maxOutputTokens != null
? (params.maxOutputTokens as number)
: null,
topP: params.topP != null ? (params.topP as number) : null,
topK: params.topK != null ? (params.topK as number) : null,
safetySettings: params.safetySettings !== undefined ? (params.safetySettings as google.ai.generativelanguage.v1beta2.ISafetySetting[]) : null,
stopSequences: params.stopSequences !== undefined ? (params.stopSequences as string[]) : null,
safetySettings:
params.safetySettings !== undefined
? (params.safetySettings as google.ai.generativelanguage.v1beta2.ISafetySetting[])
: null,
stopSequences:
params.stopSequences !== undefined
? (params.stopSequences as string[])
: null,
};
}

function constructOutputs(response: protos.google.ai.generativelanguage.v1beta2.IGenerateTextResponse): ExecuteResult[] {
function constructOutputs(
response: protos.google.ai.generativelanguage.v1beta2.IGenerateTextResponse
): ExecuteResult[] {
if (!response.candidates) {
return [];
}
Expand All @@ -195,7 +234,7 @@ function constructOutputs(response: protos.google.ai.generativelanguage.v1beta2.
output_type: "execute_result",
data: candidate.output,
execution_count: i,
metadata: _.omit(candidate, ["output"])
metadata: _.omit(candidate, ["output"]),
};

outputs.push(output);
Expand Down
7 changes: 4 additions & 3 deletions typescript/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ export function extractOverrideSettings(
modelName: string
) {
let modelMetadata: ModelMetadata | string;
const globalModelSettings: InferenceSettings =
{...(configRuntime.getGlobalSettings(modelName)) ?? {}};
inferenceSettings = {...(inferenceSettings) ?? {}}
const globalModelSettings: InferenceSettings = {
...(configRuntime.getGlobalSettings(modelName) ?? {}),
};
inferenceSettings = { ...(inferenceSettings ?? {}) };

if (globalModelSettings != null) {
// Check if the model settings from the input data are the same as the global model settings
Expand Down

0 comments on commit aa463e2

Please sign in to comment.