diff --git a/hasher-matcher-actioner/.devcontainer/postcreate.sh b/hasher-matcher-actioner/.devcontainer/postcreate.sh index 25db67b42..101494c2b 100755 --- a/hasher-matcher-actioner/.devcontainer/postcreate.sh +++ b/hasher-matcher-actioner/.devcontainer/postcreate.sh @@ -3,7 +3,7 @@ set -e pip install --editable .[all] -# Find Python packages in opt and install them +# Find Python packages in extensions and install them for setup_script in "$(find /workspace/extensions -name setup.py)" do module_dir="$(dirname "$setup_script")" diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/app.py b/hasher-matcher-actioner/src/OpenMediaMatch/app.py index be8a34bee..b7b11182b 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/app.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/app.py @@ -7,7 +7,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") from threatexchange.signal_type.pdq import signal as _ -## Resume regularly scheduled imports +# Resume regularly scheduled imports import logging import os diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py index 658349e80..eeb9905bd 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py @@ -139,7 +139,7 @@ def query_index( try: signal = signal_type.validate_signal_str(signal) except Exception as e: - abort(400, f"invalid signal type: {e}") + abort(400, f"invalid signal: {e}") index = _get_index(signal_type) @@ -203,8 +203,10 @@ def lookup_get(): Output: * List of matching banks """ - # Url was passed as a query param? if request.args.get("url", None): + if not current_app.config.get("ROLE_HASHER", False): + abort(403, "Hashing is disabled, missing role") + hashes = hashing.hash_media() resp = {} for signal_type in hashes.keys(): @@ -230,6 +232,9 @@ def lookup_post(): Output: * List of matching banks """ + if not current_app.config.get("ROLE_HASHER", False): + abort(403, "Hashing is disabled, missing role") + hashes = hashing.hash_media_post() resp = {} diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py index d0b1b378a..4178f2a82 100644 --- a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py +++ b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py @@ -1,9 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +from io import BytesIO +import tempfile import typing as t from flask.testing import FlaskClient from flask import Flask +from PIL import Image +import requests from threatexchange.exchanges.impl.fb_threatexchange_api import ( FBThreatExchangeSignalExchangeAPI, @@ -16,13 +20,9 @@ from OpenMediaMatch.tests.utils import ( app, client, - create_bank, - add_hash_to_bank, - IMAGE_URL_TO_PDQ, ) from OpenMediaMatch.background_tasks.build_index import build_all_indices from OpenMediaMatch.persistence import get_storage -from OpenMediaMatch.storage.postgres import database def test_status_response(client: FlaskClient): @@ -31,174 +31,7 @@ def test_status_response(client: FlaskClient): assert response.data == b"I-AM-ALIVE" -def test_banks_empty_index(client: FlaskClient): - response = client.get("/c/banks") - assert response.status_code == 200 - assert response.json == [] - - -def test_banks_create(client: FlaskClient): - # Must not start with number - post_response = client.post( - "/c/banks", - json={"name": "01_TEST_BANK"}, - ) - assert post_response.status_code == 400 - - # Cannot contain lowercase letters - post_response = client.post( - "/c/banks", - json={"name": "my_test_bank"}, - ) - assert post_response.status_code == 400 - - post_response = client.post( - "/c/banks", - json={"name": "MY_TEST_BANK_01"}, - ) - assert post_response.status_code == 201 - assert post_response.json == { - "matching_enabled_ratio": 1.0, - "name": "MY_TEST_BANK_01", - } - - # Should now be visible on index - response = client.get("/c/banks") - assert response.status_code == 200 - assert response.json == [post_response.json] - - -def test_banks_update(client: FlaskClient): - post_response = client.post( - "/c/banks", - json={"name": "MY_TEST_BANK"}, - ) - assert post_response.status_code == 201 - - # check name validation - post_response = client.put( - "/c/bank/MY_TEST_BANK", - json={"name": "1_invalid_name"}, - ) - assert post_response.status_code == 400 - - # check update with rename - post_response = client.put( - "/c/bank/MY_TEST_BANK", - json={"name": "MY_TEST_BANK_RENAMED"}, - ) - assert post_response.status_code == 200 - assert post_response.get_json()["name"] == "MY_TEST_BANK_RENAMED" - - # check update without rename - post_response = client.put( - "/c/bank/MY_TEST_BANK_RENAMED", - json={"enabled": False}, - ) - assert post_response.status_code == 200 - assert post_response.get_json()["matching_enabled_ratio"] == 0 - - # check update without ratio - post_response = client.put( - "/c/bank/MY_TEST_BANK_RENAMED", - json={"enabled_ratio": 0.5}, - ) - assert post_response.status_code == 200 - assert post_response.get_json()["matching_enabled_ratio"] == 0.5 - - # Final test to make sure we only have one bank with proper name and disabled - - get_response = client.get("/c/banks") - assert get_response.status_code == 200 - json = get_response.get_json() - assert len(json) == 1 - assert json[0] == {"name": "MY_TEST_BANK_RENAMED", "matching_enabled_ratio": 0.5} - - -def test_banks_delete(client: FlaskClient): - post_response = client.post( - "/c/banks", - json={"name": "MY_TEST_BANK"}, - ) - assert post_response.status_code == 201 - - # check name validation - post_response = client.delete( - "/c/bank/MY_TEST_BANK", - ) - assert post_response.status_code == 200 - - # deleting non existing bank should succeed - post_response = client.delete( - "/c/bank/MY_TEST_BANK", - ) - assert post_response.status_code == 200 - - -def test_banks_add_hash(client: FlaskClient): - bank_name = "NEW_BANK" - create_bank(client, bank_name) - - image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" - - post_response = client.post( - f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo" - ) - - assert post_response.status_code == 200, str(post_response.get_json()) - assert post_response.json == { - "id": 1, - "signals": { - "pdq": "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" - }, - } - - -def test_banks_delete_hash(client: FlaskClient): - bank_name = "NEW_BANK" - image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" - - create_bank(client, bank_name) - add_hash_to_bank(client, bank_name, image_url, 1) - - post_response = client.delete(f"/c/bank/{bank_name}/content/1") - - assert post_response.status_code == 200 - assert post_response.json == {"deleted": 1} - - -def test_banks_add_metadata(client: FlaskClient): - bank_name = "NEW_BANK" - create_bank(client, bank_name) - - image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" - post_request = f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo" - - post_response = client.post( - post_request, json={"metadata": {"invalid_metadata": 5}} - ) - assert post_response.status_code == 400, str(post_response.get_json()) - - post_response = client.post( - post_request, - json={"metadata": {"content_id": "1197433091", "json": {"asdf": {}}}}, - ) - - assert post_response.status_code == 200, str(post_response.get_json()) - - -def test_banks_add_hash_index(app: Flask, client: FlaskClient): - bank_name = "NEW_BANK" - bank_name_2 = "NEW_BANK_2" - image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" - image_url_2 = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/misc-images/c.png?raw=true" - - # Make two banks and add images to each bank - create_bank(client, bank_name) - add_hash_to_bank(client, bank_name, image_url, 1) - create_bank(client, bank_name_2) - add_hash_to_bank(client, bank_name, image_url_2, 2) - +def test_lookup_success(app: Flask, client: FlaskClient): storage = get_storage() # ensure index is empty to start with assert storage.get_signal_type_index(PdqSignal) is None @@ -206,19 +39,40 @@ def test_banks_add_hash_index(app: Flask, client: FlaskClient): # Build index build_all_indices(storage, storage, storage) - # Test against first image - post_response = client.get( - f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}" - ) - assert post_response.status_code == 200 - assert post_response.json == {"matches": [1]} + # test GET + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + get_resp = client.get(f"/m/lookup?url={image_url}") + assert get_resp.status_code == 200 - # Test against second image - post_response = client.get( - f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}" - ) - assert post_response.status_code == 200 - assert post_response.json == {"matches": [2]} + # test POST with temp file + response = requests.get(image_url) + image = Image.open(BytesIO(response.content)) + with tempfile.NamedTemporaryFile(suffix=".jpg") as f: + image.save(f, format="JPEG") + files = {"photo": (f.name, f.name, "image/jpeg")} + resp = client.post("/m/lookup", data=files) + assert resp.status_code == 200 + + +def test_lookup_without_role(app: Flask, client: FlaskClient): + # role resets to True in the next test + client.application.config["ROLE_HASHER"] = False + + # test GET + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + get_resp = client.get(f"/m/lookup?url={image_url}") + assert get_resp.status_code == 403 + + # test POST with temp file + with tempfile.NamedTemporaryFile(suffix=".jpg") as f: + # Write a minimal valid JPEG file header + f.write( + b"\xff\xd8\xff\xe0\x00\x10\x4a\x46\x49\x46\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + ) + f.flush() + files = {"file": (f.name, f.name, "image/jpeg")} + resp = client.post("/m/lookup", data=files) + assert resp.status_code == 403 def test_exchange_api_set_auth(app: Flask, client: FlaskClient): diff --git a/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api_banks.py b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api_banks.py new file mode 100644 index 000000000..043fa4468 --- /dev/null +++ b/hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api_banks.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from flask.testing import FlaskClient +from flask import Flask + +from threatexchange.signal_type.pdq.signal import PdqSignal + +from OpenMediaMatch.tests.utils import ( + app, + client, + create_bank, + add_hash_to_bank, + IMAGE_URL_TO_PDQ, +) +from OpenMediaMatch.background_tasks.build_index import build_all_indices +from OpenMediaMatch.persistence import get_storage + + +def test_banks_empty_index(client: FlaskClient): + response = client.get("/c/banks") + assert response.status_code == 200 + assert response.json == [] + + +def test_banks_create(client: FlaskClient): + # Must not start with number + post_response = client.post( + "/c/banks", + json={"name": "01_TEST_BANK"}, + ) + assert post_response.status_code == 400 + + # Cannot contain lowercase letters + post_response = client.post( + "/c/banks", + json={"name": "my_test_bank"}, + ) + assert post_response.status_code == 400 + + post_response = client.post( + "/c/banks", + json={"name": "MY_TEST_BANK_01"}, + ) + assert post_response.status_code == 201 + assert post_response.json == { + "matching_enabled_ratio": 1.0, + "name": "MY_TEST_BANK_01", + } + + # Should now be visible on index + response = client.get("/c/banks") + assert response.status_code == 200 + assert response.json == [post_response.json] + + +def test_banks_update(client: FlaskClient): + post_response = client.post( + "/c/banks", + json={"name": "MY_TEST_BANK"}, + ) + assert post_response.status_code == 201 + + # check name validation + post_response = client.put( + "/c/bank/MY_TEST_BANK", + json={"name": "1_invalid_name"}, + ) + assert post_response.status_code == 400 + + # check update with rename + post_response = client.put( + "/c/bank/MY_TEST_BANK", + json={"name": "MY_TEST_BANK_RENAMED"}, + ) + assert post_response.status_code == 200 + assert post_response.get_json()["name"] == "MY_TEST_BANK_RENAMED" + + # check update without rename + post_response = client.put( + "/c/bank/MY_TEST_BANK_RENAMED", + json={"enabled": False}, + ) + assert post_response.status_code == 200 + assert post_response.get_json()["matching_enabled_ratio"] == 0 + + # check update without ratio + post_response = client.put( + "/c/bank/MY_TEST_BANK_RENAMED", + json={"enabled_ratio": 0.5}, + ) + assert post_response.status_code == 200 + assert post_response.get_json()["matching_enabled_ratio"] == 0.5 + + # Final test to make sure we only have one bank with proper name and disabled + + get_response = client.get("/c/banks") + assert get_response.status_code == 200 + json = get_response.get_json() + assert len(json) == 1 + assert json[0] == {"name": "MY_TEST_BANK_RENAMED", "matching_enabled_ratio": 0.5} + + +def test_banks_delete(client: FlaskClient): + post_response = client.post( + "/c/banks", + json={"name": "MY_TEST_BANK"}, + ) + assert post_response.status_code == 201 + + # check name validation + post_response = client.delete( + "/c/bank/MY_TEST_BANK", + ) + assert post_response.status_code == 200 + + # deleting non existing bank should succeed + post_response = client.delete( + "/c/bank/MY_TEST_BANK", + ) + assert post_response.status_code == 200 + + +def test_banks_add_hash(client: FlaskClient): + bank_name = "NEW_BANK" + create_bank(client, bank_name) + + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + + post_response = client.post( + f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo" + ) + + assert post_response.status_code == 200, str(post_response.get_json()) + assert post_response.json == { + "id": 1, + "signals": { + "pdq": "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" + }, + } + + +def test_banks_delete_hash(client: FlaskClient): + bank_name = "NEW_BANK" + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + + create_bank(client, bank_name) + add_hash_to_bank(client, bank_name, image_url, 1) + + post_response = client.delete(f"/c/bank/{bank_name}/content/1") + + assert post_response.status_code == 200 + assert post_response.json == {"deleted": 1} + + +def test_banks_add_metadata(client: FlaskClient): + bank_name = "NEW_BANK" + create_bank(client, bank_name) + + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + post_request = f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo" + + post_response = client.post( + post_request, json={"metadata": {"invalid_metadata": 5}} + ) + assert post_response.status_code == 400, str(post_response.get_json()) + + post_response = client.post( + post_request, + json={"metadata": {"content_id": "1197433091", "json": {"asdf": {}}}}, + ) + + assert post_response.status_code == 200, str(post_response.get_json()) + + +def test_banks_add_hash_index(app: Flask, client: FlaskClient): + bank_name = "NEW_BANK" + bank_name_2 = "NEW_BANK_2" + image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true" + image_url_2 = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/misc-images/c.png?raw=true" + + # Make two banks and add images to each bank + create_bank(client, bank_name) + add_hash_to_bank(client, bank_name, image_url, 1) + create_bank(client, bank_name_2) + add_hash_to_bank(client, bank_name, image_url_2, 2) + + storage = get_storage() + # ensure index is empty to start with + assert storage.get_signal_type_index(PdqSignal) is None + + # Build index + build_all_indices(storage, storage, storage) + + # Test against first image + post_response = client.get( + f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}" + ) + assert post_response.status_code == 200 + assert post_response.json == {"matches": [1]} + + # Test against second image + post_response = client.get( + f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}" + ) + assert post_response.status_code == 200 + assert post_response.json == {"matches": [2]}