diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index f82a4108b1..33deda2ade 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -156,7 +156,7 @@ async def create_task_and_step_from_block( navigation_goal=task_block.navigation_goal, data_extraction_goal=task_block.data_extraction_goal, navigation_payload=navigation_payload, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, proxy_location=workflow_run.proxy_location, extracted_information_schema=task_block.data_schema, workflow_run_id=workflow_run.workflow_run_id, diff --git a/skyvern/forge/sdk/artifact/storage/base.py b/skyvern/forge/sdk/artifact/storage/base.py index 27a36b6d40..103b1fa054 100644 --- a/skyvern/forge/sdk/artifact/storage/base.py +++ b/skyvern/forge/sdk/artifact/storage/base.py @@ -40,6 +40,10 @@ class BaseStorage(ABC): def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str: pass + @abstractmethod + async def retrieve_global_workflows(self) -> list[str]: + pass + @abstractmethod def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str: pass diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py index 50a7781235..e4483a9fed 100644 --- a/skyvern/forge/sdk/artifact/storage/local.py +++ b/skyvern/forge/sdk/artifact/storage/local.py @@ -26,6 +26,17 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) - file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + async def retrieve_global_workflows(self) -> list[str]: + file_path = Path(f"{self.artifact_path}/{settings.ENV}/global_workflows.txt") + self._create_directories_if_not_exists(file_path) + if not file_path.exists(): + return [] + try: + with open(file_path, "r") as f: + return [line.strip() for line in f.readlines() if line.strip()] + except Exception: + return [] + def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str: file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"file://{self.artifact_path}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}" diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py index 87ba1e2581..21ab740d73 100644 --- a/skyvern/forge/sdk/artifact/storage/s3.py +++ b/skyvern/forge/sdk/artifact/storage/s3.py @@ -29,6 +29,13 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) - file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"s3://{self.bucket}/{settings.ENV}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + async def retrieve_global_workflows(self) -> list[str]: + uri = f"s3://{self.bucket}/{settings.ENV}/global_workflows.txt" + data = await self.async_client.download_file(uri, log_exception=False) + if not data: + return [] + return [line.strip() for line in data.decode("utf-8").split("\n") if line.strip()] + def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str: file_ext = FILE_EXTENTSION_MAP[artifact_type] return f"s3://{self.bucket}/{settings.ENV}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}" diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 6d40b966ca..0efd2ac74e 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1172,6 +1172,55 @@ async def get_workflow_by_permanent_id( LOG.error("SQLAlchemyError", exc_info=True) raise + async def get_workflows_by_permanent_ids( + self, + workflow_permanent_ids: list[str], + organization_id: str | None = None, + page: int = 1, + page_size: int = 10, + title: str = "", + statuses: list[WorkflowStatus] | None = None, + ) -> list[Workflow]: + """ + Get all workflows with the latest version for the organization. + """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + db_page = page - 1 + try: + async with self.Session() as session: + subquery = ( + select( + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids)) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + main_query = select(WorkflowModel).join( + subquery, + (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + if organization_id: + main_query = main_query.where(WorkflowModel.organization_id == organization_id) + if title: + main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%")) + if statuses: + main_query = main_query.where(WorkflowModel.status.in_(statuses)) + main_query = ( + main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + ) + workflows = (await session.scalars(main_query)).all() + return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def get_workflows_by_organization_id( self, organization_id: str, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index df0b2fac84..fea8646afe 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -60,6 +60,7 @@ from skyvern.forge.sdk.workflow.exceptions import ( FailedToCreateWorkflow, FailedToUpdateWorkflow, + InvalidTemplateWorkflowPermanentId, WorkflowParameterMissingRequiredValue, ) from skyvern.forge.sdk.workflow.models.workflow import ( @@ -635,12 +636,18 @@ async def execute_workflow( workflow_request: WorkflowRequestBody, version: int | None = None, current_org: Organization = Depends(org_auth_service.get_current_org), + template: bool = Query(False), x_api_key: Annotated[str | None, Header()] = None, x_max_steps_override: Annotated[int | None, Header()] = None, ) -> RunWorkflowResponse: analytics.capture("skyvern-oss-agent-workflow-execute") context = skyvern_context.ensure_context() request_id = context.request_id + + if template: + if workflow_id not in await app.STORAGE.retrieve_global_workflows(): + raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id) + workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run( request_id=request_id, workflow_request=workflow_request, @@ -648,6 +655,7 @@ async def execute_workflow( organization_id=current_org.organization_id, version=version, max_steps_override=x_max_steps_override, + is_template_workflow=template, ) if x_max_steps_override: LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) @@ -914,12 +922,26 @@ async def get_workflows( only_workflows: bool = Query(False), title: str = Query(""), current_org: Organization = Depends(org_auth_service.get_current_org), + template: bool = Query(False), ) -> list[Workflow]: """ Get all workflows with the latest version for the organization. """ analytics.capture("skyvern-oss-agent-workflows-get") + if template: + global_workflows_permanent_ids = await app.STORAGE.retrieve_global_workflows() + if not global_workflows_permanent_ids: + return [] + workflows = await app.WORKFLOW_SERVICE.get_workflows_by_permanent_ids( + workflow_permanent_ids=global_workflows_permanent_ids, + page=page, + page_size=page_size, + title=title, + statuses=[WorkflowStatus.published, WorkflowStatus.draft], + ) + return workflows + if only_saved_tasks and only_workflows: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -943,11 +965,16 @@ async def get_workflow( workflow_permanent_id: str, version: int | None = None, current_org: Organization = Depends(org_auth_service.get_current_org), + template: bool = Query(False), ) -> Workflow: analytics.capture("skyvern-oss-agent-workflows-get") + if template: + if workflow_permanent_id not in await app.STORAGE.retrieve_global_workflows(): + raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_permanent_id) + return await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( workflow_permanent_id=workflow_permanent_id, - organization_id=current_org.organization_id, + organization_id=None if template else current_org.organization_id, version=version, ) diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index 4ea7ff5077..196ba4c9dc 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -704,7 +704,7 @@ async def handle_block_result( # refresh workflow run model return await app.WORKFLOW_SERVICE.get_workflow_run( workflow_run_id=workflow_run_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, ) diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py index 1ba5ca61cf..3aceed396d 100644 --- a/skyvern/forge/sdk/workflow/exceptions.py +++ b/skyvern/forge/sdk/workflow/exceptions.py @@ -124,3 +124,11 @@ def __init__(self, template: str, msg: str) -> None: class NoIterableValueFound(SkyvernException): def __init__(self) -> None: super().__init__("No iterable value found for the loop block") + + +class InvalidTemplateWorkflowPermanentId(SkyvernHTTPException): + def __init__(self, workflow_permanent_id: str) -> None: + super().__init__( + message=f"Invalid template workflow permanent id: {workflow_permanent_id}. Please make sure the workflow is a valid template.", + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index ca160f40ea..c2862ed30f 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -441,9 +441,8 @@ async def execute( workflow_run_id=workflow_run_id, organization_id=organization_id, ) - workflow = await app.WORKFLOW_SERVICE.get_workflow( - workflow_id=workflow_run.workflow_id, - organization_id=organization_id, + workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_run.workflow_permanent_id, ) # if the task url is parameterized, we need to get the value from the workflow run context if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url): @@ -512,12 +511,12 @@ async def execute( workflow_run_block = await app.DATABASE.update_workflow_run_block( workflow_run_block_id=workflow_run_block_id, task_id=task.task_id, - organization_id=workflow.organization_id, + organization_id=organization_id, ) current_running_task = task - organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id) + organization = await app.DATABASE.get_organization(organization_id=workflow_run.organization_id) if not organization: - raise Exception(f"Organization is missing organization_id={workflow.organization_id}") + raise Exception(f"Organization is missing organization_id={workflow_run.organization_id}") browser_state: BrowserState | None = None if is_first_task: @@ -544,7 +543,7 @@ async def execute( await app.DATABASE.update_task( task.task_id, status=TaskStatus.failed, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, failure_reason=str(e), ) raise e @@ -569,7 +568,7 @@ async def execute( workflow_run_id=workflow_run_id, task_id=task.task_id, workflow_id=workflow.workflow_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, step_id=step.step_id, ) try: @@ -578,7 +577,7 @@ async def execute( await app.DATABASE.update_task( task.task_id, status=TaskStatus.failed, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, failure_reason=str(e), ) raise e @@ -597,13 +596,15 @@ async def execute( await app.DATABASE.update_task( task.task_id, status=TaskStatus.failed, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, failure_reason=str(e), ) raise e # Check task status - updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id) + updated_task = await app.DATABASE.get_task( + task_id=task.task_id, organization_id=workflow_run.organization_id + ) if not updated_task: raise TaskNotFound(task.task_id) if not updated_task.status.is_final(): @@ -624,7 +625,7 @@ async def execute( task_status=updated_task.status, workflow_run_id=workflow_run_id, workflow_id=workflow.workflow_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, ) success = updated_task.status == TaskStatus.completed task_output = TaskOutput.from_task(updated_task) @@ -645,7 +646,7 @@ async def execute( task_status=updated_task.status, workflow_run_id=workflow_run_id, workflow_id=workflow.workflow_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, ) return await self.build_block_result( success=False, @@ -662,7 +663,7 @@ async def execute( task_status=updated_task.status, workflow_run_id=workflow_run_id, workflow_id=workflow.workflow_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, ) return await self.build_block_result( success=False, @@ -683,7 +684,7 @@ async def execute( status=updated_task.status, workflow_run_id=workflow_run_id, workflow_id=workflow.workflow_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, current_retry=current_retry, max_retries=self.max_retries, task_output=task_output.model_dump_json(), diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index e53ed337c8..07b0027a37 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -93,6 +93,7 @@ async def setup_workflow_run( workflow_request: WorkflowRequestBody, workflow_permanent_id: str, organization_id: str, + is_template_workflow: bool = False, version: int | None = None, max_steps_override: int | None = None, ) -> WorkflowRun: @@ -109,7 +110,7 @@ async def setup_workflow_run( # Validate the workflow and the organization workflow = await self.get_workflow_by_permanent_id( workflow_permanent_id=workflow_permanent_id, - organization_id=organization_id, + organization_id=None if is_template_workflow else organization_id, version=version, ) if workflow is None: @@ -125,7 +126,7 @@ async def setup_workflow_run( workflow_request=workflow_request, workflow_permanent_id=workflow_permanent_id, workflow_id=workflow_id, - organization_id=workflow.organization_id, + organization_id=organization_id, ) LOG.info( f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}", @@ -202,7 +203,7 @@ async def execute_workflow( browser_session_id=browser_session_id, ) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) - workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id) + workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id=workflow_run.workflow_permanent_id) # Set workflow run status to running, create workflow run parameters await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) @@ -520,6 +521,24 @@ async def get_workflow_by_permanent_id( raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version) return workflow + async def get_workflows_by_permanent_ids( + self, + workflow_permanent_ids: list[str], + organization_id: str | None = None, + page: int = 1, + page_size: int = 10, + title: str = "", + statuses: list[WorkflowStatus] | None = None, + ) -> list[Workflow]: + return await app.DATABASE.get_workflows_by_permanent_ids( + workflow_permanent_ids, + organization_id=organization_id, + page=page, + page_size=page_size, + title=title, + statuses=statuses, + ) + async def get_workflows_by_organization_id( self, organization_id: str, @@ -864,7 +883,7 @@ async def build_workflow_run_status_response( organization_id: str, include_cost: bool = False, ) -> WorkflowRunStatusResponse: - workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id, organization_id=organization_id) + workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id) if workflow is None: LOG.error(f"Workflow {workflow_permanent_id} not found") raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id) @@ -903,7 +922,9 @@ async def build_workflow_run_status_response( try: async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT): downloaded_file_urls = await app.STORAGE.get_downloaded_files( - organization_id=workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id + organization_id=workflow_run.organization_id, + task_id=None, + workflow_run_id=workflow_run.workflow_run_id, ) except asyncio.TimeoutError: LOG.warning( @@ -989,7 +1010,7 @@ async def clean_up_workflow( await self.persist_debug_artifacts(browser_state, tasks[-1], workflow, workflow_run) if workflow.persist_browser_session and browser_state.browser_artifacts.browser_session_dir: await app.STORAGE.store_browser_session( - workflow.organization_id, + workflow_run.organization_id, workflow.workflow_permanent_id, browser_state.browser_artifacts.browser_session_dir, ) @@ -1000,7 +1021,7 @@ async def clean_up_workflow( try: async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT): await app.STORAGE.save_downloaded_files( - workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id + workflow_run.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id ) except asyncio.TimeoutError: LOG.warning( @@ -1106,7 +1127,7 @@ async def persist_video_data( for video_artifact in video_artifacts: await app.ARTIFACT_MANAGER.update_artifact_data( artifact_id=video_artifact.video_artifact_id, - organization_id=workflow.organization_id, + organization_id=workflow_run.organization_id, data=video_artifact.video_data, )