Skip to content

Commit

Permalink
🔄 synced local 'skyvern/' with remote 'skyvern/'
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->

> [!IMPORTANT]
> Add support for global workflows by using permanent workflow IDs and updating workflow retrieval and execution methods.
>
>   - **Behavior**:
>     - Modify `get_user_data_dir_for_workflow_run_id()` in `special_browsers.py` to use `get_workflow_by_permanent_id()`.
>     - Modify `execute_workflow()` in `run_workflow.py` to use `get_workflow_by_permanent_id()`.
>     - Add `retrieve_global_workflows()` method to `BaseStorage` in `base.py` and implement it in `local.py` and `s3.py`.
>     - Add `get_workflows_by_permanent_ids()` to `client.py` and `service.py` to retrieve workflows by permanent IDs.
>     - Add `InvalidTemplateWorkflowPermanentId` exception in `exceptions.py`.
>   - **Routes**:
>     - Update `execute_workflow()` and `get_workflows()` in `agent_protocol.py` to handle template workflows using global workflow IDs.
>   - **Misc**:
>     - Change `organization_id` references to `workflow_run.organization_id` in `agent.py`, `observer_service.py`, and `block.py`.
>     - Add `is_template_workflow` parameter to `setup_workflow_run()` in `service.py`.
>     - Update `get_workflow_by_permanent_id()` calls across multiple files to support global workflows.
>
> <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=Skyvern-AI%2Fskyvern-cloud&utm_source=github&utm_medium=referral)<sup> for 5fd39022532b95b86fcd39dd726424f9de19f358. It will automatically update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
wintonzheng committed Jan 28, 2025
1 parent 833cd81 commit 34daf57
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 26 deletions.
2 changes: 1 addition & 1 deletion skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def create_task_and_step_from_block(
navigation_goal=task_block.navigation_goal,
data_extraction_goal=task_block.data_extraction_goal,
navigation_payload=navigation_payload,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
proxy_location=workflow_run.proxy_location,
extracted_information_schema=task_block.data_schema,
workflow_run_id=workflow_run.workflow_run_id,
Expand Down
4 changes: 4 additions & 0 deletions skyvern/forge/sdk/artifact/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class BaseStorage(ABC):
def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
pass

@abstractmethod
async def retrieve_global_workflows(self) -> list[str]:
    """Return the workflow permanent IDs registered as global (template) workflows.

    Each storage backend implements where and how this list is persisted.

    Returns:
        A list of workflow permanent ID strings.
    """
    pass

@abstractmethod
def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str:
    """Build the storage URI for a log artifact of the given entity and artifact type.

    Returns:
        The URI string; its scheme and layout are defined by the concrete backend.
    """
    pass
Expand Down
11 changes: 11 additions & 0 deletions skyvern/forge/sdk/artifact/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"file://{self.artifact_path}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

async def retrieve_global_workflows(self) -> list[str]:
    """Read the permanent IDs of global (template) workflows from local storage.

    The IDs are stored one per line in
    ``{artifact_path}/{settings.ENV}/global_workflows.txt``. Surrounding
    whitespace is stripped and blank lines are ignored.

    Returns:
        The list of workflow permanent IDs, or an empty list when the file
        is missing, unreadable, or not valid UTF-8 (best-effort lookup).
    """
    file_path = Path(f"{self.artifact_path}/{settings.ENV}/global_workflows.txt")
    self._create_directories_if_not_exists(file_path)
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding (the S3 backend decodes UTF-8 too).
        with open(file_path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except (OSError, UnicodeError):
        # A missing file (FileNotFoundError is an OSError) or an unreadable
        # one simply means "no global workflows configured".
        return []

def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str:
    """Build a ``file://`` URI for a log artifact stored on the local filesystem.

    The path is ``{artifact_path}/logs/{log_entity_type}/{log_entity_id}/`` and the
    file name is the ISO-formatted ``datetime.utcnow()`` timestamp followed by the
    artifact type, with the extension looked up in ``FILE_EXTENTSION_MAP``.
    """
    # Raises KeyError if the artifact type has no registered file extension.
    file_ext = FILE_EXTENTSION_MAP[artifact_type]
    return f"file://{self.artifact_path}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}"
Expand Down
7 changes: 7 additions & 0 deletions skyvern/forge/sdk/artifact/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def build_uri(self, artifact_id: str, step: Step, artifact_type: ArtifactType) -
file_ext = FILE_EXTENTSION_MAP[artifact_type]
return f"s3://{self.bucket}/{settings.ENV}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"

async def retrieve_global_workflows(self) -> list[str]:
    """Fetch the permanent IDs of global (template) workflows from S3.

    Downloads ``s3://{bucket}/{settings.ENV}/global_workflows.txt`` and treats
    each non-blank line (after stripping whitespace) as one workflow permanent ID.

    Returns:
        The list of IDs, or an empty list when the download yields no data.
    """
    payload = await self.async_client.download_file(
        f"s3://{self.bucket}/{settings.ENV}/global_workflows.txt",
        log_exception=False,
    )
    if not payload:
        return []

    workflow_ids: list[str] = []
    for raw_line in payload.decode("utf-8").split("\n"):
        cleaned = raw_line.strip()
        if cleaned:
            workflow_ids.append(cleaned)
    return workflow_ids

def build_log_uri(self, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType) -> str:
    """Build an ``s3://`` URI for a log artifact.

    The key is ``{settings.ENV}/logs/{log_entity_type}/{log_entity_id}/`` inside
    ``self.bucket`` and the object name is the ISO-formatted ``datetime.utcnow()``
    timestamp followed by the artifact type, with the extension looked up in
    ``FILE_EXTENTSION_MAP``.
    """
    # Raises KeyError if the artifact type has no registered file extension.
    file_ext = FILE_EXTENTSION_MAP[artifact_type]
    return f"s3://{self.bucket}/{settings.ENV}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}"
Expand Down
49 changes: 49 additions & 0 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,55 @@ async def get_workflow_by_permanent_id(
LOG.error("SQLAlchemyError", exc_info=True)
raise

async def get_workflows_by_permanent_ids(
    self,
    workflow_permanent_ids: list[str],
    organization_id: str | None = None,
    page: int = 1,
    page_size: int = 10,
    title: str = "",
    statuses: list[WorkflowStatus] | None = None,
) -> list[Workflow]:
    """
    Get the latest non-deleted version of every workflow whose permanent ID is listed.

    Args:
        workflow_permanent_ids: Permanent IDs of the workflows to fetch.
        organization_id: When given, only workflows owned by this organization are returned.
        page: 1-based page number.
        page_size: Maximum number of workflows per page.
        title: When non-empty, case-insensitive substring filter on the workflow title.
        statuses: When non-empty, only workflows in one of these statuses are returned.

    Returns:
        The matching workflows, newest first (by creation time).

    Raises:
        ValueError: If ``page`` is less than 1.
        SQLAlchemyError: Re-raised after logging when the query fails.
    """
    if page < 1:
        raise ValueError(f"Page must be greater than 0, got {page}")
    if not workflow_permanent_ids:
        # Nothing can match an empty ID list; skip the session and the empty IN () query.
        return []
    db_page = page - 1
    try:
        async with self.Session() as session:
            # Highest non-deleted version for each requested permanent ID.
            subquery = (
                select(
                    WorkflowModel.workflow_permanent_id,
                    func.max(WorkflowModel.version).label("max_version"),
                )
                .where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
                .where(WorkflowModel.deleted_at.is_(None))
                .group_by(
                    WorkflowModel.workflow_permanent_id,
                )
                .subquery()
            )
            # Join back to the full rows so only the latest version of each workflow is selected.
            main_query = select(WorkflowModel).join(
                subquery,
                (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
                & (WorkflowModel.version == subquery.c.max_version),
            )
            if organization_id:
                main_query = main_query.where(WorkflowModel.organization_id == organization_id)
            if title:
                main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%"))
            if statuses:
                main_query = main_query.where(WorkflowModel.status.in_(statuses))
            main_query = (
                main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
            )
            workflows = (await session.scalars(main_query)).all()
            return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise

async def get_workflows_by_organization_id(
self,
organization_id: str,
Expand Down
29 changes: 28 additions & 1 deletion skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from skyvern.forge.sdk.workflow.exceptions import (
FailedToCreateWorkflow,
FailedToUpdateWorkflow,
InvalidTemplateWorkflowPermanentId,
WorkflowParameterMissingRequiredValue,
)
from skyvern.forge.sdk.workflow.models.workflow import (
Expand Down Expand Up @@ -635,19 +636,26 @@ async def execute_workflow(
workflow_request: WorkflowRequestBody,
version: int | None = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> RunWorkflowResponse:
analytics.capture("skyvern-oss-agent-workflow-execute")
context = skyvern_context.ensure_context()
request_id = context.request_id

if template:
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)

workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_request=workflow_request,
workflow_permanent_id=workflow_id,
organization_id=current_org.organization_id,
version=version,
max_steps_override=x_max_steps_override,
is_template_workflow=template,
)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
Expand Down Expand Up @@ -914,12 +922,26 @@ async def get_workflows(
only_workflows: bool = Query(False),
title: str = Query(""),
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
analytics.capture("skyvern-oss-agent-workflows-get")

if template:
global_workflows_permanent_ids = await app.STORAGE.retrieve_global_workflows()
if not global_workflows_permanent_ids:
return []
workflows = await app.WORKFLOW_SERVICE.get_workflows_by_permanent_ids(
workflow_permanent_ids=global_workflows_permanent_ids,
page=page,
page_size=page_size,
title=title,
statuses=[WorkflowStatus.published, WorkflowStatus.draft],
)
return workflows

if only_saved_tasks and only_workflows:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -943,11 +965,16 @@ async def get_workflow(
workflow_permanent_id: str,
version: int | None = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
) -> Workflow:
analytics.capture("skyvern-oss-agent-workflows-get")
if template:
if workflow_permanent_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_permanent_id)

return await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=current_org.organization_id,
organization_id=None if template else current_org.organization_id,
version=version,
)

Expand Down
2 changes: 1 addition & 1 deletion skyvern/forge/sdk/services/observer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ async def handle_block_result(
# refresh workflow run model
return await app.WORKFLOW_SERVICE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)


Expand Down
8 changes: 8 additions & 0 deletions skyvern/forge/sdk/workflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,11 @@ def __init__(self, template: str, msg: str) -> None:
class NoIterableValueFound(SkyvernException):
    """Error for a loop block that has no iterable value to loop over."""

    def __init__(self) -> None:
        message = "No iterable value found for the loop block"
        super().__init__(message)


class InvalidTemplateWorkflowPermanentId(SkyvernHTTPException):
    """HTTP 400 error for a workflow permanent id that is not a valid template workflow."""

    def __init__(self, workflow_permanent_id: str) -> None:
        detail = f"Invalid template workflow permanent id: {workflow_permanent_id}. Please make sure the workflow is a valid template."
        super().__init__(message=detail, status_code=status.HTTP_400_BAD_REQUEST)
31 changes: 16 additions & 15 deletions skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,8 @@ async def execute(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow(
workflow_id=workflow_run.workflow_id,
organization_id=organization_id,
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_run.workflow_permanent_id,
)
# if the task url is parameterized, we need to get the value from the workflow run context
if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
Expand Down Expand Up @@ -512,12 +511,12 @@ async def execute(
workflow_run_block = await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task.task_id,
organization_id=workflow.organization_id,
organization_id=organization_id,
)
current_running_task = task
organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
organization = await app.DATABASE.get_organization(organization_id=workflow_run.organization_id)
if not organization:
raise Exception(f"Organization is missing organization_id={workflow.organization_id}")
raise Exception(f"Organization is missing organization_id={workflow_run.organization_id}")

browser_state: BrowserState | None = None
if is_first_task:
Expand All @@ -544,7 +543,7 @@ async def execute(
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e
Expand All @@ -569,7 +568,7 @@ async def execute(
workflow_run_id=workflow_run_id,
task_id=task.task_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
step_id=step.step_id,
)
try:
Expand All @@ -578,7 +577,7 @@ async def execute(
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e
Expand All @@ -597,13 +596,15 @@ async def execute(
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
failure_reason=str(e),
)
raise e

# Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
updated_task = await app.DATABASE.get_task(
task_id=task.task_id, organization_id=workflow_run.organization_id
)
if not updated_task:
raise TaskNotFound(task.task_id)
if not updated_task.status.is_final():
Expand All @@ -624,7 +625,7 @@ async def execute(
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
success = updated_task.status == TaskStatus.completed
task_output = TaskOutput.from_task(updated_task)
Expand All @@ -645,7 +646,7 @@ async def execute(
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
return await self.build_block_result(
success=False,
Expand All @@ -662,7 +663,7 @@ async def execute(
task_status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
)
return await self.build_block_result(
success=False,
Expand All @@ -683,7 +684,7 @@ async def execute(
status=updated_task.status,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
organization_id=workflow_run.organization_id,
current_retry=current_retry,
max_retries=self.max_retries,
task_output=task_output.model_dump_json(),
Expand Down
Loading

0 comments on commit 34daf57

Please sign in to comment.