From 6237e1447cac7d5013a30407e8bcc7b54a2dd55e Mon Sep 17 00:00:00 2001 From: Thomas Legros Date: Mon, 15 Jan 2024 16:34:23 +0100 Subject: [PATCH 1/2] Update models for pydantic v2 migration --- src/pytmv1/core.py | 2 +- src/pytmv1/model/commons.py | 116 +++++++++++++++---------------- src/pytmv1/model/responses.py | 40 ++++------- tests/data.py | 34 ++++----- tests/integration/test_object.py | 34 +++++++++ tests/integration/test_search.py | 22 ++++++ tests/unit/test_core.py | 44 ++++++------ tests/unit/test_mapper.py | 22 +++--- 8 files changed, 177 insertions(+), 137 deletions(-) diff --git a/src/pytmv1/core.py b/src/pytmv1/core.py index cf6818e..b076511 100755 --- a/src/pytmv1/core.py +++ b/src/pytmv1/core.py @@ -283,7 +283,7 @@ def _parse_data(raw_response: Response, class_: Type[R]) -> R: etag=raw_response.headers.get("ETag", ""), ) if class_ == BaseTaskResp: - resp_class = task_action(raw_response.json()["action"]).resp_class + resp_class = task_action(raw_response.json()["action"]).class_ class_ = resp_class if resp_class else class_ return class_(**raw_response.json()) if "application" in content_type and class_ == BytesResp: diff --git a/src/pytmv1/model/commons.py b/src/pytmv1/model/commons.py index 35f6928..d1edbd6 100644 --- a/src/pytmv1/model/commons.py +++ b/src/pytmv1/model/commons.py @@ -3,8 +3,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel as PydanticBaseModel -from pydantic import Field +from pydantic import ConfigDict, Field from pydantic import RootModel as PydanticRootModel +from pydantic import model_validator from pydantic.alias_generators import to_camel from .enums import ( @@ -25,37 +26,19 @@ Status, ) +CFG = ConfigDict(alias_generator=to_camel, populate_by_name=True) -class BaseModel(PydanticBaseModel): - def __init__(self, **data: Any): - super().__init__(**data) - class Config: - alias_generator = to_camel - populate_by_name = True +class BaseModel(PydanticBaseModel): + model_config = CFG class RootModel(PydanticRootModel[List[int]]): - class Config: - alias_generator = to_camel - populate_by_name = True + model_config = CFG class BaseConsumable(BaseModel): - def __init__(self, **data: Any): - super().__init__(**data) - - -def _get_task_id(headers: List[Dict[str, str]]) -> Optional[str]: - task_id: str = next( - ( - h.get("value", "") - for h in headers - if "Operation-Location" == h.get("name", "") - ), - "", - ).split("/")[-1] - return task_id if task_id != "" else None + ... class Account(BaseModel): @@ -204,9 +187,6 @@ class Error(BaseModel): message: Optional[str] = None number: Optional[int] = None - def __init__(self, **data: Any): - super().__init__(**data) - class ExceptionObject(BaseConsumable): value: str @@ -214,15 +194,11 @@ class ExceptionObject(BaseConsumable): last_modified_date_time: str description: Optional[str] = None - def __init__(self, **data: str) -> None: - super().__init__(value=self._obj_value(data), **data) - - @staticmethod - def _obj_value(args: Dict[str, str]) -> str: - obj_value: Optional[str] = args.get(args.get("type", "")) - if obj_value is None: - raise ValueError("Object value not found") - return obj_value + @model_validator(mode="before") + @classmethod + def _map_data(cls, data: Dict[str, str]) -> Dict[str, str]: + data["value"] = data[data["type"]] + return data class ImpactScope(BaseModel): @@ -272,11 +248,11 @@ class MsData(BaseModel): status: int task_id: Optional[str] = None - def __init__(self, **data: Any): - super().__init__( - taskId=_get_task_id(data.pop("headers", {})), - **data, - ) + @model_validator(mode="before") + @classmethod + def map_task_id(cls, data: Dict[str, Any]) -> Dict[str, Any]: + data["task_id"] = _get_task_id(data) + return data class MsDataUrl(MsData): @@ -284,23 +260,27 @@ class MsDataUrl(MsData): id: Optional[str] = None digest: Optional[Digest] = None - def __init__(self, **data: Any): + @model_validator(mode="before") + @classmethod + def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]: data.update(data.pop("body", {})) - super().__init__(**data) + return data class MsError(Error): extra: Dict[str, str] = {} task_id: Optional[str] = None - def __init__(self, **data: Any): + @model_validator(mode="before") + @classmethod + def map_data(cls, data: Dict[str, Any]) -> Dict[str, Any]: data.update(data.pop("body", {})) data.update(data.pop("error", {})) - super().__init__( - extra={"url": data.pop("url", "")}, - taskId=_get_task_id(data.pop("headers", {})), - **data, - ) + url = data.pop("url", None) + data["task_id"] = _get_task_id(data) + if url: + data["extra"] = {"url": url} + return data class MsStatus(RootModel): @@ -328,17 +308,13 @@ class SandboxSuspiciousObject(BaseModel): type: ObjectType value: str - def __init__(self, **data: Any): - obj: Tuple[str, str] = self._map(data) - super().__init__(type=obj[0], value=obj[1], **data) - - @staticmethod - def _map(args: Dict[str, str]) -> Tuple[str, str]: - return { - (k, v) - for k, v in args.items() - if k in map(lambda ot: ot.value, ObjectType) - }.pop() + @model_validator(mode="before") + @classmethod + def map_data(cls, data: Dict[str, str]) -> Dict[str, str]: + obj = get_object(data) + data["type"] = obj[0] + data["value"] = obj[1] + return data class Script(BaseConsumable): @@ -371,3 +347,23 @@ class TiIndicator(Indicator): matched_indicator_pattern_ids: List[str] first_seen_date_times: List[str] last_seen_date_times: List[str] + + +def get_object(data: Dict[str, str]) -> Tuple[str, str]: + for k, v in data.items(): + if k in map(lambda ot: ot.value, ObjectType): + return k, v + raise ValueError("Could not find ObjectType") + + +def _get_task_id(data: Dict[str, Any]) -> Optional[str]: + return next( + map( + lambda header: header.get("value", "").split("/")[-1], + filter( + lambda header: "Operation-Location" == header.get("name", ""), + data.pop("headers", []), + ), + ), + None, + ) diff --git a/src/pytmv1/model/responses.py b/src/pytmv1/model/responses.py index d1aeb87..afe993d 100644 --- a/src/pytmv1/model/responses.py +++ b/src/pytmv1/model/responses.py @@ -1,19 +1,9 @@ from __future__ import annotations from enum import Enum -from typing import ( - Any, - Dict, - Generic, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union -from pydantic import Field +from pydantic import Field, model_validator from .commons import ( Account, @@ -32,6 +22,7 @@ Script, SuspiciousObject, TiAlert, + get_object, ) from .enums import ( ObjectType, @@ -71,9 +62,6 @@ class BaseTaskResp(BaseStatusResponse): description: Optional[str] = None account: Optional[str] = None - def __init__(self, **data: Any): - super().__init__(**data) - MR = TypeVar("MR", bound=BaseMultiResponse[Any]) R = TypeVar("R", bound=BaseResponse) @@ -103,17 +91,13 @@ class BlockListTaskResp(BaseTaskResp): type: ObjectType value: str - def __init__(self, **data: Any): - obj: Tuple[str, str] = self._map(data) - super().__init__(type=obj[0], value=obj[1], **data) - - @staticmethod - def _map(args: Dict[str, str]) -> Tuple[str, str]: - return { - (k, v) - for k, v in args.items() - if k in map(lambda ot: ot.value, ObjectType) - }.pop() + @model_validator(mode="before") + @classmethod + def map_data(cls, data: Dict[str, str]) -> Dict[str, str]: + obj = get_object(data) + data["type"] = obj[0] + data["value"] = obj[1] + return data class BytesResp(BaseResponse): @@ -286,6 +270,6 @@ class TaskAction(Enum): RUN_OS_QUERY = ("runOsquery", None) RUN_YARA_RULES = ("runYaraRules", None) - def __init__(self, action: str, resp_class: Optional[Type[T]]): + def __init__(self, action: str, class_: Optional[Type[T]]): self.action = action - self.resp_class = resp_class + self.class_ = class_ diff --git a/tests/data.py b/tests/data.py index 59532f9..44a2fd1 100755 --- a/tests/data.py +++ b/tests/data.py @@ -30,7 +30,7 @@ def text(self) -> str: def sae_alert(): - return SaeAlert.construct( + return SaeAlert.model_construct( id="1", investigationStatus=InvestigationStatus.NEW, model="Possible Credential Dumping via Registry", @@ -40,32 +40,32 @@ def sae_alert(): description="description", workbenchLink="https://THE_WORKBENCH_URL", score=64, - impactScope=ImpactScope.construct( + impactScope=ImpactScope.model_construct( desktopCount=1, serverCount=0, accountCount=1, emailAddressCount=0, entities=[ - Entity.construct( - entity_value=HostInfo.construct( + Entity.model_construct( + entity_value=HostInfo.model_construct( name="host", ips=["1.1.1.1", "2.2.2.2"] ) ) ], ), indicators=[ - Indicator.construct( + Indicator.model_construct( provenance=["Alert"], - value=HostInfo.construct( + value=HostInfo.model_construct( name="host", ips=["1.1.1.1", "2.2.2.2"] ), ) ], matchedRules=[ - MatchedRule.construct( + MatchedRule.model_construct( name="Potential Credential Dumping via Registry", matchedFilters=[ - MatchedFilter.construct( + MatchedFilter.model_construct( name="Possible Credential Dumping via Registry Hive", mitreTechniqueIds=[ "V9.T1003.004", @@ -80,7 +80,7 @@ def sae_alert(): def ti_alert(): - return TiAlert.construct( + return TiAlert.model_construct( id="1", investigationStatus=InvestigationStatus.NEW, model="Threat Intelligence Sweeping", @@ -94,38 +94,38 @@ def ti_alert(): reportLink="https://THE_TI_REPORT_URL", createdBy="n/a", score=42, - impactScope=ImpactScope.construct( + impactScope=ImpactScope.model_construct( desktopCount=1, serverCount=0, accountCount=1, emailAddressCount=0, entities=[ - Entity.construct( - entity_value=HostInfo.construct( + Entity.model_construct( + entity_value=HostInfo.model_construct( name="host", ips=["1.1.1.1", "2.2.2.2"] ) ) ], ), indicators=[ - Indicator.construct( + Indicator.model_construct( provenance=["Alert"], - value=HostInfo.construct( + value=HostInfo.model_construct( name="host", ips=["1.1.1.1", "2.2.2.2"] ), ) ], matchedIndicatorPatterns=[ - MatchedIndicatorPattern.construct( + MatchedIndicatorPattern.model_construct( tags=["STIX2.malicious-activity"], pattern="[file:name = 'goog-phish-proto-1.vlpset']", ) ], matchedRules=[ - MatchedRule.construct( + MatchedRule.model_construct( name="Potential Credential Dumping via Registry", matchedFilters=[ - MatchedFilter.construct( + MatchedFilter.model_construct( name="Possible Credential Dumping via Registry Hive", mitreTechniqueIds=[ "V9.T1003.004", diff --git a/tests/integration/test_object.py b/tests/integration/test_object.py index b1f4062..0c30a61 100755 --- a/tests/integration/test_object.py +++ b/tests/integration/test_object.py @@ -21,6 +21,17 @@ def test_add_to_exception_list(client): assert result.response.items[0].status == 201 +def test_add_to_block_list(client): + result = client.add_to_block_list( + ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id + assert result.response.items[0].status == 202 + + def test_add_to_suspicious_list(client): result = client.add_to_suspicious_list( SuspiciousObjectTask( @@ -36,6 +47,17 @@ def test_add_to_suspicious_list(client): assert result.response.items[0].status == 201 +def test_remove_from_block_list(client): + result = client.remove_from_block_list( + ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id + assert result.response.items[0].status == 202 + + def test_remove_from_exception_list(client): result = client.remove_from_exception_list( ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") @@ -58,6 +80,12 @@ def test_remove_from_suspicious_list(client): assert result.response.items[0].status == 204 +def test_consume_exception_list(client): + result = client.consume_exception_list(lambda s: None) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + def test_get_exception_list(client): result = client.get_exception_list() assert isinstance(result.response, GetExceptionListResp) @@ -67,6 +95,12 @@ def test_get_exception_list(client): assert result.response.items[0].value == "https://*.example.com/path1/*" +def test_consume_suspicious_list(client): + result = client.consume_suspicious_list(lambda s: None) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + def test_get_suspicious_list(client): result = client.get_suspicious_list() assert isinstance(result.response, GetSuspiciousListResp) diff --git a/tests/integration/test_search.py b/tests/integration/test_search.py index a408759..6cd335f 100755 --- a/tests/integration/test_search.py +++ b/tests/integration/test_search.py @@ -9,6 +9,14 @@ ) +def test_consume_email_activity_data(client): + result = client.consume_email_activity_data( + lambda s: None, mailMsgSubject="spam" + ) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + def test_get_email_activity_data(client): result = client.get_email_activity_data( mailMsgSubject="spam", mailSenderIp="192.169.1.1" @@ -25,6 +33,12 @@ def test_get_email_activity_data_count(client): assert result.response.total_count > 0 +def test_consume_endpoint_activity_data(client): + result = client.consume_endpoint_activity_data(lambda s: None, dpt="443") + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + def test_get_endpoint_activity_data(client): result = client.get_endpoint_activity_data(dpt="443") assert result.result_code == ResultCode.SUCCESS @@ -39,6 +53,14 @@ def test_get_endpoint_activity_count(client): assert result.response.total_count > 0 +def test_consume_endpoint_data(client): + result = client.consume_endpoint_data( + lambda s: None, QueryOp.AND, "client1" + ) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + def test_get_endpoint_data(client): result = client.get_endpoint_data(QueryOp.AND, "client1") assert result.result_code == ResultCode.SUCCESS diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index b2d9bdc..f5b17d3 100755 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -11,7 +11,6 @@ Error, ExceptionObject, GetExceptionListResp, - MsData, MsError, MultiResp, NoContentResp, @@ -19,7 +18,6 @@ SandboxAnalysisResultResp, SandboxSubmissionStatusResp, SandboxSuspiciousListResp, - SandboxSuspiciousObject, Status, __version__, ) @@ -33,7 +31,7 @@ ServerMultiJsonError, ServerTextError, ) -from pytmv1.model.enums import Api, RiskLevel +from pytmv1.model.enums import Api from pytmv1.model.responses import BaseStatusResponse from tests.data import TextResponse @@ -48,14 +46,14 @@ def test_consume_linkable_with_next_link_multiple_items(mocker, core): GetExceptionListResp( nextLink="not_empty", items=[ - ExceptionObject.construct(), - ExceptionObject.construct(), + ExceptionObject.model_construct(), + ExceptionObject.model_construct(), ], ), GetExceptionListResp( items=[ - ExceptionObject.construct(), - ExceptionObject.construct(), + ExceptionObject.model_construct(), + ExceptionObject.model_construct(), ] ), ], @@ -78,7 +76,7 @@ def test_consume_linkable_with_next_link_single_item(mocker, core): nextLink="https://host/api/path?skipToken=c2tpcFRva2Vu", items=[], ), - GetExceptionListResp(items=[ExceptionObject.construct()]), + GetExceptionListResp(items=[ExceptionObject.model_construct()]), ], ) total = core._consume_linkable( @@ -176,13 +174,13 @@ def test_parse_data_with_json(): raw_response.headers = {"Content-Type": "application/json"} raw_response.json = lambda: { "items": [ - SandboxSuspiciousObject( - riskLevel=RiskLevel.HIGH, - analysisCompletionDateTime="2021-05-07T03:08:40", - expiredDateTime="2021-06-07T03:08:40Z", - rootSha1="fb5608fa03de204a12fe1e9e5275e4a682107471", - ip="6.6.6.6", - ) + { + "riskLevel": "high", + "analysisCompletionDateTime": "2021-05-07T03:08:40", + "expiredDateTime": "2021-06-07T03:08:40Z", + "rootSha1": "fb5608fa03de204a12fe1e9e5275e4a682107471", + "ip": "6.6.6.6", + } ] } response = core_m._parse_data(raw_response, SandboxSuspiciousListResp) @@ -204,8 +202,8 @@ def test_parse_data_with_multi_and_wrong_model_is_failed(): raw_response = Response() raw_response.headers = {"Content-Type": "application/json"} raw_response.status_code = 207 - raw_response.json = lambda: MultiResp(items=[MsData(status=200)]) - with pytest.raises(TypeError): + raw_response.json = lambda: {"items": [{"status": 200}]} + with pytest.raises(ValidationError): core_m._parse_data(raw_response, AddAlertNoteResp) @@ -213,7 +211,7 @@ def test_parse_data_with_single_and_wrong_model_is_failed(): raw_response = Response() raw_response.headers = {"Content-Type": "application/json"} raw_response.status_code = 200 - raw_response.json = lambda: AddAlertNoteResp(location="test") + raw_response.json = lambda: {"location": "test"} with pytest.raises(ValidationError): core_m._parse_data(raw_response, MultiResp) @@ -233,7 +231,7 @@ def test_parse_html(): def test_poll_status_with_rejected_status_is_not_polling(): start_time = time.time() core_m._poll_status( - lambda: BaseStatusResponse.construct(status=Status.REJECTED), + lambda: BaseStatusResponse.model_construct(status=Status.REJECTED), 20, ) assert time.time() - start_time < 20 @@ -242,7 +240,7 @@ def test_poll_status_with_rejected_status_is_not_polling(): def test_poll_status_with_running_status_is_polling(): start_time = time.time() core_m._poll_status( - lambda: BaseStatusResponse.construct(status=Status.RUNNING), + lambda: BaseStatusResponse.model_construct(status=Status.RUNNING), 2, ) assert time.time() - start_time >= 2 @@ -251,7 +249,7 @@ def test_poll_status_with_running_status_is_polling(): def test_poll_status_with_succeeded_status(): start_time = time.time() core_m._poll_status( - lambda: BaseStatusResponse.construct(status=Status.SUCCEEDED), + lambda: BaseStatusResponse.model_construct(status=Status.SUCCEEDED), 20, ) assert time.time() - start_time < 20 @@ -270,7 +268,7 @@ def test_send(core, mocker): def test_send_linkable(mocker, core): mock_process = mocker.patch.object(core, "_process") mock_process.return_value = GetExceptionListResp( - items=[ExceptionObject.construct()] + items=[ExceptionObject.model_construct()] ) result = core.send_linkable( GetExceptionListResp, @@ -284,7 +282,7 @@ def test_send_linkable(mocker, core): def test_send_sandbox_result_with_polling(core, mocker): mock_poll = mocker.patch.object(core_m, "_poll_status") - mock_poll.return_value = SandboxSubmissionStatusResp.construct( + mock_poll.return_value = SandboxSubmissionStatusResp.model_construct( status=Status.SUCCEEDED ) mock_send = mocker.patch.object(core, "_process") diff --git a/tests/unit/test_mapper.py b/tests/unit/test_mapper.py index c23bfa5..6aaf154 100755 --- a/tests/unit/test_mapper.py +++ b/tests/unit/test_mapper.py @@ -44,7 +44,7 @@ def test_map_common(): def test_map_entities_with_type_email(): - entities = [Entity.construct(entity_value="email@email.com")] + entities = [Entity.model_construct(entity_value="email@email.com")] dictionary = {} mapper._map_entities(dictionary, entities) assert dictionary["duser"] == "email@email.com" @@ -52,8 +52,8 @@ def test_map_entities_with_type_email(): def test_map_entities_with_type_host_info(): entities = [ - Entity.construct( - entity_value=HostInfo.construct( + Entity.model_construct( + entity_value=HostInfo.model_construct( name="host", ips=["1.1.1.1", "2.2.2.2"] ) ) @@ -65,14 +65,16 @@ def test_map_entities_with_type_host_info(): def test_map_entities_with_type_user(): - entities = [Entity.construct(entity_value="username")] + entities = [Entity.model_construct(entity_value="username")] dictionary = {} mapper._map_entities(dictionary, entities) assert dictionary["duser"] == "username" def test_map_indicators_with_type_command_line(): - indicators = [Indicator.construct(type="command_line", value="cmd.exe")] + indicators = [ + Indicator.model_construct(type="command_line", value="cmd.exe") + ] dictionary = {} mapper._map_indicators(dictionary, indicators) assert dictionary["dproc"] == "cmd.exe" @@ -80,8 +82,10 @@ def test_map_indicators_with_type_command_line(): def test_map_indicators_with_type_host_info(): indicators = [ - Indicator.construct( - value=HostInfo.construct(name="host", ips=["1.1.1.1", "2.2.2.2"]) + Indicator.model_construct( + value=HostInfo.model_construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ) ) ] dictionary = {} @@ -91,7 +95,9 @@ def test_map_indicators_with_type_host_info(): def test_map_indicators_with_unknown_type(): - indicators = [Indicator.construct(type="unknown_type", value="unknown")] + indicators = [ + Indicator.model_construct(type="unknown_type", value="unknown") + ] dictionary = {} mapper._map_indicators(dictionary, indicators) assert dictionary["unknownType"] == "unknown" From ef6383a34ade95b2a14dcd7a0e028e381804867e Mon Sep 17 00:00:00 2001 From: Thomas Legros Date: Mon, 15 Jan 2024 18:43:59 +0100 Subject: [PATCH 2/2] Fixed release_conn not working properly --- src/pytmv1/adapter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pytmv1/adapter.py b/src/pytmv1/adapter.py index 4ef476b..5802d20 100644 --- a/src/pytmv1/adapter.py +++ b/src/pytmv1/adapter.py @@ -11,8 +11,14 @@ class HTTPConnectionPool(HTTPUrllib): @typing.no_type_check def urlopen(self, method, url, **kwargs): + kwargs.pop("preload_content", "") return super().urlopen( - method, url, pool_timeout=5, release_conn=True, **kwargs + method, + url, + pool_timeout=5, + release_conn=True, + preload_content=False, + **kwargs, )