Skip to content

Commit

Permalink
fix includes media
Browse files Browse the repository at this point in the history
  • Loading branch information
yu23ki14 committed Jan 20, 2025
1 parent b1f3a3d commit 416a5ff
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 3 deletions.
139 changes: 139 additions & 0 deletions api/tests/routers/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from datetime import datetime, timezone
from typing import Dict, List, Union
from unittest.mock import MagicMock

from fastapi.testclient import TestClient

from birdxplorer_common.models import Note, Post, Topic, TwitterTimestamp, XUser


def test_search_basic(client: TestClient, mock_storage: MagicMock) -> None:
# Mock data
timestamp = TwitterTimestamp.from_int(int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() * 1000))

note = Note(
note_id="1234567890123456789", # 19-digit string
post_id="2234567890123456789", # 19-digit string
language="ja",
topics=[Topic(topic_id=1, label={"ja": "テスト", "en": "test"}, reference_count=1)],
summary="Test summary",
current_status="NEEDS_MORE_RATINGS",
created_at=timestamp,
)

post = Post(
post_id="2234567890123456789", # 19-digit string
x_user_id="9876543210123456789", # 19-digit string
x_user=XUser(
user_id="9876543210123456789", # 19-digit string
name="test_user",
profile_image="http://example.com/image.jpg",
followers_count=100,
following_count=50,
),
text="Test post",
media_details=[],
created_at=timestamp,
like_count=10,
repost_count=5,
impression_count=100,
links=[],
link="http://x.com/test_user/status/2234567890123456789",
)

# Mock storage response
mock_storage.search_notes_with_posts.return_value = [(note, post)]
mock_storage.count_search_results.return_value = 1

# Test basic search
response = client.get("/api/v1/data/search?note_includes_text=test")
assert response.status_code == 200

data = response.json()
assert "data" in data
assert "meta" in data
assert len(data["data"]) == 1

# Verify response structure
result = data["data"][0]
assert result["noteId"] == "1234567890123456789"
assert result["postId"] == "2234567890123456789"
assert result["language"] == "ja"
assert result["summary"] == "Test summary"
assert result["currentStatus"] == "NEEDS_MORE_RATINGS"
assert result["post"]["postId"] == "2234567890123456789"


def test_search_pagination(client: TestClient, mock_storage: MagicMock) -> None:
# Mock data for pagination test
mock_storage.search_notes_with_posts.return_value = []
mock_storage.count_search_results.return_value = 150

# Test first page
response = client.get("/api/v1/data/search?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["meta"]["next"] is not None # Should have next page
assert data["meta"]["prev"] is None # Should not have prev page

# Test middle page
response = client.get("/api/v1/data/search?limit=50&offset=50")
assert response.status_code == 200
data = response.json()
assert data["meta"]["next"] is not None # Should have next page
assert data["meta"]["prev"] is not None # Should have prev page

# Test last page
response = client.get("/api/v1/data/search?limit=50&offset=100")
assert response.status_code == 200
data = response.json()
assert data["meta"]["next"] is None # Should not have next page
assert data["meta"]["prev"] is not None # Should have prev page


def test_search_parameters(client: TestClient, mock_storage: MagicMock) -> None:
mock_storage.search_notes_with_posts.return_value = []
mock_storage.count_search_results.return_value = 0

# Test various parameter combinations
test_cases: List[Dict[str, Union[str, List[str], List[int], int, bool]]] = [
{"note_includes_text": "test"},
{"note_excludes_text": "spam"},
{"post_includes_text": "hello"},
{"post_excludes_text": "goodbye"},
{"language": "ja"},
{"topic_ids": [1, 2, 3]},
{"note_status": ["NEEDS_MORE_RATINGS"]},
{"x_user_names": ["test_user"]},
{"x_user_followers_count_from": 1000},
{"post_like_count_from": 100},
{"post_includes_media": True},
]

for params in test_cases:
query = "&".join(
f"{k}={v}" if not isinstance(v, list) else f"{k}={','.join(map(str, v))}" for k, v in params.items()
)
response = client.get(f"/api/v1/data/search?{query}")
assert response.status_code == 200


def test_search_timestamp_conversion(client: TestClient, mock_storage: MagicMock) -> None:
mock_storage.search_notes_with_posts.return_value = []
mock_storage.count_search_results.return_value = 0

# Test various timestamp formats
base_timestamp = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() * 1000)
timestamp_cases = [
f"note_created_at_from={base_timestamp}", # Unix timestamp in milliseconds
"note_created_at_from=2023-01-01", # Date string
"note_created_at_from=2023-01-01T00:00:00Z", # ISO format
]

for query in timestamp_cases:
response = client.get(f"/api/v1/data/search?{query}")
assert response.status_code == 200

# Test invalid timestamp
response = client.get("/api/v1/data/search?note_created_at_from=invalid")
assert response.status_code == 422
14 changes: 11 additions & 3 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def search_notes_with_posts(
post_like_count_from: Union[int, None] = None,
post_repost_count_from: Union[int, None] = None,
post_impression_count_from: Union[int, None] = None,
post_includes_media: bool = True,
post_includes_media: Union[bool, None] = None,
offset: int = 0,
limit: int = 100,
) -> Generator[Tuple[NoteModel, PostModel], None, None]:
Expand Down Expand Up @@ -575,7 +575,11 @@ def search_notes_with_posts(
query = query.filter(PostRecord.repost_count >= post_repost_count_from)
if post_impression_count_from:
query = query.filter(PostRecord.impression_count >= post_impression_count_from)
if not post_includes_media:
if post_includes_media:
# Only include posts that have media
query = query.filter(PostRecord.media_details.any())
if post_includes_media is False:
# Only include posts that don't have media
query = query.filter(~PostRecord.media_details.any())

# Pagination
Expand Down Expand Up @@ -664,7 +668,11 @@ def count_search_results(
query = query.filter(PostRecord.repost_count >= post_repost_count_from)
if post_impression_count_from:
query = query.filter(PostRecord.impression_count >= post_impression_count_from)
if not post_includes_media:
if post_includes_media:
# Only include posts that have media
query = query.filter(PostRecord.media_details.any())
if post_includes_media is False:
# Only include posts that don't have media
query = query.filter(~PostRecord.media_details.any())

return query.count()
Expand Down

0 comments on commit 416a5ff

Please sign in to comment.