Skip to content

Commit

Permalink
Merge pull request #8 from trendmicro/feature/update_models
Browse files Browse the repository at this point in the history
Update models for pydantic v2 migration
  • Loading branch information
AmbientPlatypus authored Jan 15, 2024
2 parents b60a433 + ef6383a commit 3e24a12
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 138 deletions.
8 changes: 7 additions & 1 deletion src/pytmv1/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
class HTTPConnectionPool(HTTPUrllib):
@typing.no_type_check
def urlopen(self, method, url, **kwargs):
kwargs.pop("preload_content", "")
return super().urlopen(
method, url, pool_timeout=5, release_conn=True, **kwargs
method,
url,
pool_timeout=5,
release_conn=True,
preload_content=False,
**kwargs,
)


Expand Down
2 changes: 1 addition & 1 deletion src/pytmv1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R:
etag=raw_response.headers.get("ETag", ""),
)
if class_ == BaseTaskResp:
resp_class = task_action(raw_response.json()["action"]).resp_class
resp_class = task_action(raw_response.json()["action"]).class_
class_ = resp_class if resp_class else class_
return class_(**raw_response.json())
if "application" in content_type and class_ == BytesResp:
Expand Down
116 changes: 56 additions & 60 deletions src/pytmv1/model/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from pydantic import ConfigDict, Field
from pydantic import RootModel as PydanticRootModel
from pydantic import model_validator
from pydantic.alias_generators import to_camel

from .enums import (
Expand All @@ -25,37 +26,19 @@
Status,
)

CFG = ConfigDict(alias_generator=to_camel, populate_by_name=True)

class BaseModel(PydanticBaseModel):
def __init__(self, **data: Any):
super().__init__(**data)

class Config:
alias_generator = to_camel
populate_by_name = True
class BaseModel(PydanticBaseModel):
model_config = CFG


class RootModel(PydanticRootModel[List[int]]):
class Config:
alias_generator = to_camel
populate_by_name = True
model_config = CFG


class BaseConsumable(BaseModel):
def __init__(self, **data: Any):
super().__init__(**data)


def _get_task_id(headers: List[Dict[str, str]]) -> Optional[str]:
task_id: str = next(
(
h.get("value", "")
for h in headers
if "Operation-Location" == h.get("name", "")
),
"",
).split("/")[-1]
return task_id if task_id != "" else None
...


class Account(BaseModel):
Expand Down Expand Up @@ -204,25 +187,18 @@ class Error(BaseModel):
message: Optional[str] = None
number: Optional[int] = None

def __init__(self, **data: Any):
super().__init__(**data)


class ExceptionObject(BaseConsumable):
value: str
type: ObjectType
last_modified_date_time: str
description: Optional[str] = None

def __init__(self, **data: str) -> None:
super().__init__(value=self._obj_value(data), **data)

@staticmethod
def _obj_value(args: Dict[str, str]) -> str:
obj_value: Optional[str] = args.get(args.get("type", ""))
if obj_value is None:
raise ValueError("Object value not found")
return obj_value
@model_validator(mode="before")
@classmethod
def _map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
data["value"] = data[data["type"]]
return data


class ImpactScope(BaseModel):
Expand Down Expand Up @@ -272,35 +248,39 @@ class MsData(BaseModel):
status: int
task_id: Optional[str] = None

def __init__(self, **data: Any):
super().__init__(
taskId=_get_task_id(data.pop("headers", {})),
**data,
)
@model_validator(mode="before")
@classmethod
def map_task_id(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data["task_id"] = _get_task_id(data)
return data


class MsDataUrl(MsData):
url: str
id: Optional[str] = None
digest: Optional[Digest] = None

def __init__(self, **data: Any):
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data.update(data.pop("body", {}))
super().__init__(**data)
return data


class MsError(Error):
extra: Dict[str, str] = {}
task_id: Optional[str] = None

def __init__(self, **data: Any):
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data.update(data.pop("body", {}))
data.update(data.pop("error", {}))
super().__init__(
extra={"url": data.pop("url", "")},
taskId=_get_task_id(data.pop("headers", {})),
**data,
)
url = data.pop("url", None)
data["task_id"] = _get_task_id(data)
if url:
data["extra"] = {"url": url}
return data


class MsStatus(RootModel):
Expand Down Expand Up @@ -328,17 +308,13 @@ class SandboxSuspiciousObject(BaseModel):
type: ObjectType
value: str

def __init__(self, **data: Any):
obj: Tuple[str, str] = self._map(data)
super().__init__(type=obj[0], value=obj[1], **data)

@staticmethod
def _map(args: Dict[str, str]) -> Tuple[str, str]:
return {
(k, v)
for k, v in args.items()
if k in map(lambda ot: ot.value, ObjectType)
}.pop()
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
obj = get_object(data)
data["type"] = obj[0]
data["value"] = obj[1]
return data


class Script(BaseConsumable):
Expand Down Expand Up @@ -371,3 +347,23 @@ class TiIndicator(Indicator):
matched_indicator_pattern_ids: List[str]
first_seen_date_times: List[str]
last_seen_date_times: List[str]


def get_object(data: Dict[str, str]) -> Tuple[str, str]:
for k, v in data.items():
if k in map(lambda ot: ot.value, ObjectType):
return k, v
raise ValueError("Could not find ObjectType")


def _get_task_id(data: Dict[str, Any]) -> Optional[str]:
return next(
map(
lambda header: header.get("value", "").split("/")[-1],
filter(
lambda header: "Operation-Location" == header.get("name", ""),
data.pop("headers", []),
),
),
None,
)
40 changes: 12 additions & 28 deletions src/pytmv1/model/responses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
from __future__ import annotations

from enum import Enum
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union

from pydantic import Field
from pydantic import Field, model_validator

from .commons import (
Account,
Expand All @@ -32,6 +22,7 @@
Script,
SuspiciousObject,
TiAlert,
get_object,
)
from .enums import (
ObjectType,
Expand Down Expand Up @@ -71,9 +62,6 @@ class BaseTaskResp(BaseStatusResponse):
description: Optional[str] = None
account: Optional[str] = None

def __init__(self, **data: Any):
super().__init__(**data)


MR = TypeVar("MR", bound=BaseMultiResponse[Any])
R = TypeVar("R", bound=BaseResponse)
Expand Down Expand Up @@ -103,17 +91,13 @@ class BlockListTaskResp(BaseTaskResp):
type: ObjectType
value: str

def __init__(self, **data: Any):
obj: Tuple[str, str] = self._map(data)
super().__init__(type=obj[0], value=obj[1], **data)

@staticmethod
def _map(args: Dict[str, str]) -> Tuple[str, str]:
return {
(k, v)
for k, v in args.items()
if k in map(lambda ot: ot.value, ObjectType)
}.pop()
@model_validator(mode="before")
@classmethod
def map_data(cls, data: Dict[str, str]) -> Dict[str, str]:
obj = get_object(data)
data["type"] = obj[0]
data["value"] = obj[1]
return data


class BytesResp(BaseResponse):
Expand Down Expand Up @@ -286,6 +270,6 @@ class TaskAction(Enum):
RUN_OS_QUERY = ("runOsquery", None)
RUN_YARA_RULES = ("runYaraRules", None)

def __init__(self, action: str, resp_class: Optional[Type[T]]):
def __init__(self, action: str, class_: Optional[Type[T]]):
self.action = action
self.resp_class = resp_class
self.class_ = class_
34 changes: 17 additions & 17 deletions tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def text(self) -> str:


def sae_alert():
return SaeAlert.construct(
return SaeAlert.model_construct(
id="1",
investigationStatus=InvestigationStatus.NEW,
model="Possible Credential Dumping via Registry",
Expand All @@ -40,32 +40,32 @@ def sae_alert():
description="description",
workbenchLink="https://THE_WORKBENCH_URL",
score=64,
impactScope=ImpactScope.construct(
impactScope=ImpactScope.model_construct(
desktopCount=1,
serverCount=0,
accountCount=1,
emailAddressCount=0,
entities=[
Entity.construct(
entity_value=HostInfo.construct(
Entity.model_construct(
entity_value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
)
)
],
),
indicators=[
Indicator.construct(
Indicator.model_construct(
provenance=["Alert"],
value=HostInfo.construct(
value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
),
)
],
matchedRules=[
MatchedRule.construct(
MatchedRule.model_construct(
name="Potential Credential Dumping via Registry",
matchedFilters=[
MatchedFilter.construct(
MatchedFilter.model_construct(
name="Possible Credential Dumping via Registry Hive",
mitreTechniqueIds=[
"V9.T1003.004",
Expand All @@ -80,7 +80,7 @@ def sae_alert():


def ti_alert():
return TiAlert.construct(
return TiAlert.model_construct(
id="1",
investigationStatus=InvestigationStatus.NEW,
model="Threat Intelligence Sweeping",
Expand All @@ -94,38 +94,38 @@ def ti_alert():
reportLink="https://THE_TI_REPORT_URL",
createdBy="n/a",
score=42,
impactScope=ImpactScope.construct(
impactScope=ImpactScope.model_construct(
desktopCount=1,
serverCount=0,
accountCount=1,
emailAddressCount=0,
entities=[
Entity.construct(
entity_value=HostInfo.construct(
Entity.model_construct(
entity_value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
)
)
],
),
indicators=[
Indicator.construct(
Indicator.model_construct(
provenance=["Alert"],
value=HostInfo.construct(
value=HostInfo.model_construct(
name="host", ips=["1.1.1.1", "2.2.2.2"]
),
)
],
matchedIndicatorPatterns=[
MatchedIndicatorPattern.construct(
MatchedIndicatorPattern.model_construct(
tags=["STIX2.malicious-activity"],
pattern="[file:name = 'goog-phish-proto-1.vlpset']",
)
],
matchedRules=[
MatchedRule.construct(
MatchedRule.model_construct(
name="Potential Credential Dumping via Registry",
matchedFilters=[
MatchedFilter.construct(
MatchedFilter.model_construct(
name="Possible Credential Dumping via Registry Hive",
mitreTechniqueIds=[
"V9.T1003.004",
Expand Down
Loading

0 comments on commit 3e24a12

Please sign in to comment.