Skip to content

Commit

Permalink
Implement basic search functionality for the search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
yu23ki14 committed Jan 15, 2025
1 parent 1cfd73b commit ef339d3
Show file tree
Hide file tree
Showing 5 changed files with 512 additions and 144 deletions.
4 changes: 2 additions & 2 deletions api/birdxplorer_api/openapi_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ class FastAPIEndpointDocs(Generic[_KEY]):
},
)

v1_data_post_favorite_count = FastAPIEndpointParamDocs(
v1_data_post_like_count = FastAPIEndpointParamDocs(
description="Postのお気に入り数。",
openapi_examples={
"single": {
Expand Down Expand Up @@ -583,7 +583,7 @@ class FastAPIEndpointDocs(Generic[_KEY]):
"x_user_name": v1_data_x_user_name,
"x_user_followers_count_from": v1_data_x_user_follower_count,
"x_user_follow_count_from": v1_data_x_user_follow_count,
"post_favorite_count_from": v1_data_post_favorite_count,
"post_like_count_from": v1_data_post_like_count,
"post_repost_count_from": v1_data_post_repost_count,
"post_impression_count_from": v1_data_post_impression_count,
"post_includes_media": v1_data_post_includes_media,
Expand Down
126 changes: 72 additions & 54 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def get_posts(

@router.get("/search", description=V1DataSearchDocs.description, response_model=SearchResponse)
def search(
request: Request,
note_includes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["note_includes_text"]),
note_excludes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["note_excludes_text"]),
post_includes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["post_includes_text"]),
Expand All @@ -437,9 +438,7 @@ def search(
x_user_follow_count_from: Union[None, int] = Query(
default=None, **V1DataSearchDocs.params["x_user_follow_count_from"]
),
post_favorite_count_from: Union[None, int] = Query(
default=None, **V1DataSearchDocs.params["post_favorite_count_from"]
),
post_like_count_from: Union[None, int] = Query(default=None, **V1DataSearchDocs.params["post_like_count_from"]),
post_repost_count_from: Union[None, int] = Query(
default=None, **V1DataSearchDocs.params["post_repost_count_from"]
),
Expand All @@ -450,59 +449,78 @@ def search(
offset: int = Query(default=0, ge=0, **V1DataSearchDocs.params["offset"]),
limit: int = Query(default=100, gt=0, le=1000, **V1DataSearchDocs.params["limit"]),
) -> SearchResponse:
return SearchResponse(
data=[
# Convert timestamp strings to TwitterTimestamp objects
if note_created_at_from is not None and isinstance(note_created_at_from, str):
note_created_at_from = ensure_twitter_timestamp(note_created_at_from)
if note_created_at_to is not None and isinstance(note_created_at_to, str):
note_created_at_to = ensure_twitter_timestamp(note_created_at_to)

# Get search results using the optimized storage method
results = []
for note, post in storage.search_notes_with_posts(
note_includes_text=note_includes_text,
note_excludes_text=note_excludes_text,
post_includes_text=post_includes_text,
post_excludes_text=post_excludes_text,
language=language,
topic_ids=topic_ids,
note_status=note_status,
note_created_at_from=note_created_at_from,
note_created_at_to=note_created_at_to,
x_user_names=x_user_names,
x_user_followers_count_from=x_user_followers_count_from,
x_user_follow_count_from=x_user_follow_count_from,
post_like_count_from=post_like_count_from,
post_repost_count_from=post_repost_count_from,
post_impression_count_from=post_impression_count_from,
post_includes_media=post_includes_media,
offset=offset,
limit=limit,
):
results.append(
SearchedNote(
noteId="1845672983001710655",
language="ja",
topics=[
{
"topicId": 26,
"label": {"ja": "セキュリティ上の脅威", "en": "security threat"},
"referenceCount": 0,
},
{"topicId": 47, "label": {"ja": "検閲", "en": "Censorship"}, "referenceCount": 0},
{"topicId": 51, "label": {"ja": "テクノロジー", "en": "technology"}, "referenceCount": 0},
],
postId="1846718284369912064",
summary="Content Security Policyは情報の持ち出しを防止する仕組みではありません。コンテンツインジェクションの脆弱性のリスクを軽減する仕組みです。適切なContent Security Policyがレスポンスヘッダーに設定されている場合でも、外部への通信をブロックできない点に注意が必要です。 Content Security Policy Level 3 https://w3c.github.io/webappsec-csp/", # noqa: E501
current_status="NEEDS_MORE_RATINGS",
created_at=1728877704750,
post={
"postId": "1846718284369912064",
"xUserId": "90954365",
"xUser": {
"userId": "90954365",
"name": "earthquakejapan",
"profileImage": "https://pbs.twimg.com/profile_images/1638600342/japan_rel96_normal.jpg",
"followersCount": 162934,
"followingCount": 6,
},
"text": "今後48時間以内に日本ではマグニチュード6.0の地震が発生する可能性があります。地図をご覧ください。",
"mediaDetails": [
{
"mediaKey": "3_1846718279236177920-1846718284369912064",
"type": "photo",
"url": "https://pbs.twimg.com/media/GaDcfZoX0AAko2-.jpg",
"width": 900,
"height": 738,
}
],
"createdAt": 1729094524000,
"likeCount": 451,
"repostCount": 104,
"impressionCount": 82378,
"links": [
{
"linkId": "9c139b99-8111-e4f0-ad41-fc9e40d08722",
"url": "https://www.quakeprediction.com/Earthquake%20Forecast%20Japan.html",
}
],
"link": "https://x.com/earthquakejapan/status/1846718284369912064",
},
noteId=note.note_id,
language=note.language,
topics=note.topics,
postId=note.post_id,
summary=note.summary,
current_status=note.current_status,
created_at=note.created_at,
post=post,
)
],
meta=PaginationMeta(next=None, prev=None),
)

# Get total count for pagination
total_count = storage.count_search_results(
note_includes_text=note_includes_text,
note_excludes_text=note_excludes_text,
post_includes_text=post_includes_text,
post_excludes_text=post_excludes_text,
language=language,
topic_ids=topic_ids,
note_status=note_status,
note_created_at_from=note_created_at_from,
note_created_at_to=note_created_at_to,
x_user_names=x_user_names,
x_user_followers_count_from=x_user_followers_count_from,
x_user_follow_count_from=x_user_follow_count_from,
post_like_count_from=post_like_count_from,
post_repost_count_from=post_repost_count_from,
post_impression_count_from=post_impression_count_from,
post_includes_media=post_includes_media,
)

# Generate pagination URLs
base_url = str(request.url).split("?")[0]
next_offset = offset + limit
prev_offset = max(offset - limit, 0)
next_url = None
if next_offset < total_count:
next_url = f"{base_url}?offset={next_offset}&limit={limit}"
prev_url = None
if offset > 0:
prev_url = f"{base_url}?offset={prev_offset}&limit={limit}"

return SearchResponse(data=results, meta=PaginationMeta(next=next_url, prev=prev_url))

return router
164 changes: 163 additions & 1 deletion common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator, List, Union
from typing import Generator, List, Tuple, Union

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
Expand Down Expand Up @@ -507,6 +507,168 @@ def get_number_of_posts(
)
return query.count()

def search_notes_with_posts(
    self,
    note_includes_text: Union[str, None] = None,
    note_excludes_text: Union[str, None] = None,
    post_includes_text: Union[str, None] = None,
    post_excludes_text: Union[str, None] = None,
    language: Union[LanguageIdentifier, None] = None,
    topic_ids: Union[List[TopicId], None] = None,
    note_status: Union[List[str], None] = None,
    note_created_at_from: Union[TwitterTimestamp, None] = None,
    note_created_at_to: Union[TwitterTimestamp, None] = None,
    x_user_names: Union[List[str], None] = None,
    x_user_followers_count_from: Union[int, None] = None,
    x_user_follow_count_from: Union[int, None] = None,
    post_like_count_from: Union[int, None] = None,
    post_repost_count_from: Union[int, None] = None,
    post_impression_count_from: Union[int, None] = None,
    post_includes_media: bool = True,
    offset: int = 0,
    limit: int = 100,
) -> Generator[Tuple[NoteModel, Union[PostModel, None]], None, None]:
    """Search notes joined with their posts, yielding ``(note, post)`` pairs.

    Every filter argument is optional and only applied when provided.
    Numeric thresholds are compared against ``None`` explicitly so that an
    explicit ``0`` still applies the filter (a plain truthiness check would
    silently drop it).  Posts and users are outer-joined, so the yielded
    ``post`` is ``None`` for notes whose post is missing from storage —
    hence the ``Union[PostModel, None]`` in the return type.

    ``offset``/``limit`` paginate the result set; use
    :meth:`count_search_results` with the same filters for the total count.
    """
    with Session(self.engine) as sess:
        # Base query joining notes, posts and users.  outerjoin keeps notes
        # whose post (or post author) row is absent.
        query = (
            sess.query(NoteRecord, PostRecord)
            .outerjoin(PostRecord, NoteRecord.post_id == PostRecord.post_id)
            .outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
        )

        # Note filters.  Text/list filters deliberately keep the truthiness
        # check: an empty string or empty list means "no filter".
        if note_includes_text:
            query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
        if note_excludes_text:
            query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
        if language:
            query = query.filter(NoteRecord.language == language)
        if topic_ids:
            # Restrict to notes associated with at least one requested topic.
            subq = (
                select(NoteTopicAssociation.note_id)
                .filter(NoteTopicAssociation.topic_id.in_(topic_ids))
                .group_by(NoteTopicAssociation.note_id)
                .subquery()
            )
            query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
        if note_status:
            query = query.filter(NoteRecord.current_status.in_(note_status))
        if note_created_at_from is not None:
            query = query.filter(NoteRecord.created_at >= note_created_at_from)
        if note_created_at_to is not None:
            query = query.filter(NoteRecord.created_at <= note_created_at_to)

        # Post / user filters.
        if post_includes_text:
            query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
        if post_excludes_text:
            query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
        if x_user_names:
            query = query.filter(XUserRecord.name.in_(x_user_names))
        if x_user_followers_count_from is not None:
            query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
        if x_user_follow_count_from is not None:
            query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
        if post_like_count_from is not None:
            query = query.filter(PostRecord.like_count >= post_like_count_from)
        if post_repost_count_from is not None:
            query = query.filter(PostRecord.repost_count >= post_repost_count_from)
        if post_impression_count_from is not None:
            query = query.filter(PostRecord.impression_count >= post_impression_count_from)
        if not post_includes_media:
            query = query.filter(~PostRecord.media_details.any())

        # Pagination
        query = query.offset(offset).limit(limit)

        # .all() materializes the rows, and the models are built while the
        # session is still open so lazy relationships (topics, media) load.
        for note_record, post_record in query.all():
            note = NoteModel(
                note_id=note_record.note_id,
                post_id=note_record.post_id,
                topics=[
                    # NOTE(review): reference_count is hard-coded to 0 here —
                    # confirm whether real reference counts should be surfaced.
                    TopicModel(topic_id=topic.topic_id, label=topic.topic.label, reference_count=0)
                    for topic in note_record.topics
                ],
                language=note_record.language,
                summary=note_record.summary,
                current_status=note_record.current_status,
                created_at=note_record.created_at,
            )
            post = self._post_record_to_model(post_record, with_media=post_includes_media) if post_record else None
            yield note, post

def count_search_results(
    self,
    note_includes_text: Union[str, None] = None,
    note_excludes_text: Union[str, None] = None,
    post_includes_text: Union[str, None] = None,
    post_excludes_text: Union[str, None] = None,
    language: Union[LanguageIdentifier, None] = None,
    topic_ids: Union[List[TopicId], None] = None,
    note_status: Union[List[str], None] = None,
    note_created_at_from: Union[TwitterTimestamp, None] = None,
    note_created_at_to: Union[TwitterTimestamp, None] = None,
    x_user_names: Union[List[str], None] = None,
    x_user_followers_count_from: Union[int, None] = None,
    x_user_follow_count_from: Union[int, None] = None,
    post_like_count_from: Union[int, None] = None,
    post_repost_count_from: Union[int, None] = None,
    post_impression_count_from: Union[int, None] = None,
    post_includes_media: bool = True,
) -> int:
    """Return the total number of notes matching the given search filters.

    Takes the same filter arguments as ``search_notes_with_posts`` (minus
    pagination) and must apply identical filtering, since the endpoint uses
    this count to build pagination links for that query.  In particular the
    joins are *outer* joins to match the search query — an inner join would
    drop notes without a post/user row and under-count the results.
    Numeric thresholds are compared against ``None`` explicitly so that an
    explicit ``0`` still applies the filter.
    """
    with Session(self.engine) as sess:
        # outerjoin, not join: keep parity with search_notes_with_posts so
        # the reported total matches the rows actually returned.
        query = (
            sess.query(NoteRecord)
            .outerjoin(PostRecord, NoteRecord.post_id == PostRecord.post_id)
            .outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
        )

        # Note filters.  Text/list filters deliberately keep the truthiness
        # check: an empty string or empty list means "no filter".
        if note_includes_text:
            query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
        if note_excludes_text:
            query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
        if language:
            query = query.filter(NoteRecord.language == language)
        if topic_ids:
            # Restrict to notes associated with at least one requested topic.
            subq = (
                select(NoteTopicAssociation.note_id)
                .filter(NoteTopicAssociation.topic_id.in_(topic_ids))
                .group_by(NoteTopicAssociation.note_id)
                .subquery()
            )
            query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
        if note_status:
            query = query.filter(NoteRecord.current_status.in_(note_status))
        if note_created_at_from is not None:
            query = query.filter(NoteRecord.created_at >= note_created_at_from)
        if note_created_at_to is not None:
            query = query.filter(NoteRecord.created_at <= note_created_at_to)

        # Post / user filters.
        if post_includes_text:
            query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
        if post_excludes_text:
            query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
        if x_user_names:
            query = query.filter(XUserRecord.name.in_(x_user_names))
        if x_user_followers_count_from is not None:
            query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
        if x_user_follow_count_from is not None:
            query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
        if post_like_count_from is not None:
            query = query.filter(PostRecord.like_count >= post_like_count_from)
        if post_repost_count_from is not None:
            query = query.filter(PostRecord.repost_count >= post_repost_count_from)
        if post_impression_count_from is not None:
            query = query.filter(PostRecord.impression_count >= post_impression_count_from)
        if not post_includes_media:
            query = query.filter(~PostRecord.media_details.any())

        return query.count()


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
Loading

0 comments on commit ef339d3

Please sign in to comment.