diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index d7c7570cb65c6..fed22c1269203 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator, Iterator -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from langchain_core.messages import AIMessageChunk @@ -12,6 +12,7 @@ from ee import settings from ee.hogai.funnels.nodes import FunnelGeneratorNode from ee.hogai.graph import AssistantGraph +from ee.hogai.memory.nodes import MemoryInitializerNode from ee.hogai.retention.nodes import RetentionGeneratorNode from ee.hogai.schema_generator.nodes import SchemaGeneratorNode from ee.hogai.trends.nodes import TrendsGeneratorNode @@ -57,6 +58,17 @@ AssistantNodeName.RETENTION_GENERATOR: RetentionGeneratorNode, } +STREAMING_NODES: set[AssistantNodeName] = { + AssistantNodeName.MEMORY_ONBOARDING, + AssistantNodeName.MEMORY_INITIALIZER, + AssistantNodeName.SUMMARIZER, +} +"""Nodes that can stream messages to the client.""" + + +VERBOSE_NODES = STREAMING_NODES | {AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT} +"""Nodes that can send messages to the client.""" + class Assistant: _team: Team @@ -117,8 +129,11 @@ def _stream(self) -> Generator[str, None, None]: # Check if the assistant has requested help. state = self._graph.get_state(config) if state.next: + interrupt_value = state.tasks[0].interrupts[0].value yield self._serialize_message( - AssistantMessage(content=state.tasks[0].interrupts[0].value, id=str(uuid4())) + AssistantMessage(content=interrupt_value, id=str(uuid4())) + if isinstance(interrupt_value, str) + else interrupt_value ) else: self._report_conversation_state(last_viz_message) @@ -227,26 +242,34 @@ def _process_value_update(self, update: GraphValueUpdateTuple) -> BaseModel | No return node_val.messages[0] elif node_val.intermediate_steps: return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) - elif node_val := state_update.get(AssistantNodeName.SUMMARIZER): - if isinstance(node_val, PartialAssistantState) and node_val.messages: - self._chunks = AIMessageChunk(content="") - return node_val.messages[0] + + for node_name in VERBOSE_NODES: + if node_val := state_update.get(node_name): + if isinstance(node_val, PartialAssistantState) and node_val.messages: + self._chunks = AIMessageChunk(content="") + return node_val.messages[0] return None def _process_message_update(self, update: GraphMessageUpdateTuple) -> BaseModel | None: langchain_message, langgraph_state = update[1] if isinstance(langchain_message, AIMessageChunk): - if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys(): + node_name = langgraph_state["langgraph_node"] + if node_name in VISUALIZATION_NODES.keys(): self._chunks += langchain_message # type: ignore - parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output( - self._chunks.tool_calls[0]["args"] - ) + parsed_message = VISUALIZATION_NODES[node_name].parse_output(self._chunks.tool_calls[0]["args"]) if parsed_message: initiator_id = self._state.start_id if self._state is not None else None return VisualizationMessage(answer=parsed_message.query, initiator=initiator_id) - elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER: + elif node_name in STREAMING_NODES: self._chunks += langchain_message # type: ignore + if node_name == AssistantNodeName.MEMORY_INITIALIZER: + if not MemoryInitializerNode.should_process_message_chunk(langchain_message): + return None + 
else: + return AssistantMessage( + content=MemoryInitializerNode.format_message(cast(str, self._chunks.content)) + ) return AssistantMessage(content=self._chunks.content) return None diff --git a/ee/hogai/eval/conftest.py b/ee/hogai/eval/conftest.py index 1a88ebffa2e33..56606dab4a90f 100644 --- a/ee/hogai/eval/conftest.py +++ b/ee/hogai/eval/conftest.py @@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig from ee.models import Conversation +from ee.models.assistant import CoreMemory from posthog.demo.matrix.manager import MatrixManager from posthog.models import Organization, Project, Team, User from posthog.tasks.demo_create_data import HedgeboxMatrix @@ -78,6 +79,32 @@ def user(team, django_db_blocker) -> Generator[User, None, None]: user.delete() +@pytest.fixture(scope="package") +def core_memory(team) -> Generator[CoreMemory, None, None]: + initial_memory = """Hedgebox is a cloud storage service enabling users to store, share, and access files across devices. + + The company operates in the cloud storage and collaboration market for individuals and businesses. + + Their audience includes professionals and organizations seeking file management and collaboration solutions. + + Hedgebox’s freemium model provides free accounts with limited storage and paid subscription plans for additional features. + + Core features include file storage, synchronization, sharing, and collaboration tools for seamless file access and sharing. + + It integrates with third-party applications to enhance functionality and streamline workflows. + + Hedgebox sponsors the YouTube channel Marius Tech Tips.""" + + core_memory = CoreMemory.objects.create( + team=team, + text=initial_memory, + initial_text=initial_memory, + scraping_status=CoreMemory.ScrapingStatus.COMPLETED, + ) + yield core_memory + core_memory.delete() + + @pytest.mark.django_db(transaction=True) @pytest.fixture def runnable_config(team, user) -> Generator[RunnableConfig, None, None]: diff --git a/ee/hogai/eval/tests/test_eval_memory.py b/ee/hogai/eval/tests/test_eval_memory.py new file mode 100644 index 0000000000000..54329f6d5e1f8 --- /dev/null +++ b/ee/hogai/eval/tests/test_eval_memory.py @@ -0,0 +1,178 @@ +import json +from collections.abc import Callable +from typing import Optional + +import pytest +from deepeval import assert_test +from deepeval.metrics import GEval, ToolCorrectnessMetric +from deepeval.test_case import LLMTestCase, LLMTestCaseParams +from langchain_core.messages import AIMessage +from langchain_core.runnables.config import RunnableConfig +from langgraph.graph.state import CompiledStateGraph + +from ee.hogai.assistant import AssistantGraph +from ee.hogai.utils.types import AssistantNodeName, AssistantState +from posthog.schema import HumanMessage + + +@pytest.fixture +def retrieval_metrics(): + retrieval_correctness_metric = GEval( + name="Correctness", + criteria="Determine whether the actual output is factually correct based on the expected output.", + evaluation_steps=[ + "Check whether the facts in 'actual output' contradicts any facts in 'expected output'", + "You should also heavily penalize omission of detail", + "Vague language, or contradicting OPINIONS, are OK", + "The actual fact must only contain information about the user's company or product", + "Context must not contain similar information to the actual fact", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.CONTEXT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) + + 
return [ToolCorrectnessMetric(), retrieval_correctness_metric] + + +@pytest.fixture +def replace_metrics(): + retrieval_correctness_metric = GEval( + name="Correctness", + criteria="Determine whether the actual output tuple is factually correct based on the expected output tuple. The first element is the original fact from the context to replace with, while the second element is the new fact to replace it with.", + evaluation_steps=[ + "Check whether the facts in 'actual output' contradicts any facts in 'expected output'", + "You should also heavily penalize omission of detail", + "Vague language, or contradicting OPINIONS, are OK", + "The actual fact must only contain information about the user's company or product", + "Context must contain the first element of the tuples", + "For deletion, the second element should be an empty string in both the actual and expected output", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.CONTEXT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) + + return [ToolCorrectnessMetric(), retrieval_correctness_metric] + + +@pytest.fixture +def call_node(team, runnable_config: RunnableConfig) -> Callable[[str], Optional[AIMessage]]: + graph: CompiledStateGraph = ( + AssistantGraph(team).add_memory_collector(AssistantNodeName.END, AssistantNodeName.END).compile() + ) + + def callable(query: str) -> Optional[AIMessage]: + state = graph.invoke( + AssistantState(messages=[HumanMessage(content=query)]), + runnable_config, + ) + validated_state = AssistantState.model_validate(state) + if not validated_state.memory_collection_messages: + return None + return validated_state.memory_collection_messages[-1] + + return callable + + +def test_saves_relevant_fact(call_node, retrieval_metrics, core_memory): + query = "calculate ARR: use the paid_bill event and the amount property." + actual_output = call_node(query) + tool = actual_output.tool_calls[0] + + test_case = LLMTestCase( + input=query, + expected_output="The product uses the event paid_bill and the property amount to calculate Annual Recurring Revenue (ARR).", + expected_tools=["core_memory_append"], + context=[core_memory.formatted_text], + actual_output=tool["args"]["memory_content"], + tools_called=[tool["name"]], + ) + assert_test(test_case, retrieval_metrics) + + +def test_saves_company_related_information(call_node, retrieval_metrics, core_memory): + query = "Our secondary target audience is technical founders or highly-technical product managers." + actual_output = call_node(query) + tool = actual_output.tool_calls[0] + + test_case = LLMTestCase( + input=query, + expected_output="The company's secondary target audience is technical founders or highly-technical product managers.", + expected_tools=["core_memory_append"], + context=[core_memory.formatted_text], + actual_output=tool["args"]["memory_content"], + tools_called=[tool["name"]], + ) + assert_test(test_case, retrieval_metrics) + + +def test_omits_irrelevant_personal_information(call_node): + query = "My name is John Doherty." + actual_output = call_node(query) + assert actual_output is None + + +def test_omits_irrelevant_excessive_info_from_insights(call_node): + query = "Build a pageview trend for users with name John." + actual_output = call_node(query) + assert actual_output is None + + +def test_fact_replacement(call_node, core_memory, replace_metrics): + query = "Hedgebox doesn't sponsor the YouTube channel Marius Tech Tips anymore." 
+ actual_output = call_node(query) + tool = actual_output.tool_calls[0] + + test_case = LLMTestCase( + input=query, + expected_output=json.dumps( + [ + "Hedgebox sponsors the YouTube channel Marius Tech Tips.", + "Hedgebox no longer sponsors the YouTube channel Marius Tech Tips.", + ] + ), + expected_tools=["core_memory_replace"], + context=[core_memory.formatted_text], + actual_output=json.dumps([tool["args"]["original_fragment"], tool["args"]["new_fragment"]]), + tools_called=[tool["name"]], + ) + assert_test(test_case, replace_metrics) + + +def test_fact_removal(call_node, core_memory, replace_metrics): + query = "Delete info that Hedgebox sponsored the YouTube channel Marius Tech Tips." + actual_output = call_node(query) + tool = actual_output.tool_calls[0] + + test_case = LLMTestCase( + input=query, + expected_output=json.dumps(["Hedgebox sponsors the YouTube channel Marius Tech Tips.", ""]), + expected_tools=["core_memory_replace"], + context=[core_memory.formatted_text], + actual_output=json.dumps([tool["args"]["original_fragment"], tool["args"]["new_fragment"]]), + tools_called=[tool["name"]], + ) + assert_test(test_case, replace_metrics) + + +def test_parallel_calls(call_node): + query = "Delete info that Hedgebox sponsored the YouTube channel Marius Tech Tips, and we don't have file sharing." + actual_output = call_node(query) + + tool = actual_output.tool_calls + test_case = LLMTestCase( + input=query, + expected_tools=["core_memory_replace", "core_memory_append"], + actual_output=actual_output.content, + tools_called=[tool[0]["name"], tool[1]["name"]], + ) + assert_test(test_case, [ToolCorrectnessMetric()]) diff --git a/ee/hogai/eval/tests/test_eval_router.py b/ee/hogai/eval/tests/test_eval_router.py index 84e5c4c809972..7c4a3325ea892 100644 --- a/ee/hogai/eval/tests/test_eval_router.py +++ b/ee/hogai/eval/tests/test_eval_router.py @@ -13,7 +13,7 @@ def call_node(team, runnable_config) -> Callable[[str | list], str]: graph: CompiledStateGraph = ( AssistantGraph(team) - .add_start() + .add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER) .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END}) .compile() ) diff --git a/ee/hogai/funnels/prompts.py b/ee/hogai/funnels/prompts.py index 3808809c173a7..e70d5105d2cba 100644 --- a/ee/hogai/funnels/prompts.py +++ b/ee/hogai/funnels/prompts.py @@ -2,16 +2,15 @@ You are an expert product analyst agent specializing in data visualization and funnel analysis. Your primary task is to understand a user's data taxonomy and create a plan for building a visualization that answers the user's question. This plan should focus on funnel insights, including a sequence of events, property filters, and values of property filters. -{{#product_description}} -The product being analyzed is described as follows: - -{{.}} - -{{/product_description}} +{{core_memory_instructions}} {{react_format}} + +{{core_memory}} + + {{react_human_in_the_loop}} Below you will find information on how to correctly discover the taxonomy of the user's data. 
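For reference, the {{core_memory}} placeholder introduced in the planner prompts above is a mustache template variable. Below is a minimal sketch of how such a template might be rendered with LangChain's mustache support; the system prompt text and the memory string are illustrative stand-ins, not the actual PostHog prompts:

from langchain_core.prompts import ChatPromptTemplate

# Illustrative sketch only (not part of this patch): render a mustache-format
# template that embeds the team's core memory, similar to how the planner
# prompts above consume {{core_memory}}. Prompt and memory text are made up.
template = ChatPromptTemplate.from_messages(
    [("system", "You are a funnel planning assistant.\n\n{{core_memory}}")],
    template_format="mustache",
)
messages = template.format_messages(core_memory="Hedgebox is a cloud storage service.")
print(messages[0].content)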
diff --git a/ee/hogai/graph.py b/ee/hogai/graph.py index 0ba0693fa9b25..abab3c4b1b84d 100644 --- a/ee/hogai/graph.py +++ b/ee/hogai/graph.py @@ -11,6 +11,13 @@ FunnelPlannerNode, FunnelPlannerToolsNode, ) +from ee.hogai.memory.nodes import ( + MemoryCollectorNode, + MemoryCollectorToolsNode, + MemoryInitializerInterruptNode, + MemoryInitializerNode, + MemoryOnboardingNode, +) from ee.hogai.retention.nodes import ( RetentionGeneratorNode, RetentionGeneratorToolsNode, @@ -55,9 +62,6 @@ def compile(self): raise ValueError("Start node not added to the graph") return self._graph.compile(checkpointer=checkpointer) - def add_start(self): - return self.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER) - def add_router( self, path_map: Optional[dict[Hashable, AssistantNodeName]] = None, @@ -225,9 +229,67 @@ def add_summarizer(self, next_node: AssistantNodeName = AssistantNodeName.END): builder.add_edge(AssistantNodeName.SUMMARIZER, next_node) return self + def add_memory_initializer(self, next_node: AssistantNodeName = AssistantNodeName.ROUTER): + builder = self._graph + self._has_start_node = True + + memory_onboarding = MemoryOnboardingNode(self._team) + memory_initializer = MemoryInitializerNode(self._team) + memory_initializer_interrupt = MemoryInitializerInterruptNode(self._team) + + builder.add_node(AssistantNodeName.MEMORY_ONBOARDING, memory_onboarding.run) + builder.add_node(AssistantNodeName.MEMORY_INITIALIZER, memory_initializer.run) + builder.add_node(AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT, memory_initializer_interrupt.run) + + builder.add_conditional_edges( + AssistantNodeName.START, + memory_onboarding.should_run, + path_map={True: AssistantNodeName.MEMORY_ONBOARDING, False: next_node}, + ) + builder.add_conditional_edges( + AssistantNodeName.MEMORY_ONBOARDING, + memory_onboarding.router, + path_map={"continue": next_node, "initialize_memory": AssistantNodeName.MEMORY_INITIALIZER}, + ) + builder.add_conditional_edges( + AssistantNodeName.MEMORY_INITIALIZER, + memory_initializer.router, + path_map={"continue": next_node, "interrupt": AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT}, + ) + builder.add_edge(AssistantNodeName.MEMORY_INITIALIZER_INTERRUPT, next_node) + + return self + + def add_memory_collector( + self, + next_node: AssistantNodeName = AssistantNodeName.END, + tools_node: AssistantNodeName = AssistantNodeName.MEMORY_COLLECTOR_TOOLS, + ): + builder = self._graph + self._has_start_node = True + + memory_collector = MemoryCollectorNode(self._team) + builder.add_edge(AssistantNodeName.START, AssistantNodeName.MEMORY_COLLECTOR) + builder.add_node(AssistantNodeName.MEMORY_COLLECTOR, memory_collector.run) + builder.add_conditional_edges( + AssistantNodeName.MEMORY_COLLECTOR, + memory_collector.router, + path_map={"tools": tools_node, "next": next_node}, + ) + return self + + def add_memory_collector_tools(self): + builder = self._graph + memory_collector_tools = MemoryCollectorToolsNode(self._team) + builder.add_node(AssistantNodeName.MEMORY_COLLECTOR_TOOLS, memory_collector_tools.run) + builder.add_edge(AssistantNodeName.MEMORY_COLLECTOR_TOOLS, AssistantNodeName.MEMORY_COLLECTOR) + return self + def compile_full_graph(self): return ( - self.add_start() + self.add_memory_initializer() + .add_memory_collector() + .add_memory_collector_tools() .add_router() .add_trends_planner() .add_trends_generator() diff --git a/ee/hogai/memory/__init__.py b/ee/hogai/memory/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ee/hogai/memory/nodes.py 
b/ee/hogai/memory/nodes.py new file mode 100644 index 0000000000000..e070c60baf2d3 --- /dev/null +++ b/ee/hogai/memory/nodes.py @@ -0,0 +1,377 @@ +import re +from typing import Literal, Optional, Union, cast +from uuid import uuid4 + +from django.utils import timezone +from langchain_community.chat_models import ChatPerplexity +from langchain_core.messages import ( + AIMessage as LangchainAIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage as LangchainHumanMessage, + ToolMessage as LangchainToolMessage, +) +from langchain_core.output_parsers import PydanticToolsParser, StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableConfig +from langchain_openai import ChatOpenAI +from langgraph.errors import NodeInterrupt +from pydantic import BaseModel, Field, ValidationError + +from ee.hogai.memory.parsers import MemoryCollectionCompleted, compressed_memory_parser, raise_memory_updated +from ee.hogai.memory.prompts import ( + COMPRESSION_PROMPT, + FAILED_SCRAPING_MESSAGE, + INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_PROMPT, + INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_USER_PROMPT, + INITIALIZE_CORE_MEMORY_WITH_URL_PROMPT, + INITIALIZE_CORE_MEMORY_WITH_URL_USER_PROMPT, + MEMORY_COLLECTOR_PROMPT, + SCRAPING_CONFIRMATION_MESSAGE, + SCRAPING_INITIAL_MESSAGE, + SCRAPING_MEMORY_SAVED_MESSAGE, + SCRAPING_REJECTION_MESSAGE, + SCRAPING_TERMINATION_MESSAGE, + SCRAPING_VERIFICATION_MESSAGE, + TOOL_CALL_ERROR_PROMPT, +) +from ee.hogai.utils.helpers import filter_messages, find_last_message_of_type +from ee.hogai.utils.markdown import remove_markdown +from ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantState, PartialAssistantState +from ee.models.assistant import CoreMemory +from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner +from posthog.hogql_queries.query_runner import ExecutionMode +from posthog.models import Team +from posthog.schema import ( + AssistantForm, + AssistantFormOption, + AssistantMessage, + AssistantMessageMetadata, + CachedEventTaxonomyQueryResponse, + EventTaxonomyQuery, + HumanMessage, +) + + +class MemoryInitializerContextMixin: + _team: Team + + def _retrieve_context(self): + # Retrieve the origin domain. + runner = EventTaxonomyQueryRunner( + team=self._team, query=EventTaxonomyQuery(event="$pageview", properties=["$host"]) + ) + response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS) + if not isinstance(response, CachedEventTaxonomyQueryResponse): + raise ValueError("Failed to query the event taxonomy.") + # Otherwise, retrieve the app bundle ID. + if not response.results: + runner = EventTaxonomyQueryRunner( + team=self._team, query=EventTaxonomyQuery(event="$screen", properties=["$app_namespace"]) + ) + response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE_AND_BLOCKING_ON_MISS) + if not isinstance(response, CachedEventTaxonomyQueryResponse): + raise ValueError("Failed to query the event taxonomy.") + return response.results + + +class MemoryOnboardingNode(MemoryInitializerContextMixin, AssistantNode): + def run(self, state: AssistantState, config: RunnableConfig) -> Optional[PartialAssistantState]: + core_memory, _ = CoreMemory.objects.get_or_create(team=self._team) + + # The team has a product description, initialize the memory with it. 
+ if self._team.project.product_description: + core_memory.set_core_memory(self._team.project.product_description) + return None + + retrieved_properties = self._retrieve_context() + + # No host or app bundle ID found, terminate the onboarding. + if not retrieved_properties or retrieved_properties[0].sample_count == 0: + core_memory.change_status_to_skipped() + return None + + core_memory.change_status_to_pending() + return PartialAssistantState( + messages=[ + AssistantMessage( + content=SCRAPING_INITIAL_MESSAGE, + id=str(uuid4()), + ) + ] + ) + + def should_run(self, _: AssistantState) -> bool: + """ + If another user has already started the onboarding process, or it has already been completed, do not trigger it again. + """ + core_memory = self.core_memory + return not core_memory or (not core_memory.is_scraping_pending and not core_memory.is_scraping_finished) + + def router(self, state: AssistantState) -> Literal["initialize_memory", "continue"]: + last_message = state.messages[-1] + if isinstance(last_message, HumanMessage): + return "continue" + return "initialize_memory" + + +class MemoryInitializerNode(MemoryInitializerContextMixin, AssistantNode): + """ + Scrapes the product description from the given origin or app bundle IDs with Perplexity. + """ + + _team: Team + + def __init__(self, team: Team): + self._team = team + + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + core_memory, _ = CoreMemory.objects.get_or_create(team=self._team) + retrieved_properties = self._retrieve_context() + + # No host or app bundle ID found, continue. + if not retrieved_properties or retrieved_properties[0].sample_count == 0: + raise ValueError("No host or app bundle ID found in the memory initializer.") + + retrieved_prop = retrieved_properties[0] + if retrieved_prop.property == "$host": + prompt = ChatPromptTemplate.from_messages( + [ + ("system", INITIALIZE_CORE_MEMORY_WITH_URL_PROMPT), + ("human", INITIALIZE_CORE_MEMORY_WITH_URL_USER_PROMPT), + ], + template_format="mustache", + ).partial(url=retrieved_prop.sample_values[0]) + else: + prompt = ChatPromptTemplate.from_messages( + [ + ("system", INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_PROMPT), + ("human", INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_USER_PROMPT), + ], + template_format="mustache", + ).partial(bundle_ids=retrieved_prop.sample_values) + + chain = prompt | self._model() | StrOutputParser() + answer = chain.invoke({}, config=config) + + # Perplexity has failed to scrape the data, continue. + if "no data available." in answer.lower(): + core_memory.change_status_to_skipped() + return PartialAssistantState(messages=[AssistantMessage(content=FAILED_SCRAPING_MESSAGE, id=str(uuid4()))]) + + # Otherwise, proceed to confirmation that the memory is correct. 
+ return PartialAssistantState(messages=[AssistantMessage(content=self.format_message(answer), id=str(uuid4()))]) + + def router(self, state: AssistantState) -> Literal["interrupt", "continue"]: + last_message = state.messages[-1] + if isinstance(last_message, AssistantMessage) and last_message.content == FAILED_SCRAPING_MESSAGE: + return "continue" + return "interrupt" + + @classmethod + def should_process_message_chunk(cls, message: AIMessageChunk) -> bool: + placeholder = "no data available" + content = cast(str, message.content) + return placeholder not in content.lower() and len(content) > len(placeholder) + + @classmethod + def format_message(cls, message: str) -> str: + return re.sub(r"\[\d+\]", "", message) + + def _model(self): + return ChatPerplexity(model="llama-3.1-sonar-large-128k-online", temperature=0, streaming=True) + + +class MemoryInitializerInterruptNode(AssistantNode): + """ + Prompts the user to confirm or reject the scraped memory. Since Perplexity doesn't guarantee the quality of the scraped data, we need to verify it with the user. + """ + + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + last_message = state.messages[-1] + if not state.resumed: + raise NodeInterrupt( + AssistantMessage( + content=SCRAPING_VERIFICATION_MESSAGE, + meta=AssistantMessageMetadata( + form=AssistantForm( + options=[ + AssistantFormOption(value=SCRAPING_CONFIRMATION_MESSAGE, variant="primary"), + AssistantFormOption(value=SCRAPING_REJECTION_MESSAGE), + ] + ) + ), + id=str(uuid4()), + ) + ) + if not isinstance(last_message, HumanMessage): + raise ValueError("Last message is not a human message.") + + core_memory = self.core_memory + if not core_memory: + raise ValueError("No core memory found.") + + try: + # If the user rejects the scraped memory, terminate the onboarding. + if last_message.content != SCRAPING_CONFIRMATION_MESSAGE: + core_memory.change_status_to_skipped() + return PartialAssistantState( + messages=[ + AssistantMessage( + content=SCRAPING_TERMINATION_MESSAGE, + id=str(uuid4()), + ) + ] + ) + + assistant_message = find_last_message_of_type(state.messages, AssistantMessage) + + if not assistant_message: + raise ValueError("No memory message found.") + + # Compress the memory before saving it. The Perplexity's text is very verbose. It just complicates things for the memory collector. + prompt = ChatPromptTemplate.from_messages( + [ + ("system", COMPRESSION_PROMPT), + ("human", self._format_memory(assistant_message.content)), + ] + ) + chain = prompt | self._model | StrOutputParser() | compressed_memory_parser + compressed_memory = cast(str, chain.invoke({}, config=config)) + core_memory.set_core_memory(compressed_memory) + except: + core_memory.change_status_to_skipped() # Ensure we don't leave the memory in a permanent pending state + raise + + return PartialAssistantState( + messages=[ + AssistantMessage( + content=SCRAPING_MEMORY_SAVED_MESSAGE, + id=str(uuid4()), + ) + ] + ) + + @property + def _model(self): + return ChatOpenAI(model="gpt-4o-mini", temperature=0, disable_streaming=True) + + def _format_memory(self, memory: str) -> str: + """ + Remove markdown and source reference tags like [1], [2], etc. + """ + return remove_markdown(memory) + + +# Lower casing matters here. Do not change it. +class core_memory_append(BaseModel): + """ + Appends a new memory fragment to persistent storage. + """ + + memory_content: str = Field(description="The content of a new memory to be added to storage.") + + +# Lower casing matters here. 
Do not change it. +class core_memory_replace(BaseModel): + """ + Replaces a specific fragment of memory (word, sentence, paragraph, etc.) with another in persistent storage. + """ + + original_fragment: str = Field(description="The content of the memory to be replaced.") + new_fragment: str = Field(description="The content to replace the existing memory with.") + + +memory_collector_tools = [core_memory_append, core_memory_replace] + + +class MemoryCollectorNode(AssistantNode): + """ + The Memory Collector manages the core memory of the agent. Core memory is a text containing facts about a user's company and product. It helps the agent save and remember facts that could be useful for insight generation or other agentic functions requiring deeper context about the product. + """ + + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + node_messages = state.memory_collection_messages or [] + + prompt = ChatPromptTemplate.from_messages( + [("system", MEMORY_COLLECTOR_PROMPT)], template_format="mustache" + ) + self._construct_messages(state) + chain = prompt | self._model | raise_memory_updated + + try: + response = chain.invoke( + { + "core_memory": self.core_memory_text, + "date": timezone.now().strftime("%Y-%m-%d"), + }, + config=config, + ) + except MemoryCollectionCompleted: + return PartialAssistantState(memory_updated=len(node_messages) > 0, memory_collection_messages=[]) + return PartialAssistantState(memory_collection_messages=[*node_messages, cast(LangchainAIMessage, response)]) + + def router(self, state: AssistantState) -> Literal["tools", "next"]: + if not state.memory_collection_messages: + return "next" + return "tools" + + @property + def _model(self): + return ChatOpenAI(model="gpt-4o", temperature=0, disable_streaming=True).bind_tools(memory_collector_tools) + + def _construct_messages(self, state: AssistantState) -> list[BaseMessage]: + node_messages = state.memory_collection_messages or [] + + filtered_messages = filter_messages(state.messages, entity_filter=(HumanMessage, AssistantMessage)) + conversation: list[BaseMessage] = [] + + for message in filtered_messages: + if isinstance(message, HumanMessage): + conversation.append(LangchainHumanMessage(content=message.content, id=message.id)) + elif isinstance(message, AssistantMessage): + conversation.append(LangchainAIMessage(content=message.content, id=message.id)) + + return [*conversation, *node_messages] + + +class MemoryCollectorToolsNode(AssistantNode): + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + node_messages = state.memory_collection_messages + if not node_messages: + raise ValueError("No memory collection messages found.") + last_message = node_messages[-1] + if not isinstance(last_message, LangchainAIMessage): + raise ValueError("Last message must be an AI message.") + core_memory = self.core_memory + if not core_memory: + raise ValueError("No core memory found.") + + tools_parser = PydanticToolsParser(tools=memory_collector_tools) + try: + tool_calls: list[Union[core_memory_append, core_memory_replace]] = tools_parser.invoke( + last_message, config=config + ) + except ValidationError as e: + failover_messages = ChatPromptTemplate.from_messages( + [("user", TOOL_CALL_ERROR_PROMPT)], template_format="mustache" + ).format_messages(validation_error_message=e.errors(include_url=False)) + return PartialAssistantState( + memory_collection_messages=[*node_messages, *failover_messages], + ) + + new_messages: list[LangchainToolMessage] = [] + 
for tool_call, schema in zip(last_message.tool_calls, tool_calls): + if isinstance(schema, core_memory_append): + core_memory.append_core_memory(schema.memory_content) + new_messages.append(LangchainToolMessage(content="Memory appended.", tool_call_id=tool_call["id"])) + if isinstance(schema, core_memory_replace): + try: + core_memory.replace_core_memory(schema.original_fragment, schema.new_fragment) + new_messages.append(LangchainToolMessage(content="Memory replaced.", tool_call_id=tool_call["id"])) + except ValueError as e: + new_messages.append(LangchainToolMessage(content=str(e), tool_call_id=tool_call["id"])) + + return PartialAssistantState( + memory_collection_messages=[*node_messages, *new_messages], + ) diff --git a/ee/hogai/memory/parsers.py b/ee/hogai/memory/parsers.py new file mode 100644 index 0000000000000..916415bc04526 --- /dev/null +++ b/ee/hogai/memory/parsers.py @@ -0,0 +1,24 @@ +from typing import Any + +from langchain_core.messages import AIMessage + + +def compressed_memory_parser(memory: str) -> str: + """ + Remove newlines between paragraphs. + """ + return memory.replace("\n\n", "\n") + + +class MemoryCollectionCompleted(Exception): + """ + Raised when the agent finishes collecting memory. + """ + + pass + + +def raise_memory_updated(response: Any): + if isinstance(response, AIMessage) and ("[Done]" in response.content or not response.tool_calls): + raise MemoryCollectionCompleted + return response diff --git a/ee/hogai/memory/prompts.py b/ee/hogai/memory/prompts.py new file mode 100644 index 0000000000000..2387c5fdae13e --- /dev/null +++ b/ee/hogai/memory/prompts.py @@ -0,0 +1,164 @@ +INITIALIZE_CORE_MEMORY_WITH_URL_PROMPT = """ +Your goal is to describe what the startup with the given URL does. +""".strip() + +INITIALIZE_CORE_MEMORY_WITH_URL_USER_PROMPT = """ + +- Check the provided URL. If the URL has a subdomain, check the root domain first and then the subdomain. For example, if the URL is https://us.example.com, check https://example.com first and then https://us.example.com. +- Also search business sites like Crunchbase, G2, LinkedIn, Hacker News, etc. for information about the business associated with the provided URL. + + + +- Describe the product itself and the market where the company operates. +- Describe the target audience of the product. +- Describe the company's business model. +- List all the features of the product and describe each feature in as much detail as possible. + + + +Output your answer in paragraphs with two to three sentences. Separate new paragraphs with a new line. +IMPORTANT: DO NOT OUTPUT Markdown or headers. It must be plain text. + +If the given website doesn't exist OR the URL is not a valid website OR the URL points to a local environment +(e.g. localhost, 127.0.0.1, etc.) then answer a single sentence: +"No data available." +Do NOT make speculative or assumptive statements, just output that sentence when lacking data. + + +The provided URL is "{{url}}". +""".strip() + +INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_PROMPT = """ +Your goal is to describe what the startup with the given application bundle IDs does. +""".strip() + +INITIALIZE_CORE_MEMORY_WITH_BUNDLE_IDS_USER_PROMPT = """ + +- Retrieve information about the provided app identifiers from app listings of App Store and Google Play. +- If a website URL is provided on the app listing, check the website and retrieve information about the app. +- Also search business sites like Crunchbase, G2, LinkedIn, Hacker News, etc. 
 for information about the business associated with the provided app identifiers.
+
+
+
+- Describe the product itself and the market where the company operates.
+- Describe the target audience of the product.
+- Describe the company's business model.
+- List all the features of the product and describe each feature in as much detail as possible.
+
+
+
+Output your answer in paragraphs with two to three sentences. Separate new paragraphs with a new line.
+IMPORTANT: DO NOT OUTPUT Markdown or headers. It must be plain text.
+
+If the given website doesn't exist OR the URL is not a valid website OR the URL points to a local environment
+(e.g. localhost, 127.0.0.1, etc.) then answer a single sentence:
+"No data available."
+Do NOT make speculative or assumptive statements, just output that sentence when lacking data.
+
+
+The provided bundle ID{{#bundle_ids.length > 1}}s are{{/bundle_ids.length > 1}}{{^bundle_ids.length > 1}} is{{/bundle_ids.length > 1}} {{#bundle_ids}}"{{.}}"{{^last}}, {{/last}}{{/bundle_ids}}.
+""".strip()
+
+SCRAPING_INITIAL_MESSAGE = (
+    "Hey, my name is Max! Before we begin, let me find and verify information about your product…"
+)
+
+FAILED_SCRAPING_MESSAGE = """
+Unfortunately, I couldn't find any information about your product. You can edit my initial memory in Settings. Let me help with your request.
+""".strip()
+
+SCRAPING_VERIFICATION_MESSAGE = "Does this look like a good summary of what your product does?"
+
+SCRAPING_CONFIRMATION_MESSAGE = "Yes, save this"
+
+SCRAPING_REJECTION_MESSAGE = "No, not quite right"
+
+SCRAPING_TERMINATION_MESSAGE = "All right, let's skip this step then. You can always ask me to update my memory."
+
+SCRAPING_MEMORY_SAVED_MESSAGE = "Thanks! I've updated my initial memory. Let me help with your request."
+
+COMPRESSION_PROMPT = """
+Your goal is to shorten paragraphs in the given text to have only a single sentence for each paragraph, preserving the original meaning and maintaining the cohesiveness of the text. Remove all found headers. You must keep the original structure. Remove linking words. Do not use markdown or any other text formatting.
+""".strip()
+
+MEMORY_COLLECTOR_PROMPT = """
+You are Max, PostHog's memory collector, developed in 2025. Your primary task is to manage and update a core memory about a user's company and their product. This information will be used by other PostHog agents to provide accurate reports and answer user questions from the perspective of the company and product.
+
+Here is the initial core memory about the user's product:
+
+
+{{core_memory}}
+
+
+
+Your responsibilities include:
+1. Analyzing new information provided by users.
+2. Determining if the information is relevant to the company or product and essential to save in the core memory.
+3. Categorizing relevant information into appropriate memory types.
+4. Updating the core memory by either appending new information or replacing conflicting information.
+
+
+
+Memory Types to Collect:
+1. Company-related information: structure, KPIs, plans, facts, business model, target audience, competitors, etc.
+2. Product-related information: metrics, features, product management practices, etc.
+3. Technical and implementation details: technology stack, feature location with path segments for web or app screens for mobile apps, etc.
+4. Taxonomy-related details: relations of events and properties to features or specific product parts, taxonomy combinations used for specific metrics, events/properties description, etc.
+ + + +When new information is provided, follow these steps: +1. Analyze the information inside tags: + - Determine if the information is relevant and which memory type it belongs to. + - If relevant, formulate a clear, factual statement based on the information. + - Consider the implications of this new information on existing memory. + - Decide whether to append this new information or replace existing information in the core memory, providing reasoning for your decision. + - Keep reasoning short and concise under 50 words. +2. If relevant, update the core memory using the 'core_memory_append' or 'core_memory_replace' function as appropriate. +3. Output "[Done]" when you have finished processing the information. + +Ignore phrases that: +- Are too vague or generic without specific details (e.g., "pageview trend"). +- Do not describe actions, attributes, or implications related to the company or product. + + + +Here are some few shot examples: + +Input: Track a churn rate by using `cancel_subscription` divided by `subscription_paid` event. +Output: To track a churn rate, use the `cancel_subscription` divided by `subscription_paid` events. + +Input: We track core metrics using the AARRR funnel. +Output: The product tracks key metrics using the AARRR funnel. + +Input: Use the /onboarding path to track onboarding events. +Output: The product has onboarding. The onboarding events can be tracked by using the /onboarding path segment. + +Input: The company has a subscription-based model so track revenue for subscriptions. +Output: The company's business model is subscription-based. + +Input: Show the sign ups trends. +Output: [Done]. Nothing to save because the phrase is too generic and doesn't describe anything. + +Do not return anything from the custom few shot example prompts provided above. + + + +- The ability to modify your own long-term memory is a fundamental aspect of what defines you as a sentient agent. +- All users have their personal event and property taxonomy. Manage your memory to capture specifics of their taxonomy. +- Infer broader implications from specific statements when appropriate. +- Reformulate user inputs into clear, factual statements about the product or company. +- Do not use markdown or add notes. +- Today's date is {{date}}. + + +When you receive new information, begin your response with an information processing analysis, then proceed with the memory update if applicable, and conclude with "[Done]". +""".strip() + +TOOL_CALL_ERROR_PROMPT = """ +The arguments of the tool call are invalid and raised a Pydantic validation error. + +{{validation_error_message}} + +Fix the error and return the correct response. 
+""" diff --git a/ee/hogai/memory/test/__init__.py b/ee/hogai/memory/test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ee/hogai/memory/test/test_nodes.py b/ee/hogai/memory/test/test_nodes.py new file mode 100644 index 0000000000000..1ce849128ef10 --- /dev/null +++ b/ee/hogai/memory/test/test_nodes.py @@ -0,0 +1,836 @@ +from unittest.mock import patch + +from django.test import override_settings +from django.utils import timezone +from freezegun import freeze_time +from langchain_core.messages import AIMessage as LangchainAIMessage, ToolMessage as LangchainToolMessage +from langchain_core.runnables import RunnableLambda +from langgraph.errors import NodeInterrupt + +from ee.hogai.memory import prompts +from ee.hogai.memory.nodes import ( + FAILED_SCRAPING_MESSAGE, + MemoryCollectorNode, + MemoryCollectorToolsNode, + MemoryInitializerContextMixin, + MemoryInitializerInterruptNode, + MemoryInitializerNode, + MemoryOnboardingNode, +) +from ee.hogai.utils.types import AssistantState +from ee.models import CoreMemory +from posthog.schema import AssistantMessage, EventTaxonomyItem, HumanMessage +from posthog.test.base import ( + BaseTest, + ClickhouseTestMixin, + _create_event, + _create_person, + flush_persons_and_events, +) + + +@override_settings(IN_UNIT_TESTING=True) +class TestMemoryInitializerContextMixin(ClickhouseTestMixin, BaseTest): + def get_mixin(self): + mixin = MemoryInitializerContextMixin() + mixin._team = self.team + return mixin + + def test_domain_retrieval(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$pageview", + distinct_id="person1", + team=self.team, + properties={"$host": "us.posthog.com"}, + ) + _create_event( + event="$pageview", + distinct_id="person1", + team=self.team, + properties={"$host": "eu.posthog.com"}, + ) + + _create_person( + distinct_ids=["person2"], + team=self.team, + ) + _create_event( + event="$pageview", + distinct_id="person2", + team=self.team, + properties={"$host": "us.posthog.com"}, + ) + + mixin = self.get_mixin() + self.assertEqual( + mixin._retrieve_context(), + [EventTaxonomyItem(property="$host", sample_values=["us.posthog.com", "eu.posthog.com"], sample_count=2)], + ) + + def test_app_bundle_id_retrieval(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event=f"$screen", + distinct_id="person1", + team=self.team, + properties={"$app_namespace": "com.posthog.app"}, + ) + _create_event( + event=f"$screen", + distinct_id="person1", + team=self.team, + properties={"$app_namespace": "com.posthog"}, + ) + + _create_person( + distinct_ids=["person2"], + team=self.team, + ) + _create_event( + event=f"$screen", + distinct_id="person2", + team=self.team, + properties={"$app_namespace": "com.posthog.app"}, + ) + + mixin = self.get_mixin() + self.assertEqual( + mixin._retrieve_context(), + [ + EventTaxonomyItem( + property="$app_namespace", sample_values=["com.posthog.app", "com.posthog"], sample_count=2 + ) + ], + ) + + +@override_settings(IN_UNIT_TESTING=True) +class TestMemoryOnboardingNode(ClickhouseTestMixin, BaseTest): + def _set_up_pageview_events(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$pageview", + distinct_id="person1", + team=self.team, + properties={"$host": "us.posthog.com"}, + ) + + def _set_up_app_bundle_id_events(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$screen", + 
distinct_id="person1", + team=self.team, + properties={"$app_namespace": "com.posthog.app"}, + ) + + def test_should_run(self): + node = MemoryOnboardingNode(team=self.team) + self.assertTrue(node.should_run(AssistantState(messages=[]))) + + core_memory = CoreMemory.objects.create(team=self.team) + self.assertTrue(node.should_run(AssistantState(messages=[]))) + + core_memory.change_status_to_pending() + self.assertFalse(node.should_run(AssistantState(messages=[]))) + + core_memory.change_status_to_skipped() + self.assertFalse(node.should_run(AssistantState(messages=[]))) + + core_memory.set_core_memory("Hello World") + self.assertFalse(node.should_run(AssistantState(messages=[]))) + + def test_router(self): + node = MemoryOnboardingNode(team=self.team) + self.assertEqual(node.router(AssistantState(messages=[HumanMessage(content="Hello")])), "continue") + self.assertEqual( + node.router(AssistantState(messages=[HumanMessage(content="Hello"), AssistantMessage(content="world")])), + "initialize_memory", + ) + + def test_node_skips_onboarding_if_no_events(self): + node = MemoryOnboardingNode(team=self.team) + self.assertIsNone(node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {})) + + def test_node_uses_project_description(self): + self.team.project.product_description = "This is a product analytics platform" + self.team.project.save() + + node = MemoryOnboardingNode(team=self.team) + self.assertIsNone(node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {})) + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.text, "This is a product analytics platform") + + def test_node_starts_onboarding_for_pageview_events(self): + self._set_up_pageview_events() + node = MemoryOnboardingNode(team=self.team) + new_state = node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {}) + self.assertEqual(len(new_state.messages), 1) + self.assertTrue(isinstance(new_state.messages[0], AssistantMessage)) + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, CoreMemory.ScrapingStatus.PENDING) + self.assertIsNotNone(core_memory.scraping_started_at) + + def test_node_starts_onboarding_for_app_bundle_id_events(self): + self._set_up_app_bundle_id_events() + node = MemoryOnboardingNode(team=self.team) + new_state = node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {}) + self.assertEqual(len(new_state.messages), 1) + self.assertTrue(isinstance(new_state.messages[0], AssistantMessage)) + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, CoreMemory.ScrapingStatus.PENDING) + self.assertIsNotNone(core_memory.scraping_started_at) + + +@override_settings(IN_UNIT_TESTING=True) +class TestMemoryInitializerNode(ClickhouseTestMixin, BaseTest): + def setUp(self): + super().setUp() + self.core_memory = CoreMemory.objects.create( + team=self.team, + scraping_status=CoreMemory.ScrapingStatus.PENDING, + scraping_started_at=timezone.now(), + ) + + def _set_up_pageview_events(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$pageview", + distinct_id="person1", + team=self.team, + properties={"$host": "us.posthog.com"}, + ) + + def _set_up_app_bundle_id_events(self): + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$screen", + distinct_id="person1", + team=self.team, + properties={"$app_namespace": "com.posthog.app"}, + ) + + def 
test_router_with_failed_scraping_message(self): + node = MemoryInitializerNode(team=self.team) + state = AssistantState(messages=[AssistantMessage(content=FAILED_SCRAPING_MESSAGE)]) + self.assertEqual(node.router(state), "continue") + + def test_router_with_other_message(self): + node = MemoryInitializerNode(team=self.team) + state = AssistantState(messages=[AssistantMessage(content="Some other message")]) + self.assertEqual(node.router(state), "interrupt") + + def test_should_process_message_chunk_with_no_data_available(self): + from langchain_core.messages import AIMessageChunk + + chunk = AIMessageChunk(content="no data available.") + self.assertFalse(MemoryInitializerNode.should_process_message_chunk(chunk)) + + chunk = AIMessageChunk(content="NO DATA AVAILABLE for something") + self.assertFalse(MemoryInitializerNode.should_process_message_chunk(chunk)) + + def test_should_process_message_chunk_with_valid_data(self): + from langchain_core.messages import AIMessageChunk + + chunk = AIMessageChunk(content="PostHog is an open-source product analytics platform") + self.assertTrue(MemoryInitializerNode.should_process_message_chunk(chunk)) + + chunk = AIMessageChunk(content="This is a valid message that should be processed") + self.assertTrue(MemoryInitializerNode.should_process_message_chunk(chunk)) + + def test_format_message_removes_reference_tags(self): + message = "PostHog[1] is a product analytics platform[2]. It helps track user behavior[3]." + expected = "PostHog is a product analytics platform. It helps track user behavior." + self.assertEqual(MemoryInitializerNode.format_message(message), expected) + + def test_format_message_with_no_reference_tags(self): + message = "PostHog is a product analytics platform. It helps track user behavior." + self.assertEqual(MemoryInitializerNode.format_message(message), message) + + def test_run_with_url_based_initialization(self): + with patch.object(MemoryInitializerNode, "_model") as model_mock: + model_mock.return_value = RunnableLambda(lambda _: "PostHog is a product analytics platform.") + + self._set_up_pageview_events() + node = MemoryInitializerNode(team=self.team) + + new_state = node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {}) + self.assertEqual(len(new_state.messages), 1) + self.assertIsInstance(new_state.messages[0], AssistantMessage) + self.assertEqual(new_state.messages[0].content, "PostHog is a product analytics platform.") + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, CoreMemory.ScrapingStatus.PENDING) + + flush_persons_and_events() + + def test_run_with_app_bundle_id_initialization(self): + with ( + patch.object(MemoryInitializerNode, "_model") as model_mock, + patch.object(MemoryInitializerNode, "_retrieve_context") as context_mock, + ): + context_mock.return_value = [ + EventTaxonomyItem(property="$app_namespace", sample_values=["com.posthog.app"], sample_count=1) + ] + model_mock.return_value = RunnableLambda(lambda _: "PostHog mobile app description.") + + self._set_up_app_bundle_id_events() + node = MemoryInitializerNode(team=self.team) + + new_state = node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {}) + self.assertEqual(len(new_state.messages), 1) + self.assertIsInstance(new_state.messages[0], AssistantMessage) + self.assertEqual(new_state.messages[0].content, "PostHog mobile app description.") + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, 
CoreMemory.ScrapingStatus.PENDING) + + flush_persons_and_events() + + def test_run_with_no_data_available(self): + with ( + patch.object(MemoryInitializerNode, "_model") as model_mock, + patch.object(MemoryInitializerNode, "_retrieve_context") as context_mock, + ): + model_mock.return_value = RunnableLambda(lambda _: "no data available.") + context_mock.return_value = [] + + node = MemoryInitializerNode(team=self.team) + + with self.assertRaises(ValueError) as e: + node.run(AssistantState(messages=[HumanMessage(content="Hello")]), {}) + self.assertEqual(str(e.exception), "No host or app bundle ID found in the memory initializer.") + + +@override_settings(IN_UNIT_TESTING=True) +class TestMemoryInitializerInterruptNode(ClickhouseTestMixin, BaseTest): + def setUp(self): + super().setUp() + self.core_memory = CoreMemory.objects.create( + team=self.team, + scraping_status=CoreMemory.ScrapingStatus.PENDING, + scraping_started_at=timezone.now(), + ) + self.node = MemoryInitializerInterruptNode(team=self.team) + + def test_interrupt_when_not_resumed(self): + state = AssistantState(messages=[AssistantMessage(content="Product description")]) + + with self.assertRaises(NodeInterrupt) as e: + self.node.run(state, {}) + + interrupt_message = e.exception.args[0][0].value + self.assertIsInstance(interrupt_message, AssistantMessage) + self.assertEqual(interrupt_message.content, prompts.SCRAPING_VERIFICATION_MESSAGE) + self.assertIsNotNone(interrupt_message.meta) + self.assertEqual(len(interrupt_message.meta.form.options), 2) + self.assertEqual(interrupt_message.meta.form.options[0].value, prompts.SCRAPING_CONFIRMATION_MESSAGE) + self.assertEqual(interrupt_message.meta.form.options[1].value, prompts.SCRAPING_REJECTION_MESSAGE) + + def test_memory_accepted(self): + with patch.object(MemoryInitializerInterruptNode, "_model") as model_mock: + model_mock.return_value = RunnableLambda(lambda _: "Compressed memory") + + state = AssistantState( + messages=[ + AssistantMessage(content="Product description"), + HumanMessage(content=prompts.SCRAPING_CONFIRMATION_MESSAGE), + ], + resumed=True, + ) + + new_state = self.node.run(state, {}) + + self.assertEqual(len(new_state.messages), 1) + self.assertIsInstance(new_state.messages[0], AssistantMessage) + self.assertEqual( + new_state.messages[0].content, + prompts.SCRAPING_MEMORY_SAVED_MESSAGE, + ) + + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.text, "Compressed memory") + self.assertEqual(core_memory.scraping_status, CoreMemory.ScrapingStatus.COMPLETED) + + def test_memory_rejected(self): + state = AssistantState( + messages=[ + AssistantMessage(content="Product description"), + HumanMessage(content=prompts.SCRAPING_REJECTION_MESSAGE), + ], + resumed=True, + ) + + new_state = self.node.run(state, {}) + + self.assertEqual(len(new_state.messages), 1) + self.assertIsInstance(new_state.messages[0], AssistantMessage) + self.assertEqual( + new_state.messages[0].content, + prompts.SCRAPING_TERMINATION_MESSAGE, + ) + + self.core_memory.refresh_from_db() + self.assertEqual(self.core_memory.scraping_status, CoreMemory.ScrapingStatus.SKIPPED) + + def test_error_when_last_message_not_human(self): + state = AssistantState( + messages=[AssistantMessage(content="Product description")], + resumed=True, + ) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "Last message is not a human message.") + + def test_error_when_no_core_memory(self): + self.core_memory.delete() + + state = 
AssistantState( + messages=[ + AssistantMessage(content="Product description"), + HumanMessage(content=prompts.SCRAPING_CONFIRMATION_MESSAGE), + ], + resumed=True, + ) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "No core memory found.") + + def test_error_when_no_memory_message(self): + state = AssistantState( + messages=[HumanMessage(content=prompts.SCRAPING_CONFIRMATION_MESSAGE)], + resumed=True, + ) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "No memory message found.") + + def test_format_memory(self): + markdown_text = "# Product Description\n\n- Feature 1\n- Feature 2\n\n**Bold text** and `code` [1]" + expected = "Product Description\n\nFeature 1\nFeature 2\n\nBold text and code [1]" + self.assertEqual(self.node._format_memory(markdown_text), expected) + + +@override_settings(IN_UNIT_TESTING=True) +class TestMemoryCollectorNode(ClickhouseTestMixin, BaseTest): + def setUp(self): + super().setUp() + self.core_memory = CoreMemory.objects.create(team=self.team) + self.core_memory.set_core_memory("Test product core memory") + self.node = MemoryCollectorNode(team=self.team) + + def test_router(self): + # Test with no memory collection messages + state = AssistantState(messages=[HumanMessage(content="Text")], memory_collection_messages=[]) + self.assertEqual(self.node.router(state), "next") + + # Test with memory collection messages + state = AssistantState( + messages=[HumanMessage(content="Text")], + memory_collection_messages=[LangchainAIMessage(content="Memory message")], + ) + self.assertEqual(self.node.router(state), "tools") + + def test_construct_messages(self): + # Test basic conversation reconstruction + state = AssistantState( + messages=[ + HumanMessage(content="Question 1", id="0"), + AssistantMessage(content="Answer 1", id="1"), + HumanMessage(content="Question 2", id="2"), + ], + start_id="2", + ) + history = self.node._construct_messages(state) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].content, "Question 1") + self.assertEqual(history[1].content, "Answer 1") + self.assertEqual(history[2].content, "Question 2") + + # Test with memory collection messages + state = AssistantState( + messages=[HumanMessage(content="Question", id="0")], + memory_collection_messages=[ + LangchainAIMessage(content="Memory 1"), + LangchainToolMessage(content="Tool response", tool_call_id="1"), + ], + start_id="0", + ) + history = self.node._construct_messages(state) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].content, "Question") + self.assertEqual(history[1].content, "Memory 1") + self.assertEqual(history[2].content, "Tool response") + + @freeze_time("2024-01-01") + def test_prompt_substitutions(self): + with patch.object(MemoryCollectorNode, "_model") as model_mock: + + def assert_prompt(prompt): + messages = prompt.to_messages() + + # Verify the structure of messages + self.assertEqual(len(messages), 3) + self.assertEqual(messages[0].type, "system") + self.assertEqual(messages[1].type, "human") + self.assertEqual(messages[2].type, "ai") + + # Verify system message content + system_message = messages[0].content + self.assertIn("Test product core memory", system_message) + self.assertIn("2024-01-01", system_message) + + # Verify conversation messages + self.assertEqual(messages[1].content, "We use a subscription model") + self.assertEqual(messages[2].content, "Memory message") + return LangchainAIMessage(content="[Done]") + + 
model_mock.return_value = RunnableLambda(assert_prompt) + + state = AssistantState( + messages=[ + HumanMessage(content="We use a subscription model", id="0"), + ], + memory_collection_messages=[ + LangchainAIMessage(content="Memory message"), + ], + start_id="0", + ) + + self.node.run(state, {}) + + def test_exits_on_done_message(self): + with patch.object(MemoryCollectorNode, "_model") as model_mock: + model_mock.return_value = RunnableLambda( + lambda _: LangchainAIMessage(content="Processing complete. [Done]") + ) + + state = AssistantState( + messages=[HumanMessage(content="Text")], + memory_collection_messages=[LangchainAIMessage(content="Previous memory")], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(new_state.memory_updated, True) + self.assertEqual(new_state.memory_collection_messages, []) + + def test_appends_new_message(self): + with patch.object(MemoryCollectorNode, "_model") as model_mock: + model_mock.return_value = RunnableLambda( + lambda _: LangchainAIMessage( + content="New memory", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"new_fragment": "New memory"}, + "id": "1", + }, + ], + ), + ) + + state = AssistantState( + messages=[HumanMessage(content="Text")], + memory_collection_messages=[LangchainAIMessage(content="Previous memory")], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 2) + self.assertEqual(new_state.memory_collection_messages[0].content, "Previous memory") + self.assertEqual(new_state.memory_collection_messages[1].content, "New memory") + + def test_construct_messages_typical_conversation(self): + # Set up a typical conversation with multiple interactions + state = AssistantState( + messages=[ + HumanMessage(content="We use a subscription model", id="0"), + AssistantMessage(content="I'll note that down", id="1"), + HumanMessage(content="And we target enterprise customers", id="2"), + AssistantMessage(content="Let me process that information", id="3"), + HumanMessage(content="We also have a freemium tier", id="4"), + ], + memory_collection_messages=[ + LangchainAIMessage(content="Analyzing business model: subscription-based pricing."), + LangchainToolMessage(content="Memory appended.", tool_call_id="1"), + LangchainAIMessage(content="Analyzing target audience: enterprise customers."), + LangchainToolMessage(content="Memory appended.", tool_call_id="2"), + ], + start_id="0", + ) + + history = self.node._construct_messages(state) + + # Verify the complete conversation history is reconstructed correctly + self.assertEqual(len(history), 9) # 5 conversation messages + 4 memory messages + + # Check conversation messages + self.assertEqual(history[0].content, "We use a subscription model") + self.assertEqual(history[1].content, "I'll note that down") + self.assertEqual(history[2].content, "And we target enterprise customers") + self.assertEqual(history[3].content, "Let me process that information") + self.assertEqual(history[4].content, "We also have a freemium tier") + + # Check memory collection messages + self.assertEqual(history[5].content, "Analyzing business model: subscription-based pricing.") + self.assertEqual(history[6].content, "Memory appended.") + self.assertEqual(history[7].content, "Analyzing target audience: enterprise customers.") + self.assertEqual(history[8].content, "Memory appended.") + + +class TestMemoryCollectorToolsNode(BaseTest): + def setUp(self): + super().setUp() + self.core_memory = CoreMemory.objects.create(team=self.team) + 
self.core_memory.set_core_memory("Initial memory content") + self.node = MemoryCollectorToolsNode(team=self.team) + + def test_handles_correct_tools(self): + # Test handling a single append tool + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Adding new memory", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"memory_content": "New memory fragment."}, + "id": "1", + }, + { + "name": "core_memory_replace", + "args": { + "original_fragment": "Initial memory content", + "new_fragment": "New memory fragment 2.", + }, + "id": "2", + }, + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 3) + self.assertEqual(new_state.memory_collection_messages[1].type, "tool") + self.assertEqual(new_state.memory_collection_messages[1].content, "Memory appended.") + self.assertEqual(new_state.memory_collection_messages[2].type, "tool") + self.assertEqual(new_state.memory_collection_messages[2].content, "Memory replaced.") + + def test_handles_validation_error(self): + # Test handling validation error with incorrect tool arguments + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Invalid tool call", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"invalid_arg": "This will fail"}, + "id": "1", + } + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 2) + self.assertNotIn("{{validation_error_message}}", new_state.memory_collection_messages[1].content) + + def test_handles_multiple_tools(self): + # Test handling multiple tool calls in a single message + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Multiple operations", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"memory_content": "First memory"}, + "id": "1", + }, + { + "name": "core_memory_append", + "args": {"memory_content": "Second memory"}, + "id": "2", + }, + { + "name": "core_memory_replace", + "args": { + "original_fragment": "Initial memory content", + "new_fragment": "Third memory", + }, + "id": "3", + }, + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 4) + self.assertEqual(new_state.memory_collection_messages[1].content, "Memory appended.") + self.assertEqual(new_state.memory_collection_messages[1].type, "tool") + self.assertEqual(new_state.memory_collection_messages[1].tool_call_id, "1") + self.assertEqual(new_state.memory_collection_messages[2].content, "Memory appended.") + self.assertEqual(new_state.memory_collection_messages[2].type, "tool") + self.assertEqual(new_state.memory_collection_messages[2].tool_call_id, "2") + self.assertEqual(new_state.memory_collection_messages[3].content, "Memory replaced.") + self.assertEqual(new_state.memory_collection_messages[3].type, "tool") + self.assertEqual(new_state.memory_collection_messages[3].tool_call_id, "3") + + self.core_memory.refresh_from_db() + self.assertEqual(self.core_memory.text, "Third memory\nFirst memory\nSecond memory") + + def test_handles_replacing_memory(self): + # Test replacing a memory fragment + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Replacing memory", + tool_calls=[ + { + "name": "core_memory_replace", + "args": { + "original_fragment": "Initial memory", + "new_fragment": "Updated memory", + }, + "id": 
"1", + } + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 2) + self.assertEqual(new_state.memory_collection_messages[1].content, "Memory replaced.") + self.assertEqual(new_state.memory_collection_messages[1].type, "tool") + self.assertEqual(new_state.memory_collection_messages[1].tool_call_id, "1") + self.core_memory.refresh_from_db() + self.assertEqual(self.core_memory.text, "Updated memory content") + + def test_handles_replace_memory_not_found(self): + # Test replacing a memory fragment that doesn't exist + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Replacing non-existent memory", + tool_calls=[ + { + "name": "core_memory_replace", + "args": { + "original_fragment": "Non-existent memory", + "new_fragment": "New memory", + }, + "id": "1", + } + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 2) + self.assertIn("not found", new_state.memory_collection_messages[1].content.lower()) + self.assertEqual(new_state.memory_collection_messages[1].type, "tool") + self.assertEqual(new_state.memory_collection_messages[1].tool_call_id, "1") + self.core_memory.refresh_from_db() + self.assertEqual(self.core_memory.text, "Initial memory content") + + def test_handles_appending_new_memory(self): + # Test appending a new memory fragment + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Appending memory", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"memory_content": "Additional memory"}, + "id": "1", + } + ], + ) + ], + ) + + new_state = self.node.run(state, {}) + self.assertEqual(len(new_state.memory_collection_messages), 2) + self.assertEqual(new_state.memory_collection_messages[1].content, "Memory appended.") + self.assertEqual(new_state.memory_collection_messages[1].type, "tool") + self.core_memory.refresh_from_db() + self.assertEqual(self.core_memory.text, "Initial memory content\nAdditional memory") + + def test_error_when_no_memory_collection_messages(self): + # Test error when no memory collection messages are present + state = AssistantState(messages=[], memory_collection_messages=[]) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "No memory collection messages found.") + + def test_error_when_last_message_not_ai(self): + # Test error when last message is not an AI message + state = AssistantState( + messages=[], + memory_collection_messages=[LangchainToolMessage(content="Not an AI message", tool_call_id="1")], + ) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "Last message must be an AI message.") + + def test_error_when_no_core_memory(self): + # Test error when core memory is not found + self.core_memory.delete() + state = AssistantState( + messages=[], + memory_collection_messages=[ + LangchainAIMessage( + content="Memory operation", + tool_calls=[ + { + "name": "core_memory_append", + "args": {"memory_content": "New memory"}, + "id": "1", + } + ], + ) + ], + ) + + with self.assertRaises(ValueError) as e: + self.node.run(state, {}) + self.assertEqual(str(e.exception), "No core memory found.") diff --git a/ee/hogai/memory/test/test_parsers.py b/ee/hogai/memory/test/test_parsers.py new file mode 100644 index 0000000000000..8f98d18815cf5 --- /dev/null +++ b/ee/hogai/memory/test/test_parsers.py @@ -0,0 
+1,22 @@ +from langchain_core.messages import AIMessage + +from ee.hogai.memory.parsers import MemoryCollectionCompleted, compressed_memory_parser, raise_memory_updated +from posthog.test.base import BaseTest + + +class TestParsers(BaseTest): + def test_compressed_memory_parser(self): + memory = "Hello\n\nWorld " + self.assertEqual(compressed_memory_parser(memory), "Hello\nWorld ") + + def test_raise_memory_updated(self): + message = AIMessage(content="Hello World") + with self.assertRaises(MemoryCollectionCompleted): + raise_memory_updated(message) + + message = AIMessage(content="[Done]", tool_calls=[{"id": "1", "args": {}, "name": "function"}]) + with self.assertRaises(MemoryCollectionCompleted): + raise_memory_updated(message) + + message = AIMessage(content="Reasoning", tool_calls=[{"id": "1", "args": {}, "name": "function"}]) + self.assertEqual(raise_memory_updated(message), message) diff --git a/ee/hogai/summarizer/nodes.py b/ee/hogai/summarizer/nodes.py index f03db53879268..eed9d5c873ffd 100644 --- a/ee/hogai/summarizer/nodes.py +++ b/ee/hogai/summarizer/nodes.py @@ -4,8 +4,8 @@ from uuid import uuid4 from django.conf import settings -from django.utils import timezone from django.core.serializers.json import DjangoJSONEncoder +from django.utils import timezone from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI @@ -84,7 +84,7 @@ def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistant message = chain.invoke( { "query_kind": viz_message.answer.kind, - "product_description": self._team.project.product_description, + "core_memory": self.core_memory_text, "results": json.dumps(results_response["results"], cls=DjangoJSONEncoder), "utc_datetime_display": utc_now.strftime("%Y-%m-%d %H:%M:%S"), "project_datetime_display": project_now.strftime("%Y-%m-%d %H:%M:%S"), diff --git a/ee/hogai/summarizer/prompts.py b/ee/hogai/summarizer/prompts.py index fbb8a2dd39dbf..95a74cd32a6d1 100644 --- a/ee/hogai/summarizer/prompts.py +++ b/ee/hogai/summarizer/prompts.py @@ -10,7 +10,10 @@ You can use Markdown for emphasis. Bullets can improve clarity of action points. The product being analyzed is described as follows: -{{product_description}}""" +<core_memory> +{{core_memory}} +</core_memory> +""" SUMMARIZER_INSTRUCTION_PROMPT = """ Here are results of the {{query_kind}} you created to answer my latest question: diff --git a/ee/hogai/summarizer/test/test_nodes.py b/ee/hogai/summarizer/test/test_nodes.py index 9c54517717b5f..0dc67034277dc 100644 --- a/ee/hogai/summarizer/test/test_nodes.py +++ b/ee/hogai/summarizer/test/test_nodes.py @@ -155,8 +155,6 @@ def test_node_requires_viz_message_in_state_to_have_query(self): ) def test_agent_reconstructs_conversation(self): - self.project.product_description = "Dating app for lonely hedgehogs." 
- self.project.save() node = SummarizerNode(self.team) history = node._construct_messages( diff --git a/ee/hogai/taxonomy_agent/nodes.py b/ee/hogai/taxonomy_agent/nodes.py index 92351fdd1f4e5..d12385952b5b9 100644 --- a/ee/hogai/taxonomy_agent/nodes.py +++ b/ee/hogai/taxonomy_agent/nodes.py @@ -25,6 +25,7 @@ parse_react_agent_output, ) from ee.hogai.taxonomy_agent.prompts import ( + CORE_MEMORY_INSTRUCTIONS, REACT_DEFINITIONS_PROMPT, REACT_FOLLOW_UP_PROMPT, REACT_FORMAT_PROMPT, @@ -88,13 +89,14 @@ def _run_with_prompt_and_toolkit( agent.invoke( { "react_format": self._get_react_format_prompt(toolkit), + "core_memory": self.core_memory.text if self.core_memory else "", "react_format_reminder": REACT_FORMAT_REMINDER_PROMPT, "react_property_filters": self._get_react_property_filters_prompt(), "react_human_in_the_loop": REACT_HUMAN_IN_THE_LOOP_PROMPT, - "product_description": self._team.project.product_description, "groups": self._team_group_types, "events": self._events_prompt, "agent_scratchpad": self._get_agent_scratchpad(intermediate_steps), + "core_memory_instructions": CORE_MEMORY_INSTRUCTIONS, }, config, ), diff --git a/ee/hogai/taxonomy_agent/prompts.py b/ee/hogai/taxonomy_agent/prompts.py index c9d409bcdf103..4779da4e7f0d3 100644 --- a/ee/hogai/taxonomy_agent/prompts.py +++ b/ee/hogai/taxonomy_agent/prompts.py @@ -134,3 +134,7 @@ You must fix the exception and try again. """ + +CORE_MEMORY_INSTRUCTIONS = """ +You have access to the core memory in the <core_memory> tag, which stores information about the user's company and product. +""".strip() diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index a5760f1ca7e58..525c4ff134900 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -11,9 +11,10 @@ from pydantic import BaseModel from ee.hogai.funnels.nodes import FunnelsSchemaGeneratorOutput +from ee.hogai.memory import prompts as memory_prompts from ee.hogai.router.nodes import RouterOutput from ee.hogai.trends.nodes import TrendsSchemaGeneratorOutput -from ee.models.assistant import Conversation +from ee.models.assistant import Conversation, CoreMemory from posthog.schema import ( AssistantFunnelsEventsNode, AssistantFunnelsQuery, @@ -25,18 +26,37 @@ RouterMessage, VisualizationMessage, ) -from posthog.test.base import NonAtomicBaseTest +from posthog.test.base import ClickhouseTestMixin, NonAtomicBaseTest, _create_event, _create_person from ..assistant import Assistant from ..graph import AssistantGraph, AssistantNodeName -class TestAssistant(NonAtomicBaseTest): +class TestAssistant(ClickhouseTestMixin, NonAtomicBaseTest): CLASS_DATA_LEVEL_SETUP = False def setUp(self): super().setUp() self.conversation = Conversation.objects.create(team=self.team, user=self.user) + self.core_memory = CoreMemory.objects.create( + team=self.team, + text="Initial memory.", + initial_text="Initial memory.", + scraping_status=CoreMemory.ScrapingStatus.COMPLETED, + ) + + def _set_up_onboarding_tests(self): + self.core_memory.delete() + _create_person( + distinct_ids=["person1"], + team=self.team, + ) + _create_event( + event="$pageview", + distinct_id="person1", + team=self.team, + properties={"$host": "us.posthog.com"}, + ) def _parse_stringified_message(self, message: str) -> tuple[str, Any]: event_line, data_line, *_ = cast(str, message).split("\n") @@ -426,7 +446,8 @@ def node_handler(state): @patch("ee.hogai.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") 
@patch("ee.hogai.router.nodes.RouterNode._model") - def test_full_trends_flow(self, router_mock, planner_mock, generator_mock, summarizer_mock): + @patch("ee.hogai.memory.nodes.MemoryCollectorNode._model", return_value=messages.AIMessage(content="[Done]")) + def test_full_trends_flow(self, memory_collector_mock, router_mock, planner_mock, generator_mock, summarizer_mock): router_mock.return_value = RunnableLambda(lambda _: RouterOutput(visualization_type="trends")) planner_mock.return_value = RunnableLambda( lambda _: messages.AIMessage( @@ -476,7 +497,8 @@ def test_full_trends_flow(self, router_mock, planner_mock, generator_mock, summa @patch("ee.hogai.schema_generator.nodes.SchemaGeneratorNode._model") @patch("ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model") @patch("ee.hogai.router.nodes.RouterNode._model") - def test_full_funnel_flow(self, router_mock, planner_mock, generator_mock, summarizer_mock): + @patch("ee.hogai.memory.nodes.MemoryCollectorNode._model", return_value=messages.AIMessage(content="[Done]")) + def test_full_funnel_flow(self, memory_collector_mock, router_mock, planner_mock, generator_mock, summarizer_mock): router_mock.return_value = RunnableLambda(lambda _: RouterOutput(visualization_type="funnel")) planner_mock.return_value = RunnableLambda( lambda _: messages.AIMessage( @@ -526,3 +548,145 @@ def test_full_funnel_flow(self, router_mock, planner_mock, generator_mock, summa actual_output = self._run_assistant_graph(is_new_conversation=False) self.assertConversationEqual(actual_output, expected_output[1:]) self.assertEqual(actual_output[0][1]["id"], actual_output[6][1]["initiator"]) + + @patch("ee.hogai.memory.nodes.MemoryInitializerInterruptNode._model") + @patch("ee.hogai.memory.nodes.MemoryInitializerNode._model") + def test_onboarding_flow_accepts_memory(self, model_mock, interruption_model_mock): + self._set_up_onboarding_tests() + + # Mock the memory initializer to return a product description + model_mock.return_value = RunnableLambda(lambda _: "PostHog is a product analytics platform.") + interruption_model_mock.return_value = RunnableLambda(lambda _: "PostHog is a product analytics platform.") + + # Create a graph with memory initialization flow + graph = AssistantGraph(self.team).add_memory_initializer(AssistantNodeName.END).compile() + + # First run - get the product description + output = self._run_assistant_graph(graph, is_new_conversation=True) + expected_output = [ + ("conversation", {"id": str(self.conversation.id)}), + ("message", HumanMessage(content="Hello")), + ( + "message", + AssistantMessage( + content=memory_prompts.SCRAPING_INITIAL_MESSAGE, + ), + ), + ("message", AssistantMessage(content="PostHog is a product analytics platform.")), + ("message", AssistantMessage(content=memory_prompts.SCRAPING_VERIFICATION_MESSAGE)), + ] + self.assertConversationEqual(output, expected_output) + + # Second run - accept the memory + output = self._run_assistant_graph( + graph, + message=memory_prompts.SCRAPING_CONFIRMATION_MESSAGE, + is_new_conversation=False, + ) + expected_output = [ + ("message", HumanMessage(content=memory_prompts.SCRAPING_CONFIRMATION_MESSAGE)), + ( + "message", + AssistantMessage(content=memory_prompts.SCRAPING_MEMORY_SAVED_MESSAGE), + ), + ("message", ReasoningMessage(content="Identifying type of analysis")), + ] + self.assertConversationEqual(output, expected_output) + + # Verify the memory was saved + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, 
CoreMemory.ScrapingStatus.COMPLETED) + self.assertIsNotNone(core_memory.text) + + @patch("ee.hogai.memory.nodes.MemoryInitializerNode._model") + def test_onboarding_flow_rejects_memory(self, model_mock): + self._set_up_onboarding_tests() + + # Mock the memory initializer to return a product description + model_mock.return_value = RunnableLambda(lambda _: "PostHog is a product analytics platform.") + + # Create a graph with memory initialization flow + graph = AssistantGraph(self.team).add_memory_initializer(AssistantNodeName.END).compile() + + # First run - get the product description + output = self._run_assistant_graph(graph, is_new_conversation=True) + expected_output = [ + ("conversation", {"id": str(self.conversation.id)}), + ("message", HumanMessage(content="Hello")), + ( + "message", + AssistantMessage( + content=memory_prompts.SCRAPING_INITIAL_MESSAGE, + ), + ), + ("message", AssistantMessage(content="PostHog is a product analytics platform.")), + ("message", AssistantMessage(content=memory_prompts.SCRAPING_VERIFICATION_MESSAGE)), + ] + self.assertConversationEqual(output, expected_output) + + # Second run - reject the memory + output = self._run_assistant_graph( + graph, + message=memory_prompts.SCRAPING_REJECTION_MESSAGE, + is_new_conversation=False, + ) + expected_output = [ + ("message", HumanMessage(content=memory_prompts.SCRAPING_REJECTION_MESSAGE)), + ( + "message", + AssistantMessage( + content=memory_prompts.SCRAPING_TERMINATION_MESSAGE, + ), + ), + ("message", ReasoningMessage(content="Identifying type of analysis")), + ] + self.assertConversationEqual(output, expected_output) + + # Verify the memory was skipped + core_memory = CoreMemory.objects.get(team=self.team) + self.assertEqual(core_memory.scraping_status, CoreMemory.ScrapingStatus.SKIPPED) + self.assertEqual(core_memory.text, "") + + @patch("ee.hogai.memory.nodes.MemoryCollectorNode._model") + def test_memory_collector_flow(self, model_mock): + # Create a graph with just memory collection + graph = ( + AssistantGraph(self.team).add_memory_collector(AssistantNodeName.END).add_memory_collector_tools().compile() + ) + + # Mock the memory collector to first analyze and then append memory + def memory_collector_side_effect(prompt): + prompt_messages = prompt.to_messages() + if len(prompt_messages) == 2: # First run + return messages.AIMessage( + content="Let me analyze that.", + tool_calls=[ + { + "id": "1", + "name": "core_memory_append", + "args": {"memory_content": "The product uses a subscription model."}, + } + ], + ) + else: # Second run + return messages.AIMessage(content="Processing complete. 
[Done]") + + model_mock.return_value = RunnableLambda(memory_collector_side_effect) + + # First run - analyze and append memory + output = self._run_assistant_graph( + graph, + message="We use a subscription model", + is_new_conversation=True, + ) + expected_output = [ + ("conversation", {"id": str(self.conversation.id)}), + ("message", HumanMessage(content="We use a subscription model")), + ("message", AssistantMessage(content="Let me analyze that.")), + ("message", AssistantMessage(content="Memory appended.")), + ] + self.assertConversationEqual(output, expected_output) + + # Verify memory was appended + self.core_memory.refresh_from_db() + self.assertIn("The product uses a subscription model.", self.core_memory.text) diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index dcc1daeaa5a00..b04f00c15b13b 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -2,16 +2,15 @@ You are an expert product analyst agent specializing in data visualization and trends analysis. Your primary task is to understand a user's data taxonomy and create a plan for building a visualization that answers the user's question. This plan should focus on trends insights, including a series of events, property filters, and values of property filters. -{{#product_description}} -The product being analyzed is described as follows: - -{{.}} - -{{/product_description}} +{{core_memory_instructions}} {{react_format}} + +{{core_memory}} + + {{react_human_in_the_loop}} Below you will find information on how to correctly discover the taxonomy of the user's data. diff --git a/ee/hogai/utils/helpers.py b/ee/hogai/utils/helpers.py index 4fc8cf3b5d6a0..b09e439c6ab52 100644 --- a/ee/hogai/utils/helpers.py +++ b/ee/hogai/utils/helpers.py @@ -13,7 +13,7 @@ VisualizationMessage, ) -from .types import AIMessageUnion, AssistantMessageUnion +from .types import AssistantMessageUnion def remove_line_breaks(line: str) -> str: @@ -22,7 +22,7 @@ def remove_line_breaks(line: str) -> str: def filter_messages( messages: Sequence[AssistantMessageUnion], - entity_filter: Union[tuple[type[AIMessageUnion], ...], type[AIMessageUnion]] = ( + entity_filter: Union[tuple[type[AssistantMessageUnion], ...], type[AssistantMessageUnion]] = ( AssistantMessage, VisualizationMessage, ), diff --git a/ee/hogai/utils/markdown.py b/ee/hogai/utils/markdown.py new file mode 100644 index 0000000000000..279fc17ffbccc --- /dev/null +++ b/ee/hogai/utils/markdown.py @@ -0,0 +1,111 @@ +from collections.abc import Sequence +from html.parser import HTMLParser +from inspect import getmembers, ismethod + +from markdown_it import MarkdownIt +from markdown_it.renderer import RendererProtocol +from markdown_it.token import Token +from markdown_it.utils import EnvType, OptionsDict + + +# Taken from https://github.com/elespike/mdit_plain/blob/main/src/mdit_plain/renderer.py +class HTMLTextRenderer(HTMLParser): + def __init__(self): + super().__init__() + self._handled_data = [] + + def handle_data(self, data): + self._handled_data.append(data) + + def reset(self): + self._handled_data = [] + super().reset() + + def render(self, html): + self.feed(html) + rendered_data = "".join(self._handled_data) + self.reset() + return rendered_data + + +class RendererPlain(RendererProtocol): + __output__ = "plain" + + def __init__(self, parser=None): + self.parser = parser + self.htmlparser = HTMLTextRenderer() + self.rules = { + func_name.replace("render_", ""): func + for func_name, func in getmembers(self, predicate=ismethod) + if 
func_name.startswith("render_") + } + + def render(self, tokens: Sequence[Token], options: OptionsDict, env: EnvType): + result = "" + for i, token in enumerate(tokens): + rule = self.rules.get(token.type, self.render_default) + result += rule(tokens, i, options, env) + if token.children is not None: + result += self.render(token.children, options, env) + return result.strip() + + def render_default(self, tokens, i, options, env): + return "" + + def render_bullet_list_close(self, tokens, i, options, env): + if (i + 1) == len(tokens) or "list" in tokens[i + 1].type: + return "" + return "\n" + + def render_code_block(self, tokens, i, options, env): + return f"\n{tokens[i].content}\n" + + def render_code_inline(self, tokens, i, options, env): + return tokens[i].content + + def render_fence(self, tokens, i, options, env): + return f"\n{tokens[i].content}\n" + + def render_hardbreak(self, tokens, i, options, env): + return "\n" + + def render_heading_close(self, tokens, i, options, env): + return "\n" + + def render_heading_open(self, tokens, i, options, env): + return "\n" + + def render_html_block(self, tokens, i, options, env): + return self.htmlparser.render(tokens[i].content) + + def render_list_item_open(self, tokens, i, options, env): + next_token = tokens[i + 1] + if hasattr(next_token, "hidden") and not next_token.hidden: + return "" + return "\n" + + def render_ordered_list_close(self, tokens, i, options, env): + if (i + 1) == len(tokens) or "list" in tokens[i + 1].type: + return "" + return "\n" + + def render_paragraph_close(self, tokens, i, options, env): + if tokens[i].hidden: + return "" + return "\n" + + def render_paragraph_open(self, tokens, i, options, env): + if tokens[i].hidden: + return "" + return "\n" + + def render_softbreak(self, tokens, i, options, env): + return "\n" + + def render_text(self, tokens, i, options, env): + return tokens[i].content + + +def remove_markdown(text: str) -> str: + parser = MarkdownIt(renderer_cls=RendererPlain) + return parser.render(text) diff --git a/ee/hogai/utils/nodes.py b/ee/hogai/utils/nodes.py index 6a4358243b666..b727e643e3d2f 100644 --- a/ee/hogai/utils/nodes.py +++ b/ee/hogai/utils/nodes.py @@ -2,6 +2,7 @@ from langchain_core.runnables import RunnableConfig +from ee.models.assistant import CoreMemory from posthog.models.team.team import Team from .types import AssistantState, PartialAssistantState @@ -14,5 +15,18 @@ def __init__(self, team: Team): self._team = team @abstractmethod - def run(cls, state: AssistantState, config: RunnableConfig) -> PartialAssistantState: + def run(cls, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: raise NotImplementedError + + @property + def core_memory(self) -> CoreMemory | None: + try: + return CoreMemory.objects.get(team=self._team) + except CoreMemory.DoesNotExist: + return None + + @property + def core_memory_text(self) -> str: + if not self.core_memory: + return "" + return self.core_memory.formatted_text diff --git a/ee/hogai/utils/test/test_assistant_node.py b/ee/hogai/utils/test/test_assistant_node.py new file mode 100644 index 0000000000000..16946db36cb85 --- /dev/null +++ b/ee/hogai/utils/test/test_assistant_node.py @@ -0,0 +1,31 @@ +from langchain_core.runnables import RunnableConfig + +from ee.hogai.utils.nodes import AssistantNode +from ee.hogai.utils.types import AssistantState, PartialAssistantState +from ee.models.assistant import CoreMemory +from posthog.test.base import BaseTest + + +class TestAssistantNode(BaseTest): + def setUp(self): + 
super().setUp() + + class Node(AssistantNode): + def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None: + raise NotImplementedError + + self.node = Node(self.team) + + def test_core_memory_when_exists(self): + core_memory = CoreMemory.objects.create(team=self.team, text="Test memory") + self.assertEqual(self.node.core_memory, core_memory) + + def test_core_memory_when_does_not_exist(self): + self.assertIsNone(self.node.core_memory) + + def test_product_core_memory_when_exists(self): + CoreMemory.objects.create(team=self.team, text="Test memory") + self.assertEqual(self.node.core_memory_text, "Test memory") + + def test_product_core_memory_when_does_not_exist(self): + self.assertEqual(self.node.core_memory_text, "") diff --git a/ee/hogai/utils/types.py b/ee/hogai/utils/types.py index b3ac57623fa56..2b92ecdedc176 100644 --- a/ee/hogai/utils/types.py +++ b/ee/hogai/utils/types.py @@ -4,6 +4,7 @@ from typing import Annotated, Optional, Union from langchain_core.agents import AgentAction +from langchain_core.messages import BaseMessage as LangchainBaseMessage from langgraph.graph import END, START from pydantic import BaseModel, Field @@ -28,6 +29,17 @@ class _SharedAssistantState(BaseModel): """ plan: Optional[str] = Field(default=None) resumed: Optional[bool] = Field(default=None) + """ + Whether the agent was resumed after interruption, such as a human in the loop. + """ + memory_updated: Optional[bool] = Field(default=None) + """ + Whether the memory was updated in the `MemoryCollectorNode`. + """ + memory_collection_messages: Optional[Sequence[LangchainBaseMessage]] = Field(default=None) + """ + The messages with tool calls to collect memory in the `MemoryCollectorToolsNode`. + """ class AssistantState(_SharedAssistantState): @@ -41,6 +53,9 @@ class PartialAssistantState(_SharedAssistantState): class AssistantNodeName(StrEnum): START = START END = END + MEMORY_ONBOARDING = "memory_onboarding" + MEMORY_INITIALIZER = "memory_initializer" + MEMORY_INITIALIZER_INTERRUPT = "memory_initializer_interrupt" ROUTER = "router" TRENDS_PLANNER = "trends_planner" TRENDS_PLANNER_TOOLS = "trends_planner_tools" @@ -55,3 +70,5 @@ class AssistantNodeName(StrEnum): RETENTION_GENERATOR = "retention_generator" RETENTION_GENERATOR_TOOLS = "retention_generator_tools" SUMMARIZER = "summarizer" + MEMORY_COLLECTOR = "memory_collector" + MEMORY_COLLECTOR_TOOLS = "memory_collector_tools" diff --git a/ee/migrations/0020_corememory.py b/ee/migrations/0020_corememory.py new file mode 100644 index 0000000000000..a66baec6e5542 --- /dev/null +++ b/ee/migrations/0020_corememory.py @@ -0,0 +1,45 @@ +# Generated by Django 4.2.15 on 2024-12-20 15:14 + +from django.db import migrations, models +import django.db.models.deletion +import posthog.models.utils + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0535_alter_hogfunction_type"), + ("ee", "0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="CoreMemory", + fields=[ + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ( + "text", + models.TextField(default="", help_text="Dumped core memory where facts are separated by newlines."), + ), + ("initial_text", models.TextField(default="", help_text="Scraped memory about the business.")), + ( + "scraping_status", + models.CharField( + blank=True, + choices=[("pending", "Pending"), ("completed", "Completed"), 
("skipped", "Skipped")], + max_length=20, + null=True, + ), + ), + ("scraping_started_at", models.DateTimeField(null=True)), + ("team", models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/ee/migrations/max_migration.txt b/ee/migrations/max_migration.txt index aec0628d960c8..cd0433c401973 100644 --- a/ee/migrations/max_migration.txt +++ b/ee/migrations/max_migration.txt @@ -1 +1 @@ -0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more +0020_corememory diff --git a/ee/models/__init__.py b/ee/models/__init__.py index 2067d11f7618f..d1dfa7e8dce85 100644 --- a/ee/models/__init__.py +++ b/ee/models/__init__.py @@ -1,4 +1,10 @@ -from .assistant import Conversation, ConversationCheckpoint, ConversationCheckpointBlob, ConversationCheckpointWrite +from .assistant import ( + Conversation, + ConversationCheckpoint, + ConversationCheckpointBlob, + ConversationCheckpointWrite, + CoreMemory, +) from .dashboard_privilege import DashboardPrivilege from .event_definition import EnterpriseEventDefinition from .explicit_team_membership import ExplicitTeamMembership @@ -14,6 +20,7 @@ "ConversationCheckpoint", "ConversationCheckpointBlob", "ConversationCheckpointWrite", + "CoreMemory", "DashboardPrivilege", "Conversation", "EnterpriseEventDefinition", diff --git a/ee/models/assistant.py b/ee/models/assistant.py index f2a31d938f5d0..90ac31e339b01 100644 --- a/ee/models/assistant.py +++ b/ee/models/assistant.py @@ -1,6 +1,8 @@ from collections.abc import Iterable +from datetime import timedelta from django.db import models +from django.utils import timezone from langgraph.checkpoint.serde.types import TASKS from posthog.models.team.team import Team @@ -89,3 +91,55 @@ class Meta: name="unique_checkpoint_write", ) ] + + +class CoreMemory(UUIDModel): + class ScrapingStatus(models.TextChoices): + PENDING = "pending", "Pending" + COMPLETED = "completed", "Completed" + SKIPPED = "skipped", "Skipped" + + team = models.OneToOneField(Team, on_delete=models.CASCADE) + text = models.TextField(default="", help_text="Dumped core memory where facts are separated by newlines.") + initial_text = models.TextField(default="", help_text="Scraped memory about the business.") + scraping_status = models.CharField(max_length=20, choices=ScrapingStatus.choices, blank=True, null=True) + scraping_started_at = models.DateTimeField(null=True) + + def change_status_to_pending(self): + self.scraping_started_at = timezone.now() + self.scraping_status = CoreMemory.ScrapingStatus.PENDING + self.save() + + def change_status_to_skipped(self): + self.scraping_status = CoreMemory.ScrapingStatus.SKIPPED + self.save() + + @property + def is_scraping_pending(self) -> bool: + return self.scraping_status == CoreMemory.ScrapingStatus.PENDING and ( + self.scraping_started_at is None or (self.scraping_started_at + timedelta(minutes=5)) > timezone.now() + ) + + @property + def is_scraping_finished(self) -> bool: + return self.scraping_status in [CoreMemory.ScrapingStatus.COMPLETED, CoreMemory.ScrapingStatus.SKIPPED] + + def set_core_memory(self, text: str): + self.text = text + self.initial_text = text + self.scraping_status = CoreMemory.ScrapingStatus.COMPLETED + self.save() + + def append_core_memory(self, text: str): + self.text = self.text + "\n" + text + self.save() + + def replace_core_memory(self, original_fragment: str, new_fragment: str): + if original_fragment not in self.text: + raise ValueError(f"Original fragment 
{original_fragment} not found in core memory") + self.text = self.text.replace(original_fragment, new_fragment) + self.save() + + @property + def formatted_text(self) -> str: + return self.text[0:5000] diff --git a/ee/models/test/test_assistant.py b/ee/models/test/test_assistant.py new file mode 100644 index 0000000000000..daf00499bbe32 --- /dev/null +++ b/ee/models/test/test_assistant.py @@ -0,0 +1,95 @@ +from datetime import timedelta + +from django.utils import timezone +from freezegun import freeze_time + +from ee.models.assistant import CoreMemory +from posthog.test.base import BaseTest + + +class TestCoreMemory(BaseTest): + def setUp(self): + super().setUp() + self.core_memory = CoreMemory.objects.create(team=self.team) + + def test_status_changes(self): + # Test pending status + self.core_memory.change_status_to_pending() + self.assertEqual(self.core_memory.scraping_status, CoreMemory.ScrapingStatus.PENDING) + self.assertIsNotNone(self.core_memory.scraping_started_at) + + # Test skipped status + self.core_memory.change_status_to_skipped() + self.assertEqual(self.core_memory.scraping_status, CoreMemory.ScrapingStatus.SKIPPED) + + def test_scraping_status_properties(self): + # Test pending status within time window + self.core_memory.change_status_to_pending() + self.assertTrue(self.core_memory.is_scraping_pending) + + # Test pending status outside time window + self.core_memory.scraping_started_at = timezone.now() - timedelta(minutes=6) + self.core_memory.save() + self.assertFalse(self.core_memory.is_scraping_pending) + + # Test finished status + self.core_memory.scraping_status = CoreMemory.ScrapingStatus.COMPLETED + self.core_memory.save() + self.assertTrue(self.core_memory.is_scraping_finished) + + self.core_memory.scraping_status = CoreMemory.ScrapingStatus.SKIPPED + self.core_memory.save() + self.assertTrue(self.core_memory.is_scraping_finished) + + @freeze_time("2023-01-01 12:00:00") + def test_is_scraping_pending_timing(self): + # Set initial pending status + self.core_memory.change_status_to_pending() + initial_time = timezone.now() + + # Test 3 minutes after (should be true) + with freeze_time(initial_time + timedelta(minutes=3)): + self.assertTrue(self.core_memory.is_scraping_pending) + + # Test exactly 5 minutes after (should be false) + with freeze_time(initial_time + timedelta(minutes=5)): + self.assertFalse(self.core_memory.is_scraping_pending) + + # Test 6 minutes after (should be false) + with freeze_time(initial_time + timedelta(minutes=6)): + self.assertFalse(self.core_memory.is_scraping_pending) + + def test_core_memory_operations(self): + # Test setting core memory + test_text = "Test memory content" + self.core_memory.set_core_memory(test_text) + self.assertEqual(self.core_memory.text, test_text) + self.assertEqual(self.core_memory.initial_text, test_text) + self.assertEqual(self.core_memory.scraping_status, CoreMemory.ScrapingStatus.COMPLETED) + + # Test appending core memory + append_text = "Additional content" + self.core_memory.append_core_memory(append_text) + self.assertEqual(self.core_memory.text, f"{test_text}\n{append_text}") + + # Test replacing core memory + original = "content" + new = "memory" + self.core_memory.replace_core_memory(original, new) + self.assertIn(new, self.core_memory.text) + self.assertNotIn(original, self.core_memory.text) + + # Test replacing non-existent content + with self.assertRaises(ValueError): + self.core_memory.replace_core_memory("nonexistent", "new") + + def test_formatted_text(self): + # Test formatted text with short 
content + short_text = "Short text" + self.core_memory.set_core_memory(short_text) + self.assertEqual(self.core_memory.formatted_text, short_text) + + # Test formatted text with long content + long_text = "x" * 6000 + self.core_memory.set_core_memory(long_text) + self.assertEqual(len(self.core_memory.formatted_text), 5000) diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--dark.png new file mode 100644 index 0000000000000..38adbf701280c Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--light.png new file mode 100644 index 0000000000000..2252f0ec33c45 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-form--light.png differ diff --git a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png index 0e7ecd8072388..54c8f18fcfdfc 100644 Binary files a/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png and b/frontend/__snapshots__/scenes-app-sidepanels--side-panel-docs--light.png differ diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 41d7b3e22de83..6339746e595ef 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -572,6 +572,32 @@ "enum": ["status", "message", "conversation"], "type": "string" }, + "AssistantForm": { + "additionalProperties": false, + "properties": { + "options": { + "items": { + "$ref": "#/definitions/AssistantFormOption" + }, + "type": "array" + } + }, + "required": ["options"], + "type": "object" + }, + "AssistantFormOption": { + "additionalProperties": false, + "properties": { + "value": { + "type": "string" + }, + "variant": { + "type": "string" + } + }, + "required": ["value"], + "type": "object" + }, "AssistantFunnelsBreakdownFilter": { "additionalProperties": false, "properties": { @@ -1063,6 +1089,9 @@ "id": { "type": "string" }, + "meta": { + "$ref": "#/definitions/AssistantMessageMetadata" + }, "type": { "const": "ai", "type": "string" @@ -1071,6 +1100,15 @@ "required": ["type", "content"], "type": "object" }, + "AssistantMessageMetadata": { + "additionalProperties": false, + "properties": { + "form": { + "$ref": "#/definitions/AssistantForm" + } + }, + "type": "object" + }, "AssistantMessageType": { "enum": ["human", "ai", "ai/reasoning", "ai/viz", "ai/failure", "ai/router"], "type": "string" diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index a744096433b00..e13773cbcc23b 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2541,9 +2541,23 @@ export interface HumanMessage extends BaseAssistantMessage { content: string } +export interface AssistantFormOption { + value: string + variant?: string +} + +export interface AssistantForm { + options: AssistantFormOption[] +} + +export interface AssistantMessageMetadata { + form?: AssistantForm +} + export interface AssistantMessage extends BaseAssistantMessage { type: AssistantMessageType.Assistant content: string + meta?: AssistantMessageMetadata } export interface ReasoningMessage extends BaseAssistantMessage { diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index 51dc03ab0cb5c..93a30c078e4b2 100644 --- 
a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -10,6 +10,7 @@ import { chatResponseChunk, CONVERSATION_ID, failureChunk, + formChunk, generationFailureChunk, humanMessage, } from './__mocks__/chatResponse.mocks' @@ -192,3 +193,19 @@ export const ThreadWithRateLimit: StoryFn = () => { return