From cf3575fc939f9c2c6567612908c51206e7902a18 Mon Sep 17 00:00:00 2001 From: Thomas Legros Date: Fri, 1 Dec 2023 14:35:15 +0100 Subject: [PATCH] Pytmv1 - Initial commit --- python/pmtmv1/README.md | 122 -- python/pytmv1/.coveragerc | 2 + python/pytmv1/README.md | 146 +++ python/pytmv1/pyproject.toml | 77 ++ python/pytmv1/src/pytmv1/__about__.py | 1 + python/pytmv1/src/pytmv1/__init__.py | 186 +++ python/pytmv1/src/pytmv1/caller.py | 1084 +++++++++++++++++ python/pytmv1/src/pytmv1/core.py | 360 ++++++ python/pytmv1/src/pytmv1/exceptions.py | 59 + python/pytmv1/src/pytmv1/mapper.py | 115 ++ python/pytmv1/src/pytmv1/model/__init__.py | 0 python/pytmv1/src/pytmv1/model/commons.py | 356 ++++++ python/pytmv1/src/pytmv1/model/enums.py | 260 ++++ python/pytmv1/src/pytmv1/model/requests.py | 66 + python/pytmv1/src/pytmv1/model/responses.py | 217 ++++ python/pytmv1/src/pytmv1/py.typed | 0 python/pytmv1/src/pytmv1/results.py | 148 +++ python/pytmv1/src/pytmv1/utils.py | 131 ++ python/pytmv1/tests/__init__.py | 0 python/pytmv1/tests/conftest.py | 48 + python/pytmv1/tests/data.py | 139 +++ python/pytmv1/tests/integration/__init__.py | 0 .../pytmv1/tests/integration/test_account.py | 33 + .../pytmv1/tests/integration/test_common.py | 96 ++ python/pytmv1/tests/integration/test_email.py | 49 + .../pytmv1/tests/integration/test_endpoint.py | 39 + .../pytmv1/tests/integration/test_network.py | 52 + .../pytmv1/tests/integration/test_object.py | 79 ++ .../pytmv1/tests/integration/test_sandbox.py | 112 ++ .../pytmv1/tests/integration/test_search.py | 46 + .../tests/integration/test_workbench.py | 77 ++ python/pytmv1/tests/unit/__init__.py | 0 python/pytmv1/tests/unit/test_caller.py | 9 + python/pytmv1/tests/unit/test_core.py | 497 ++++++++ python/pytmv1/tests/unit/test_mapper.py | 132 ++ python/pytmv1/tests/unit/test_utils.py | 116 ++ python/pytmv1/tox.ini | 5 + 37 files changed, 4737 insertions(+), 122 deletions(-) delete mode 100644 python/pmtmv1/README.md create mode 100755 python/pytmv1/.coveragerc create mode 100755 python/pytmv1/README.md create mode 100755 python/pytmv1/pyproject.toml create mode 100755 python/pytmv1/src/pytmv1/__about__.py create mode 100755 python/pytmv1/src/pytmv1/__init__.py create mode 100755 python/pytmv1/src/pytmv1/caller.py create mode 100755 python/pytmv1/src/pytmv1/core.py create mode 100755 python/pytmv1/src/pytmv1/exceptions.py create mode 100755 python/pytmv1/src/pytmv1/mapper.py create mode 100644 python/pytmv1/src/pytmv1/model/__init__.py create mode 100644 python/pytmv1/src/pytmv1/model/commons.py create mode 100644 python/pytmv1/src/pytmv1/model/enums.py create mode 100755 python/pytmv1/src/pytmv1/model/requests.py create mode 100644 python/pytmv1/src/pytmv1/model/responses.py create mode 100755 python/pytmv1/src/pytmv1/py.typed create mode 100755 python/pytmv1/src/pytmv1/results.py create mode 100755 python/pytmv1/src/pytmv1/utils.py create mode 100755 python/pytmv1/tests/__init__.py create mode 100755 python/pytmv1/tests/conftest.py create mode 100755 python/pytmv1/tests/data.py create mode 100755 python/pytmv1/tests/integration/__init__.py create mode 100644 python/pytmv1/tests/integration/test_account.py create mode 100755 python/pytmv1/tests/integration/test_common.py create mode 100755 python/pytmv1/tests/integration/test_email.py create mode 100755 python/pytmv1/tests/integration/test_endpoint.py create mode 100755 python/pytmv1/tests/integration/test_network.py create mode 100755 python/pytmv1/tests/integration/test_object.py create mode 100755 python/pytmv1/tests/integration/test_sandbox.py create mode 100755 python/pytmv1/tests/integration/test_search.py create mode 100755 python/pytmv1/tests/integration/test_workbench.py create mode 100755 python/pytmv1/tests/unit/__init__.py create mode 100755 python/pytmv1/tests/unit/test_caller.py create mode 100755 python/pytmv1/tests/unit/test_core.py create mode 100755 python/pytmv1/tests/unit/test_mapper.py create mode 100755 python/pytmv1/tests/unit/test_utils.py create mode 100755 python/pytmv1/tox.ini diff --git a/python/pmtmv1/README.md b/python/pmtmv1/README.md deleted file mode 100644 index d8b51da..0000000 --- a/python/pmtmv1/README.md +++ /dev/null @@ -1,122 +0,0 @@ -## Introduction - -Trend Vision Oneā„¢ [Python library (pytmv1)](https://pypi.org/project/pytmv1/) - -Please see the link above to access the library and the most current documentation. - -## Configuration - -| Parameter | Description | -| ----------- | ----------- | -| name | Identify the application using this library. | -| token | Authentication token created for your account. | -| url | Vision One API url this client connects to. | -| pool_connections | Number of connection pools to cache (defaults to 1). | -| pool_maxsize | Maximum size of the pool (defaults to 1). | - - -## Documentation -### Quick start -### Installation - -`pip install pytmv1` - -### Usage - -```python -import pytmv1 -client = pytmv1.client("MyApplication", "Token", "https://api.xdr.trendmicro.com") -result = client.get_exception_list() -result.response -GetExceptionListResp( - next_link=None, - items=[ - ExceptionObject( - url='https://*.example.com/path1/*', - type=, - last_modified_date_time='2023-01-12T14:05:37Z', - description='object description' - ) - ] -) -result.result_code -ResultCode.SUCCESS -``` - -### Build the project -### Install dependencies - -`pip install -e ".[dev]"` - -### Build - -`hatch build` - -### Run unit tests - -`pytest --verbose ./tests/unit` - -### Run integration tests - -`$url`: Vision One API url (i.e: https://api.xdr.trendmicro.com) - -`pytest --mock-url="$url" --verbose ./tests/integration` - -#### Supported APIs - -| Python | Vision One | -| --------| --------- | -| Connectivity | | -| `test_connectivity` | Check availability of service | -| Common | | -| `get_base_task_result` | Download response task results | -| `get_task_result` | Download response task results | -| Domain Account | | -| `disable_account` | Disable user account | -| `enable_account` | Enable user account | -| `reset_password_account` | Force password reset | -| `sign_out_account` | Force sign out | -| Email | | -| `delete_email_message` | Delete email message | -| `quarantine_email_message` | Quarantine email message | -| `restore_email_message` | Restore email message | -| Endpoint | | -| `collect_file` | Collect file | -| `isolate_endpoint` | Isolate endpoint | -| `restore_endpoint` | Restore endpoint | -| `terminate_process` | Terminate process | -| Sandbox Analysis | | -| `download_sandbox_analysis_result` | Download analysis results | -| `download_sandbox_investigation_package` | Download investigation package | -| `get_sandbox_analysis_result` | Get analysis results | -| `get_sandbox_submission_status` | Get submission status | -| `get_sandbox_suspicious_list` | Download suspicious object list | -| `submit_file_to_sandbox` | Submit file to sandbox | -| `submit_urls_to_sandbox` | Submit URLs to sandbox | -| Search | | -| `get_endpoint_data` `consume_endpoint_data` | Get endpoint data | -| Suspicious Objects | | -| `add_to_block_list` | Add to block list | -| `remove_from_block_list` | Remove from block list | -| Suspicious Object Exception List | | -| `add_to_exception_list` | Add to exception list | -| `get_exception_list` `consume_exception_list` | Get exception list | -| `remove_from_exception_list` | Remove from exception list | -| Suspicious Object List | | -| `add_to_suspicious_list` | Add to suspicious object list | -| `get_suspicious_list` `consume_suspicious_list` | List suspicious objects | -| `remove_from_suspicious_list` | Remove from suspicious object list | -| Workbench | | -| `add_alert_note` | Add alert note | -| `edit_alert_status` | Modify alert status | -| `get_alert_details` | Get alert details | -| `get_alert_list` `consume_alert_list` | Get alerts list | - -## Contributing -Thank you for your interest in this project, please make sure to read the contribution guide. - -## Code of conduct -See [Code of Conduct](CODE_OF_CONDUCT.md). - -## License -Project distributed under the [Apache 2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/python/pytmv1/.coveragerc b/python/pytmv1/.coveragerc new file mode 100755 index 0000000..2f15fbf --- /dev/null +++ b/python/pytmv1/.coveragerc @@ -0,0 +1,2 @@ +[run] +omit = __about__.py \ No newline at end of file diff --git a/python/pytmv1/README.md b/python/pytmv1/README.md new file mode 100755 index 0000000..4d2b6cb --- /dev/null +++ b/python/pytmv1/README.md @@ -0,0 +1,146 @@ +## PyTMV1: Python Library for Trend Micro Vision One +[![Build](https://github.com/TrendATI/pytmv1/actions/workflows/build.yml/badge.svg?branch=main)](https://github.com/TrendATI/pytmv1/actions/workflows/build.yml) +[![Lint](https://github.com/TrendATI/pytmv1/actions/workflows/lint.yml/badge.svg?branch=main)](https://github.com/TrendATI/pytmv1/actions/workflows/lint.yml) +[![Test](https://github.com/TrendATI/pytmv1/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/TrendATI/pytmv1/actions/workflows/test.yml) +[![Coverage](https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Ft0mz06%2F6c39ef59cc8beb9595e91fc96793de5b%2Fraw%2Fcoverage.json)](https://github.com/TrendATI/pytmv1/actions/workflows/coverage.yml) +[![Pypi: version](https://img.shields.io/pypi/v/pytmv1)](https://pypi.org/project/pytmv1) +[![Downloads](https://pepy.tech/badge/pytmv1)](https://pepy.tech/project/pytmv1) +[![Python: version](https://img.shields.io/pypi/pyversions/pytmv1)](https://pypi.org/project/pytmv1) +[![License: apache](https://img.shields.io/pypi/l/pytmv1)](https://spdx.org/licenses/Apache-2.0.html) +[![Types - mypy](https://img.shields.io/badge/types-mypy-blue.svg)](http://mypy-lang.org) +[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + + +#### Prerequisites +Using this project requires at least [Python 3.7](https://www.python.org/downloads/). + +#### Features + +- A thread-safe client for your application. +- HTTP pooling capabilities. +- Easy integration with Trend Micro Vision One APIs. + + +#### Configuration +| parameter | description | +|:-----------------|:-----------------------------------------------------| +| name | Identify the application using this library. | +| token | Authentication token created for your account. | +| url | Vision One API url this client connects to. | +| pool_connections | Number of connection pools to cache (defaults to 1). | +| pool_maxsize | Maximum size of the pool (defaults to 1). | + +#### Quick start +Installation +``` +pip install pytmv1 +``` + +Usage +```python +>>> import pytmv1 +>>> client = pytmv1.client("MyApplication", "Token", "https://api.xdr.trendmicro.com") +>>> result = client.get_exception_list() +>>> result.response +GetExceptionListResp( + next_link=None, + items=[ + ExceptionObject( + url='https://*.example.com/path1/*', + type=, + last_modified_date_time='2023-01-12T14:05:37Z', + description='object description' + ) + ] +) +>>> result.result_code +ResultCode.SUCCESS +``` + + +#### Build the project +Install dependencies +```console +pip install -e ".[dev]" +``` +Build +```console +hatch build +``` +Run unit tests +```console +pytest --verbose ./tests/unit +``` +Run integration tests + - `$url`: Vision One API url (i.e: https://api.xdr.trendmicro.com) + +```console +pytest --mock-url="$url" --verbose ./tests/integration +``` + +Supported APIs +-------------- +| Python | Vision One | +|:--------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Connectivity** | | +| `test_connectivity` | [Check availability of service](https://automation.trendmicro.com/xdr/api-v3#tag/Connectivity/paths/~1v3.0~1healthcheck~1connectivity/get) | +| **Common** | | +| `get_base_task_result` | [Download response task results](https://automation.trendmicro.com/xdr/api-v3#tag/Common/paths/~1v3.0~1response~1tasks~1%7Bid%7D/get) | +| `get_task_result` | [Download response task results](https://automation.trendmicro.com/xdr/api-v3#tag/Common/paths/~1v3.0~1response~1tasks~1{id}/get) | +| **Domain Account** | | +| `disable_account` | [Disable user account](https://automation.trendmicro.com/xdr/api-v3#tag/Domain-Account/paths/~1v3.0~1response~1domainAccounts~1disable/post) | +| `enable_account` | [Enable user account](https://automation.trendmicro.com/xdr/api-v3#tag/Domain-Account/paths/~1v3.0~1response~1domainAccounts~1enable/post) | +| `reset_password_account` | [Force password reset](https://automation.trendmicro.com/xdr/api-v3#tag/Domain-Account/paths/~1v3.0~1response~1domainAccounts~1resetPassword/post) | +| `sign_out_account` | [Force sign out](https://automation.trendmicro.com/xdr/api-v3#tag/Domain-Account/paths/~1v3.0~1response~1domainAccounts~1signOut/post) | +| **Email** | | +| `delete_email_message` | [Delete email message](https://automation.trendmicro.com/xdr/api-v3#tag/Email/paths/~1v3.0~1response~1emails~1delete/post) | +| `quarantine_email_message` | [Quarantine email message](https://automation.trendmicro.com/xdr/api-v3#tag/Email/paths/~1v3.0~1response~1emails~1quarantine/post) | +| `restore_email_message` | [Restore email message](https://automation.trendmicro.com/xdr/api-v3#tag/Email/paths/~1v3.0~1response~1emails~1restore/post) | +| **Endpoint** | | +| `collect_file` | [Collect file](https://automation.trendmicro.com/xdr/api-v3#tag/Endpoint/paths/~1v3.0~1response~1endpoints~1collectFile/post) | +| `isolate_endpoint` | [Isolate endpoint](https://automation.trendmicro.com/xdr/api-v3#tag/Endpoint/paths/~1v3.0~1response~1endpoints~1isolate/post) | +| `restore_endpoint` | [Restore endpoint](https://automation.trendmicro.com/xdr/api-v3#tag/Endpoint/paths/~1v3.0~1response~1endpoints~1restore/post) | +| `terminate_process` | [Terminate process](https://automation.trendmicro.com/xdr/api-v3#tag/Endpoint/paths/~1v3.0~1response~1endpoints~1terminateProcess/post) | +| **Sandbox Analysis** | | +| `download_sandbox_analysis_result` | [Download analysis results](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1analysisResults~1{id}~1report/get) | +| `download_sandbox_investigation_package` | [Download investigation package](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1analysisResults~1{id}~1investigationPackage/get) | +| `get_sandbox_analysis_result` | [Get analysis results](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1analysisResults~1{id}/get) | +| `get_sandbox_submission_status` | [Get submission status](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1tasks~1{id}/get) | +| `get_sandbox_suspicious_list` | [Download suspicious object list](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1analysisResults~1{id}~1suspiciousObjects/get) | +| `submit_file_to_sandbox` | [Submit file to sandbox](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1files~1analyze/post) | +| `submit_urls_to_sandbox` | [Submit URLs to sandbox](https://automation.trendmicro.com/xdr/api-v3#tag/Sandbox-Analysis/paths/~1v3.0~1sandbox~1urls~1analyze/post) | +| **Search** | | +| `get_email_activity_data` `consume_email_activity_data` | [Get email activity data](https://automation.trendmicro.com/xdr/api-v3#tag/Search/paths/~1v3.0~1search~1emailActivities/get) | +| `get_email_activity_data_count` | [Get email activity data count](https://automation.trendmicro.com/xdr/api-v3#tag/Search/paths/~1v3.0~1search~1emailActivities/get) | +| `get_endpoint_activity_data` `consume_endpoint_activity_data` | [Get endpoint activity data](https://automation.trendmicro.com/xdr/api-v3#tag/Search/paths/~1v3.0~1search~1endpointActivities/get) | +| `get_endpoint_activity_data_count` | [Get endpoint activity data count](https://automation.trendmicro.com/xdr/api-v3#tag/Search/paths/~1v3.0~1search~1endpointActivities/get) | +| `get_endpoint_data` `consume_endpoint_data` | [Get endpoint data](https://automation.trendmicro.com/xdr/api-v3#tag/Search/paths/~1v3.0~1eiqs~1endpoints/get) | +| **Suspicious Objects** | | +| `add_to_block_list` | [Add to block list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Objects/paths/~1v3.0~1response~1suspiciousObjects/post) | +| `remove_from_block_list` | [Remove from block list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Objects/paths/~1v3.0~1response~1suspiciousObjects~1delete/post) | +| **Suspicious Object Exception List** | | +| `add_to_exception_list` | [Add to exception list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-Exception-List/paths/~1v3.0~1threatintel~1suspiciousObjectExceptions/post) | +| `get_exception_list` `consume_exception_list` | [Get exception list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-Exception-List/paths/~1v3.0~1threatintel~1suspiciousObjectExceptions/get) | +| `remove_from_exception_list` | [Remove from exception list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-Exception-List/paths/~1v3.0~1threatintel~1suspiciousObjectExceptions~1delete/post) | +| **Suspicious Object List** | | +| `add_to_suspicious_list` | [Add to suspicious object list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-List/paths/~1v3.0~1threatintel~1suspiciousObjects/post) | +| `get_suspicious_list` `consume_suspicious_list` | [List suspicious objects](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-List/paths/~1v3.0~1threatintel~1suspiciousObjects/get) | +| `remove_from_suspicious_list` | [Remove from suspicious object list](https://automation.trendmicro.com/xdr/api-v3#tag/Suspicious-Object-List/paths/~1v3.0~1threatintel~1suspiciousObjects~1delete/post) | +| **Workbench** | | +| `add_alert_note` | [Add alert note](https://automation.trendmicro.com/xdr/api-v3#tag/Workbench-notes/paths/~1v3.0~1workbench~1alerts~1{alertId}~1notes/post) | +| `edit_alert_status` | [Modify alert status](https://automation.trendmicro.com/xdr/api-v3#tag/Workbench/paths/~1v3.0~1workbench~1alerts~1{id}/patch) | +| `get_alert_details` | [Get alert details](https://automation.trendmicro.com/xdr/api-v3#tag/Workbench/paths/~1v3.0~1workbench~1alerts/get) | +| `get_alert_list` `consume_alert_list` | [Get alerts list](https://automation.trendmicro.com/xdr/api-v3#tag/Workbench/paths/~1v3.0~1workbench~1alerts/get) | + +Contributing +------------ +Thank you for your interest in this project, please make sure to read the [contribution guide](CONTRIBUTING.md). + +Code of conduct +--------------- +See [Code of conduct](CODE_OF_CONDUCT.md). + +License +------- +Project distributed under the [Apache 2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/python/pytmv1/pyproject.toml b/python/pytmv1/pyproject.toml new file mode 100755 index 0000000..3bd8fc6 --- /dev/null +++ b/python/pytmv1/pyproject.toml @@ -0,0 +1,77 @@ +[build-system] +requires = ["hatchling>=1.12.2"] +build-backend = "hatchling.build" + +[project] +name = "pytmv1" +description = "Python library for Trend Micro Vision One" +license = "Apache-2.0" +readme = "README.md" +dynamic = ["version"] +requires-python = ">=3.7" +authors = [ + { name = "Thomas Legros", email = "thomas_legros@trendmicro.com" } +] +maintainers = [ + { name = "TrendATI", email = "ati-integration@trendmicro.com"}, +] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "beautifulsoup4 ~= 4.11.1", + "requests ~= 2.27.1", + "pydantic ~= 1.10.4", +] + +[project.optional-dependencies] +dev = [ + "hatch ~= 1.6.3", + "psutil ~= 5.9.4", + "pytest ~= 7.2.0", + "pytest-mock ~= 3.10.0", + "pytest-cov ~= 4.0.0", +] + +[project.urls] +"Source" = "https://github.com/TrendATI/pytmv1" +"Issues" = "https://github.com/TrendATI/pytmv1/issues" + +[tool.hatch.build.targets.sdist] +exclude = [".github", "tests"] + +[tool.hatch.version] +path = "src/pytmv1/__about__.py" + +[tool.black] +target-version = ["py37"] +line-length = 79 +preview = true +color = true + +[tool.isort] +profile = "black" +line_length = 79 +color_output = true + +[tool.mypy] +python_version = "3.7" +exclude = ["dist", "tests", "venv"] +show_column_numbers = true +warn_unused_configs = true +pretty = true +strict = true + +[tool.pytest.ini_options] +addopts = "--show-capture=log -s" \ No newline at end of file diff --git a/python/pytmv1/src/pytmv1/__about__.py b/python/pytmv1/src/pytmv1/__about__.py new file mode 100755 index 0000000..22049ab --- /dev/null +++ b/python/pytmv1/src/pytmv1/__about__.py @@ -0,0 +1 @@ +__version__ = "0.6.2" diff --git a/python/pytmv1/src/pytmv1/__init__.py b/python/pytmv1/src/pytmv1/__init__.py new file mode 100755 index 0000000..aa00871 --- /dev/null +++ b/python/pytmv1/src/pytmv1/__init__.py @@ -0,0 +1,186 @@ +from .__about__ import __version__ +from .caller import Client, client +from .mapper import map_cef +from .model.commons import ( + Alert, + Digest, + EmailActivity, + EmailMessage, + Endpoint, + EndpointActivity, + Entity, + Error, + ExceptionObject, + HostInfo, + ImpactScope, + Indicator, + MatchedEvent, + MatchedFilter, + MatchedIndicatorPattern, + MatchedRule, + MsData, + MsDataUrl, + MsError, + SaeAlert, + SaeIndicator, + SandboxSuspiciousObject, + SuspiciousObject, + TiAlert, + TiIndicator, + Value, + ValueList, +) +from .model.enums import ( + EntityType, + EventID, + EventSubID, + IntegrityLevel, + InvestigationStatus, + ObjectType, + OperatingSystem, + ProductCode, + Provenance, + Provider, + QueryField, + QueryOp, + RiskLevel, + SandboxAction, + SandboxObjectType, + ScanAction, + Severity, + Status, + TaskAction, +) +from .model.requests import ( + AccountTask, + EmailMessageIdTask, + EmailMessageUIdTask, + EndpointTask, + FileTask, + ObjectTask, + ProcessTask, + SuspiciousObjectTask, +) +from .model.responses import ( + AccountTaskResp, + AddAlertNoteResp, + BaseTaskResp, + BlockListTaskResp, + BytesResp, + CollectFileTaskResp, + ConnectivityResp, + ConsumeLinkableResp, + EmailMessageTaskResp, + EndpointTaskResp, + GetAlertDetailsResp, + GetAlertListResp, + GetEmailActivityDataCountResp, + GetEmailActivityDataResp, + GetEndpointActivityDataCountResp, + GetEndpointActivityDataResp, + GetEndpointDataResp, + GetExceptionListResp, + GetSuspiciousListResp, + MultiResp, + MultiUrlResp, + NoContentResp, + SandboxAnalysisResultResp, + SandboxSubmissionStatusResp, + SandboxSubmitUrlTaskResp, + SandboxSuspiciousListResp, + SubmitFileToSandboxResp, + TerminateProcessTaskResp, +) +from .results import MultiResult, Result, ResultCode + +__all__ = [ + "__version__", + "client", + "map_cef", + "AccountTask", + "AccountTaskResp", + "AddAlertNoteResp", + "Alert", + "BaseTaskResp", + "BlockListTaskResp", + "BytesResp", + "Client", + "CollectFileTaskResp", + "ConnectivityResp", + "ConsumeLinkableResp", + "Digest", + "EmailActivity", + "EmailMessage", + "EmailMessageIdTask", + "EmailMessageTaskResp", + "EmailMessageUIdTask", + "Endpoint", + "EndpointActivity", + "EndpointTask", + "EndpointTaskResp", + "Entity", + "EntityType", + "Error", + "EventID", + "EventSubID", + "ExceptionObject", + "FileTask", + "GetAlertDetailsResp", + "GetAlertListResp", + "GetEmailActivityDataResp", + "GetEmailActivityDataCountResp", + "GetEndpointActivityDataResp", + "GetEndpointActivityDataCountResp", + "GetEndpointDataResp", + "GetExceptionListResp", + "GetSuspiciousListResp", + "HostInfo", + "ImpactScope", + "Indicator", + "IntegrityLevel", + "InvestigationStatus", + "MatchedEvent", + "MatchedFilter", + "MatchedIndicatorPattern", + "MatchedRule", + "MsData", + "MsDataUrl", + "MsError", + "MultiResult", + "MultiResp", + "MultiUrlResp", + "NoContentResp", + "ObjectTask", + "ObjectType", + "OperatingSystem", + "ProcessTask", + "ProductCode", + "Provenance", + "Provider", + "QueryField", + "QueryOp", + "Result", + "ResultCode", + "RiskLevel", + "SaeAlert", + "SaeIndicator", + "SandboxAction", + "SandboxAnalysisResultResp", + "SandboxObjectType", + "SandboxSubmissionStatusResp", + "SandboxSubmitUrlTaskResp", + "SandboxSuspiciousListResp", + "SandboxSuspiciousObject", + "ScanAction", + "Severity", + "Status", + "SubmitFileToSandboxResp", + "SuspiciousObject", + "SuspiciousObjectTask", + "TaskAction", + "TerminateProcessTaskResp", + "TiAlert", + "TiIndicator", + "Value", + "ValueList", +] diff --git a/python/pytmv1/src/pytmv1/caller.py b/python/pytmv1/src/pytmv1/caller.py new file mode 100755 index 0000000..07e999c --- /dev/null +++ b/python/pytmv1/src/pytmv1/caller.py @@ -0,0 +1,1084 @@ +from __future__ import annotations + +import logging +from functools import lru_cache +from logging import Logger +from typing import Callable, List, Optional, Type, Union + +from . import utils +from .core import Core +from .model.commons import ( + EmailActivity, + Endpoint, + EndpointActivity, + ExceptionObject, + SaeAlert, + SuspiciousObject, + TiAlert, +) +from .model.enums import ( + Api, + HttpMethod, + InvestigationStatus, + QueryOp, + SearchMode, +) +from .model.requests import ( + AccountTask, + EmailMessageIdTask, + EmailMessageUIdTask, + EndpointTask, + FileTask, + ObjectTask, + ProcessTask, + SuspiciousObjectTask, +) +from .model.responses import ( + AddAlertNoteResp, + BaseTaskResp, + BytesResp, + ConnectivityResp, + ConsumeLinkableResp, + GetAlertDetailsResp, + GetAlertListResp, + GetEmailActivityDataCountResp, + GetEmailActivityDataResp, + GetEndpointActivityDataCountResp, + GetEndpointActivityDataResp, + GetEndpointDataResp, + GetExceptionListResp, + GetSuspiciousListResp, + MultiResp, + MultiUrlResp, + NoContentResp, + S, + SandboxAnalysisResultResp, + SandboxSubmissionStatusResp, + SandboxSuspiciousListResp, + SubmitFileToSandboxResp, +) +from .results import MultiResult, Result + +log: Logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=None) +def client( + name: str, + token: str, + url: str, + pool_connections: int = 1, + pool_maxsize: int = 1, + connect_timeout: int = 30, + read_timeout: int = 30, +) -> Client: + """Helper function to initialize a :class:`Client`. + + :param name: Identify the application using this library. + :type name: str + :param token: Authentication token created for your account. + :type token: str + :param url: Vision One API url this client connects to. + :type url: str + :param pool_connections: (optional) Number of connection to cache. + :type pool_connections: int + :param pool_maxsize: (optional) Maximum size of the pool. + :type pool_maxsize: int + :param connect_timeout: (optional) Seconds before connection timeout. + :type connect_timeout: int + :param read_timeout: (optional) Seconds before read timeout. + :type connect_timeout: int + :rtype: Client + """ + log.debug( + "Initializing new client with [Appname=%s, Token=*****, URL=%s]", + name, + url, + ) + return Client( + Core( + name, + token, + url, + pool_connections, + pool_maxsize, + connect_timeout, + read_timeout, + ) + ) + + +class Client: + def __init__(self, core: Core): + self._core = core + + def add_alert_note( + self, alert_id: str, note: str + ) -> Result[AddAlertNoteResp]: + """Adds a note to the specified Workbench alert. + + :param alert_id: Workbench alert id. + :type alert_id: str + :param note: Value of the note. + :type note: str + :rtype: Result[AddAlertNoteResp]: + """ + return self._core.send( + AddAlertNoteResp, + Api.ADD_ALERT_NOTE.value.format(alert_id), + HttpMethod.POST, + json={"content": note}, + ) + + def add_to_block_list( + self, *objects: ObjectTask + ) -> MultiResult[MultiResp]: + """Adds object(s) to the Suspicious Object List, + which blocks the objects on subsequent detections. + + :param objects: Object(s) to add. + :type objects: Tuple[ObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.ADD_TO_BLOCK_LIST, + json=utils.build_object_request(*objects), + ) + + def add_to_exception_list( + self, *objects: ObjectTask + ) -> MultiResult[MultiResp]: + """Adds object(s) to the Exception List. + + :param objects: Object(s) to add. + :type objects: Tuple[ObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.ADD_TO_EXCEPTION_LIST, + json=utils.build_object_request(*objects), + ) + + def add_to_suspicious_list( + self, *objects: SuspiciousObjectTask + ) -> MultiResult[MultiResp]: + """Adds object(s) to the Suspicious Object List. + + :param objects: Object(s) to add. + :type objects: Tuple[SuspiciousObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.ADD_TO_SUSPICIOUS_LIST, + json=utils.build_suspicious_request(*objects), + ) + + def collect_file(self, *files: FileTask) -> MultiResult[MultiResp]: + """Collects a file from one or more endpoints and then sends the files + to Vision One in a password-protected archive. + + :param files: File(s) to collect. + :type files: Tuple[FileTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_endpoint(Api.COLLECT_ENDPOINT_FILE, *files) + + def consume_alert_list( + self, + consumer: Callable[[Union[SaeAlert, TiAlert]], None], + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume workbench alerts. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[Union[SaeAlert, TiAlert]], None] + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetAlertListResp, + Api.GET_ALERT_LIST, + consumer, + params=utils.filter_none( + { + "startDateTime": start_time, + "endDateTime": end_time, + "orderBy": "createdDateTime desc", + } + ), + ) + + def consume_email_activity_data( + self, + consumer: Callable[[EmailActivity], None], + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume email activity data in a paginated list + filtered by provided values. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[EmailActivity], None] + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: uuid=... OR tags=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: uuid="123456") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetEmailActivityDataResp, + Api.GET_EMAIL_ACTIVITY_DATA, + consumer, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.DEFAULT, + ), + headers=utils.activity_query(op, **fields), + ) + + def consume_endpoint_activity_data( + self, + consumer: Callable[[EndpointActivity], None], + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume endpoint activity data in a paginated list + filtered by provided values. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[EndpointActivity], None] + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: dpt=... OR src=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: dpt="443") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetEndpointActivityDataResp, + Api.GET_ENDPOINT_ACTIVITY_DATA, + consumer, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.DEFAULT, + ), + headers=utils.activity_query(op, **fields), + ) + + def consume_endpoint_data( + self, + consumer: Callable[[Endpoint], None], + op: QueryOp, + *values: str, + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume endpoints. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[Endpoint], None] + :param op: Query operator to apply. + :type op: QueryOp + :param values: Agent guid, login account, endpoint name, ip address, + mac address, operating system, product code. + :type values: Tuple[str, ...] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetEndpointDataResp, + Api.GET_ENDPOINT_DATA, + consumer, + headers=utils.endpoint_query(op, *values), + ) + + def consume_exception_list( + self, consumer: Callable[[ExceptionObject], None] + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume exception objects. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[ExceptionObject], None] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetExceptionListResp, Api.GET_EXCEPTION_LIST, consumer + ) + + def consume_suspicious_list( + self, consumer: Callable[[SuspiciousObject], None] + ) -> Result[ConsumeLinkableResp]: + """Retrieves and consume suspicious objects. + + :param consumer: Function which will consume every record in result. + :type consumer: Callable[[SuspiciousObject], None] + :rtype: Result[ConsumeLinkableResp]: + """ + return self._core.send_linkable( + GetSuspiciousListResp, Api.GET_SUSPICIOUS_LIST, consumer + ) + + def delete_email_message( + self, *messages: Union[EmailMessageUIdTask, EmailMessageIdTask] + ) -> MultiResult[MultiResp]: + """Deletes a message from one or more mailboxes. + + :param messages: Message(s) to delete. + :type messages: Tuple[EmailUIdTask, EmailMsgIdTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.DELETE_EMAIL_MESSAGE, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in messages + ], + ) + + def disable_account( + self, *accounts: AccountTask + ) -> MultiResult[MultiResp]: + """Signs the user out of all active application and browser sessions, + and prevents the user from signing in any new session. + + :param accounts: Account(s) to disable. + :type accounts: Tuple[AccountTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.DISABLE_ACCOUNT, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in accounts + ], + ) + + def download_sandbox_analysis_result( + self, + submit_id: str, + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[BytesResp]: + """Downloads the analysis results of the specified object as PDF. + + :param submit_id: Sandbox submission id. + :type submit_id: str + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result to + be available. + :type poll_time_sec: float + :rtype: Result[BytesResp]: + """ + return self._core.send_sandbox_result( + BytesResp, + Api.DOWNLOAD_SANDBOX_ANALYSIS_RESULT, + submit_id, + poll, + poll_time_sec, + ) + + def download_sandbox_investigation_package( + self, + submit_id: str, + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[BytesResp]: + """Downloads the Investigation Package of the specified object. + + :param submit_id: Sandbox submission id. + :type submit_id: str + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result to + be available. + :type poll_time_sec: float + :rtype: Result[BytesResp]: + """ + return self._core.send_sandbox_result( + BytesResp, + Api.DOWNLOAD_SANDBOX_INVESTIGATION_PACKAGE, + submit_id, + poll, + poll_time_sec, + ) + + def edit_alert_status( + self, + alert_id: str, + status: InvestigationStatus, + if_match: str, + ) -> Result[NoContentResp]: + """Edit the status of an alert or investigation triggered in Workbench. + + :param alert_id: Workbench alert id. + :type alert_id: str + :param status: Status to be updated. + :type status: InvestigationStatus + :param if_match: Target resource will be updated only if + it matches ETag of the target one. + :type if_match: str + :rtype: Result[NoContentResp]: + """ + return self._core.send( + NoContentResp, + Api.EDIT_ALERT_STATUS.value.format(alert_id), + HttpMethod.PATCH, + json={"investigationStatus": status}, + headers={ + "If-Match": ( + if_match + if if_match.startswith('"') + else '"' + if_match + '"' + ) + }, + ) + + def enable_account(self, *accounts: AccountTask) -> MultiResult[MultiResp]: + """Allows the user to sign in to new application and browser sessions. + + :param accounts: Account(s) to enable. + :type accounts: Tuple[AccountTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.ENABLE_ACCOUNT, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in accounts + ], + ) + + def get_alert_details(self, alert_id: str) -> Result[GetAlertDetailsResp]: + """Displays information about the specified alert. + + :param alert_id: Workbench alert id. + :type alert_id: str + :rtype: Result[GetAlertDetailsResp]: + """ + return self._core.send( + GetAlertDetailsResp, + Api.GET_ALERT_DETAILS.value.format(alert_id), + ) + + def get_alert_list( + self, start_time: Optional[str] = None, end_time: Optional[str] = None + ) -> Result[GetAlertListResp]: + """Retrieves workbench alerts in a paginated list. + + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :rtype: Result[GetAlertListResp]: + """ + return self._core.send( + GetAlertListResp, + Api.GET_ALERT_LIST, + params=utils.filter_none( + { + "startDateTime": start_time, + "endDateTime": end_time, + "orderBy": "createdDateTime desc", + } + ), + ) + + def get_base_task_result( + self, + task_id: str, + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[BaseTaskResp]: + """Retrieves the result of a response task. + + :param task_id: Task id. + :type task_id: str + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result + to be available. + :type poll_time_sec: float + :rtype: Result[BaseTaskResultResp]: + """ + return self._core.send_task_result( + BaseTaskResp, task_id, poll, poll_time_sec + ) + + def get_email_activity_data( + self, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[GetEmailActivityDataResp]: + """Retrieves email activity data in a paginated list + filtered by provided values. + + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: uuid=... OR tags=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: uuid="123456") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[GetEmailActivityDataResp]: + """ + return self._core.send( + GetEmailActivityDataResp, + Api.GET_EMAIL_ACTIVITY_DATA, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.DEFAULT, + ), + headers=utils.activity_query(op, **fields), + ) + + def get_email_activity_data_count( + self, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[GetEmailActivityDataCountResp]: + """Retrieves the count of email activity data in a paginated list + filtered by provided values. + + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: uuid=... OR tags=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: uuid="123456") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[GetEmailActivityDataCountResp]: + """ + return self._core.send( + GetEmailActivityDataCountResp, + Api.GET_EMAIL_ACTIVITY_DATA, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.COUNT_ONLY, + ), + headers=utils.activity_query(op, **fields), + ) + + def get_endpoint_activity_data( + self, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[GetEndpointActivityDataResp]: + """Retrieves endpoint activity data in a paginated list + filtered by provided values. + + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: dpt=... OR src=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: dpt="443") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[GetEndpointActivityDataResp]: + """ + return self._core.send( + GetEndpointActivityDataResp, + Api.GET_ENDPOINT_ACTIVITY_DATA, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.DEFAULT, + ), + headers=utils.activity_query(op, **fields), + ) + + def get_endpoint_activity_data_count( + self, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + select: Optional[List[str]] = None, + top: int = 500, + op: QueryOp = QueryOp.AND, + **fields: str, + ) -> Result[GetEndpointActivityDataCountResp]: + """Retrieves the count of endpoint activity data in a paginated list + filtered by provided values. + + :param start_time: Date that indicates the start of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to 24 hours before the request is made. + :type start_time: Optional[str] + :param end_time: Date that indicates the end of the data retrieval + time range (yyyy-MM-ddThh:mm:ssZ in UTC). + Defaults to the time the request is made. + :type end_time: Optional[str] + :param select: List of fields to include in the search results, + if no fields are specified, the query returns all supported fields. + :type select: Optional[List[str]] + :param top: Number of records fetched per page. + :type top: int + :param op: Operator to apply between fields (ie: dpt=... OR src=...) + :type op: QueryOp + :param fields: Field/value used to filter result (ie: dpt="443") + check Vision One API documentation for full list of supported fields. + :type fields: Dict[str, str] + :rtype: Result[GetEndpointActivityDataCountResp]: + """ + return self._core.send( + GetEndpointActivityDataCountResp, + Api.GET_ENDPOINT_ACTIVITY_DATA, + params=utils.build_activity_request( + start_time, + end_time, + select, + top, + SearchMode.COUNT_ONLY, + ), + headers=utils.activity_query(op, **fields), + ) + + def get_endpoint_data( + self, op: QueryOp, *values: str + ) -> Result[GetEndpointDataResp]: + """Retrieves endpoints in a paginated list filtered by provided values. + + :param op: Query operator to apply. + :type op: QueryOp + :param values: Agent guid, login account, endpoint name, ip address, + mac address, operating system, product code. + :type values: Tuple[str, ...] + :rtype: Result[GetEndpointDataResp]: + """ + return self._core.send( + GetEndpointDataResp, + Api.GET_ENDPOINT_DATA, + headers=utils.endpoint_query(op, *values), + ) + + def get_exception_list(self) -> Result[GetExceptionListResp]: + """Retrieves exception objects in a paginated list. + + :rtype: Result[GetExceptionListResp]: + """ + return self._core.send(GetExceptionListResp, Api.GET_EXCEPTION_LIST) + + def get_sandbox_analysis_result( + self, + submit_id: str, + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[SandboxAnalysisResultResp]: + """Retrieves the analysis results of the specified object. + + :param submit_id: Sandbox submission id. + :type submit_id: str + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result + to be available. + :type poll_time_sec: float + :rtype: Result[SandboxAnalysisResultResp]: + """ + return self._core.send_sandbox_result( + SandboxAnalysisResultResp, + Api.GET_SANDBOX_ANALYSIS_RESULT, + submit_id, + poll, + poll_time_sec, + ) + + def get_sandbox_submission_status( + self, submit_id: str + ) -> Result[SandboxSubmissionStatusResp]: + """Retrieves the submission status of the specified object. + + :param submit_id: Sandbox submission id. + :type submit_id: str + :rtype: Result[SandboxSubmissionStatusResp]: + """ + return self._core.send( + SandboxSubmissionStatusResp, + Api.GET_SANDBOX_SUBMISSION_STATUS.value.format(submit_id), + ) + + def get_sandbox_suspicious_list( + self, + submit_id: str, + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[SandboxSuspiciousListResp]: + """Retrieves the suspicious object list associated to the + specified object. + + :param submit_id: Sandbox submission id. + :type submit_id: str + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result + to be available. + :type poll_time_sec: float + :rtype: Result[SandboxSuspiciousListResp]: + """ + return self._core.send_sandbox_result( + SandboxSuspiciousListResp, + Api.GET_SANDBOX_SUSPICIOUS_LIST, + submit_id, + poll, + poll_time_sec, + ) + + def get_suspicious_list( + self, + ) -> Result[GetSuspiciousListResp]: + """Retrieves suspicious objects in a paginated list. + + :rtype: Result[GetSuspiciousListResp]: + """ + return self._core.send(GetSuspiciousListResp, Api.GET_SUSPICIOUS_LIST) + + def get_task_result( + self, + task_id: str, + class_: Type[S], + poll: bool = True, + poll_time_sec: float = 1800, + ) -> Result[S]: + """Retrieves the result of a response task. + + :param task_id: Task id. + :type task_id: str + :param class_: Expected task result class. + :type class_: Type[S] + :param poll: If we should wait until the task is finished before + to return the result. + :type poll: bool + :param poll_time_sec: Maximum time to wait for the result + to be available. + :type poll_time_sec: float + :rtype: Result[BaseTaskResultResp]: + """ + return self._core.send_task_result( + class_, task_id, poll, poll_time_sec + ) + + def isolate_endpoint( + self, *endpoints: EndpointTask + ) -> MultiResult[MultiResp]: + """Disconnects one or more endpoints from the network + but allows communication with the managing Trend Micro server product. + + :param endpoints: Endpoint(s) to isolate. + :type endpoints: Tuple[EndpointTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_endpoint(Api.ISOLATE_ENDPOINT, *endpoints) + + def quarantine_email_message( + self, *messages: Union[EmailMessageUIdTask, EmailMessageIdTask] + ) -> MultiResult[MultiResp]: + """Quarantine a message from one or more mailboxes. + + :param messages: Message(s) to quarantine. + :type messages: Tuple[EmailUIdTask, EmailMsgIdTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.QUARANTINE_EMAIL_MESSAGE, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in messages + ], + ) + + def remove_from_block_list( + self, *objects: ObjectTask + ) -> MultiResult[MultiResp]: + """Removes object(s) that was added to the Suspicious Object List + using the "Add to block list" action + + :param objects: Object(s) to remove. + :type objects: Tuple[ObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.REMOVE_FROM_BLOCK_LIST, + json=utils.build_object_request(*objects), + ) + + def remove_from_exception_list( + self, *objects: ObjectTask + ) -> MultiResult[MultiResp]: + """Removes object(s) from the Exception List. + + :param objects: Object(s) to remove. + :type objects: Tuple[ObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.REMOVE_FROM_EXCEPTION_LIST, + json=utils.build_object_request(*objects), + ) + + def remove_from_suspicious_list( + self, *objects: ObjectTask + ) -> MultiResult[MultiResp]: + """Removes object(s) from the Suspicious List. + + :param objects: Object(s) to remove. + :type objects: Tuple[ObjectTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.REMOVE_FROM_SUSPICIOUS_LIST, + json=utils.build_object_request(*objects), + ) + + def reset_password_account( + self, *accounts: AccountTask + ) -> MultiResult[MultiResp]: + """Signs the user out of all active application and browser sessions, + and forces the user to create a new password during the next sign-in + attempt. + + :param accounts: Account(s) to reset. + :type accounts: Tuple[AccountTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.RESET_PASSWORD, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in accounts + ], + ) + + def restore_endpoint( + self, *endpoints: EndpointTask + ) -> MultiResult[MultiResp]: + """Restores network connectivity to one or more endpoints that applied + the "Isolate endpoint" action. + + :param endpoints: Endpoint(s) to restore. + :type endpoints: Tuple[EndpointTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_endpoint(Api.RESTORE_ENDPOINT, *endpoints) + + def restore_email_message( + self, *messages: Union[EmailMessageUIdTask, EmailMessageIdTask] + ) -> MultiResult[MultiResp]: + """Restore quarantined email message(s). + + :param messages: Message(s) to restore. + :type messages: Tuple[EmailUIdTask, EmailMsgIdTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.RESTORE_EMAIL_MESSAGE, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in messages + ], + ) + + def sign_out_account( + self, *accounts: AccountTask + ) -> MultiResult[MultiResp]: + """Signs the user out of all active application and browser sessions. + + :param accounts: Account(s) to sign out. + :type accounts: Tuple[AccountTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_multi( + MultiResp, + Api.SIGN_OUT_ACCOUNT, + json=[ + task.dict(by_alias=True, exclude_none=True) + for task in accounts + ], + ) + + def submit_file_to_sandbox( + self, + file: bytes, + file_name: str, + document_password: Optional[str] = None, + archive_password: Optional[str] = None, + arguments: Optional[str] = None, + ) -> Result[SubmitFileToSandboxResp]: + """Submits a file to the sandbox for analysis. + + :param file: Raw content in bytes. + :type file: bytes + :param file_name: Name of the file. + :type file_name: str + :param document_password: Password used to + decrypt the submitted file sample. + :type document_password: Optional[str] + :param archive_password: Password encoded in Base64 used to decrypt + the submitted archive. + :type archive_password: Optional[str] + :param arguments: Command line arguments to run the submitted file. + Only available for Portable Executable (PE) files and script files. + :type arguments: Optional[str] + :rtype: Result[SubmitFileToSandboxResp]: + """ + return self._core.send( + SubmitFileToSandboxResp, + Api.SUBMIT_FILE_TO_SANDBOX, + HttpMethod.POST, + data=utils.build_sandbox_file_request( + document_password, archive_password, arguments + ), + files={"file": (file_name, file, "application/octet-stream")}, + ) + + def submit_urls_to_sandbox(self, *urls: str) -> MultiResult[MultiUrlResp]: + """Submits URLs to the sandbox for analysis. + + :param urls: URL(s) to be submitted. + :type urls: Tuple[str, ...] + :rtype: MultiResult[MultiUrlResp] + """ + return self._core.send_multi( + MultiUrlResp, + Api.SUBMIT_URLS_TO_SANDBOX, + json=[{"url": url} for url in urls], + ) + + def terminate_process( + self, *processes: ProcessTask + ) -> MultiResult[MultiResp]: + """Terminates a process that is running on one or more endpoints. + + :param processes: Process(es) to terminate. + :type processes: Tuple[ProcessTask, ...] + :rtype: MultiResult[MultiResp] + """ + return self._core.send_endpoint( + Api.TERMINATE_ENDPOINT_PROCESS, *processes + ) + + def check_connectivity(self) -> Result[ConnectivityResp]: + """Checks the connection to the API service + and verifies if your authentication token is valid. + + :rtype: Result[ConnectivityResp] + """ + return self._core.send(ConnectivityResp, Api.CONNECTIVITY) diff --git a/python/pytmv1/src/pytmv1/core.py b/python/pytmv1/src/pytmv1/core.py new file mode 100755 index 0000000..2f2b908 --- /dev/null +++ b/python/pytmv1/src/pytmv1/core.py @@ -0,0 +1,360 @@ +import logging +import re +import time +from logging import Logger +from typing import Any, Callable, Dict, List, Type +from urllib.parse import SplitResult, urlsplit + +from bs4 import BeautifulSoup +from pydantic import AnyHttpUrl, parse_obj_as +from requests import PreparedRequest, Request, Response +from requests.adapters import HTTPAdapter + +from .__about__ import __version__ +from .exceptions import ( + ParseModelError, + ServerHtmlError, + ServerJsonError, + ServerMultiJsonError, + ServerTextError, +) +from .model.commons import ( + Error, + MsData, + MsDataUrl, + MsError, + MsStatus, + SaeAlert, + TiAlert, +) +from .model.enums import Api, HttpMethod, Provider, Status +from .model.requests import EndpointTask +from .model.responses import ( + MR, + AddAlertNoteResp, + BaseLinkableResp, + BaseMultiResponse, + BytesResp, + C, + ConsumeLinkableResp, + GetAlertDetailsResp, + MultiResp, + MultiUrlResp, + NoContentResp, + R, + S, + SandboxSubmissionStatusResp, +) +from .results import multi_result, result + +USERAGENT_SUFFIX: str = "PyTMV1" +API_VERSION: str = "v3.0" + +log: Logger = logging.getLogger(__name__) + + +class Core: + def __init__( + self, + appname: str, + token: str, + url: str, + pool_connections: int, + pool_maxsize: int, + connect_timeout: int, + read_timeout: int, + ): + self._adapter = HTTPAdapter(pool_connections, pool_maxsize, 0, True) + self._c_timeout = connect_timeout + self._r_timeout = read_timeout + self._appname = appname + self._token = token + self._url = parse_obj_as(AnyHttpUrl, _format(url)) + self._headers: Dict[str, str] = { + "Authorization": f"Bearer {self._token}", + "User-Agent": f"{self._appname}-{USERAGENT_SUFFIX}/{__version__}", + } + + @result + def send( + self, + class_: Type[R], + api: str, + method: HttpMethod = HttpMethod.GET, + **kwargs: Any, + ) -> R: + return self._process( + class_, + api, + method, + **kwargs, + ) + + @multi_result + def send_endpoint( + self, + api: Api, + *tasks: EndpointTask, + ) -> MultiResp: + return self._process( + MultiResp, + api, + HttpMethod.POST, + json=[ + task.dict(by_alias=True, exclude_none=True) for task in tasks + ], + ) + + @result + def send_linkable( + self, + class_: Type[BaseLinkableResp[C]], + api: str, + consumer: Callable[[C], None], + **kwargs: Any, + ) -> ConsumeLinkableResp: + return ConsumeLinkableResp( + total_consumed=self._consume_linkable( + lambda: self._process( + class_, + api, + **kwargs, + ), + consumer, + kwargs.get("headers", {}), + ) + ) + + @multi_result + def send_multi( + self, + class_: Type[MR], + api: str, + **kwargs: Any, + ) -> MR: + return self._process( + class_, + api, + HttpMethod.POST, + **kwargs, + ) + + @result + def send_sandbox_result( + self, + class_: Type[R], + api: Api, + submit_id: str, + poll: bool, + poll_time_sec: float, + ) -> R: + if poll: + _poll_status( + lambda: self._process( + SandboxSubmissionStatusResp, + Api.GET_SANDBOX_SUBMISSION_STATUS.value.format(submit_id), + ), + poll_time_sec, + ) + return self._process(class_, api.value.format(submit_id)) + + @result + def send_task_result( + self, class_: Type[S], task_id: str, poll: bool, poll_time_sec: float + ) -> S: + status_call: Callable[[], S] = lambda: self._process( + class_, + Api.GET_TASK_RESULT.value.format(task_id), + ) + if poll: + _poll_status( + status_call, + poll_time_sec, + ) + return status_call() + + def _consume_linkable( + self, + api_call: Callable[[], BaseLinkableResp[C]], + consumer: Callable[[C], None], + headers: Dict[str, str], + count: int = 0, + ) -> int: + total_count: int = count + response: BaseLinkableResp[C] = api_call() + for item in response.items: + consumer(item) + total_count += 1 + if response.next_link: + sr: SplitResult = urlsplit(response.next_link) + log.info("Found nextLink") + return self._consume_linkable( + lambda: self._process( + type(response), + f"{sr.path[5:]}?{sr.query}", + headers=headers, + ), + consumer, + headers, + total_count, + ) + log.info( + "Records consumed: [Total=%s, Type=%s]", + total_count, + type( + response.items[0] if len(response.items) > 0 else response + ).__name__, + ) + return total_count + + def _process( + self, + class_: Type[R], + uri: str, + method: HttpMethod = HttpMethod.GET, + **kwargs: Any, + ) -> R: + log.info( + "Processing request [Method=%s, Class=%s, URI=%s, Options=%s]", + method.value, + class_.__name__, + uri, + kwargs, + ) + raw_response: Response = self._send_internal( + self._prepare(uri, method, **kwargs) + ) + _validate(raw_response) + return _parse_data(raw_response, class_) + + def _prepare( + self, uri: str, method: HttpMethod, **kwargs: Any + ) -> PreparedRequest: + return Request( + method.value, + self._url + uri, + headers={**self._headers, **kwargs.pop("headers", {})}, + **kwargs, + ).prepare() + + def _send_internal(self, request: PreparedRequest) -> Response: + log.info( + "Sending request [Method=%s, URL=%s, Headers=%s, Body=%s]", + request.method, + request.url, + re.sub("Bearer \\S+", "*****", str(request.headers)), + ("Bytes" if type(request.body) == bytes else request.body), + ) + response: Response = self._adapter.send( + request, timeout=(self._c_timeout, self._r_timeout) + ) + log.info( + "Received response [Status=%s, Headers=%s, Body=%s]", + response.status_code, + response.headers, + _hide_binary(response), + ) + return response + + +def _format(url: str) -> str: + return (url if url.endswith("/") else url + "/") + API_VERSION + + +def _hide_binary(response: Response) -> str: + content_type = response.headers.get("Content-Type", "") + if "json" not in content_type and "application" in content_type: + return "***binary content***" + return response.text + + +def _is_http_success(status_codes: List[int]) -> bool: + return len(list(filter(lambda s: not 200 <= s < 399, status_codes))) == 0 + + +def _parse_data(raw_response: Response, class_: Type[R]) -> R: + content_type = raw_response.headers.get("Content-Type", "") + if "json" in content_type: + if issubclass(class_, BaseMultiResponse): + log.info("Parsing json multi response [Class=%s]", class_.__name__) + class_d: Type[List[Any]] + if issubclass(class_, MultiUrlResp): + class_d = List[MsDataUrl] + else: + class_d = List[MsData] + return class_( + items=parse_obj_as( + class_d, + raw_response.json(), + ) + ) + log.info("Parsing json response [Class=%s]", class_.__name__) + if class_ == GetAlertDetailsResp: + response_json: Dict[str, str] = raw_response.json() + return class_( + alert=parse_obj_as( + ( + SaeAlert + if response_json.get("alertProvider") == Provider.SAE + else TiAlert + ), + response_json, + ), + etag=raw_response.headers.get("ETag", ""), + ) + return class_.parse_obj(raw_response.json()) + if "application" in content_type and class_ == BytesResp: + log.info("Parsing binary response") + return class_(content=raw_response.content) + if raw_response.status_code == 201 and class_ == AddAlertNoteResp: + return class_.parse_obj(raw_response.headers) + if raw_response.status_code == 204 and class_ == NoContentResp: + return class_() + raise ParseModelError(class_.__name__, raw_response) + + +def _parse_html(html: str) -> str: + log.info("Parsing html response [Html=%s]", html) + soup = BeautifulSoup(html, "html.parser") + return "\n".join( + line.strip() for line in soup.text.split("\n") if line.strip() + ) + + +def _poll_status( + status_call: Callable[[], S], + poll_time_sec: float, +) -> None: + start_time: float = time.time() + elapsed_time: float = 0 + response: S = status_call() + while elapsed_time < poll_time_sec: + if response.status in [Status.QUEUED, Status.RUNNING]: + response = status_call() + elapsed_time = time.time() - start_time + else: + break + + +def _validate(raw_response: Response) -> None: + log.info("Validating response [%s]", raw_response) + content_type: str = raw_response.headers.get("Content-Type", "") + if "text/html" in content_type: + raise ServerHtmlError( + raw_response.status_code, _parse_html(raw_response.text) + ) + if not _is_http_success([raw_response.status_code]): + if "application/json" in content_type: + error: Dict[str, Any] = raw_response.json().get("error") + error["status"] = raw_response.status_code + raise ServerJsonError( + Error.parse_obj(error), + ) + raise ServerTextError(raw_response.status_code, raw_response.text) + if raw_response.status_code == 207: + if not _is_http_success( + MsStatus.parse_obj(raw_response.json()).values() + ): + raise ServerMultiJsonError( + parse_obj_as(List[MsError], raw_response.json()) + ) diff --git a/python/pytmv1/src/pytmv1/exceptions.py b/python/pytmv1/src/pytmv1/exceptions.py new file mode 100755 index 0000000..0b45c19 --- /dev/null +++ b/python/pytmv1/src/pytmv1/exceptions.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import List + +from requests import Response + +from .model.commons import Error, MsError + + +class ServerCustError(Exception): + def __init__(self, status: int, message: str): + super().__init__(message) + self.status = status + + +class ServerJsonError(Exception): + def __init__(self, error: Error): + super().__init__( + f"Error response received from Vision One. [Error={error}]" + ) + self.error = error + + +class ServerMultiJsonError(Exception): + def __init__(self, errors: List[MsError]): + super().__init__( + ( + "Multi error response received from Vision One." + f" [Errors={errors}]" + ), + ) + self.errors = errors + + +class ServerHtmlError(ServerCustError): + def __init__(self, status: int, html: str): + super().__init__( + status, + html, + ) + + +class ServerTextError(ServerCustError): + def __init__(self, status: int, text: str): + super().__init__( + status, + text, + ) + + +class ParseModelError(ServerCustError): + def __init__(self, model: str, raw_response: Response): + super().__init__( + 500, + ( + "Could not parse response from Vision One.\n" + f"Conditions unmet [Model={model}, {raw_response}]" + ), + ) diff --git a/python/pytmv1/src/pytmv1/mapper.py b/python/pytmv1/src/pytmv1/mapper.py new file mode 100755 index 0000000..6ae3277 --- /dev/null +++ b/python/pytmv1/src/pytmv1/mapper.py @@ -0,0 +1,115 @@ +from typing import Dict, List + +from pydantic.utils import to_lower_camel + +from .model.commons import ( + Alert, + Entity, + HostInfo, + Indicator, + SaeAlert, + TiAlert, +) + +INDICATOR_CEF_MAP: Dict[str, str] = { + "command_line": "dproc", + "url": "request", + "domain": "sntdom", + "ip": "src", + "email_sender": "suser", + "fullpath": "filePath", + "filename": "fname", + "file_sha1": "fileHash", + "user_account": "suser", + "host": "shost", + "port": "spt", + "process_id": "dpid", + "registry_key": "TrendMicroVoRegistryKeyHandle", + "registry_value": "TrendMicroVoRegistryValue", + "registry_value_data": "TrendMicroVoRegistryData", + "file_sha256": "TrendMicroVoFileHashSha256", + "email_message_id": "TrendMicroVoEmailMessageId", + "email_message_unique_id": "TrendMicroVoEmailMessageUniqueId", +} + + +def map_cef(alert: Alert) -> Dict[str, str]: + data: Dict[str, str] = _map_common(alert) + _map_entities(data, alert.impact_scope.entities) + _map_indicators(data, alert.indicators) + if isinstance(alert, SaeAlert): + _map_sae(data, alert) + if isinstance(alert, TiAlert): + _map_ti(data, alert) + return data + + +def _map_common(alert: Alert) -> Dict[str, str]: + return dict( + externalId=alert.id, + act=alert.investigation_status, + cat=alert.model, + Severity=alert.severity, + rt=alert.created_date_time, + sourceServiceName=alert.alert_provider, + msg="Workbench Link: " + alert.workbench_link, + cnt=str(alert.score), + cn1=str(alert.impact_scope.desktop_count), + cn1Label="Desktop Count", + cn2=str(alert.impact_scope.server_count), + cn2Label="Server Count", + cn3=str(alert.impact_scope.account_count), + cn3Label="Account Count", + cn4=str(alert.impact_scope.email_address_count), + cn4Label="Email Address Count", + cs1=", ".join(alert.indicators[0].provenance), + cs1Label="Provenance", + ) + + +def _map_entities(data: Dict[str, str], entities: List[Entity]) -> None: + for entity in entities: + if isinstance(entity.entity_value, HostInfo): + data["dhost"] = entity.entity_value.name + data["dst"] = ", ".join(entity.entity_value.ips) + else: + data["duser"] = entity.entity_value + + +def _map_indicators(data: Dict[str, str], indicators: List[Indicator]) -> None: + for indicator in indicators: + if isinstance(indicator.value, HostInfo): + data["shost"] = indicator.value.name + data["src"] = ", ".join(indicator.value.ips) + else: + data[ + INDICATOR_CEF_MAP.get( + indicator.type, to_lower_camel(indicator.type) + ) + ] = indicator.value + + +def _map_sae(data: Dict[str, str], alert: SaeAlert) -> None: + data["cs2"] = alert.matched_rules[0].matched_filters[0].name + data["cs2Label"] = "Matched Filter" + data["cs3"] = ", ".join( + alert.matched_rules[0].matched_filters[0].mitre_technique_ids + ) + data["cs3Label"] = "Matched Techniques" + data["reason"] = alert.matched_rules[0].name + data["msg"] = data.get("msg", "") + f"\nDescription: {alert.description}" + + +def _map_ti(data: Dict[str, str], alert: TiAlert) -> None: + data["cs2"] = ", ".join(alert.matched_indicator_patterns[0].tags) + data["cs2Label"] = "Matched Pattern Tags" + data["cs3"] = alert.matched_indicator_patterns[0].pattern + data["cs3Label"] = "Matched Pattern" + data["msg"] = data.get("msg", "") + f"\nReport Link: {alert.report_link}" + data["createdBy"] = alert.created_by + if alert.campaign: + data["campaign"] = alert.campaign + if alert.industry: + data["industry"] = alert.industry + if alert.region_and_country: + data["regionAndCountry"] = alert.region_and_country diff --git a/python/pytmv1/src/pytmv1/model/__init__.py b/python/pytmv1/src/pytmv1/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/pytmv1/src/pytmv1/model/commons.py b/python/pytmv1/src/pytmv1/model/commons.py new file mode 100644 index 0000000..df095c7 --- /dev/null +++ b/python/pytmv1/src/pytmv1/model/commons.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic.utils import to_lower_camel + +from .enums import ( + EntityType, + EventID, + EventSubID, + Iam, + IntegrityLevel, + InvestigationStatus, + ObjectType, + OperatingSystem, + ProductCode, + Provider, + RiskLevel, + ScanAction, + Severity, + Status, +) + + +class BaseModel(PydanticBaseModel): + class Config: + alias_generator = to_lower_camel + allow_population_by_field_name = True + + +class BaseConsumable(BaseModel): + ... + + +def _get_task_id(headers: List[Dict[str, str]]) -> Optional[str]: + task_id: str = next( + ( + h.get("value", "") + for h in headers + if "Operation-Location" == h.get("name", "") + ), + "", + ).split("/")[-1] + return task_id if task_id != "" else None + + +class Account(BaseModel): + account_name: str + iam: Iam + last_action_date_time: str + status: Status + + +class Alert(BaseConsumable): + id: str + schema_version: str + investigation_status: InvestigationStatus + workbench_link: str + alert_provider: Provider + model: str + score: int + severity: Severity + impact_scope: ImpactScope + created_date_time: str + updated_date_time: str + indicators: List[Indicator] + + +class Digest(BaseModel): + md5: str + sha1: str + sha256: str + + +class EmailMessage(BaseModel): + last_action_date_time: str + message_id: Optional[str] + mail_box: Optional[str] + message_subject: Optional[str] + unique_id: Optional[str] + organization_id: Optional[str] + status: Status + + +class Value(BaseModel): + updated_date_time: str + value: str + + +class ValueList(BaseModel): + updated_date_time: str + value: List[str] + + +class Endpoint(BaseConsumable): + agent_guid: str + login_account: ValueList + endpoint_name: Value + mac_address: ValueList + ip: ValueList + os_name: OperatingSystem + os_version: str + os_description: str + product_code: ProductCode + installed_product_codes: List[ProductCode] + + +class EmailActivity(BaseConsumable): + mail_msg_subject: Optional[str] + mail_msg_id: Optional[str] + msg_uuid: Optional[str] + mailbox: Optional[str] + mail_sender_ip: Optional[str] + mail_from_addresses: List[str] = Field(default=[]) + mail_whole_header: List[str] = Field(default=[]) + mail_to_addresses: List[str] = Field(default=[]) + mail_source_domain: Optional[str] + search_d_l: Optional[str] + scan_type: Optional[str] + event_time: Optional[int] + org_id: Optional[str] + mail_urls_visible_link: List[str] = Field(default=[]) + mail_urls_real_link: List[str] = Field(default=[]) + + +class EndpointActivity(BaseConsumable): + dpt: Optional[int] + dst: Optional[str] + endpoint_guid: Optional[str] + endpoint_host_name: Optional[str] + endpoint_ip: List[str] = Field(default=[]) + event_id: Optional[EventID] + event_sub_id: Optional[EventSubID] + object_integrity_level: Optional[IntegrityLevel] + object_true_type: Optional[int] + object_sub_true_type: Optional[int] + win_event_id: Optional[int] + event_time: Optional[int] + event_time_d_t: Optional[str] + host_name: Optional[str] + logon_user: List[str] = Field(default=[]) + object_cmd: Optional[str] + object_file_hash_sha1: Optional[str] + object_file_path: Optional[str] + object_host_name: Optional[str] + object_ip: Optional[str] + object_ips: List[str] = Field(default=[]) + object_port: Optional[int] + object_registry_data: Optional[str] + object_registry_key_handle: Optional[str] + object_registry_value: Optional[str] + object_signer: List[str] = Field(default=[]) + object_signer_valid: List[bool] = Field(default=[]) + object_user: Optional[str] + os: Optional[str] + parent_cmd: Optional[str] + parent_file_hash_sha1: Optional[str] + parent_file_path: Optional[str] + process_cmd: Optional[str] + process_file_hash_sha1: Optional[str] + process_file_path: Optional[str] + request: Optional[str] + search_d_l: Optional[str] + spt: Optional[int] + src: Optional[str] + src_file_hash_sha1: Optional[str] + src_file_path: Optional[str] + tags: List[str] = Field(default=[]) + uuid: Optional[str] + + +class HostInfo(BaseModel): + name: str + ips: List[str] + guid: str + + +class Entity(BaseModel): + entity_id: str + entity_type: EntityType + entity_value: Union[str, HostInfo] + related_entities: List[str] + related_indicator_ids: List[int] + provenance: List[str] + + +class Error(BaseModel): + status: int + code: Optional[str] = None + message: Optional[str] = None + number: Optional[int] = None + + +class ExceptionObject(BaseConsumable): + value: str + type: ObjectType + last_modified_date_time: str + description: Optional[str] + + def __init__(self, **data: str) -> None: + super().__init__(value=self._obj_value(data), **data) + + @staticmethod + def _obj_value(args: Dict[str, str]) -> str: + obj_value: Optional[str] = args.get(args.get("type", "")) + if obj_value is None: + raise ValueError("Object value not found") + return obj_value + + +class ImpactScope(BaseModel): + desktop_count: int + server_count: int + account_count: int + email_address_count: int + entities: List[Entity] + + +class Indicator(BaseModel): + id: int + type: str + value: Union[str, HostInfo] + related_entities: List[str] + provenance: List[str] + + +class MatchedEvent(BaseModel): + uuid: str + matched_date_time: str + type: str + + +class MatchedFilter(BaseModel): + id: str + name: str + matched_date_time: str + mitre_technique_ids: List[str] + matched_events: List[MatchedEvent] + + +class MatchedIndicatorPattern(BaseModel): + id: str + pattern: str + tags: List[str] + matched_logs: List[str] = Field(default=[]) + + +class MatchedRule(BaseModel): + id: str + name: str + matched_filters: List[MatchedFilter] + + +class MsData(BaseModel): + status: int + task_id: Optional[str] = None + + def __init__(self, **data: Any): + super().__init__( + taskId=_get_task_id(data.pop("headers", {})), + **data, + ) + + +class MsDataUrl(MsData): + url: str + id: Optional[str] + digest: Optional[Digest] + + def __init__(self, **data: Any): + data.update(data.pop("body", {})) + super().__init__(**data) + + +class MsError(Error): + extra: Dict[str, str] = {} + task_id: Optional[str] + + def __init__(self, **data: Any): + data.update(data.pop("body", {})) + data.update(data.pop("error", {})) + super().__init__( + extra={"url": data.pop("url", "")}, + taskId=_get_task_id(data.pop("headers", {})), + **data, + ) + + +class MsStatus(BaseModel): + __root__: List[int] + + def __init__(self, **data: Any): + super().__init__( + root=[int(d.get("status", 500)) for d in data.get("__root__", [])] + ) + + def values(self) -> List[int]: + return self.__root__ + + +class SaeAlert(Alert): + description: str + matched_rules: List[MatchedRule] + + +class SaeIndicator(Indicator): + field: str + filter_ids: List[str] + + +class SandboxSuspiciousObject(BaseModel): + risk_level: RiskLevel + analysis_completion_date_time: str + expired_date_time: str + root_sha1: str + type: ObjectType + value: str + + def __init__(self, **data: str) -> None: + obj: Tuple[str, str] = self._map(data) + super().__init__(type=obj[0], value=obj[1], **data) + + @staticmethod + def _map(args: Dict[str, str]) -> Tuple[str, str]: + return { + (k, v) + for k, v in args.items() + if k in map(lambda ot: ot.value, ObjectType) + }.pop() + + +class SuspiciousObject(ExceptionObject): + scan_action: ScanAction + risk_level: RiskLevel + in_exception_list: bool + expired_date_time: str + + +class TiAlert(Alert): + campaign: Optional[str] + industry: Optional[str] + region_and_country: Optional[str] + created_by: str + total_indicator_count: int + matched_indicator_count: int + report_link: str + matched_indicator_patterns: List[MatchedIndicatorPattern] + + +class TiIndicator(Indicator): + fields: List[List[str]] + matched_indicator_pattern_ids: List[str] + first_seen_date_times: List[str] + last_seen_date_times: List[str] diff --git a/python/pytmv1/src/pytmv1/model/enums.py b/python/pytmv1/src/pytmv1/model/enums.py new file mode 100644 index 0000000..dc2a648 --- /dev/null +++ b/python/pytmv1/src/pytmv1/model/enums.py @@ -0,0 +1,260 @@ +from enum import Enum + + +class Api(str, Enum): + ADD_ALERT_NOTE = "/workbench/alerts/{0}/notes" + ADD_TO_BLOCK_LIST = "/response/suspiciousObjects" + ADD_TO_EXCEPTION_LIST = "/threatintel/suspiciousObjectExceptions" + ADD_TO_SUSPICIOUS_LIST = "/threatintel/suspiciousObjects" + COLLECT_ENDPOINT_FILE = "/response/endpoints/collectFile" + CONNECTIVITY = ("/healthcheck/connectivity",) + DELETE_EMAIL_MESSAGE = "/response/emails/delete" + DISABLE_ACCOUNT = "/response/domainAccounts/disable" + DOWNLOAD_SANDBOX_ANALYSIS_RESULT = "/sandbox/analysisResults/{0}/report" + DOWNLOAD_SANDBOX_INVESTIGATION_PACKAGE = ( + "/sandbox/analysisResults/{0}/investigationPackage" + ) + EDIT_ALERT_STATUS = "/workbench/alerts/{0}" + ENABLE_ACCOUNT = "/response/domainAccounts/enable" + ISOLATE_ENDPOINT = "/response/endpoints/isolate" + GET_ALERT_DETAILS = "/workbench/alerts/{0}" + GET_ALERT_LIST = "/workbench/alerts" + GET_EMAIL_ACTIVITY_DATA = "/search/emailActivities" + GET_ENDPOINT_ACTIVITY_DATA = "/search/endpointActivities" + GET_ENDPOINT_DATA = "/eiqs/endpoints" + GET_EXCEPTION_LIST = "/threatintel/suspiciousObjectExceptions" + GET_SANDBOX_SUBMISSION_STATUS = "/sandbox/tasks/{0}" + GET_SANDBOX_ANALYSIS_RESULT = "/sandbox/analysisResults/{0}" + GET_SANDBOX_SUSPICIOUS_LIST = ( + "/sandbox/analysisResults/{0}/suspiciousObjects" + ) + GET_SUSPICIOUS_LIST = "/threatintel/suspiciousObjects" + GET_TASK_RESULT = "/response/tasks/{0}" + QUARANTINE_EMAIL_MESSAGE = "/response/emails/quarantine" + REMOVE_FROM_BLOCK_LIST = "/response/suspiciousObjects/delete" + REMOVE_FROM_EXCEPTION_LIST = ( + "/threatintel/suspiciousObjectExceptions/delete" + ) + REMOVE_FROM_SUSPICIOUS_LIST = "/threatintel/suspiciousObjects/delete" + RESET_PASSWORD = "/response/domainAccounts/resetPassword" + RESTORE_EMAIL_MESSAGE = "/response/emails/restore" + RESTORE_ENDPOINT = "/response/endpoints/restore" + SIGN_OUT_ACCOUNT = "/response/domainAccounts/signOut" + SUBMIT_FILE_TO_SANDBOX = "/sandbox/files/analyze" + SUBMIT_URLS_TO_SANDBOX = "/sandbox/urls/analyze" + TERMINATE_ENDPOINT_PROCESS = "/response/endpoints/terminateProcess" + + +class Iam(str, Enum): + # Azure AD + AAD = "AAD" + # On-premise AD + OPAD = "OPAD" + + +class IntegrityLevel(int, Enum): + UNTRUSTED = 0 + LOW = 4096 + MEDIUM = 8192 + HIGH = 12288 + SYSTEM = 16384 + + +class InvestigationStatus(str, Enum): + BENIGN_TRUE_POSITIVE = "Benign True Positive" + CLOSED = "Closed" + FALSE_POSITIVE = "False Positive" + IN_PROGRESS = "In Progress" + NEW = "New" + TRUE_POSITIVE = "True Positive" + + +class EntityType(str, Enum): + HOST = "host" + ACCOUNT = "account" + EMAIL_ADDRESS = "emailAddress" + CONTAINER = "container" + CLOUD_IDENTITY = "cloudIdentity" + AWS_LAMBDA = "awsLambda" + + +class EventID(str, Enum): + EVENT_PROCESS = "1" + EVENT_FILE = "2" + EVENT_CONNECTIO = "3" + EVENT_DNS = "4" + EVENT_REGISTRY = "5" + EVENT_ACCOUNT = "6" + EVENT_INTERNET = "7" + XDR_EVENT_MODIFIED_PROCESS = "8" + EVENT_WINDOWS_HOOK = "9" + EVENT_WINDOWS_EVENT = "10" + EVENT_AMSI = "11" + EVENT_WMI = "12" + TELEMETRY_MEMORY = "13" + TELEMETRY_BM = "14" + + +class EventSubID(int, Enum): + TELEMETRY_NONE = 0 + XDR_PROCESS_OPEN = 1 + XDR_PROCESS_CREATE = 2 + XDR_PROCESS_TERMINATE = 3 + XDR_PROCESS_LOAD_IMAGE = 4 + TELEMETRY_PROCESS_EXECUTE = 5 + TELEMETRY_PROCESS_CONNECT = 6 + TELEMETRY_PROCESS_TRACME = 7 + XDR_FILE_CREATE = 101 + XDR_FILE_OPEN = 102 + XDR_FILE_DELETE = 103 + XDR_FILE_SET_SECURITY = 104 + XDR_FILE_COPY = 105 + XDR_FILE_MOVE = 106 + XDR_FILE_CLOSE = 107 + TELEMETRY_FILE_MODIFY_TIMESTAMP = 108 + TELEMETRY_FILE_MODIFY = 109 + XDR_CONNECTION_CONNECT = 201 + XDR_CONNECTION_LISTEN = 202 + XDR_CONNECTION_CONNECT_INBOUND = 203 + XDR_CONNECTION_CONNECT_OUTBOUND = 204 + XDR_DNS_QUERY = 301 + XDR_REGISTRY_CREATE = 401 + XDR_REGISTRY_SET = 402 + XDR_REGISTRY_DELETE = 403 + XDR_REGISTRY_RENAME = 404 + XDR_ACCOUNT_ADD = 501 + XDR_ACCOUNT_DELETE = 502 + XDR_ACCOUNT_IMPERSONATE = 503 + XDR_ACCOUNT_MODIFY = 504 + XDR_INTERNET_OPEN = 601 + XDR_INTERNET_CONNECT = 602 + XDR_INTERNET_DOWNLOAD = 603 + XDR_MODIFIED_PROCESS_CREATE_REMOTETHREAD = 701 + XDR_MODIFIED_PROCESS_WRITE_MEMORY = 702 + TELEMETRY_MODIFIED_PROCESS_WRITE_PROCESS = 703 + TELEMETRY_MODIFIED_PROCESS_READ_PROCESS = 704 + TELEMETRY_MODIFIED_PROCESS_WRITE_PROCESS_NAME = 705 + XDR_WINDOWS_HOOK_SET = 801 + XDR_AMSI_EXECUTE = 901 + TELEMETRY_MEMORY_MODIFY = 1001 + TELEMETRY_MEMORY_MODIFY_PERMISSION = 1002 + TELEMETRY_MEMORY_READ = 1003 + TELEMETRY_BM_INVOKE = 1101 + TELEMETRY_BM_INVOKE_API = 1102 + + +class HttpMethod(str, Enum): + GET = "GET" + PATCH = "PATCH" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + + +class ObjectType(str, Enum): + IP = "ip" + URL = "url" + DOMAIN = "domain" + FILE_SHA1 = "fileSha1" + FILE_SHA256 = "fileSha256" + SENDER_MAIL_ADDRESS = "senderMailAddress" + + +class OperatingSystem(str, Enum): + LINUX = "Linux" + WINDOWS = "Windows" + MACOS = "macOS" + MACOSX = "macOSX" + + +class ProductCode(str, Enum): + SAO = "sao" + SDS = "sds" + XES = "xes" + + +class Provenance(str, Enum): + ALERT = "Alert" + SWEEPING = "Sweeping" + NETWORK_ANALYTICS = "Network Analytics" + + +class Provider(str, Enum): + SAE = "SAE" + TI = "TI" + + +class QueryField(str, Enum): + AGENT_GUID = "agentGuid" + LOGIN_ACCOUNT = "loginAccount" + ENDPOINT_NAME = "endpointName" + MAC_ADDRESS = "macAddress" + IP = "ip" + OS_NAME = "osName" + PRODUCT_CODE = "productCode" + INSTALLED_PRODUCT_CODES = "installedProductCodes" + + +class QueryOp(str, Enum): + AND = " and " + OR = " or " + + +class RiskLevel(str, Enum): + NO_RISK = "noRisk" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class SandboxAction(str, Enum): + ANALYZE_FILE = "analyzeFile" + ANALYZE_URL = "analyzeUrl" + + +class SandboxObjectType(str, Enum): + URL = "url" + FILE = "file" + + +class ScanAction(str, Enum): + BLOCK = "block" + LOG = "log" + + +class SearchMode(str, Enum): + DEFAULT = "default" + COUNT_ONLY = "countOnly" + + +class Severity(str, Enum): + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +class Status(str, Enum): + FAILED = "failed" + QUEUED = "queued" + REJECTED = "rejected" + RUNNING = "running" + SUCCEEDED = "succeeded" + WAIT_FOR_APPROVAL = "waitForApproval" + + +class TaskAction(str, Enum): + COLLECT_FILE = ("collectFile",) + ISOLATE_ENDPOINT = ("isolate",) + RESTORE_ENDPOINT = ("restoreIsolate",) + TERMINATE_PROCESS = ("terminateProcess",) + QUARANTINE_MESSAGE = "quarantineMessage" + DELETE_MESSAGE = ("deleteMessage",) + RESTORE_MESSAGE = ("restoreMessage",) + BLOCK_SUSPICIOUS = ("block",) + REMOVE_SUSPICIOUS = ("restoreBlock",) + RESET_PASSWORD = "resetPassword" + SUBMIT_SANDBOX = ("submitSandbox",) + ENABLE_ACCOUNT = ("enableAccount",) + DISABLE_ACCOUNT = ("disableAccount",) + FORCE_SIGN_OUT = "forceSignOut" diff --git a/python/pytmv1/src/pytmv1/model/requests.py b/python/pytmv1/src/pytmv1/model/requests.py new file mode 100755 index 0000000..368f6f0 --- /dev/null +++ b/python/pytmv1/src/pytmv1/model/requests.py @@ -0,0 +1,66 @@ +from typing import Optional + +from .commons import BaseModel +from .enums import ObjectType, RiskLevel, ScanAction + + +class AccountTask(BaseModel): + account_name: str + """User account name.""" + description: Optional[str] = None + """Description of a response task.""" + + +class EndpointTask(BaseModel): + endpoint_name: Optional[str] + """Endpoint name.""" + agent_guid: Optional[str] = None + """Agent guid""" + description: Optional[str] = None + """Description of a response task.""" + + +class EmailMessageIdTask(BaseModel): + message_id: str + """Email message id.""" + mail_box: Optional[str] + """Email address.""" + description: Optional[str] = None + """Description of a response task.""" + + +class EmailMessageUIdTask(BaseModel): + unique_id: str + """Email unique message id.""" + description: Optional[str] = None + """Description of a response task.""" + + +class ObjectTask(BaseModel): + object_type: ObjectType + """Type of object.""" + object_value: str + """Value of an object.""" + description: Optional[str] = None + """Description of an object.""" + + +class SuspiciousObjectTask(ObjectTask): + scan_action: Optional[ScanAction] = None + """Action applied after detecting a suspicious object.""" + risk_level: Optional[RiskLevel] = None + """Risk level of a suspicious object.""" + days_to_expiration: Optional[int] = None + """Number of days before the object expires.""" + + +class FileTask(EndpointTask): + file_path: str + """File path of the file to be collected from the target.""" + + +class ProcessTask(EndpointTask): + file_sha1: str + """SHA1 hash of the terminated process's executable file.""" + file_name: Optional[str] = None + """File name of the target.""" diff --git a/python/pytmv1/src/pytmv1/model/responses.py b/python/pytmv1/src/pytmv1/model/responses.py new file mode 100644 index 0000000..3e17c76 --- /dev/null +++ b/python/pytmv1/src/pytmv1/model/responses.py @@ -0,0 +1,217 @@ +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union + +from pydantic import Field +from pydantic.generics import GenericModel + +from .commons import ( + Account, + BaseConsumable, + BaseModel, + Digest, + EmailActivity, + EmailMessage, + Endpoint, + EndpointActivity, + ExceptionObject, + MsData, + MsDataUrl, + SaeAlert, + SandboxSuspiciousObject, + SuspiciousObject, + TiAlert, +) +from .enums import ( + ObjectType, + RiskLevel, + SandboxAction, + SandboxObjectType, + Status, + TaskAction, +) + +C = TypeVar("C", bound=BaseConsumable) +M = TypeVar("M", bound=MsData) + + +class BaseResponse(BaseModel): + ... + + +class BaseLinkableResp(BaseResponse, GenericModel, Generic[C]): + next_link: Optional[str] + items: List[C] = [] + + +class BaseMultiResponse(BaseResponse, GenericModel, Generic[M]): + items: List[M] = [] + + +class BaseStatusResponse(BaseResponse): + id: str + status: Status + created_date_time: str + last_action_date_time: str + + +class BaseTaskResp(BaseStatusResponse): + action: TaskAction + description: Optional[str] + account: Optional[str] + + +MR = TypeVar("MR", bound=BaseMultiResponse[Any]) +R = TypeVar("R", bound=BaseResponse) +S = TypeVar("S", bound=BaseStatusResponse) + + +class AccountTaskResp(BaseTaskResp): + tasks: List[Account] + + +class AddAlertNoteResp(BaseResponse): + location: str = Field(alias="Location") + + def note_id(self) -> str: + return self.location.split("/")[-1] + + +class BlockListTaskResp(BaseTaskResp): + type: ObjectType + value: str + + def __init__(self, **data: str) -> None: + obj: Tuple[str, str] = self._map(data) + super().__init__(type=obj[0], value=obj[1], **data) + + @staticmethod + def _map(args: Dict[str, str]) -> Tuple[str, str]: + return { + (k, v) + for k, v in args.items() + if k in map(lambda ot: ot.value, ObjectType) + }.pop() + + +class BytesResp(BaseResponse): + content: bytes + + +class CollectFileTaskResp(BaseTaskResp): + agent_guid: str + endpoint_name: str + file_path: str + file_sha1: Optional[str] + file_sha256: Optional[str] + file_size: Optional[int] + resource_location: Optional[str] + expired_date_time: Optional[str] + password: Optional[str] + + +class ConnectivityResp(BaseResponse): + status: str + + +class ConsumeLinkableResp(BaseResponse, alias_generator=None): + total_consumed: int + + +class EndpointTaskResp(BaseTaskResp): + agent_guid: str + endpoint_name: str + + +class GetAlertDetailsResp(BaseResponse): + alert: Union[SaeAlert, TiAlert] + etag: str + + +class GetAlertListResp(BaseLinkableResp[Union[SaeAlert, TiAlert]]): + total_count: int + count: int + + +class GetEndpointActivityDataResp(BaseLinkableResp[EndpointActivity]): + progress_rate: int + + +class GetEndpointActivityDataCountResp(BaseResponse): + total_count: int + + +class GetEmailActivityDataResp(BaseLinkableResp[EmailActivity]): + progress_rate: int + + +class GetEmailActivityDataCountResp(BaseResponse): + total_count: int + + +class GetEndpointDataResp(BaseLinkableResp[Endpoint]): + ... + + +class GetExceptionListResp(BaseLinkableResp[ExceptionObject]): + ... + + +class GetSuspiciousListResp(BaseLinkableResp[SuspiciousObject]): + ... + + +class MultiResp(BaseMultiResponse[MsData]): + ... + + +class MultiUrlResp(BaseMultiResponse[MsDataUrl]): + ... + + +class NoContentResp(BaseResponse): + ... + + +class EmailMessageTaskResp(BaseTaskResp): + tasks: List[EmailMessage] + + +class SubmitFileToSandboxResp(BaseResponse): + id: str + digest: Digest + arguments: Optional[str] + + +class SandboxAnalysisResultResp(BaseResponse): + id: str + type: SandboxObjectType + analysis_completion_date_time: str + risk_level: RiskLevel + true_file_type: Optional[str] + digest: Optional[Digest] + arguments: Optional[str] + detection_names: List[str] = Field(default=[]) + threat_types: List[str] = Field(default=[]) + + +class SandboxSubmissionStatusResp(BaseStatusResponse): + action: SandboxAction + resource_location: Optional[str] + is_cached: Optional[bool] + digest: Optional[Digest] + arguments: Optional[str] + + +class SandboxSuspiciousListResp(BaseResponse): + items: List[SandboxSuspiciousObject] + + +class SandboxSubmitUrlTaskResp(BaseTaskResp): + url: str + sandbox_task_id: str + + +class TerminateProcessTaskResp(BaseTaskResp): + agent_guid: str + endpoint_name: str + file_sha1: str + file_name: Optional[str] diff --git a/python/pytmv1/src/pytmv1/py.typed b/python/pytmv1/src/pytmv1/py.typed new file mode 100755 index 0000000..e69de29 diff --git a/python/pytmv1/src/pytmv1/results.py b/python/pytmv1/src/pytmv1/results.py new file mode 100755 index 0000000..272c03e --- /dev/null +++ b/python/pytmv1/src/pytmv1/results.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from functools import wraps +from logging import Logger +from typing import Any, Callable, Generic, List, Optional, TypeVar + +from pydantic import ValidationError +from requests import RequestException + +from .exceptions import ServerCustError, ServerJsonError, ServerMultiJsonError +from .model.commons import Error, MsError +from .model.responses import MR, R + +E = TypeVar("E", bound=Error) +F = TypeVar("F", bound=Callable[..., Any]) + +log: Logger = logging.getLogger(__name__) + + +def multi_result(func: F) -> Callable[..., MultiResult[MR]]: + @wraps(func) + def _multi_result(*args: Any, **kwargs: Any) -> MultiResult[MR]: + obj: MR | Exception = _wrapper(func, *args, **kwargs) + return ( + MultiResult.success(obj) + if not isinstance(obj, Exception) + else MultiResult.failed(obj) + ) + + return _multi_result + + +def result(func: F) -> Callable[..., Result[R]]: + @wraps(func) + def _result(*args: Any, **kwargs: Any) -> Result[R]: + obj: R | Exception = _wrapper(func, *args, **kwargs) + return ( + Result.success(obj) + if not isinstance(obj, Exception) + else Result.failed(obj) + ) + + return _result + + +def _wrapper(func: F, *args: Any, **kwargs: Any) -> R | Exception: + try: + start_time: float = time.time() + log.debug( + "Execution started [%s, %s]", + args, + kwargs, + ) + response: R = func(*args, **kwargs) + log.debug( + "Execution finished [Elapsed=%s, %s]", + time.time() - start_time, + response, + ) + return response + except ( + ServerCustError, + ServerJsonError, + ServerMultiJsonError, + ValidationError, + RequestException, + RuntimeError, + ) as exc: + log.exception("Unexpected issue occurred [%s]", exc) + return exc + + +def _error(exc: Exception) -> Error: + if isinstance(exc, ServerJsonError): + return exc.error + return Error( + status=_status(exc), code=type(exc).__name__, message=str(exc) + ) + + +def _errors(exc: Exception) -> List[MsError]: + if isinstance(exc, ServerMultiJsonError): + return exc.errors + if isinstance(exc, ServerJsonError): + return [ + MsError( + status=exc.error.status, + code=exc.error.code, + message=exc.error.message, + number=exc.error.number, + ) + ] + return [ + MsError(status=_status(exc), code=type(exc).__name__, message=str(exc)) + ] + + +def _status(exc: Exception) -> int: + return exc.status if isinstance(exc, ServerCustError) else 500 + + +@dataclass +class BaseResult(Generic[R]): + result_code: ResultCode + response: Optional[R] = None + + +@dataclass +class Result(BaseResult[R]): + error: Optional[Error] = None + + @classmethod + def success(cls, response: R) -> Result[R]: + return cls(ResultCode.SUCCESS, response) + + @classmethod + def failed(cls, exc: Exception) -> Result[R]: + return cls( + ResultCode.ERROR, + None, + _error(exc), + ) + + +@dataclass +class MultiResult(BaseResult[MR]): + errors: List[MsError] = field(default_factory=list) + + @classmethod + def success(cls, response: MR) -> MultiResult[MR]: + return cls(ResultCode.SUCCESS, response) + + @classmethod + def failed(cls, exc: Exception) -> MultiResult[MR]: + return cls( + ResultCode.ERROR, + None, + _errors(exc), + ) + + +class ResultCode(str, Enum): + SUCCESS = "SUCCESS" + ERROR = "ERROR" diff --git a/python/pytmv1/src/pytmv1/utils.py b/python/pytmv1/src/pytmv1/utils.py new file mode 100755 index 0000000..6e977c3 --- /dev/null +++ b/python/pytmv1/src/pytmv1/utils.py @@ -0,0 +1,131 @@ +import base64 +import re +from typing import Any, Dict, List, Optional, Pattern, Tuple + +from pydantic import IPvAnyAddress, IPvAnyAddressError + +from .model.enums import ( + OperatingSystem, + ProductCode, + QueryField, + QueryOp, + SearchMode, +) +from .model.requests import ObjectTask, SuspiciousObjectTask + +MAC_ADDRESS_PATTERN: Pattern[str] = re.compile( + "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$" +) +GUID_PATTERN: Pattern[str] = re.compile("^(\\w+-+){1,5}\\w+$") + + +def build_activity_request( + start_time: Optional[str], + end_time: Optional[str], + select: Optional[List[str]], + top: int, + search_mode: SearchMode, +) -> Dict[str, str]: + return filter_none( + { + "startDateTime": start_time, + "endDateTime": end_time, + "select": ",".join(select) if select else select, + "top": top, + "mode": search_mode, + } + ) + + +def build_object_request(*tasks: ObjectTask) -> List[Dict[str, str]]: + return [ + filter_none( + { + task.object_type.value: task.object_value, + "description": task.description, + } + ) + for task in tasks + ] + + +def build_sandbox_file_request( + document_password: Optional[str], + archive_password: Optional[str], + arguments: Optional[str], +) -> Dict[str, str]: + return filter_none( + { + "documentPassword": _b64_encode(document_password), + "archivePassword": _b64_encode(archive_password), + "arguments": _b64_encode(arguments), + } + ) + + +def build_suspicious_request( + *tasks: SuspiciousObjectTask, +) -> List[Dict[str, Any]]: + return [ + filter_none( + { + task.object_type.value: task.object_value, + "description": task.description, + "riskLevel": ( + task.risk_level.value if task.risk_level else None + ), + "scanAction": ( + task.scan_action.value if task.scan_action else None + ), + "daysToExpiration": task.days_to_expiration, + } + ) + for task in tasks + ] + + +def activity_query(op: QueryOp, **fields: str) -> Dict[str, str]: + return {"TMV1-Query": op.join([f'{k}:"{v}"' for k, v in fields.items()])} + + +def endpoint_query(op: QueryOp, *values: str) -> Dict[str, str]: + return { + "TMV1-Query": op.join( + "(" + + QueryOp.OR.join( + f"{qt.value} eq '{value}'" + for qt in endpoint_query_field(value) + ) + + ")" + for value in values + ) + } + + +def endpoint_query_field(value: str) -> Tuple[QueryField, ...]: + if _is_ip_address(value): + return (QueryField.IP,) + if bool(MAC_ADDRESS_PATTERN.match(value)): + return (QueryField.MAC_ADDRESS,) + if bool(GUID_PATTERN.match(value)): + return (QueryField.AGENT_GUID,) + if next(filter(lambda os: os.value == value, OperatingSystem), None): + return (QueryField.OS_NAME,) + if next(filter(lambda pc: pc.value == value, ProductCode), None): + return QueryField.PRODUCT_CODE, QueryField.INSTALLED_PRODUCT_CODES + return QueryField.ENDPOINT_NAME, QueryField.LOGIN_ACCOUNT + + +def filter_none(dictionary: Dict[str, Optional[Any]]) -> Dict[str, Any]: + return {k: v for k, v in dictionary.items() if v} + + +def _b64_encode(value: Optional[str]) -> Optional[str]: + return base64.b64encode(value.encode()).decode() if value else None + + +def _is_ip_address(endpoint_value: str) -> bool: + try: + return bool(IPvAnyAddress.validate(endpoint_value)) + except IPvAnyAddressError: + return False diff --git a/python/pytmv1/tests/__init__.py b/python/pytmv1/tests/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/python/pytmv1/tests/conftest.py b/python/pytmv1/tests/conftest.py new file mode 100755 index 0000000..f6db2ac --- /dev/null +++ b/python/pytmv1/tests/conftest.py @@ -0,0 +1,48 @@ +import socket + +import pytest + +import pytmv1 +from pytmv1.core import Core + + +def pytest_addoption(parser): + parser.addoption( + "--mock-url", + action="store", + default="", + dest="mock-url", + help="Mock URL for Vision One API", + ) + + +@pytest.fixture(scope="package") +def client(pytestconfig): + return pytmv1.client( + "appname", + "token", + _default(pytestconfig.getoption("mock-url")), + ) + + +@pytest.fixture(scope="package") +def core(pytestconfig): + return Core( + "appname", + "token", + _default(pytestconfig.getoption("mock-url")), + 0, + 0, + 30, + 30, + ) + + +@pytest.fixture(scope="package") +def ip(pytestconfig): + url = pytestconfig.getoption("mock-url") + return socket.gethostbyname(url.split("/")[2]) if url != "" else None + + +def _default(url: str): + return url if url else "https://dummy-server.com" diff --git a/python/pytmv1/tests/data.py b/python/pytmv1/tests/data.py new file mode 100755 index 0000000..59532f9 --- /dev/null +++ b/python/pytmv1/tests/data.py @@ -0,0 +1,139 @@ +from requests import Response + +from pytmv1 import ( + Entity, + HostInfo, + ImpactScope, + Indicator, + InvestigationStatus, + MatchedFilter, + MatchedIndicatorPattern, + MatchedRule, + SaeAlert, + Severity, + TiAlert, +) + + +class TextResponse(Response): + def __init__(self, value: str): + super().__init__() + self.value = value + + @property + def content(self) -> bytes: + return self.value.encode("utf-8") + + @property + def text(self) -> str: + return self.value + + +def sae_alert(): + return SaeAlert.construct( + id="1", + investigationStatus=InvestigationStatus.NEW, + model="Possible Credential Dumping via Registry", + severity=Severity.HIGH, + createdDateTime="2022-09-06T02:49:33Z", + alertProvider="SAE", + description="description", + workbenchLink="https://THE_WORKBENCH_URL", + score=64, + impactScope=ImpactScope.construct( + desktopCount=1, + serverCount=0, + accountCount=1, + emailAddressCount=0, + entities=[ + Entity.construct( + entity_value=HostInfo.construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ) + ) + ], + ), + indicators=[ + Indicator.construct( + provenance=["Alert"], + value=HostInfo.construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ), + ) + ], + matchedRules=[ + MatchedRule.construct( + name="Potential Credential Dumping via Registry", + matchedFilters=[ + MatchedFilter.construct( + name="Possible Credential Dumping via Registry Hive", + mitreTechniqueIds=[ + "V9.T1003.004", + "V9.T1003.002", + "T1003", + ], + ) + ], + ) + ], + ) + + +def ti_alert(): + return TiAlert.construct( + id="1", + investigationStatus=InvestigationStatus.NEW, + model="Threat Intelligence Sweeping", + campaign="campaign", + industry="industry", + regionAndCountry="regionAndCountry", + severity=Severity.MEDIUM, + createdDateTime="2022-09-06T02:49:33Z", + alertProvider="TI", + workbenchLink="https://THE_WORKBENCH_URL", + reportLink="https://THE_TI_REPORT_URL", + createdBy="n/a", + score=42, + impactScope=ImpactScope.construct( + desktopCount=1, + serverCount=0, + accountCount=1, + emailAddressCount=0, + entities=[ + Entity.construct( + entity_value=HostInfo.construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ) + ) + ], + ), + indicators=[ + Indicator.construct( + provenance=["Alert"], + value=HostInfo.construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ), + ) + ], + matchedIndicatorPatterns=[ + MatchedIndicatorPattern.construct( + tags=["STIX2.malicious-activity"], + pattern="[file:name = 'goog-phish-proto-1.vlpset']", + ) + ], + matchedRules=[ + MatchedRule.construct( + name="Potential Credential Dumping via Registry", + matchedFilters=[ + MatchedFilter.construct( + name="Possible Credential Dumping via Registry Hive", + mitreTechniqueIds=[ + "V9.T1003.004", + "V9.T1003.002", + "T1003", + ], + ) + ], + ) + ], + ) diff --git a/python/pytmv1/tests/integration/__init__.py b/python/pytmv1/tests/integration/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/python/pytmv1/tests/integration/test_account.py b/python/pytmv1/tests/integration/test_account.py new file mode 100644 index 0000000..ac44d53 --- /dev/null +++ b/python/pytmv1/tests/integration/test_account.py @@ -0,0 +1,33 @@ +from pytmv1 import AccountTask, ResultCode + + +def test_disable_account(client): + result = client.disable_account(AccountTask(accountName="test")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + assert result.response.items[0].task_id == "00000009" + + +def test_enable_account(client): + result = client.enable_account(AccountTask(accountName="test")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + assert result.response.items[0].task_id == "00000010" + + +def test_reset_password_account(client): + result = client.reset_password_account(AccountTask(accountName="test")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + assert result.response.items[0].task_id == "00000011" + + +def test_sign_out_account(client): + result = client.sign_out_account(AccountTask(accountName="test")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + assert result.response.items[0].task_id == "00000012" diff --git a/python/pytmv1/tests/integration/test_common.py b/python/pytmv1/tests/integration/test_common.py new file mode 100755 index 0000000..6d85c9f --- /dev/null +++ b/python/pytmv1/tests/integration/test_common.py @@ -0,0 +1,96 @@ +from pytmv1 import ( + BaseTaskResp, + CollectFileTaskResp, + EmailMessageIdTask, + ResultCode, + Status, +) + + +def test_check_connectivity(client): + assert client.check_connectivity() + + +def test_get_base_task_result(client): + result = client.get_base_task_result("00000004", False) + assert isinstance(result.response, BaseTaskResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.status == Status.SUCCEEDED + assert result.response.id == "00000004" + + +def test_collect_file_task_result(client): + result = client.get_task_result("collect_file", CollectFileTaskResp, False) + assert isinstance(result.response, CollectFileTaskResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.status == Status.SUCCEEDED + assert result.response.file_sha256 + assert result.response.id == "00000003" + + +def test_collect_file_task_result_is_failed(client): + result = client.get_task_result( + "internal_error", CollectFileTaskResp, False + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.error.code == "InternalServerError" + assert result.error.status == 500 + + +def test_collect_file_task_result_is_bad_request(client): + result = client.get_task_result("bad_request", CollectFileTaskResp, False) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.error.code == "BadRequest" + assert result.error.status == 400 + + +def test_multi_status_is_failed(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="internal_server_error") + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "InternalServerError" + assert result.errors[0].status == 500 + + +def test_multi_status_is_bad_request(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="fields_not_found") + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "BadRequest" + assert result.errors[0].status == 400 + + +def test_multi_status_is_denied(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="insufficient_permissions") + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "AccessDenied" + assert result.errors[0].status == 403 + + +def test_multi_status_is_not_supported(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="action_not_supported") + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "NotSupported" + assert result.errors[0].status == 400 + + +def test_multi_status_is_task_error(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="task_duplication") + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "TaskError" + assert result.errors[0].status == 400 diff --git a/python/pytmv1/tests/integration/test_email.py b/python/pytmv1/tests/integration/test_email.py new file mode 100755 index 0000000..2104c85 --- /dev/null +++ b/python/pytmv1/tests/integration/test_email.py @@ -0,0 +1,49 @@ +from pytmv1 import EmailMessageIdTask, ResultCode + + +def test_delete_email_message(client): + result = client.delete_email_message(EmailMessageIdTask(messageId="1")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + + +def test_delete_email_message_is_failed(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="server_error") + ) + assert result.result_code == ResultCode.ERROR + assert result.errors[0].status == 500 + assert result.errors[0].code == "InternalServerError" + + +def test_delete_email_message_is_bad_request(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="invalid_format") + ) + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "BadRequest" + assert result.errors[0].status == 400 + + +def test_delete_email_message_is_denied(client): + result = client.delete_email_message( + EmailMessageIdTask(messageId="access_denied") + ) + assert result.result_code == ResultCode.ERROR + assert result.errors[0].code == "AccessDenied" + assert result.errors[0].status == 403 + + +def test_quarantine_email_message(client): + result = client.quarantine_email_message(EmailMessageIdTask(messageId="1")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + + +def test_restore_email_message(client): + result = client.restore_email_message(EmailMessageIdTask(messageId="1")) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 diff --git a/python/pytmv1/tests/integration/test_endpoint.py b/python/pytmv1/tests/integration/test_endpoint.py new file mode 100755 index 0000000..a54b5d6 --- /dev/null +++ b/python/pytmv1/tests/integration/test_endpoint.py @@ -0,0 +1,39 @@ +from pytmv1 import EndpointTask, FileTask, MultiResp, ProcessTask, ResultCode + + +def test_collect_file(client): + result = client.collect_file( + FileTask(endpointName="client1", filePath="/tmp/dummy.txt") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + + +def test_isolate_endpoint(client): + result = client.isolate_endpoint(EndpointTask(endpointName="client1")) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + + +def test_restore_endpoint(client): + result = client.restore_endpoint(EndpointTask(endpointName="client1")) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 + + +def test_terminate_process(client): + result = client.terminate_process( + ProcessTask( + endpointName="client1", fileSha1="sha12345", fileName="dummy.exe" + ) + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].status == 202 diff --git a/python/pytmv1/tests/integration/test_network.py b/python/pytmv1/tests/integration/test_network.py new file mode 100755 index 0000000..4e14204 --- /dev/null +++ b/python/pytmv1/tests/integration/test_network.py @@ -0,0 +1,52 @@ +from threading import Thread + +import psutil +import pytest + + +def test_conn_opened_with_single_call_single_client_is_one(client, ip): + client.get_exception_list() + assert len(list_tcp_conn(ip)) == 1 + + +@pytest.mark.parametrize("execution_number", range(10)) +def test_conn_opened_with_multi_call_single_client_is_one( + execution_number, client, ip +): + client.get_exception_list() + assert len(list_tcp_conn(ip)) == 1 + + +def test_conn_opened_with_multi_processing_single_client_is_one(client, ip): + threads = thread_list(lambda: client.get_exception_list()) + for t in threads: + t.start() + for t in threads: + t.join() + assert len(list_tcp_conn(ip)) == 1 + + +def test_conn_opened_with_multi_processing_multi_client_is_one( + pytestconfig, client, ip +): + threads = thread_list(lambda: client.get_exception_list()) + for t in threads: + t.start() + for t in threads: + t.join() + assert len(list_tcp_conn(ip)) == 1 + + +def list_tcp_conn(ipaddr): + return list( + filter( + lambda sc: len(sc[4]) > 0 + and sc[4][0] == ipaddr + and sc[5] == "ESTABLISHED", + psutil.net_connections("tcp"), + ) + ) + + +def thread_list(func): + return [Thread(target=func) for _ in range(10)] diff --git a/python/pytmv1/tests/integration/test_object.py b/python/pytmv1/tests/integration/test_object.py new file mode 100755 index 0000000..b1f4062 --- /dev/null +++ b/python/pytmv1/tests/integration/test_object.py @@ -0,0 +1,79 @@ +from pytmv1 import ( + GetExceptionListResp, + GetSuspiciousListResp, + MultiResp, + ObjectTask, + ObjectType, + ResultCode, + SuspiciousObjectTask, +) +from pytmv1.model.enums import ScanAction + + +def test_add_to_exception_list(client): + result = client.add_to_exception_list( + ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id is None + assert result.response.items[0].status == 201 + + +def test_add_to_suspicious_list(client): + result = client.add_to_suspicious_list( + SuspiciousObjectTask( + objectType=ObjectType.IP, + objectValue="1.1.1.1", + scanAction=ScanAction.BLOCK, + ) + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id is None + assert result.response.items[0].status == 201 + + +def test_remove_from_exception_list(client): + result = client.remove_from_exception_list( + ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id is None + assert result.response.items[0].status == 204 + + +def test_remove_from_suspicious_list(client): + result = client.remove_from_suspicious_list( + ObjectTask(objectType=ObjectType.IP, objectValue="1.1.1.1") + ) + assert isinstance(result.response, MultiResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].task_id is None + assert result.response.items[0].status == 204 + + +def test_get_exception_list(client): + result = client.get_exception_list() + assert isinstance(result.response, GetExceptionListResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].type == ObjectType.URL + assert result.response.items[0].value == "https://*.example.com/path1/*" + + +def test_get_suspicious_list(client): + result = client.get_suspicious_list() + assert isinstance(result.response, GetSuspiciousListResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].type == ObjectType.FILE_SHA256 + assert ( + result.response.items[0].value + == "asidj123123jsdsidjsid123sidsidj123sss123s224212312312312312sdaas" + ) diff --git a/python/pytmv1/tests/integration/test_sandbox.py b/python/pytmv1/tests/integration/test_sandbox.py new file mode 100755 index 0000000..a0f2276 --- /dev/null +++ b/python/pytmv1/tests/integration/test_sandbox.py @@ -0,0 +1,112 @@ +from pytmv1 import ( + BytesResp, + ResultCode, + SandboxAnalysisResultResp, + SandboxSubmissionStatusResp, + SandboxSuspiciousListResp, + SubmitFileToSandboxResp, +) + + +def test_submit_file_to_sandbox(client): + result = client.submit_file_to_sandbox( + bytes("content", "utf-8"), "fileName.txt" + ) + assert isinstance(result.response, SubmitFileToSandboxResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.id + + +def test_submit_file_to_sandbox_is_too_large_file(client): + result = client.submit_file_to_sandbox( + bytes("content", "utf-8"), "tooBig.txt" + ) + assert result.result_code == ResultCode.ERROR + assert result.error.code == "RequestEntityTooLarge" + assert result.error.status == 413 + + +def test_submit_file_to_sandbox_is_too_many_request(client): + result = client.submit_file_to_sandbox( + bytes("content", "utf-8"), "tooMany.txt" + ) + assert result.result_code == ResultCode.ERROR + assert result.error.code == "TooManyRequests" + assert result.error.status == 429 + + +def test_submit_urls_to_sandbox_with_multi_url(client): + result = client.submit_urls_to_sandbox( + "https://trendmicro.com", "https://trendmicro2.com" + ) + assert result.result_code == ResultCode.SUCCESS + assert result.response.items[0].url == "https://www.trendmicro.com" + assert result.response.items[0].status == 202 + assert result.response.items[0].task_id == "00000005" + assert ( + result.response.items[0].id == "012e4eac-9bd9-4e89-95db-77e02f75a6f5" + ) + assert ( + result.response.items[0].digest.md5 + == "f3a2e1227de8d5ae7296665c1f34b28d" + ) + assert result.response.items[1].url == "https://www.trendmicro2.com" + assert result.response.items[1].status == 202 + assert result.response.items[1].task_id == "00000006" + assert ( + result.response.items[1].id + == "01232cs823-9bd9-4e89-95db-77e02f75a6f34" + ) + assert ( + result.response.items[1].digest.md5 + == "x23s2sd11227de8d5ae7296665c1f34b3212" + ) + + +def test_submit_urls_to_sandbox_is_bad_request(client): + result = client.submit_urls_to_sandbox("bad_request") + assert result.result_code == ResultCode.ERROR + assert result.errors[0].extra["url"] == "https://www.trendmicro.com" + assert result.errors[0].status == 202 + assert result.errors[0].task_id == "00000005" + assert result.errors[1].extra["url"] == "test" + assert result.errors[1].status == 400 + assert result.errors[1].code == "BadRequest" + assert result.errors[1].message == "URL format is not right" + + +def test_get_sandbox_submission_status(client): + result = client.get_sandbox_submission_status("123") + assert isinstance(result.response, SandboxSubmissionStatusResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.id == "123" + + +def test_get_sandbox_analysis_result(client): + result = client.get_sandbox_analysis_result("123", False) + assert isinstance(result.response, SandboxAnalysisResultResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.id == "123" + + +def test_get_sandbox_suspicious_list(client): + result = client.get_sandbox_suspicious_list("1", False) + assert isinstance(result.response, SandboxSuspiciousListResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 + assert result.response.items[0].type == "ip" + assert result.response.items[0].value == "6.6.6.6" + + +def test_download_sandbox_analysis_result(client): + result = client.download_sandbox_analysis_result("1", False) + assert isinstance(result.response, BytesResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.content + + +def test_download_sandbox_investigation_package(client): + result = client.download_sandbox_investigation_package("1", False) + assert isinstance(result.response, BytesResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.content diff --git a/python/pytmv1/tests/integration/test_search.py b/python/pytmv1/tests/integration/test_search.py new file mode 100755 index 0000000..a408759 --- /dev/null +++ b/python/pytmv1/tests/integration/test_search.py @@ -0,0 +1,46 @@ +from pytmv1 import ( + GetEmailActivityDataCountResp, + GetEmailActivityDataResp, + GetEndpointActivityDataCountResp, + GetEndpointActivityDataResp, + GetEndpointDataResp, + QueryOp, + ResultCode, +) + + +def test_get_email_activity_data(client): + result = client.get_email_activity_data( + mailMsgSubject="spam", mailSenderIp="192.169.1.1" + ) + assert result.result_code == ResultCode.SUCCESS + assert isinstance(result.response, GetEmailActivityDataResp) + assert len(result.response.items) > 0 + + +def test_get_email_activity_data_count(client): + result = client.get_email_activity_data_count(mailMsgSubject="spam") + assert result.result_code == ResultCode.SUCCESS + assert isinstance(result.response, GetEmailActivityDataCountResp) + assert result.response.total_count > 0 + + +def test_get_endpoint_activity_data(client): + result = client.get_endpoint_activity_data(dpt="443") + assert result.result_code == ResultCode.SUCCESS + assert isinstance(result.response, GetEndpointActivityDataResp) + assert len(result.response.items) > 0 + + +def test_get_endpoint_activity_count(client): + result = client.get_endpoint_activity_data_count(dpt="443") + assert result.result_code == ResultCode.SUCCESS + assert isinstance(result.response, GetEndpointActivityDataCountResp) + assert result.response.total_count > 0 + + +def test_get_endpoint_data(client): + result = client.get_endpoint_data(QueryOp.AND, "client1") + assert result.result_code == ResultCode.SUCCESS + assert isinstance(result.response, GetEndpointDataResp) + assert len(result.response.items) > 0 diff --git a/python/pytmv1/tests/integration/test_workbench.py b/python/pytmv1/tests/integration/test_workbench.py new file mode 100755 index 0000000..e06ba01 --- /dev/null +++ b/python/pytmv1/tests/integration/test_workbench.py @@ -0,0 +1,77 @@ +from pytmv1 import ( + AddAlertNoteResp, + GetAlertListResp, + InvestigationStatus, + NoContentResp, + Provider, + ResultCode, +) + + +def test_add_alert_note(client): + result = client.add_alert_note("1", "dummy note") + assert isinstance(result.response, AddAlertNoteResp) + assert result.result_code == ResultCode.SUCCESS + assert result.response.note_id().isdigit() + + +def test_consume_alert_list(client): + result = client.consume_alert_list( + lambda s: None, "2020-06-15T10:00:00Z", "2020-06-15T10:00:00Z" + ) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 2 + + +def test_consume_alert_list_with_next_link(client): + result = client.consume_alert_list( + lambda s: None, "next_link", "2020-06-15T10:00:00Z" + ) + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 11 + + +def test_edit_alert_status(client): + result = client.edit_alert_status( + "1", + InvestigationStatus.IN_PROGRESS, + "d41d8cd98f00b204e9800998ecf8427e", + ) + assert isinstance(result.response, NoContentResp) + assert result.result_code == ResultCode.SUCCESS + + +def test_edit_alert_status_is_precondition_failed(client): + result = client.edit_alert_status( + "1", InvestigationStatus.IN_PROGRESS, "precondition_failed" + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.error.code == "ConditionNotMet" + assert result.error.status == 412 + + +def test_edit_alert_status_is_not_found(client): + result = client.edit_alert_status( + "1", InvestigationStatus.IN_PROGRESS, "not_found" + ) + assert not result.response + assert result.result_code == ResultCode.ERROR + assert result.error.code == "NotFound" + assert result.error.status == 404 + + +def test_get_alert_details(client): + result = client.get_alert_details("12345") + assert result.result_code == ResultCode.SUCCESS + assert result.response.alert.alert_provider == Provider.SAE + assert result.response.etag == "33a64df551425fcc55e4d42a148795d9f25f89d4" + + +def test_get_alert_list(client): + result = client.get_alert_list( + "2020-06-15T10:00:00Z", "2020-06-15T10:00:00Z" + ) + assert isinstance(result.response, GetAlertListResp) + assert result.result_code == ResultCode.SUCCESS + assert len(result.response.items) > 0 diff --git a/python/pytmv1/tests/unit/__init__.py b/python/pytmv1/tests/unit/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/python/pytmv1/tests/unit/test_caller.py b/python/pytmv1/tests/unit/test_caller.py new file mode 100755 index 0000000..ef23cc6 --- /dev/null +++ b/python/pytmv1/tests/unit/test_caller.py @@ -0,0 +1,9 @@ +import pytmv1 +from pytmv1.core import API_VERSION + + +def test_client(): + client = pytmv1.client("dummy_name", "dummy_token", "https://dummy.com") + assert client._core._appname == "dummy_name" + assert client._core._token == "dummy_token" + assert client._core._url == "https://dummy.com/" + API_VERSION diff --git a/python/pytmv1/tests/unit/test_core.py b/python/pytmv1/tests/unit/test_core.py new file mode 100755 index 0000000..2c947e9 --- /dev/null +++ b/python/pytmv1/tests/unit/test_core.py @@ -0,0 +1,497 @@ +import time + +import pytest +from pydantic import ValidationError +from requests import RequestException, Response + +from pytmv1 import ( + AddAlertNoteResp, + BytesResp, + CollectFileTaskResp, + Error, + ExceptionObject, + GetExceptionListResp, + MsData, + MsError, + MultiResp, + NoContentResp, + ResultCode, + SandboxAnalysisResultResp, + SandboxSubmissionStatusResp, + SandboxSuspiciousListResp, + SandboxSuspiciousObject, + Status, + __version__, +) +from pytmv1 import core as core_m +from pytmv1 import results +from pytmv1.core import API_VERSION, USERAGENT_SUFFIX, Core +from pytmv1.exceptions import ( + ParseModelError, + ServerHtmlError, + ServerJsonError, + ServerMultiJsonError, + ServerTextError, +) +from pytmv1.model.enums import Api, RiskLevel +from pytmv1.model.responses import BaseStatusResponse +from tests.data import TextResponse + +API_URL = "https://dummy.com/v3.0" + + +def test_consume_linkable_with_next_link_multiple_items(mocker, core): + mock_process = mocker.patch.object( + core, + "_process", + side_effect=[ + GetExceptionListResp( + nextLink="not_empty", + items=[ + ExceptionObject.construct(), + ExceptionObject.construct(), + ], + ), + GetExceptionListResp( + items=[ + ExceptionObject.construct(), + ExceptionObject.construct(), + ] + ), + ], + ) + total = core._consume_linkable( + lambda: core._process(GetExceptionListResp, Api.GET_EXCEPTION_LIST), + lambda x: None, + {}, + ) + assert mock_process.call_count == 2 + assert total == 4 + + +def test_consume_linkable_with_next_link_single_item(mocker, core): + mock_process = mocker.patch.object( + core, + "_process", + side_effect=[ + GetExceptionListResp( + nextLink="https://host/api/path?skipToken=c2tpcFRva2Vu", + items=[], + ), + GetExceptionListResp(items=[ExceptionObject.construct()]), + ], + ) + total = core._consume_linkable( + lambda: core._process(GetExceptionListResp, Api.GET_EXCEPTION_LIST), + lambda x: None, + {}, + ) + mock_process.assert_called() + assert total == 1 + + +def test_consume_linkable_without_next_link(mocker, core): + mock_process = mocker.patch.object( + core, "_process", return_value=GetExceptionListResp(items=[]) + ) + total = core._consume_linkable( + lambda: core._process(GetExceptionListResp, Api.GET_EXCEPTION_LIST), + lambda x: None, + {}, + ) + mock_process.assert_called() + assert total == 0 + + +def test_error(): + error = results._error( + ServerJsonError( + Error(status=500, code="X12", message="error", number=123) + ) + ) + assert error.status == 500 + assert error.code == "X12" + assert error.message == "error" + assert error.number == 123 + + +def test_errors(): + errors = results._errors( + ServerMultiJsonError( + [ + MsError(status=123, code="code", message="message"), + MsError(status=456, code="code2", message="message2"), + ] + ) + ) + assert errors[0].status == 123 + assert errors[0].code == "code" + assert errors[0].message == "message" + assert errors[1].status == 456 + assert errors[1].code == "code2" + assert errors[1].message == "message2" + + +def test_headers(core): + assert core._headers["Authorization"] == "Bearer token" + assert core._headers["User-Agent"] == "appname-{}/{}".format( + USERAGENT_SUFFIX, __version__ + ) + + +def test_hide_binary(): + raw_response = Response() + raw_response.headers = {"Content-Type": "application/pdf"} + assert core_m._hide_binary(raw_response) == "***binary content***" + raw_response.headers = {"Content-Type": "application/zip"} + assert core_m._hide_binary(raw_response) == "***binary content***" + raw_response.headers = {"Content-Type": "application/octet-stream"} + assert core_m._hide_binary(raw_response) == "***binary content***" + + +def test_is_http_success(): + assert core_m._is_http_success([200, 400, 600, 500]) is False + assert core_m._is_http_success([199, 400, 600, 500]) is False + assert core_m._is_http_success([400]) is False + assert core_m._is_http_success([200, 300, 398, 204]) is True + assert core_m._is_http_success([200]) is True + + +def test_parse_data_with_bytes(): + raw_response = TextResponse("raw") + raw_response.headers = {"Content-Type": "application/pdf"} + response = core_m._parse_data(raw_response, BytesResp) + assert response.content == raw_response.content + + +def test_parse_data_with_html_is_failed(): + raw_response = Response() + raw_response.headers = {"Content-Type": "text/html"} + with pytest.raises(ParseModelError): + core_m._parse_data(raw_response, NoContentResp) + + +def test_parse_data_with_json(): + raw_response = Response() + raw_response.headers = {"Content-Type": "application/json"} + raw_response.json = lambda: SandboxSuspiciousListResp( + items=[ + SandboxSuspiciousObject( + riskLevel=RiskLevel.HIGH, + analysisCompletionDateTime="2021-05-07T03:08:40", + expiredDateTime="2021-06-07T03:08:40Z", + rootSha1="fb5608fa03de204a12fe1e9e5275e4a682107471", + ip="6.6.6.6", + ) + ] + ) + response = core_m._parse_data(raw_response, SandboxSuspiciousListResp) + assert response.items[0].risk_level == "high" + assert ( + response.items[0].analysis_completion_date_time + == "2021-05-07T03:08:40" + ) + assert response.items[0].expired_date_time == "2021-06-07T03:08:40Z" + assert ( + response.items[0].root_sha1 + == "fb5608fa03de204a12fe1e9e5275e4a682107471" + ) + assert response.items[0].type == "ip" + assert response.items[0].value == "6.6.6.6" + + +def test_parse_data_with_multi_and_wrong_model_is_failed(): + raw_response = Response() + raw_response.headers = {"Content-Type": "application/json"} + raw_response.status_code = 207 + raw_response.json = lambda: MultiResp(items=[MsData(status=200)]) + with pytest.raises(ValidationError): + core_m._parse_data(raw_response, AddAlertNoteResp) + + +def test_parse_data_with_single_and_wrong_model_is_failed(): + raw_response = Response() + raw_response.headers = {"Content-Type": "application/json"} + raw_response.status_code = 200 + raw_response.json = lambda: AddAlertNoteResp(location="test") + with pytest.raises(ValidationError): + core_m._parse_data(raw_response, MultiResp) + + +def test_parse_data_without_content(): + raw_response = Response() + raw_response.status_code = 204 + response = core_m._parse_data(raw_response, NoContentResp) + assert isinstance(response, NoContentResp) + + +def test_parse_html(): + result = core_m._parse_html("

test

") + assert result == "test" + + +def test_poll_status_with_rejected_status_is_not_polling(): + start_time = time.time() + core_m._poll_status( + lambda: BaseStatusResponse.construct(status=Status.REJECTED), + 20, + ) + assert time.time() - start_time < 20 + + +def test_poll_status_with_running_status_is_polling(): + start_time = time.time() + core_m._poll_status( + lambda: BaseStatusResponse.construct(status=Status.RUNNING), + 2, + ) + assert time.time() - start_time >= 2 + + +def test_poll_status_with_succeeded_status(): + start_time = time.time() + core_m._poll_status( + lambda: BaseStatusResponse.construct(status=Status.SUCCEEDED), + 20, + ) + assert time.time() - start_time < 20 + + +def test_send(core, mocker): + raw_response = Response() + raw_response.status_code = 204 + mock_request = mocker.patch.object(core, "_send_internal") + mock_request.return_value = raw_response + result = core.send(NoContentResp, Api.EDIT_ALERT_STATUS) + mock_request.assert_called() + assert result.result_code == ResultCode.SUCCESS + + +def test_send_linkable(mocker, core): + mock_process = mocker.patch.object(core, "_process") + mock_process.return_value = GetExceptionListResp( + items=[ExceptionObject.construct()] + ) + result = core.send_linkable( + GetExceptionListResp, + Api.GET_EXCEPTION_LIST, + lambda x: None, + ) + mock_process.assert_called() + assert result.result_code == ResultCode.SUCCESS + assert result.response.total_consumed == 1 + + +def test_send_sandbox_result_with_polling(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_poll.return_value = SandboxSubmissionStatusResp.construct( + status=Status.SUCCEEDED + ) + mock_send = mocker.patch.object(core, "_process") + result = core.send_sandbox_result( + SandboxAnalysisResultResp, + Api.GET_SANDBOX_ANALYSIS_RESULT, + "123", + True, + 0, + ) + mock_poll.assert_called() + mock_send.assert_called() + assert result.result_code == ResultCode.SUCCESS + + +def test_send_sandbox_result_with_polling_is_failed(core, mocker): + mock_poll = mocker.patch.object( + core_m, "_poll_status", side_effect=RequestException() + ) + mock_send = mocker.patch.object(core, "_process") + result = core.send_sandbox_result( + SandboxAnalysisResultResp, + Api.GET_SANDBOX_ANALYSIS_RESULT, + "123", + True, + 0, + ) + mock_poll.assert_called() + mock_send.assert_not_called() + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RequestException" + + +def test_send_sandbox_result_without_polling(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_send = mocker.patch.object(core, "_process") + result = core.send_sandbox_result( + SandboxAnalysisResultResp, + Api.GET_SANDBOX_ANALYSIS_RESULT, + "123", + False, + 0, + ) + mock_poll.assert_not_called() + mock_send.assert_called() + assert result.result_code == ResultCode.SUCCESS + + +def test_send_sandbox_result_without_polling_is_failed(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_send = mocker.patch.object( + core, "_process", side_effect=RequestException() + ) + result = core.send_sandbox_result( + SandboxAnalysisResultResp, + Api.GET_SANDBOX_ANALYSIS_RESULT, + "123", + False, + 0, + ) + mock_poll.assert_not_called() + mock_send.assert_called() + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RequestException" + + +def test_send_task_result(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_send = mocker.patch.object(core, "_process") + result = core.send_task_result(CollectFileTaskResp, "123", False, 0) + mock_poll.assert_not_called() + mock_send.assert_called() + assert result.result_code == ResultCode.SUCCESS + + +def test_send_task_result_is_failed(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_send = mocker.patch.object( + core, "_process", side_effect=RequestException() + ) + result = core.send_task_result(CollectFileTaskResp, "123", False, 0) + mock_poll.assert_not_called() + mock_send.assert_called() + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RequestException" + + +def test_send_task_result_with_poll(core, mocker): + mock_poll = mocker.patch.object(core_m, "_poll_status") + mock_send = mocker.patch.object(core, "_process") + result = core.send_task_result(CollectFileTaskResp, "123", True, 0) + mock_poll.assert_called() + mock_send.assert_called() + assert result.result_code == ResultCode.SUCCESS + + +def test_send_task_result_with_poll_is_failed(core, mocker): + mock_poll = mocker.patch.object( + core_m, "_poll_status", side_effect=RequestException() + ) + mock_send = mocker.patch.object(core, "_process") + result = core.send_task_result(CollectFileTaskResp, "123", True, 0) + mock_poll.assert_called() + mock_send.assert_not_called() + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RequestException" + + +def test_send_with_request_exception_is_failed(core, mocker): + mocker.patch.object(core, "_send_internal", side_effect=RequestException()) + result = core.send(GetExceptionListResp, Api.GET_EXCEPTION_LIST) + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RequestException" + + +def test_send_with_runtime_error_is_failed(core, mocker): + mocker.patch.object(core, "_send_internal", side_effect=RuntimeError()) + result = core.send(GetExceptionListResp, Api.GET_EXCEPTION_LIST) + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "RuntimeError" + + +def test_send_with_validation_error_is_failed(core, mocker): + mocker.patch.object( + core, + "_send_internal", + side_effect=ValidationError([], NoContentResp), + ) + result = core.send(GetExceptionListResp, Api.GET_EXCEPTION_LIST) + assert result.result_code == ResultCode.ERROR + assert result.error.status == 500 + assert result.error.code == "ValidationError" + + +def test_status(): + status = results._status(ServerTextError(450, "error")) + assert status == 450 + + +def test_url_with_trailing_slash(): + test_core = Core("", "", "https://dummy/", 0, 0, 30, 30) + assert test_core._url == "https://dummy/" + API_VERSION + + +def test_url_without_trailing_slash(core): + test_core = Core("", "", "https://dummy", 0, 0, 30, 30) + assert test_core._url == "https://dummy/" + API_VERSION + + +def test_validate_with_html_is_failed(): + raw_response = Response() + raw_response.status_code = 200 + raw_response.headers = {"Content-Type": "text/html"} + with pytest.raises(ServerHtmlError): + core_m._validate(raw_response) + + +def test_validate_with_json_error_is_failed(): + raw_response = Response() + raw_response.headers = {"Content-Type": "application/json"} + raw_response.status_code = 500 + raw_response.json = lambda: { + "error": {"code": "CODE", "message": "some error", "number": 1} + } + with pytest.raises(ServerJsonError, match="some error"): + core_m._validate(raw_response) + + +def test_validate_with_text_error_is_failed(): + raw_response = TextResponse("some text") + raw_response.status_code = 500 + with pytest.raises(ServerTextError, match="some text"): + core_m._validate(raw_response) + + +def test_validate_multi_with_multi_data_is_failed(): + raw_response = Response() + raw_response.status_code = 207 + raw_response.json = lambda: [ + {"status": "400", "code": "code", "message": "message"}, + {"status": "403", "code": "code", "message": "message"}, + ] + with pytest.raises(ServerMultiJsonError, match="400.*403"): + core_m._validate(raw_response) + + +def test_validate_multi_with_single_data_is_failed(): + raw_response = Response() + raw_response.status_code = 207 + raw_response.json = lambda: [ + { + "status": "400", + "headers": [ + { + "name": "Operation-Location", + "value": "https://dummy-test.com/task/000004", + } + ], + "code": "code", + "message": "message", + } + ] + with pytest.raises(ServerMultiJsonError, match="400"): + core_m._validate(raw_response) diff --git a/python/pytmv1/tests/unit/test_mapper.py b/python/pytmv1/tests/unit/test_mapper.py new file mode 100755 index 0000000..c23bfa5 --- /dev/null +++ b/python/pytmv1/tests/unit/test_mapper.py @@ -0,0 +1,132 @@ +from pytmv1 import ( + Entity, + HostInfo, + Indicator, + InvestigationStatus, + Severity, + mapper, +) +from tests import data + + +def test_map_cef_with_sae_alert(mocker): + mock_mapper = mocker.patch.object(mapper, "_map_sae") + mapper.map_cef(data.sae_alert()) + mock_mapper.assert_called() + + +def test_map_cef_with_ti_alert(mocker): + mock_mapper = mocker.patch.object(mapper, "_map_ti") + mapper.map_cef(data.ti_alert()) + mock_mapper.assert_called() + + +def test_map_common(): + dictionary = mapper._map_common(data.sae_alert()) + assert dictionary["externalId"] == "1" + assert dictionary["act"] == InvestigationStatus.NEW.value + assert dictionary["cat"] == "Possible Credential Dumping via Registry" + assert dictionary["Severity"] == Severity.HIGH.value + assert dictionary["rt"] == "2022-09-06T02:49:33Z" + assert dictionary["sourceServiceName"] == "SAE" + assert dictionary["msg"] == "Workbench Link: https://THE_WORKBENCH_URL" + assert dictionary["cnt"] == "64" + assert dictionary["cn1"] == "1" + assert dictionary["cn1Label"] == "Desktop Count" + assert dictionary["cn2"] == "0" + assert dictionary["cn2Label"] == "Server Count" + assert dictionary["cn3"] == "1" + assert dictionary["cn3Label"] == "Account Count" + assert dictionary["cn4"] == "0" + assert dictionary["cn4Label"] == "Email Address Count" + assert dictionary["cs1"] == "Alert" + assert dictionary["cs1Label"] == "Provenance" + + +def test_map_entities_with_type_email(): + entities = [Entity.construct(entity_value="email@email.com")] + dictionary = {} + mapper._map_entities(dictionary, entities) + assert dictionary["duser"] == "email@email.com" + + +def test_map_entities_with_type_host_info(): + entities = [ + Entity.construct( + entity_value=HostInfo.construct( + name="host", ips=["1.1.1.1", "2.2.2.2"] + ) + ) + ] + dictionary = {} + mapper._map_entities(dictionary, entities) + assert dictionary["dhost"] == "host" + assert dictionary["dst"] == "1.1.1.1, 2.2.2.2" + + +def test_map_entities_with_type_user(): + entities = [Entity.construct(entity_value="username")] + dictionary = {} + mapper._map_entities(dictionary, entities) + assert dictionary["duser"] == "username" + + +def test_map_indicators_with_type_command_line(): + indicators = [Indicator.construct(type="command_line", value="cmd.exe")] + dictionary = {} + mapper._map_indicators(dictionary, indicators) + assert dictionary["dproc"] == "cmd.exe" + + +def test_map_indicators_with_type_host_info(): + indicators = [ + Indicator.construct( + value=HostInfo.construct(name="host", ips=["1.1.1.1", "2.2.2.2"]) + ) + ] + dictionary = {} + mapper._map_indicators(dictionary, indicators) + assert dictionary["shost"] == "host" + assert dictionary["src"] == "1.1.1.1, 2.2.2.2" + + +def test_map_indicators_with_unknown_type(): + indicators = [Indicator.construct(type="unknown_type", value="unknown")] + dictionary = {} + mapper._map_indicators(dictionary, indicators) + assert dictionary["unknownType"] == "unknown" + + +def test_map_sae(): + alert = data.sae_alert() + dictionary = mapper._map_common(alert) + mapper._map_sae(dictionary, alert) + assert dictionary["cs2"] == "Possible Credential Dumping via Registry Hive" + assert dictionary["cs2Label"] == "Matched Filter" + assert dictionary["cs3"] == "V9.T1003.004, V9.T1003.002, T1003" + assert dictionary["cs3Label"] == "Matched Techniques" + assert dictionary["reason"] == "Potential Credential Dumping via Registry" + assert ( + dictionary["msg"] + == "Workbench Link: https://THE_WORKBENCH_URL\nDescription:" + " description" + ) + + +def test_map_ti(): + alert = data.ti_alert() + dictionary = mapper._map_common(alert) + mapper._map_ti(dictionary, alert) + assert dictionary["cs2"] == "STIX2.malicious-activity" + assert dictionary["cs2Label"] == "Matched Pattern Tags" + assert dictionary["cs3"] == "[file:name = 'goog-phish-proto-1.vlpset']" + assert dictionary["cs3Label"] == "Matched Pattern" + assert ( + dictionary["msg"] + == "Workbench Link: https://THE_WORKBENCH_URL\nReport Link:" + " https://THE_TI_REPORT_URL" + ) + assert dictionary["createdBy"] == "n/a" + assert dictionary["campaign"] == "campaign" + assert dictionary["industry"] == "industry" + assert dictionary["regionAndCountry"] == "regionAndCountry" diff --git a/python/pytmv1/tests/unit/test_utils.py b/python/pytmv1/tests/unit/test_utils.py new file mode 100755 index 0000000..9e55d42 --- /dev/null +++ b/python/pytmv1/tests/unit/test_utils.py @@ -0,0 +1,116 @@ +from pytmv1 import OperatingSystem, ProductCode, QueryField, QueryOp, utils + + +def test_b64_encode(): + assert utils._b64_encode("testString") == "dGVzdFN0cmluZw==" + + +def test_b64_encode_with_none(): + assert utils._b64_encode(None) is None + + +def test_endpoint_query_field(): + assert utils.endpoint_query_field("client1")[0] == QueryField.ENDPOINT_NAME + assert utils.endpoint_query_field("client1")[1] == QueryField.LOGIN_ACCOUNT + assert utils.endpoint_query_field("1.1.1.1")[0] == QueryField.IP + assert ( + utils.endpoint_query_field("A1-7B-A5-63-16-F8")[0] + == QueryField.MAC_ADDRESS + ) + assert ( + utils.endpoint_query_field("35fa11da-a24e-40cf-8b56-baf8828cc151")[0] + == QueryField.AGENT_GUID + ) + assert utils.endpoint_query_field("Linux")[0] == QueryField.OS_NAME + assert utils.endpoint_query_field("sao")[0] == QueryField.PRODUCT_CODE + assert ( + utils.endpoint_query_field("sao")[1] + == QueryField.INSTALLED_PRODUCT_CODES + ) + + +def test_endpoint_query_with_endpoint_name(): + assert ( + utils.endpoint_query(QueryOp.AND, "dummy").get("TMV1-Query") + == "(endpointName eq 'dummy' or loginAccount eq 'dummy')" + ) + + +def test_endpoint_query_with_ip(): + assert ( + utils.endpoint_query(QueryOp.AND, "1.1.1.1").get("TMV1-Query") + == "(ip eq '1.1.1.1')" + ) + + +def test_endpoint_query_with_login_account(): + assert ( + utils.endpoint_query(QueryOp.AND, "DOMAIN\\Name_Lastname").get( + "TMV1-Query" + ) + == "(endpointName eq 'DOMAIN\\Name_Lastname' or" + " loginAccount eq 'DOMAIN\\Name_Lastname')" + ) + + +def test_endpoint_query_with_mac_address(): + assert ( + utils.endpoint_query(QueryOp.AND, "A1-7B-A5-63-16-F8").get( + "TMV1-Query" + ) + == "(macAddress eq 'A1-7B-A5-63-16-F8')" + ) + + +def test_endpoint_query_with_multiple_os_name_or_operator(): + assert ( + utils.endpoint_query( + QueryOp.OR, + OperatingSystem.WINDOWS.value, + OperatingSystem.LINUX.value, + ).get("TMV1-Query") + == "(osName eq 'Windows') or (osName eq 'Linux')" + ) + + +def test_endpoint_query_with_product_code_os_name_and_operator(): + assert ( + utils.endpoint_query( + QueryOp.AND, ProductCode.SAO.value, OperatingSystem.WINDOWS.value + ).get("TMV1-Query") + == "(productCode eq 'sao' or installedProductCodes eq 'sao') and" + " (osName eq 'Windows')" + ) + + +def test_endpoint_query_with_os_name(): + assert ( + utils.endpoint_query(QueryOp.AND, OperatingSystem.WINDOWS.value).get( + "TMV1-Query" + ) + == "(osName eq 'Windows')" + ) + + +def test_endpoint_query_with_product_code(): + assert ( + utils.endpoint_query(QueryOp.AND, ProductCode.SAO.value).get( + "TMV1-Query" + ) + == "(productCode eq 'sao' or installedProductCodes eq 'sao')" + ) + + +def test_filter_none(): + dictionary = utils.filter_none({"123": None}) + assert len(dictionary) == 0 + dictionary = utils.filter_none({"123": "Value"}) + assert len(dictionary) == 1 + + +def test_is_ip_address(): + assert not utils._is_ip_address("1.1.1") + assert not utils._is_ip_address("testvalue.com") + assert not utils._is_ip_address("A1-7B-A5-63-16-F8") + assert utils._is_ip_address("1.1.1.1") + assert utils._is_ip_address("2001:0db8:85a3:0000:0000:8a2e:0370:7334") diff --git a/python/pytmv1/tox.ini b/python/pytmv1/tox.ini new file mode 100755 index 0000000..233f531 --- /dev/null +++ b/python/pytmv1/tox.ini @@ -0,0 +1,5 @@ +[flake8] +exclude = .git,.github,.mypy_cache,.pytest_cache,__pycache__,dist,venv +max-complexity = 10 +max-line-length = 79 +statistics = true \ No newline at end of file