Skip to content

Commit

Permalink
[hma] Check for hasher role in /lookup endpoint (#1729)
Browse files Browse the repository at this point in the history
  • Loading branch information
aryzle authored Dec 30, 2024
1 parent 148d8cc commit 318ffe6
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 188 deletions.
2 changes: 1 addition & 1 deletion hasher-matcher-actioner/.devcontainer/postcreate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

pip install --editable .[all]

# Find Python packages in opt and install them
# Find Python packages in extensions and install them
for setup_script in "$(find /workspace/extensions -name setup.py)"
do
module_dir="$(dirname "$setup_script")"
Expand Down
2 changes: 1 addition & 1 deletion hasher-matcher-actioner/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from threatexchange.signal_type.pdq import signal as _
## Resume regularly scheduled imports
# Resume regularly scheduled imports

import logging
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def query_index(
try:
signal = signal_type.validate_signal_str(signal)
except Exception as e:
abort(400, f"invalid signal type: {e}")
abort(400, f"invalid signal: {e}")

index = _get_index(signal_type)

Expand Down Expand Up @@ -203,8 +203,10 @@ def lookup_get():
Output:
* List of matching banks
"""
# Url was passed as a query param?
if request.args.get("url", None):
if not current_app.config.get("ROLE_HASHER", False):
abort(403, "Hashing is disabled, missing role")

hashes = hashing.hash_media()
resp = {}
for signal_type in hashes.keys():
Expand All @@ -230,6 +232,9 @@ def lookup_post():
Output:
* List of matching banks
"""
if not current_app.config.get("ROLE_HASHER", False):
abort(403, "Hashing is disabled, missing role")

hashes = hashing.hash_media_post()

resp = {}
Expand Down
222 changes: 38 additions & 184 deletions hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from io import BytesIO
import tempfile
import typing as t

from flask.testing import FlaskClient
from flask import Flask
from PIL import Image
import requests

from threatexchange.exchanges.impl.fb_threatexchange_api import (
FBThreatExchangeSignalExchangeAPI,
Expand All @@ -16,13 +20,9 @@
from OpenMediaMatch.tests.utils import (
app,
client,
create_bank,
add_hash_to_bank,
IMAGE_URL_TO_PDQ,
)
from OpenMediaMatch.background_tasks.build_index import build_all_indices
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.storage.postgres import database


def test_status_response(client: FlaskClient):
Expand All @@ -31,194 +31,48 @@ def test_status_response(client: FlaskClient):
assert response.data == b"I-AM-ALIVE"


def test_banks_empty_index(client: FlaskClient):
    """Listing banks before any were created returns 200 with an empty list."""
    listing = client.get("/c/banks")

    assert listing.status_code == 200
    assert listing.json == []


def test_banks_create(client: FlaskClient):
    """Exercise bank creation: invalid names are rejected, a valid one is listed.

    Two malformed names must yield 400; a well-formed name must yield 201
    and then appear as the only entry of the bank index.
    """
    # A leading digit and lowercase letters are both rejected by the endpoint
    for invalid_name in ("01_TEST_BANK", "my_test_bank"):
        rejected = client.post("/c/banks", json={"name": invalid_name})
        assert rejected.status_code == 400

    created = client.post("/c/banks", json={"name": "MY_TEST_BANK_01"})
    assert created.status_code == 201

    expected_bank = {
        "matching_enabled_ratio": 1.0,
        "name": "MY_TEST_BANK_01",
    }
    assert created.json == expected_bank

    # The freshly created bank is what the index reports
    listing = client.get("/c/banks")
    assert listing.status_code == 200
    assert listing.json == [created.json]


def test_banks_update(client: FlaskClient):
    """Exercise PUT on a bank: name validation, rename, disable, and ratio change.

    Finishes by checking that the index holds exactly one bank carrying the
    new name and the last ratio that was set.
    """
    original_url = "/c/bank/MY_TEST_BANK"
    renamed_url = "/c/bank/MY_TEST_BANK_RENAMED"

    created = client.post("/c/banks", json={"name": "MY_TEST_BANK"})
    assert created.status_code == 201

    # A rename to an invalid name is refused
    bad_rename = client.put(original_url, json={"name": "1_invalid_name"})
    assert bad_rename.status_code == 400

    # A rename to a valid name is applied
    renamed = client.put(original_url, json={"name": "MY_TEST_BANK_RENAMED"})
    assert renamed.status_code == 200
    assert renamed.get_json()["name"] == "MY_TEST_BANK_RENAMED"

    # Update that leaves the name alone: disabling drops the ratio to 0
    disabled = client.put(renamed_url, json={"enabled": False})
    assert disabled.status_code == 200
    assert disabled.get_json()["matching_enabled_ratio"] == 0

    # Update that sets an explicit ratio
    ratio_set = client.put(renamed_url, json={"enabled_ratio": 0.5})
    assert ratio_set.status_code == 200
    assert ratio_set.get_json()["matching_enabled_ratio"] == 0.5

    # Only one bank may exist, under the new name and with the final ratio
    listing = client.get("/c/banks")
    assert listing.status_code == 200
    banks = listing.get_json()
    assert len(banks) == 1
    assert banks[0] == {"name": "MY_TEST_BANK_RENAMED", "matching_enabled_ratio": 0.5}


def test_banks_delete(client: FlaskClient):
    """Deleting a bank returns 200, and so does deleting it a second time."""
    bank_url = "/c/bank/MY_TEST_BANK"

    created = client.post("/c/banks", json={"name": "MY_TEST_BANK"})
    assert created.status_code == 201

    # First delete targets the bank that was just created
    first_delete = client.delete(bank_url)
    assert first_delete.status_code == 200

    # Second delete targets a bank that no longer exists and must still succeed
    second_delete = client.delete(bank_url)
    assert second_delete.status_code == 200


def test_banks_add_hash(client: FlaskClient):
    """Adding a photo by URL to a new bank returns its id and PDQ signal."""
    image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
    bank_name = "NEW_BANK"
    create_bank(client, bank_name)

    added = client.post(
        f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"
    )

    # On failure, surface the response payload to ease debugging
    assert added.status_code == 200, str(added.get_json())

    expected_content = {
        "id": 1,
        "signals": {
            "pdq": "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22"
        },
    }
    assert added.json == expected_content


def test_banks_delete_hash(client: FlaskClient):
    """Deleting content 1 from a bank that holds one image reports one deletion."""
    image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
    bank_name = "NEW_BANK"

    # Arrange: a bank holding a single image
    create_bank(client, bank_name)
    add_hash_to_bank(client, bank_name, image_url, 1)

    removed = client.delete(f"/c/bank/{bank_name}/content/1")

    assert removed.status_code == 200
    assert removed.json == {"deleted": 1}


def test_banks_add_metadata(client: FlaskClient):
    """Content metadata is validated: an unknown key gives 400, a valid payload 200."""
    image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
    bank_name = "NEW_BANK"
    create_bank(client, bank_name)

    content_endpoint = (
        f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"
    )

    # Metadata carrying an unexpected key is refused
    rejected = client.post(
        content_endpoint, json={"metadata": {"invalid_metadata": 5}}
    )
    assert rejected.status_code == 400, str(rejected.get_json())

    # Metadata with a content id and a nested json payload is accepted
    valid_metadata = {"content_id": "1197433091", "json": {"asdf": {}}}
    accepted = client.post(content_endpoint, json={"metadata": valid_metadata})
    assert accepted.status_code == 200, str(accepted.get_json())


def test_banks_add_hash_index(app: Flask, client: FlaskClient):
bank_name = "NEW_BANK"
bank_name_2 = "NEW_BANK_2"
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
image_url_2 = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/misc-images/c.png?raw=true"

# Make two banks and add images to each bank
create_bank(client, bank_name)
add_hash_to_bank(client, bank_name, image_url, 1)
create_bank(client, bank_name_2)
add_hash_to_bank(client, bank_name, image_url_2, 2)

def test_lookup_success(app: Flask, client: FlaskClient):
storage = get_storage()
# ensure index is empty to start with
assert storage.get_signal_type_index(PdqSignal) is None

# Build index
build_all_indices(storage, storage, storage)

# Test against first image
post_response = client.get(
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [1]}
# test GET
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
get_resp = client.get(f"/m/lookup?url={image_url}")
assert get_resp.status_code == 200

# Test against second image
post_response = client.get(
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [2]}
# test POST with temp file
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
image.save(f, format="JPEG")
files = {"photo": (f.name, f.name, "image/jpeg")}
resp = client.post("/m/lookup", data=files)
assert resp.status_code == 200


def test_lookup_without_role(app: Flask, client: FlaskClient):
    """With ROLE_HASHER switched off, GET and POST on /m/lookup both answer 403."""
    # Per the original author's note, the role is back to True in the next test
    client.application.config["ROLE_HASHER"] = False

    # GET: lookup by URL
    image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
    url_lookup = client.get(f"/m/lookup?url={image_url}")
    assert url_lookup.status_code == 403

    # POST: lookup by uploaded file; the payload is a JFIF header followed
    # directly by the end-of-image marker
    jpeg_stub = b"\xff\xd8\xff\xe0\x00\x10\x4a\x46\x49\x46\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9"
    with tempfile.NamedTemporaryFile(suffix=".jpg") as jpeg_file:
        jpeg_file.write(jpeg_stub)
        # Flush so the bytes are on disk before the path is handed to the client
        jpeg_file.flush()
        upload = {"file": (jpeg_file.name, jpeg_file.name, "image/jpeg")}
        upload_lookup = client.post("/m/lookup", data=upload)
        assert upload_lookup.status_code == 403


def test_exchange_api_set_auth(app: Flask, client: FlaskClient):
Expand Down
Loading

0 comments on commit 318ffe6

Please sign in to comment.