Skip to content

Commit

Permalink
shu/fix task url type (#999)
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng authored Oct 18, 2024
1 parent dad53e1 commit f690160
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 7 deletions.
7 changes: 7 additions & 0 deletions skyvern/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,10 @@ def __init__(self, data: dict | None = None) -> None:
class CachedActionPlanError(SkyvernException):
def __init__(self, message: str) -> None:
super().__init__(message)


class InvalidUrl(SkyvernHTTPException):
def __init__(self, url: str) -> None:
super().__init__(
f"Invalid URL: {url}. Skyvern supports HTTP and HTTPS urls.", status_code=status.HTTP_400_BAD_REQUEST
)
8 changes: 5 additions & 3 deletions skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.validators import validate_url
from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
Expand Down Expand Up @@ -126,6 +127,7 @@ async def create_task_and_step_from_block(

task_url = working_page.url

task_url = validate_url(task_url)
task = await app.DATABASE.create_task(
url=task_url,
title=task_block.title,
Expand Down Expand Up @@ -183,10 +185,10 @@ async def create_task_and_step_from_block(

async def create_task(self, task_request: TaskRequest, organization_id: str | None = None) -> Task:
task = await app.DATABASE.create_task(
url=task_request.url,
url=str(task_request.url),
title=task_request.title,
webhook_callback_url=task_request.webhook_callback_url,
totp_verification_url=task_request.totp_verification_url,
webhook_callback_url=str(task_request.webhook_callback_url),
totp_verification_url=str(task_request.totp_verification_url),
totp_identifier=task_request.totp_identifier,
navigation_goal=task_request.navigation_goal,
data_extraction_goal=task_request.data_extraction_goal,
Expand Down
13 changes: 13 additions & 0 deletions skyvern/forge/sdk/core/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import HttpUrl, ValidationError, parse_obj_as

from skyvern.exceptions import InvalidUrl


def validate_url(url: str) -> str:
try:
# Use parse_obj_as to validate the string as an HttpUrl
parse_obj_as(HttpUrl, url)
return url
except ValidationError:
# Handle the validation error
raise InvalidUrl(url=url)
22 changes: 18 additions & 4 deletions skyvern/forge/sdk/schemas/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from skyvern.exceptions import InvalidTaskStatusTransition, TaskAlreadyCanceled

Expand All @@ -22,7 +22,7 @@ class ProxyLocation(StrEnum):
NONE = "NONE"


class TaskRequest(BaseModel):
class TaskBase(BaseModel):
title: str | None = Field(
default=None,
description="The title of the task.",
Expand Down Expand Up @@ -76,6 +76,20 @@ class TaskRequest(BaseModel):
)


class TaskRequest(TaskBase):
url: HttpUrl = Field(
...,
description="Starting URL for the task.",
examples=["https://www.geico.com"],
)
webhook_callback_url: HttpUrl | None = Field(
default=None,
description="The URL to call when the task is completed.",
examples=["https://my-webhook.com"],
)
totp_verification_url: HttpUrl | None = None


class TaskStatus(StrEnum):
created = "created"
queued = "queued"
Expand Down Expand Up @@ -144,7 +158,7 @@ def requires_failure_reason(self) -> bool:
return self in status_requires_failure_reason


class Task(TaskRequest):
class Task(TaskBase):
created_at: datetime = Field(
...,
description="The creation datetime of the task.",
Expand Down Expand Up @@ -229,7 +243,7 @@ def to_task_response(


class TaskResponse(BaseModel):
request: TaskRequest
request: TaskBase
task_id: str
status: TaskStatus
created_at: datetime
Expand Down

0 comments on commit f690160

Please sign in to comment.