Skip to content

Commit

Permalink
fix: a number of resolver bugs outlined in PR body (#1237)
Browse files Browse the repository at this point in the history
  • Loading branch information
armandobelardo authored Aug 2, 2024
1 parent e623285 commit d5bf09a
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 21 deletions.
16 changes: 8 additions & 8 deletions packages/template-resolver/src/SnippetTemplateResolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ export class SnippetTemplateResolver {
case "BODY":
return accessByPathNonNull(this.payload.requestBody, location.path);
case "RELATIVE":
// We should warn if this ever happens, the relative directive should only really happen within containers
return accessByPathNonNull(this.payload.requestBody, location.path);
// If you're here you don't have a payload, so return undefined
return undefined;
case "QUERY":
return this.accessParameterPayloadByPath(this.payload.queryParameters, location.path);
case "PATH":
Expand Down Expand Up @@ -198,10 +198,6 @@ export class SnippetTemplateResolver {
});
}
} else {
if (payloadOverride == null && input.value.isOptional && input.value.type === "enum") {
continue;
}

const evaluatedInput = this.resolveV1Template({
template: input.value,
payloadOverride,
Expand Down Expand Up @@ -229,7 +225,7 @@ export class SnippetTemplateResolver {
if (template.templateInput == null) {
return new DefaultedV1Snippet({ template, isRequired });
}
const payloadValue = this.getPayloadValue(template.templateInput);
const payloadValue = this.getPayloadValue(template.templateInput, payloadOverride);
if (!Array.isArray(payloadValue)) {
return new DefaultedV1Snippet({ template, isRequired });
}
Expand Down Expand Up @@ -263,7 +259,6 @@ export class SnippetTemplateResolver {
return new DefaultedV1Snippet({ template, isRequired });
}

// const payloadMap = payloadValue as Map<string, unknown>;
const evaluatedInputs: V1Snippet[] = [];
for (const key in payloadValue) {
const value = payloadValue[key as keyof typeof payloadValue];
Expand Down Expand Up @@ -300,6 +295,11 @@ export class SnippetTemplateResolver {
return new DefaultedV1Snippet({ template, isRequired });
}
const maybeEnumWireValue = this.getPayloadValue(template.templateInput, payloadOverride);

if (maybeEnumWireValue == null) {
return new DefaultedV1Snippet({ template, isRequired });
}

const enumSdkValue =
(typeof maybeEnumWireValue === "string" ? enumValues[maybeEnumWireValue] : undefined) ??
defaultEnumValue;
Expand Down
63 changes: 63 additions & 0 deletions packages/template-resolver/src/__test__/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2854,6 +2854,69 @@ export const CHAT_COMPLETION_SNIPPET_WITH_LEGACY_CLIENT_INSTANTIATION: FernRegis
],
},
},
{
type: "template",
value: {
imports: [],
isOptional: true,
containerTemplateString: "messages: [\n\t\t$FERN_INPUT\n\t]",
delimiter: ",\n\t\t",
innerTemplate: {
imports: [],
isOptional: true,
templateString: "{\n\t\t\t$FERN_INPUT\n\t\t}",
templateInputs: [
{
type: "template",
value: {
imports: [],
isOptional: true,
containerTemplateString:
'"tool_calls": [\n\t\t\t\t$FERN_INPUT\n\t\t\t]',
delimiter: ",\n\t\t\t\t",
innerTemplate: {
imports: [],
isOptional: true,
templateString: "{\n\t\t\t\t\t$FERN_INPUT\n\t\t\t\t}",
templateInputs: [
{
type: "template",
value: {
imports: [],
isOptional: true,
templateString: '"id": $FERN_INPUT',
templateInputs: [
{
location: "RELATIVE",
path: "id",
type: "payload",
},
],
type: "generic",
},
},
],
inputDelimiter: ",\n\t\t\t\t\t",
type: "generic",
},
templateInput: {
location: "RELATIVE",
path: "tool_calls",
},
type: "iterable",
},
},
],
inputDelimiter: ",\n\t\t\t",
type: "generic",
},
templateInput: {
location: "BODY",
path: "messages",
},
type: "iterable",
},
},
],
},
},
Expand Down
5 changes: 3 additions & 2 deletions packages/template-resolver/src/__test__/octo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export const CHAT_COMPLETION_SNIPPET: FernRegistry.EndpointSnippetTemplate = {
templateString: "$FERN_INPUT",
templateInputs: [
{
location: "BODY",
location: "RELATIVE",
path: undefined,
type: "payload",
},
Expand All @@ -146,7 +146,7 @@ export const CHAT_COMPLETION_SNIPPET: FernRegistry.EndpointSnippetTemplate = {
templateString: "$FERN_INPUT",
templateInputs: [
{
location: "BODY",
location: "RELATIVE",
path: undefined,
type: "payload",
},
Expand Down Expand Up @@ -646,5 +646,6 @@ export const CHAT_COMPLETION_PAYLOAD: FernRegistry.CustomSnippetPayload = {
presence_penalty: 0,
temperature: 0.1,
top_p: 0.9,
logit_bias: { "": undefined },
},
};
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@ await cohere.chatStream({
message: "Hi! How can I help you today?",
},
],
promptTruncation: Cohere.ChatStreamRequestPromptTruncation.Off,
});
"
`;

exports[`Snippet Template Resolver > Test Chat Completion snippet with deeply nested iterables 1`] = `
"const cohere = new CohereClient({
token: "YOUR_TOKEN",
clientName: "YOUR_CLIENT_NAME",
});
await cohere.chatStream({
message: "Hello world!",
chatHistory: [
{
role: "USER",
message: "Hello",
},
{
role: "CHATBOT",
message: "Hi! How can I help you today?",
},
],
promptTruncation: Cohere.ChatStreamRequestPromptTruncation.Off,
messages: [
{
tool_calls: [
{
id: "qqw",
},
],
},
],
});
"
`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ client = AsyncOctoAI(
)
await client.text_gen.create_chat_completion_stream(
logit_bias={},
max_tokens=512,
messages=[],
model="qwen1.5-32b-chat",
Expand All @@ -36,6 +37,7 @@ top_p=0.9

exports[`Snippet Template Resolver > Test Snippet Template Resolution 1`] = `
"from octoai.image_gen import ImageGenerationRequest
from octoai.image_gen import Scheduler
from octoai import AsyncAcme
Expand All @@ -47,7 +49,8 @@ client.image_gen.generate_sdxl(
tune_id="someId",
offset="10",
output_format="pcm_16000",
loras={"key1": "value1", "key2": "value2"}
loras={"key1": "value1", "key2": "value2"},
sampler=OctoAI.myenum.PNDM
)
)
"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ await api.createMovie();
"
`;

exports[`Snippet Template Resolver > Test Unions Object Total Mismatch 2`] = `
"const cohere = new CohereClient({
token: "YOUR_TOKEN",
clientName: "YOUR_CLIENT_NAME",
});
await api.createMovie();
"
`;

exports[`Snippet Template Resolver > Test Unions Similar Object 1`] = `
"const cohere = new CohereClient({
token: "YOUR_TOKEN",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { SnippetTemplateResolver } from "../../SnippetTemplateResolver";
import { CHAT_COMPLETION_SNIPPET } from "../cohere";
import { CHAT_COMPLETION_SNIPPET, CHAT_COMPLETION_SNIPPET_WITH_LEGACY_CLIENT_INSTANTIATION } from "../cohere";

describe("Snippet Template Resolver", () => {
it("Test Chat Completion snippet", async () => {
Expand All @@ -17,6 +17,7 @@ describe("Snippet Template Resolver", () => {
],
requestBody: {
message: "Hello world!",
prompt_truncation: "OFF",
chat_history: [
{
role: "USER",
Expand Down Expand Up @@ -53,4 +54,52 @@ describe("Snippet Template Resolver", () => {

expect(customSnippet.client).toMatchSnapshot();
});

it("Test Chat Completion snippet with deeply nested iterables", async () => {
const resolver = new SnippetTemplateResolver({
payload: {
auth: {
type: "bearer",
token: "BE_1234",
},
headers: [
{
name: "X-Client-Name",
value: "Cohere's Client",
},
],
requestBody: {
message: "Hello world!",
prompt_truncation: "OFF",
chat_history: [
{
role: "USER",
message: "Hello",
},
{
role: "CHATBOT",
message: "Hi! How can I help you today?",
},
],
messages: [
{
tool_calls: [
{
id: "qqw",
},
],
},
],
},
},
endpointSnippetTemplate: CHAT_COMPLETION_SNIPPET_WITH_LEGACY_CLIENT_INSTANTIATION,
});
const customSnippet = await resolver.resolveWithFormatting();

if (customSnippet.type !== "typescript") {
throw new Error("Expected snippet to be typescript");
}

expect(customSnippet.client).toMatchSnapshot();
});
});

0 comments on commit d5bf09a

Please sign in to comment.