From 49fb65545e396911c156261fafede498974e7be1 Mon Sep 17 00:00:00 2001 From: "Mike P. Sinn" Date: Tue, 3 Sep 2024 22:06:32 -0500 Subject: [PATCH] Added research agent and UI --- app/actions.ts | 9 +- app/researcher/page.tsx | 67 ++++++++ components/ArticleRenderer.tsx | 118 ++++++++++++++ components/CustomReactMarkdown.tsx | 96 ++++++++++++ components/assistant/Readme.tsx | 92 +---------- lib/agents/fdai/safetySummaryAgent.ts | 2 +- lib/agents/researcher.ts | 63 -------- lib/agents/researcher/getSearchResults.ts | 14 ++ lib/agents/researcher/researcher.ts | 148 ++++++++++++++++++ .../{ => researcher}/searchQueryGenerator.ts | 0 lib/utils/dumpTypeDefinition.ts | 15 ++ tests/fdai.test.ts | 26 +-- tests/seed.test.ts | 25 +++ 13 files changed, 511 insertions(+), 164 deletions(-) create mode 100644 app/researcher/page.tsx create mode 100644 components/ArticleRenderer.tsx create mode 100644 components/CustomReactMarkdown.tsx delete mode 100644 lib/agents/researcher.ts create mode 100644 lib/agents/researcher/getSearchResults.ts create mode 100644 lib/agents/researcher/researcher.ts rename lib/agents/{ => researcher}/searchQueryGenerator.ts (100%) create mode 100644 lib/utils/dumpTypeDefinition.ts diff --git a/app/actions.ts b/app/actions.ts index 0371faee..9e5d4077 100644 --- a/app/actions.ts +++ b/app/actions.ts @@ -7,7 +7,7 @@ import { type Message as AIMessage } from "ai" import { prisma } from "@/lib/db" import { getCurrentUser } from "@/lib/session" import { type Chat } from "@/lib/types" - +import {type ReportOutput, writeArticle, type ModelName} from "@/lib/agents/researcher/researcher" type GetChatResult = Chat[] | null type SetChatResults = Chat[] @@ -166,3 +166,10 @@ export const clearAllChats = async (userId: string) => { return revalidatePath(deletedChats.map((chat) => chat.path).join(", ")) } } + +export async function writeArticleAction(topic: string, modelName?: ModelName): Promise { + const article = await writeArticle(topic, { modelName: modelName }) + + revalidatePath('/') + return article +} \ No newline at end of file diff --git a/app/researcher/page.tsx b/app/researcher/page.tsx new file mode 100644 index 00000000..d54559ed --- /dev/null +++ b/app/researcher/page.tsx @@ -0,0 +1,67 @@ +'use client' + +import { useState } from 'react' +import { writeArticleAction } from '@/app/actions' +import ArticleRenderer from '@/components/ArticleRenderer' +import { Button } from "@/components/ui/button" +import { Input } from "@/components/ui/input" +import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from "@/components/ui/card" +import { ReportOutput } from '@/lib/agents/researcher/researcher' +import GlobalBrainNetwork from "@/components/landingPage/global-brain-network" + +export default function Home() { + const [article, setArticle] = useState(null) + const [error, setError] = useState('') + const [isGenerating, setIsGenerating] = useState(false) + + async function handleSubmit(formData: FormData) { + const topic = formData.get('topic') as string + if (!topic) { + setError('Please enter a topic') + return + } + + setIsGenerating(true) + setError('') + + try { + const generatedArticle = await writeArticleAction(topic) + setArticle(generatedArticle) + } catch (err) { + setError('Failed to generate article. Please try again.') + } finally { + setIsGenerating(false) + } + } + + return ( +
+

Article Generator

+ + + Generate an Article + Enter a topic to generate an article + + +
{ + e.preventDefault() + handleSubmit(new FormData(e.currentTarget)) + }} className="flex gap-4"> + + +
+
+ {error && ( + +

{error}

+
+ )} +
+ + {isGenerating && } + {!isGenerating && article && } +
+ ) +} \ No newline at end of file diff --git a/components/ArticleRenderer.tsx b/components/ArticleRenderer.tsx new file mode 100644 index 00000000..12e3111a --- /dev/null +++ b/components/ArticleRenderer.tsx @@ -0,0 +1,118 @@ +import {useState} from 'react' +import {Card, CardContent, CardDescription, CardHeader, CardTitle} from "@/components/ui/card" +import {Badge} from "@/components/ui/badge" +import {Separator} from "@/components/ui/separator" +import {Clock, Tag, Folder, Link2} from 'lucide-react' +import {ReportOutput} from '@/lib/agents/researcher/researcher' +import {CustomReactMarkdown} from "@/components/CustomReactMarkdown"; + +export default function ArticleRenderer(props: ReportOutput) { + const [expandedResult, setExpandedResult] = useState(null) + + const { + title, + description, + content, + sources, + tags, + category, + readingTime, + searchResults + } = props + + return ( +
+
+ + + {title} + {description} + + + + + {content} + + + +
+ + + Article Info + + +
+ + {category} +
+
+ + {readingTime} min read +
+
+ + {tags?.map((tag, index) => ( + {tag} + ))} +
+
+
+ + + + Sources + + + + + +
+
+ + + + Search Results + + + {searchResults?.map((result, index) => ( +
+

+ + {result.title} + +

+ {result.publishedDate && ( +

+ Published on: {new Date(result.publishedDate).toLocaleDateString()} +

+ )} +

+ {expandedResult === result.id ? result.text : `${result.text.slice(0, 150)}...`} + {result.text.length > 150 && ( + + )} +

+ {index < (searchResults.length - 1) && } +
+ ))} +
+
+
+ ) +} \ No newline at end of file diff --git a/components/CustomReactMarkdown.tsx b/components/CustomReactMarkdown.tsx new file mode 100644 index 00000000..4a0782c3 --- /dev/null +++ b/components/CustomReactMarkdown.tsx @@ -0,0 +1,96 @@ +import React from "react" +import { ReactMarkdown } from "react-markdown/lib/react-markdown" +import rehypeRaw from "rehype-raw" +import remarkGfm from "remark-gfm" +import remarkMath from "remark-math" +import { cn } from "@/lib/utils" +import { CodeBlock } from "./ui/code-block" + +export const CustomReactMarkdown = React.memo(function CustomReactMarkdown({ + children, + className, + ...props +}: React.ComponentPropsWithoutRef) { + return ( + *]:mb-4 [&_li]:mb-2", + className + )} + rehypePlugins={[rehypeRaw as any, { allowDangerousHtml: true }]} + remarkPlugins={[remarkGfm, remarkMath]} + skipHtml={false} + components={{ + p({ children }) { + return

{children}

+ }, + br() { + return <> + }, + h1({ children }) { + return

{children}

+ }, + h2({ children }) { + return

{children}

+ }, + h3({ children }) { + return

{children}

+ }, + a({ href, children, ...props }) { + let target = "" + if (href?.startsWith("http")) { + target = "_blank" + } else if (href?.startsWith("#")) { + target = "_self" + } + return ( + + {children} + + ) + }, + code({ inline, className, children, ...props }) { + if (children && children.length) { + if (children[0] == "▍") { + return ( + + ) + } + children[0] = (children[0] as string).replace("`▍`", "▍") + } + const match = /language-(\w+)/.exec(className || "") + if (inline) { + return ( + + {children} + + ) + } + return ( + + ) + }, + }} + {...props} + > + {children} +
+ ) +}) \ No newline at end of file diff --git a/components/assistant/Readme.tsx b/components/assistant/Readme.tsx index da0b2dec..0e6462a6 100644 --- a/components/assistant/Readme.tsx +++ b/components/assistant/Readme.tsx @@ -1,101 +1,15 @@ "use client" -import rehypeRaw from "rehype-raw" -import remarkGfm from "remark-gfm" -import remarkMath from "remark-math" - import { cn } from "@/lib/utils" - -import { MemoizedReactMarkdown } from "../Markdown" -import { CodeBlock } from "../ui/code-block" +import { CustomReactMarkdown } from "../CustomReactMarkdown" export function Readme({ props: readme }: { props: string }) { return (
- Talk to Wishonia - {children}

- }, - br() { - return <> - }, - h1({ children }) { - return

{children}

- }, - h2({ children }) { - return

{children}

- }, - h3({ children }) { - return

{children}

- }, - a({ node, href, children, ...props }) { - let target = "" - - if (href?.startsWith("http")) { - target = "_blank" - } else if (href?.startsWith("#")) { - target = "_self" - } - - return ( - - {children} - - ) - }, - code({ node, inline, className, children, ...props }) { - if (children && children.length) { - if (children[0] == "▍") { - return ( - - ) - } - - children[0] = (children[0] as string).replace("`▍`", "▍") - } - - const match = /language-(\w+)/.exec(className || "") - - if (inline) { - return ( - - {children} - - ) - } - - return ( - - ) - }, - }} - > + {typeof readme === "string" ? readme : JSON.stringify(readme)} -
+
) diff --git a/lib/agents/fdai/safetySummaryAgent.ts b/lib/agents/fdai/safetySummaryAgent.ts index 9c2fbc92..0955f33f 100644 --- a/lib/agents/fdai/safetySummaryAgent.ts +++ b/lib/agents/fdai/safetySummaryAgent.ts @@ -2,7 +2,7 @@ import { anthropic } from "@ai-sdk/anthropic" import { generateObject } from "ai" import Exa from "exa-js" import { z } from "zod" -import {generateSearchQueries} from "@/lib/agents/searchQueryGenerator"; +import {generateSearchQueries} from "@/lib/agents/researcher/searchQueryGenerator"; const exa = new Exa(process.env.EXA_API_KEY); diff --git a/lib/agents/researcher.ts b/lib/agents/researcher.ts deleted file mode 100644 index 980332b3..00000000 --- a/lib/agents/researcher.ts +++ /dev/null @@ -1,63 +0,0 @@ -import dotenv from 'dotenv'; -import Exa from 'exa-js'; -import {generateSearchQueries} from "@/lib/agents/searchQueryGenerator"; -import {getLLMResponse} from "@/lib/agents/getLLMResponse"; - -dotenv.config(); - -const EXA_API_KEY = process.env.EXA_API_KEY as string; - -const exa = new Exa(EXA_API_KEY); - -interface SearchResult { - url: string; - text: string; -} - -async function getSearchResults(queries: string[], linksPerQuery: number = 5): Promise { - let results: SearchResult[] = []; - for (const query of queries) { - const searchResponse = await exa.searchAndContents(query, { - numResults: linksPerQuery, - useAutoprompt: false, - }); - results.push(...searchResponse.results); - } - return results; -} - -async function synthesizeReport(topic: string, searchContents: SearchResult[], contentSlice: number = 750): Promise { - const inputData = searchContents - .map( - (item) => - `--START ITEM--\nURL: ${item.url}\nCONTENT: ${item.text.slice(0, contentSlice)}\n--END ITEM--\n`, - ) - .join(""); - return await getLLMResponse({ - system: "You are a medical research assistant. Write a report according to the user's instructions.", - user: - "Input Data:\n" + - inputData + - `Write a ${topic} based on the provided information. -Include as many sources as possible. -Provide citations in the text using footnote notation ([#]). -First provide the report, followed by a single "References" section that lists all the URLs used, in the format [#] .`, - model: 'gpt-4' //want a better report? use gpt-4 (but it costs more) - }); -} - -export async function researcher(topic: string): Promise { - console.log(`Starting research on topic: "${topic}"`); - - const searchQueries = await generateSearchQueries(topic, 1); - console.log("Generated search queries:", searchQueries); - - const searchResults = await getSearchResults(searchQueries, 10); - console.log( - `Found ${searchResults.length} search results. Here's the first one:`, - searchResults[0], - ); - - console.log("Synthesizing report...") - return await synthesizeReport(topic, searchResults); -} \ No newline at end of file diff --git a/lib/agents/researcher/getSearchResults.ts b/lib/agents/researcher/getSearchResults.ts new file mode 100644 index 00000000..e087f69b --- /dev/null +++ b/lib/agents/researcher/getSearchResults.ts @@ -0,0 +1,14 @@ +import Exa, {SearchResult} from "exa-js"; +const exa = new Exa(process.env.EXA_API_KEY); + +export async function getSearchResults(queries: string[], linksPerQuery: number = 5): Promise { + let results: SearchResult[] = []; + for (const query of queries) { + const searchResponse = await exa.searchAndContents(query, { + numResults: linksPerQuery, + useAutoprompt: false, + }); + results.push(...searchResponse.results); + } + return results; +} diff --git a/lib/agents/researcher/researcher.ts b/lib/agents/researcher/researcher.ts new file mode 100644 index 00000000..28447885 --- /dev/null +++ b/lib/agents/researcher/researcher.ts @@ -0,0 +1,148 @@ +import { z } from 'zod'; +import { generateObject } from 'ai'; +import { getSearchResults } from "@/lib/agents/researcher/getSearchResults"; +import { generateSearchQueries } from "@/lib/agents/researcher/searchQueryGenerator"; +import {anthropic} from "@ai-sdk/anthropic"; +import {openai} from "@ai-sdk/openai"; +import {google} from "@ai-sdk/google"; +import {LanguageModelV1} from "@ai-sdk/provider"; +import {SearchResult} from "exa-js"; +import { ModelName } from '@/types'; +const GeneratedReportSchema = z.object({ + title: z.string().describe('The title of the report'), + description: z.string().describe('A brief description or summary of the report'), + content: z.string().describe('The main content of the report in markdown format. DO NOT include the title.'), + sources: z.array(z.object({ + url: z.string(), + title: z.string(), + description: z.string(), + })).describe('An array of sources used in the report'), + tags: z.array(z.string()).describe('Relevant tags for the report'), + category: z.string().describe('The main category of the report'), + readingTime: z.number().describe('Estimated reading time in minutes'), +}); + +type GeneratedReport = z.infer; + +export type ReportOutput = GeneratedReport & { + searchResults: SearchResult[]; +}; + +function getModel(modelName: string): LanguageModelV1 { + if (modelName.includes("claude")) { + return anthropic(modelName); + } + if (modelName.includes("gpt")) { + return openai(modelName); + } + if (modelName.includes("gemini")) { + return google('models/' + modelName, { + topK: 0, + safetySettings: [ + { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' }, + { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: 'BLOCK_NONE' } + ] + }); + } + return anthropic('claude-3-5-sonnet-20240620'); // Default model +} + +export async function writeArticle( + topic: string, + options: { + numberOfSearchQueryVariations?: number, + numberOfWebResultsToInclude?: number, + audience?: string, + purpose?: string, + maxCharactersOfSearchContentToUse?: number, + tone?: string, + format?: 'article' | 'bullet-points' | 'Q&A', + wordLimit?: number, + includeSummary?: boolean, + languageLevel?: 'beginner' | 'intermediate' | 'advanced' | 'expert', + citationStyle?: 'footnote' | 'hyperlinked-text' | 'endnotes', + modelName?: 'claude-3-5-sonnet-20240620' | 'claude-3-opus-20240229' | 'claude-3-sonnet-20240229' | 'claude-3-haiku-20240307' | + 'gpt-4o' | 'gpt-4o-2024-05-13' | 'gpt-4o-2024-08-06' | 'gpt-4o-mini' | 'gpt-4o-mini-2024-07-18' | 'gpt-4-turbo' | 'gpt-4-turbo-2024-04-09' | 'gpt-4-turbo-preview' | 'gpt-4-0125-preview' | 'gpt-4-1106-preview' | 'gpt-4' | 'gpt-4-0613' | 'gpt-3.5-turbo-0125' | 'gpt-3.5-turbo' | 'gpt-3.5-turbo-1106' | + 'gemini-1.5-flash-latest' | 'gemini-1.5-flash' | 'gemini-1.5-pro-latest' | 'gemini-1.5-pro' | 'gemini-1.0-pro' + // doesn't work 'gemini-pro' , + } = {} +): Promise { + const { + numberOfSearchQueryVariations = 1, + numberOfWebResultsToInclude = 10, + audience = 'general', + purpose = 'inform', + maxCharactersOfSearchContentToUse = 999999, + tone = 'neutral', + format = 'article', + wordLimit, + includeSummary = false, + languageLevel = 'intermediate', + citationStyle = 'footnote', + modelName = 'claude-3-5-sonnet-20240620', + } = options; + + console.log(`Starting research on topic: "${topic}"`); + + const searchQueries = await generateSearchQueries(topic, numberOfSearchQueryVariations); + console.log("Generated search queries:", searchQueries); + + const searchResults = await getSearchResults(searchQueries, numberOfWebResultsToInclude); + console.log(`Found ${searchResults.length} search results.`); + + console.log("Synthesizing report..."); + + const model: LanguageModelV1 = getModel(modelName); + + const inputData = searchResults.map( + (item) => `--START ITEM: ${item.title}--\n + TITLE: ${item.title}\n + URL: ${item.url}\n + CONTENT: ${item.text.slice(0, maxCharactersOfSearchContentToUse)}\n + --END ITEM: ${item.title}--\n` + ).join(""); + + let citationInstructions = ''; + if (citationStyle === 'footnote') { + citationInstructions = 'Provide citations in the text using markdown footnote notation like [^1].'; + } else if (citationStyle === 'hyperlinked-text') { + citationInstructions = 'Hyperlink the relevant text in the report to the source URLs used using markdown hyperlink notation like [text](https://link-where-you-got-the-information).'; + } + + const prompt = ` + Write an extremely information-dense and comprehensive ${format} on the topic of "${topic}" based on the Web Search Results below. + + # Guidelines + + Avoid fluff and filler content. Focus on providing the most relevant and useful information. + DO NOT include the title in the content. + + Audience: ${audience} + Purpose: ${purpose} + Tone: ${tone} + Language Level: ${languageLevel} + Citatation Style: ${citationInstructions} + ${wordLimit ? `Word Limit: ${wordLimit} words` : ''} + ${includeSummary ? 'Include a brief summary at the beginning.' : ''} + +# Web Search Results + Here is a list of web pages and excerpts from them that you can use to write the report: + ${inputData} + `; + + const result = await generateObject({ + model: model, + schema: GeneratedReportSchema, + prompt, + }); + + debugger; + console.log("Article generated successfully!", result.object); + + return { + ...(result.object as unknown as GeneratedReport), + searchResults: searchResults, + }; +} \ No newline at end of file diff --git a/lib/agents/searchQueryGenerator.ts b/lib/agents/researcher/searchQueryGenerator.ts similarity index 100% rename from lib/agents/searchQueryGenerator.ts rename to lib/agents/researcher/searchQueryGenerator.ts diff --git a/lib/utils/dumpTypeDefinition.ts b/lib/utils/dumpTypeDefinition.ts new file mode 100644 index 00000000..8bb08265 --- /dev/null +++ b/lib/utils/dumpTypeDefinition.ts @@ -0,0 +1,15 @@ +export function dumpTypeDefinition(obj: any): string { + const getType = (value: any): string => { + if (value === null) return 'null'; + if (Array.isArray(value)) return `${getType(value[0])}[]`; + if (typeof value === 'object') { + const entries = Object.entries(value) + .map(([key, val]) => `${key}: ${getType(val)}`) + .join('; '); + return `{ ${entries} }`; + } + return typeof value; + }; + + return getType(obj); +} \ No newline at end of file diff --git a/tests/fdai.test.ts b/tests/fdai.test.ts index 5c1f5d62..ee550805 100644 --- a/tests/fdai.test.ts +++ b/tests/fdai.test.ts @@ -1,22 +1,28 @@ /** * @jest-environment node */ +import { writeFileSync } from "fs"; import { getOrCreateTestUser } from "@/tests/test-helpers"; -import {writeFileSync} from "fs"; +import {writeArticle} from "@/lib/agents/researcher/researcher"; import { foodOrDrugCostBenefitAnalysis, safeUnapprovedDrugs } from "@/lib/agents/fdai/fdaiAgent"; import { doMetaAnalysis } from "@/lib/agents/fdai/fdaiMetaAnalyzer"; +import { generateSafetySummary } from "@/lib/agents/fdai/safetySummaryAgent"; +import { generateSideEffects } from "@/lib/agents/fdai/sideEffectsAgent"; +import { generateMostEffectiveTreatments, + generateMostEffectiveUnapprovedTreatments, + generateTreatmentsStartingWith } from "@/lib/agents/fdai/treatmentsIndexer"; import { getOrCreateDfdaAccessToken } from "@/lib/dfda"; -import {generateSideEffects} from "@/lib/agents/fdai/sideEffectsAgent"; -import {generateSafetySummary} from "@/lib/agents/fdai/safetySummaryAgent"; -import { - generateMostEffectiveTreatments, generateMostEffectiveUnapprovedTreatments, - generateTreatmentsByAlphabet, - generateTreatmentsStartingWith -} from "@/lib/agents/fdai/treatmentsIndexer"; -import {aiModels} from "@/lib/models/aiModelRegistry"; - +import { aiModels } from "@/lib/models/aiModelRegistry"; +import {dumpTypeDefinition} from "@/lib/utils/dumpTypeDefinition"; describe("FDAi Tests", () => { + it("generates a report based on a study", async () => { + const article = await writeArticle("The most effective experimental treatments for long covid", { + modelName: "claude-3-5-sonnet-20240620", + }) + console.log(dumpTypeDefinition(article)) + expect(article).not.toBeNull() + }); it("generates treatments by alphabet", async () => { const geminiProUnapproved = await generateMostEffectiveUnapprovedTreatments("PTSD", aiModels['gemini-pro']); console.log("geminiProUnapproved", geminiProUnapproved) diff --git a/tests/seed.test.ts b/tests/seed.test.ts index 9a1dfd91..95d744ad 100644 --- a/tests/seed.test.ts +++ b/tests/seed.test.ts @@ -7,6 +7,9 @@ import { seedGlobalSolutionPairAllocations } from "@/prisma/seedGlobalSolutionPa import { assertTestDB, getOrCreateTestUser } from "@/tests/test-helpers" import { loadJsonToDatabase } from "@/lib/prisma/loadDatabaseFromJson" +import { PrismaClient } from '@prisma/client' +import fs from 'fs' +import path from 'path' beforeAll(async () => { await assertTestDB() @@ -25,4 +28,26 @@ describe("Database-seeder tests", () => { await seedGlobalSolutionPairAllocations(testUser) await loadJsonToDatabase("WishingWell", testUser.id) }) + + it("Imports DfdaCause from JSON file", async () => { + const prisma = new PrismaClient() + + // Read the JSON file + const jsonPath = path.join(process.cwd(), 'prisma', 'ct_causes.json') + const jsonData = JSON.parse(fs.readFileSync(jsonPath, 'utf-8')) + + // Import the data + for (const cause of jsonData) { + await prisma.dfdaCause.create({ + data: { + id: cause.id, + name: cause.name, + updatedAt: new Date(cause.updated_at), + createdAt: new Date(cause.created_at), + deletedAt: cause.deleted_at ? new Date(cause.deleted_at) : null, + numberOfConditions: cause.number_of_conditions, + }, + }) + } + }) })