diff --git a/api/_lib/_llm.ts b/api/_lib/_llm.ts
deleted file mode 100644
index ecdf7e8e..00000000
--- a/api/_lib/_llm.ts
+++ /dev/null
@@ -1,74 +0,0 @@
-/* eslint-disable @typescript-eslint/no-explicit-any */
-import { z, ZodObject } from "zod";
-import { openai } from "./_openai";
-import zodToJsonSchema from "zod-to-json-schema";
-import OpenAI from "openai";
-
-type Schemas<T extends Record<string, ZodObject<any>>> = T;
-
-export async function llmMany<T extends Record<string, ZodObject<any>>>(
-  content: string,
-  schemas: Schemas<T>
-) {
-  try {
-    // if the user passes a key "message" in schemas, throw an error
-    if (schemas.message) throw new Error("Cannot use key 'message' in schemas");
-
-    const completion = await openai.chat.completions.create({
-      messages: [
-        {
-          role: "system",
-          content: `You are a helpful AI flowchart wizard. Create or update a flowchart for the user according to the given message. Always respond with a complete and correct flowchart. If the user requests an entirely new flowchart, ignore the current state and start from scratch.`,
-        },
-        {
-          role: "user",
-          content,
-        },
-      ],
-      tools: Object.entries(schemas).map(([key, schema]) => ({
-        type: "function",
-        function: {
-          name: key,
-          parameters: zodToJsonSchema(schema),
-        },
-      })),
-      model: "gpt-3.5-turbo-1106",
-      // model: "gpt-4-1106-preview",
-    });
-
-    const choice = completion.choices[0];
-
-    if (!choice) throw new Error("No choices returned");
-
-    // Must return the full thing, message and multiple tool calls
-    return simplifyChoice(choice) as SimplifiedChoice<T>;
-  } catch (error) {
-    console.error(error);
-    const message = (error as Error)?.message || "Error with prompt";
-    throw new Error(message);
-  }
-}
-
-type SimplifiedChoice<T extends Record<string, ZodObject<any>>> = {
-  message: string;
-  toolCalls: Array<
-    {
-      [K in keyof T]: {
-        name: K;
-        args: z.infer<T[K]>;
-      };
-    }[keyof T]
-  >;
-};
-
-function simplifyChoice(choice: OpenAI.Chat.Completions.ChatCompletion.Choice) {
-  return {
-    message: choice.message.content || "",
-    toolCalls:
-      choice.message.tool_calls?.map((toolCall) => ({
-        name: toolCall.function.name,
-        // Wish this were type-safe!
"{}"), - })) || [], - }; -} diff --git a/api/prompt/_edit.ts b/api/prompt/_edit.ts deleted file mode 100644 index 20126360..00000000 --- a/api/prompt/_edit.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { VercelApiHandler } from "@vercel/node"; -import { llmMany } from "../_lib/_llm"; -import { z } from "zod"; - -const nodeSchema = z.object({ - // id: z.string(), - // classes: z.string(), - label: z.string(), -}); - -const edgeSchema = z.object({ - from: z.string(), - to: z.string(), - label: z.string().optional().default(""), -}); - -const graphSchema = z.object({ - nodes: z.array(nodeSchema), - edges: z.array(edgeSchema), -}); - -const handler: VercelApiHandler = async (req, res) => { - const { graph, prompt } = req.body; - if (!graph || !prompt) { - throw new Error("Missing graph or prompt"); - } - - const result = await llmMany( - `${prompt} - -Here is the current state of the flowchart: -${JSON.stringify(graph, null, 2)} -`, - { - updateGraph: graphSchema, - } - ); - - res.json(result); -}; - -export default handler; diff --git a/api/prompt/_shared.ts b/api/prompt/_shared.ts index 4b5f1cb3..81f2479c 100644 --- a/api/prompt/_shared.ts +++ b/api/prompt/_shared.ts @@ -3,17 +3,16 @@ import { streamText } from "ai"; import { stripe } from "../_lib/_stripe"; import { kv } from "@vercel/kv"; import { Ratelimit } from "@upstash/ratelimit"; -import { openai } from "@ai-sdk/openai"; +import { createOpenAI, type openai as OpenAI } from "@ai-sdk/openai"; export const reqSchema = z.object({ prompt: z.string().min(1), document: z.string(), }); -export async function handleRateLimit(req: Request) { - const ip = getIp(req); - let isPro = false, - customerId: null | string = null; +async function checkUserStatus(req: Request) { + let isPro = false; + let customerId: null | string = null; const token = req.headers.get("Authorization"); @@ -26,6 +25,16 @@ export async function handleRateLimit(req: Request) { } } + return { isPro, customerId }; +} + +export async function handleRateLimit( + req: Request, + isPro: boolean, + customerId: string | null +) { + const ip = getIp(req); + const ratelimit = new Ratelimit({ redis: kv, limiter: isPro @@ -85,11 +94,16 @@ export async function processRequest( req: Request, systemMessage: string, content: string, - model: Parameters[0] = "gpt-4-turbo" + model: Parameters[0] = "gpt-4-turbo" ) { - const rateLimitResponse = await handleRateLimit(req); + const { isPro, customerId } = await checkUserStatus(req); + const rateLimitResponse = await handleRateLimit(req, isPro, customerId); if (rateLimitResponse) return rateLimitResponse; + const openai = createOpenAI({ + apiKey: getOpenAiApiKey(isPro), + }); + const result = await streamText({ model: openai.chat(model), system: systemMessage, @@ -105,6 +119,19 @@ export async function processRequest( return result.toTextStreamResponse(); } +/** + * Returns the right api key depending on the user's subscription + * so we can track usage. Bear in mind a development key is used for + * anything that's not production. + */ +function getOpenAiApiKey(isPro: boolean) { + if (isPro) { + return process.env.OPENAI_API_KEY_PRO; + } + + return process.env.OPENAI_API_KEY_FREE; +} + function getIp(req: Request) { return ( req.headers.get("x-real-ip") ||