Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(backend): Add ability to execute store agents without agent ownership #9179

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
45 changes: 33 additions & 12 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from typing import Any, Literal, Optional, Type

import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphWhereInput
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListing,
)
from prisma.types import AgentGraphWhereInput, StoreListingWhereInput
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
Expand Down Expand Up @@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
Copy link
Contributor

@majdyz majdyz Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user_id meaning is no longer a filter here but a security requirement, I think we need to make this mandatory.
The template is deprecated, right?

for_export: bool = False,
) -> GraphModel | None:
Expand All @@ -543,20 +548,38 @@ async def get_graph(
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}

if version is not None:
where_clause["version"] = version
Swiftyos marked this conversation as resolved.
Show resolved Hide resolved
elif not template:
where_clause["isActive"] = True

# TODO: Fix hack workaround to get adding store agents to work
if user_id is not None and not template:
where_clause["userId"] = user_id

graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)

if not graph:
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we already have this, we can remove if graph else None in both returns of this function.


if graph.userId == user_id:
return GraphModel.from_db(graph, for_export) if graph else None

# If the graph is not owned by the user, we need to check if it's a store listing.
if not version:
version = graph.version

store_listing_where: StoreListingWhereInput = {
"agentId": graph_id,
"agentVersion": version,
}

store_listing = await StoreListing.prisma().find_first(where=store_listing_where)

# If it does not belong to the user nor is not a store listing, return None
if not store_listing:
return None

# If it is a store listing, return the graph model
Comment on lines +566 to +584
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if graph.userId == user_id:
return GraphModel.from_db(graph, for_export) if graph else None
# If the graph is not owned by the user, we need to check if it's a store listing.
if not version:
version = graph.version
store_listing_where: StoreListingWhereInput = {
"agentId": graph_id,
"agentVersion": version,
}
store_listing = await StoreListing.prisma().find_first(where=store_listing_where)
# If it does not belong to the user nor is not a store listing, return None
if not store_listing:
return None
# If it is a store listing, return the graph model
# The Graph has to be owned by the user or a store listing.
if graph.userId != user_id and not (await StoreListing.prisma().find_first(where={
"agentId": graph_id,
"agentVersion": version or graph.version,
})):
return None

I think a concise form is easier to understand.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw don't we put model queries for StoreListing in store/db.py? can we put this function there?

Copy link
Contributor

@majdyz majdyz Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also ["agentId", "agentVersion"] has to be indexed.

We can replace the index from:

  @@index([agentId])

to

  @@index([agentId, agentVersion])

Or even further, why do we need ID for StoreListing, StoreListingVersion ?
Can we make [agentId, agentVersion] StoreListing id and [agentId, agentVersion, version] StoreListingVersion id ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also does that mean the query has to filter the isApproved & isDeleted too?
I think the query needs to be placed on store/db.py with these filters ?

return GraphModel.from_db(graph, for_export) if graph else None


Expand Down Expand Up @@ -611,9 +634,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)

if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph

raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def add_execution(
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
graph_version: int,
Swiftyos marked this conversation as resolved.
Show resolved Hide resolved
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
Expand Down
5 changes: 4 additions & 1 deletion autogpt_platform/backend/backend/executor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def execute_graph(**kwargs):
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
args.graph_id, args.input_data, args.user_id
graph_id=args.graph_id,
data=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
except Exception as e:
logger.exception(f"Error executing graph {args.graph_id}: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ async def webhook_ingress_generic(
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
node.graph_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
)
Expand Down
9 changes: 7 additions & 2 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ def run(self):

@staticmethod
async def test_execute_graph(
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
graph_id: str,
graph_version: int,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
return backend.server.routers.v1.execute_graph(
graph_id, graph_version, node_input, user_id
)

@staticmethod
async def test_create_graph(
Expand Down
13 changes: 5 additions & 8 deletions autogpt_platform/backend/backend/server/routers/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,11 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)


async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
Expand All @@ -217,7 +216,6 @@ async def do_create_graph(
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
Expand All @@ -230,8 +228,6 @@ async def do_create_graph(
status_code=400, detail="Either graph or template_id must be provided."
)

graph.is_template = is_template
graph.is_active = not is_template
Pwuts marked this conversation as resolved.
Show resolved Hide resolved
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)

graph = await graph_db.create_graph(graph, user_id=user_id)
Expand Down Expand Up @@ -368,12 +364,13 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
)
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
Expand Down Expand Up @@ -452,7 +449,7 @@ async def get_templates(
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
graph = await graph_db.get_graph(graph_id, version)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
Expand All @@ -466,7 +463,7 @@ async def get_template(
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)


########################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def add_agent_to_library(

# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True, user_id=user_id
agent.id, agent.version, user_id=user_id
)

if not graph:
Expand Down
4 changes: 1 addition & 3 deletions autogpt_platform/backend/backend/server/v2/store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,9 +811,7 @@ async def get_agent(

agent = store_listing_version.Agent

graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True
)
graph = await backend.data.graph.get_graph(agent.id, agent.version)

if not graph:
raise fastapi.HTTPException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ async def block_autogen_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ async def reddit_marketing_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/usecases/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def sample_agent():
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ model AgentGraphExecution {

AgentNodeExecutions AgentNodeExecution[]

// Link to User model
// Link to User model -- Executed by this user
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)

Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/test/executor/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def execute_graph(

# --- Test adding new executions --- #
response = await agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
graph_exec_id = response["id"]
logger.info(f"Created execution with ID: {graph_exec_id}")
Expand Down
Loading