Skip to content

Commit

Permalink
Switch to field validator to get IDs from location
Browse files Browse the repository at this point in the history
  • Loading branch information
t0mz06 committed Nov 7, 2024
1 parent d425587 commit c1ba944
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions src/pytmv1/model/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, model_validator
from pydantic import Field, model_validator, field_validator

from .common import (
Account,
Expand Down Expand Up @@ -88,27 +88,21 @@ class AccountTaskResp(BaseTaskResp):


class AddAlertNoteResp(BaseResponse):
note_id: str
note_id: str = Field(validation_alias="Location")

@model_validator(mode="before")
@field_validator("note_id", mode="before")
@classmethod
def map_data(
cls, data: Dict[str, Optional[str]]
) -> Dict[str, Optional[str]]:
data["note_id"] = _get_id(data)
return data
def get_id(cls, value: str) -> str:
return _get_id(value)


class AddCustomScriptResp(BaseResponse):
script_id: str
script_id: str = Field(validation_alias="Location")

@model_validator(mode="before")
@field_validator("script_id", mode="before")
@classmethod
def map_data(
cls, data: Dict[str, Optional[str]]
) -> Dict[str, Optional[str]]:
data["script_id"] = _get_id(data)
return data
def get_id(cls, value: str) -> str:
return _get_id(value)


class BlockListTaskResp(BaseTaskResp):
Expand Down Expand Up @@ -263,15 +257,12 @@ class CustomScriptTaskResp(BaseTaskResp):


class OatPipelineResp(BaseResponse):
pipeline_id: str
pipeline_id: str = Field(validation_alias="Location")

@model_validator(mode="before")
@field_validator("pipeline_id", mode="before")
@classmethod
def map_data(
cls, data: Dict[str, Optional[str]]
) -> Dict[str, Optional[str]]:
data["pipeline_id"] = _get_id(data)
return data
def get_id(cls, value: str) -> str:
return _get_id(value)


class SubmitFileToSandboxResp(BaseResponse):
Expand Down Expand Up @@ -320,6 +311,5 @@ class TextResp(BaseResponse):
text: str


def _get_id(data: Dict[str, Optional[str]]) -> Optional[str]:
location = data.get("Location")
return location.split("/")[-1] if location else None
def _get_id(location: str) -> str:
return location.split("/")[-1]

0 comments on commit c1ba944

Please sign in to comment.