From b6946b89be21dd61d3317d36141de4d7444e9edf Mon Sep 17 00:00:00 2001
From: sushichan044
Date: Thu, 16 Jan 2025 20:02:29 +0900
Subject: [PATCH 1/3] fix: explicitly allow requests from all origins
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/birdxplorer_common/settings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/birdxplorer_common/settings.py b/common/birdxplorer_common/settings.py
index 923ad98..ce15cc2 100644
--- a/common/birdxplorer_common/settings.py
+++ b/common/birdxplorer_common/settings.py
@@ -32,7 +32,7 @@ class CORSSettings(BaseSettings):
     allow_methods: list[str] = ["GET"]
     allow_headers: list[str] = ["*"]
 
-    allow_origins: list[str] = []
+    allow_origins: list[str] = ["*"]
 
 
 class GlobalSettings(BaseSettings):

From ecb36cb7439e5701bf49f869a5cd109fa7c6df3d Mon Sep 17 00:00:00 2001
From: yu23ki14
Date: Wed, 15 Jan 2025 10:34:10 +0900
Subject: [PATCH 2/3] implement basic feature of search endpoint

---
 api/birdxplorer_api/openapi_doc.py            |   4 +-
 api/birdxplorer_api/routers/data.py           | 126 +++++++-----
 common/birdxplorer_common/storage.py          | 164 ++++++++++++++-
 common/tests/test_search.py                   | 188 ++++++++++++++++++
 .../scripts/migrate_all.py                    | 174 ++++++++--------
 5 files changed, 512 insertions(+), 144 deletions(-)
 create mode 100644 common/tests/test_search.py

diff --git a/api/birdxplorer_api/openapi_doc.py b/api/birdxplorer_api/openapi_doc.py
index ce36f0b..64efa39 100644
--- a/api/birdxplorer_api/openapi_doc.py
+++ b/api/birdxplorer_api/openapi_doc.py
@@ -527,7 +527,7 @@ class FastAPIEndpointDocs(Generic[_KEY]):
     },
 )
 
-v1_data_post_favorite_count = FastAPIEndpointParamDocs(
+v1_data_post_like_count = FastAPIEndpointParamDocs(
     description="Postのお気に入り数。",
     openapi_examples={
         "single": {
@@ -583,7 +583,7 @@ class FastAPIEndpointDocs(Generic[_KEY]):
     "x_user_name": v1_data_x_user_name,
     "x_user_followers_count_from": v1_data_x_user_follower_count,
     "x_user_follow_count_from": v1_data_x_user_follow_count,
-    "post_favorite_count_from": v1_data_post_favorite_count,
+    "post_like_count_from": v1_data_post_like_count,
     "post_repost_count_from": v1_data_post_repost_count,
     "post_impression_count_from": v1_data_post_impression_count,
     "post_includes_media": v1_data_post_includes_media,
diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py
index 1b48fa6..56cb2c6 100644
--- a/api/birdxplorer_api/routers/data.py
+++ b/api/birdxplorer_api/routers/data.py
@@ -417,6 +417,7 @@ def get_posts(
 
     @router.get("/search", description=V1DataSearchDocs.description, response_model=SearchResponse)
     def search(
+        request: Request,
         note_includes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["note_includes_text"]),
         note_excludes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["note_excludes_text"]),
         post_includes_text: Union[None, str] = Query(default=None, **V1DataSearchDocs.params["post_includes_text"]),
@@ -437,9 +438,7 @@ def search(
         x_user_follow_count_from: Union[None, int] = Query(
             default=None, **V1DataSearchDocs.params["x_user_follow_count_from"]
         ),
-        post_favorite_count_from: Union[None, int] = Query(
-            default=None, **V1DataSearchDocs.params["post_favorite_count_from"]
-        ),
+        post_like_count_from: Union[None, int] = Query(default=None, **V1DataSearchDocs.params["post_like_count_from"]),
         post_repost_count_from: Union[None, int] = Query(
             default=None, **V1DataSearchDocs.params["post_repost_count_from"]
         ),
@@ -450,59 +449,78 @@ def search(
         offset: int = Query(default=0, ge=0, **V1DataSearchDocs.params["offset"]),
         limit: int = Query(default=100, gt=0, le=1000, **V1DataSearchDocs.params["limit"]),
     ) -> SearchResponse:
-        return SearchResponse(
-            data=[
+        # Convert timestamp strings to TwitterTimestamp objects
+        if note_created_at_from is not None and isinstance(note_created_at_from, str):
+            note_created_at_from = ensure_twitter_timestamp(note_created_at_from)
+        if note_created_at_to is not None and isinstance(note_created_at_to, str):
+            note_created_at_to = ensure_twitter_timestamp(note_created_at_to)
+
+        # Get search results using the optimized storage method
+        results = []
+        for note, post in storage.search_notes_with_posts(
+            note_includes_text=note_includes_text,
+            note_excludes_text=note_excludes_text,
+            post_includes_text=post_includes_text,
+            post_excludes_text=post_excludes_text,
+            language=language,
+            topic_ids=topic_ids,
+            note_status=note_status,
+            note_created_at_from=note_created_at_from,
+            note_created_at_to=note_created_at_to,
+            x_user_names=x_user_names,
+            x_user_followers_count_from=x_user_followers_count_from,
+            x_user_follow_count_from=x_user_follow_count_from,
+            post_like_count_from=post_like_count_from,
+            post_repost_count_from=post_repost_count_from,
+            post_impression_count_from=post_impression_count_from,
+            post_includes_media=post_includes_media,
+            offset=offset,
+            limit=limit,
+        ):
+            results.append(
                 SearchedNote(
-                    noteId="1845672983001710655",
-                    language="ja",
-                    topics=[
-                        {
-                            "topicId": 26,
-                            "label": {"ja": "セキュリティ上の脅威", "en": "security threat"},
-                            "referenceCount": 0,
-                        },
-                        {"topicId": 47, "label": {"ja": "検閲", "en": "Censorship"}, "referenceCount": 0},
-                        {"topicId": 51, "label": {"ja": "テクノロジー", "en": "technology"}, "referenceCount": 0},
-                    ],
-                    postId="1846718284369912064",
-                    summary="Content Security Policyは情報の持ち出しを防止する仕組みではありません。コンテンツインジェクションの脆弱性のリスクを軽減する仕組みです。適切なContent Security Policyがレスポンスヘッダーに設定されている場合でも、外部への通信をブロックできない点に注意が必要です。 Content Security Policy Level 3 https://w3c.github.io/webappsec-csp/",  # noqa: E501
-                    current_status="NEEDS_MORE_RATINGS",
-                    created_at=1728877704750,
-                    post={
-                        "postId": "1846718284369912064",
-                        "xUserId": "90954365",
-                        "xUser": {
-                            "userId": "90954365",
-                            "name": "earthquakejapan",
-                            "profileImage": "https://pbs.twimg.com/profile_images/1638600342/japan_rel96_normal.jpg",
-                            "followersCount": 162934,
-                            "followingCount": 6,
-                        },
-                        "text": "今後48時間以内に日本ではマグニチュード6.0の地震が発生する可能性があります。地図をご覧ください。",
-                        "mediaDetails": [
-                            {
-                                "mediaKey": "3_1846718279236177920-1846718284369912064",
-                                "type": "photo",
-                                "url": "https://pbs.twimg.com/media/GaDcfZoX0AAko2-.jpg",
-                                "width": 900,
-                                "height": 738,
-                            }
-                        ],
-                        "createdAt": 1729094524000,
-                        "likeCount": 451,
-                        "repostCount": 104,
-                        "impressionCount": 82378,
-                        "links": [
-                            {
-                                "linkId": "9c139b99-8111-e4f0-ad41-fc9e40d08722",
-                                "url": "https://www.quakeprediction.com/Earthquake%20Forecast%20Japan.html",
-                            }
-                        ],
-                        "link": "https://x.com/earthquakejapan/status/1846718284369912064",
-                    },
+                    noteId=note.note_id,
+                    language=note.language,
+                    topics=note.topics,
+                    postId=note.post_id,
+                    summary=note.summary,
+                    current_status=note.current_status,
+                    created_at=note.created_at,
+                    post=post,
                 )
-            ],
-            meta=PaginationMeta(next=None, prev=None),
+            )
+
+        # Get total count for pagination
+        total_count = storage.count_search_results(
+            note_includes_text=note_includes_text,
+            note_excludes_text=note_excludes_text,
+            post_includes_text=post_includes_text,
+            post_excludes_text=post_excludes_text,
+            language=language,
+            topic_ids=topic_ids,
+            note_status=note_status,
+            note_created_at_from=note_created_at_from,
+            note_created_at_to=note_created_at_to,
+            x_user_names=x_user_names,
+            x_user_followers_count_from=x_user_followers_count_from,
+            x_user_follow_count_from=x_user_follow_count_from,
+            post_like_count_from=post_like_count_from,
+            post_repost_count_from=post_repost_count_from,
+            post_impression_count_from=post_impression_count_from,
+            post_includes_media=post_includes_media,
         )
+
+        # Generate pagination URLs, preserving the active filter parameters
+        # (include_query_params swaps in the new offset/limit but keeps the rest)
+        next_offset = offset + limit
+        prev_offset = max(offset - limit, 0)
+        next_url = None
+        if next_offset < total_count:
+            next_url = f"{request.url.include_query_params(offset=next_offset, limit=limit)}"
+        prev_url = None
+        if offset > 0:
+            prev_url = f"{request.url.include_query_params(offset=prev_offset, limit=limit)}"
+
+        return SearchResponse(data=results, meta=PaginationMeta(next=next_url, prev=prev_url))
 
     return router
diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py
index ab59332..f103096 100644
--- a/common/birdxplorer_common/storage.py
+++ b/common/birdxplorer_common/storage.py
@@ -1,4 +1,4 @@
-from typing import Generator, List, Union
+from typing import Generator, List, Tuple, Union
 
 from psycopg2.extensions import AsIs, register_adapter
 from pydantic import AnyUrl, HttpUrl
@@ -507,6 +507,168 @@ def get_number_of_posts(
         )
         return query.count()
 
+    def search_notes_with_posts(
+        self,
+        note_includes_text: Union[str, None] = None,
+        note_excludes_text: Union[str, None] = None,
+        post_includes_text: Union[str, None] = None,
+        post_excludes_text: Union[str, None] = None,
+        language: Union[LanguageIdentifier, None] = None,
+        topic_ids: Union[List[TopicId], None] = None,
+        note_status: Union[List[str], None] = None,
+        note_created_at_from: Union[TwitterTimestamp, None] = None,
+        note_created_at_to: Union[TwitterTimestamp, None] = None,
+        x_user_names: Union[List[str], None] = None,
+        x_user_followers_count_from: Union[int, None] = None,
+        x_user_follow_count_from: Union[int, None] = None,
+        post_like_count_from: Union[int, None] = None,
+        post_repost_count_from: Union[int, None] = None,
+        post_impression_count_from: Union[int, None] = None,
+        post_includes_media: bool = True,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> Generator[Tuple[NoteModel, Union[PostModel, None]], None, None]:
+        with Session(self.engine) as sess:
+            # Base query joining notes, posts and users
+            query = (
+                sess.query(NoteRecord, PostRecord)
+                .outerjoin(PostRecord, NoteRecord.post_id == PostRecord.post_id)
+                .outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
+            )
+
+            # Apply note filters
+            if note_includes_text:
+                query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
+            if note_excludes_text:
+                query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
+            if language:
+                query = query.filter(NoteRecord.language == language)
+            if topic_ids:
+                subq = (
+                    select(NoteTopicAssociation.note_id)
+                    .filter(NoteTopicAssociation.topic_id.in_(topic_ids))
+                    .group_by(NoteTopicAssociation.note_id)
+                    .subquery()
+                )
+                query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
+            if note_status:
+                query = query.filter(NoteRecord.current_status.in_(note_status))
+            if note_created_at_from:
+                query = query.filter(NoteRecord.created_at >= note_created_at_from)
+            if note_created_at_to:
+                query = query.filter(NoteRecord.created_at <= note_created_at_to)
+
+            # Apply post filters
+            if post_includes_text:
+                query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
+            if post_excludes_text:
+                query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
+            if x_user_names:
+                query = query.filter(XUserRecord.name.in_(x_user_names))
+            if x_user_followers_count_from:
+                query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
+            if x_user_follow_count_from:
+                query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
+            if post_like_count_from:
+                query = query.filter(PostRecord.like_count >= post_like_count_from)
+            if post_repost_count_from:
+                query = query.filter(PostRecord.repost_count >= post_repost_count_from)
+            if post_impression_count_from:
+                query = query.filter(PostRecord.impression_count >= post_impression_count_from)
+            if not post_includes_media:
+                query = query.filter(~PostRecord.media_details.any())
+
+            # Pagination
+            query = query.offset(offset).limit(limit)
+
+            # Execute query and yield results
+            for note_record, post_record in query.all():
+                note = NoteModel(
+                    note_id=note_record.note_id,
+                    post_id=note_record.post_id,
+                    topics=[
+                        TopicModel(topic_id=topic.topic_id, label=topic.topic.label, reference_count=0)
+                        for topic in note_record.topics
+                    ],
+                    language=note_record.language,
+                    summary=note_record.summary,
+                    current_status=note_record.current_status,
+                    created_at=note_record.created_at,
+                )
+
+                post = self._post_record_to_model(post_record, with_media=post_includes_media) if post_record else None
+                yield note, post
+
+    def count_search_results(
+        self,
+        note_includes_text: Union[str, None] = None,
+        note_excludes_text: Union[str, None] = None,
+        post_includes_text: Union[str, None] = None,
+        post_excludes_text: Union[str, None] = None,
+        language: Union[LanguageIdentifier, None] = None,
+        topic_ids: Union[List[TopicId], None] = None,
+        note_status: Union[List[str], None] = None,
+        note_created_at_from: Union[TwitterTimestamp, None] = None,
+        note_created_at_to: Union[TwitterTimestamp, None] = None,
+        x_user_names: Union[List[str], None] = None,
+        x_user_followers_count_from: Union[int, None] = None,
+        x_user_follow_count_from: Union[int, None] = None,
+        post_like_count_from: Union[int, None] = None,
+        post_repost_count_from: Union[int, None] = None,
+        post_impression_count_from: Union[int, None] = None,
+        post_includes_media: bool = True,
+    ) -> int:
+        with Session(self.engine) as sess:
+            # Outer joins keep notes without posts, matching search_notes_with_posts
+            query = (
+                sess.query(NoteRecord)
+                .outerjoin(PostRecord, NoteRecord.post_id == PostRecord.post_id)
+                .outerjoin(XUserRecord, PostRecord.user_id == XUserRecord.user_id)
+            )
+
+            # Apply note filters
+            if note_includes_text:
+                query = query.filter(NoteRecord.summary.like(f"%{note_includes_text}%"))
+            if note_excludes_text:
+                query = query.filter(~NoteRecord.summary.like(f"%{note_excludes_text}%"))
+            if language:
+                query = query.filter(NoteRecord.language == language)
+            if topic_ids:
+                subq = (
+                    select(NoteTopicAssociation.note_id)
+                    .filter(NoteTopicAssociation.topic_id.in_(topic_ids))
+                    .group_by(NoteTopicAssociation.note_id)
+                    .subquery()
+                )
+                query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
+            if note_status:
+                query = query.filter(NoteRecord.current_status.in_(note_status))
+            if note_created_at_from:
+                query = query.filter(NoteRecord.created_at >= note_created_at_from)
+            if note_created_at_to:
+                query = query.filter(NoteRecord.created_at <= note_created_at_to)
+
+            # Apply post filters
+            if post_includes_text:
+                query = query.filter(PostRecord.text.like(f"%{post_includes_text}%"))
+            if post_excludes_text:
+                query = query.filter(~PostRecord.text.like(f"%{post_excludes_text}%"))
+            if x_user_names:
+                query = query.filter(XUserRecord.name.in_(x_user_names))
+            if x_user_followers_count_from:
+                query = query.filter(XUserRecord.followers_count >= x_user_followers_count_from)
+            if x_user_follow_count_from:
+                query = query.filter(XUserRecord.following_count >= x_user_follow_count_from)
+            if post_like_count_from:
+                query = query.filter(PostRecord.like_count >= post_like_count_from)
+            if post_repost_count_from:
+                query = query.filter(PostRecord.repost_count >= post_repost_count_from)
+            if post_impression_count_from:
+                query = query.filter(PostRecord.impression_count >= post_impression_count_from)
+            if not post_includes_media:
+                query = query.filter(~PostRecord.media_details.any())
+
+            return query.count()
+
 
 def gen_storage(settings: GlobalSettings) -> Storage:
     engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
diff --git a/common/tests/test_search.py b/common/tests/test_search.py
new file mode 100644
index 0000000..4485c53
--- /dev/null
+++ b/common/tests/test_search.py
@@ -0,0 +1,188 @@
+from typing import List
+
+import pytest
+from sqlalchemy.engine import Engine
+
+from birdxplorer_common.models import (
+    LanguageIdentifier,
+    Note,
+    Post,
+    Topic,
+    TopicId,
+    TwitterTimestamp,
+)
+from birdxplorer_common.storage import NoteRecord, PostRecord, Storage, TopicRecord
+
+
+def test_basic_search(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test basic search functionality without any filters"""
+    storage = Storage(engine=engine_for_test)
+    results = list(storage.search_notes_with_posts(limit=2))
+    assert len(results) == 2
+    for note, post in results:
+        assert note is not None
+
+
+def test_search_by_note_text(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test searching notes by included and excluded text"""
+    storage = Storage(engine=engine_for_test)
+
+    # Test searching notes with text that should be included
+    results = list(storage.search_notes_with_posts(note_includes_text="summary"))
+    assert len(results) > 0
+    for note, _ in results:
+        assert "summary" in note.summary.lower()
+
+    # Test searching notes with text that should be excluded
+    results = list(storage.search_notes_with_posts(note_excludes_text="empty"))
+    assert len(results) > 0
+    for note, _ in results:
+        assert "empty" not in note.summary.lower()
+
+
+def test_search_by_language(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test filtering by language"""
+    storage = Storage(engine=engine_for_test)
+
+    # Test searching for English notes
+    results = list(storage.search_notes_with_posts(language=LanguageIdentifier("en")))
+    assert len(results) > 0
+    for note, _ in results:
+        assert note.language == "en"
+
+    # Test searching for Japanese notes
+    results = list(storage.search_notes_with_posts(language=LanguageIdentifier("ja")))
+    assert len(results) > 0
+    for note, _ in results:
+        assert note.language == "ja"
+
+
+def test_search_by_topics(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+    topic_records_sample: List[TopicRecord],
+) -> None:
+    """Test filtering by topics"""
+    storage = Storage(engine=engine_for_test)
+    topic_ids = [TopicId(0)]  # Topic 0 is used in several notes in the sample data
+
+    results = list(storage.search_notes_with_posts(topic_ids=topic_ids))
+    assert len(results) > 0
+    for note, _ in results:
+        note_topic_ids = [topic.topic_id for topic in note.topics]
+        assert any(tid in note_topic_ids for tid in topic_ids)
+
+
+def test_search_by_post_text(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test searching posts by included and excluded text"""
+    storage = Storage(engine=engine_for_test)
+
+    # Test searching posts with text that should be included
+    results = list(storage.search_notes_with_posts(post_includes_text="プロジェクト"))
+    assert len(results) > 0
+    for _, post in results:
+        assert post is not None
+        assert "プロジェクト" in post.text
+
+    # Test searching posts with text that should be excluded
+    results = list(storage.search_notes_with_posts(post_excludes_text="empty"))
+    assert len(results) > 0
+    for _, post in results:
+        if post is not None:
+            assert "empty" not in post.text
+
+
+def test_combined_search(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test combining multiple search criteria"""
+    storage = Storage(engine=engine_for_test)
+
+    results = list(
+        storage.search_notes_with_posts(note_includes_text="summary", language=LanguageIdentifier("en"), limit=2)
+    )
+
+    assert len(results) <= 2
+    for note, _ in results:
+        assert "summary" in note.summary.lower()
+        assert note.language == "en"
+
+
+def test_pagination(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test pagination functionality"""
+    storage = Storage(engine=engine_for_test)
+
+    # Get first page
+    page_size = 2
+    first_page = list(storage.search_notes_with_posts(limit=page_size, offset=0))
+    assert len(first_page) <= page_size
+
+    # Get second page
+    second_page = list(storage.search_notes_with_posts(limit=page_size, offset=page_size))
+    assert len(second_page) <= page_size
+
+    # Ensure pages are different
+    first_page_ids = {note.note_id for note, _ in first_page}
+    second_page_ids = {note.note_id for note, _ in second_page}
+    assert not first_page_ids.intersection(second_page_ids)
+
+
+def test_count_search_results(
+    engine_for_test: Engine,
+    note_samples: List[Note],
+    post_samples: List[Post],
+    note_records_sample: List[NoteRecord],
+    post_records_sample: List[PostRecord],
+) -> None:
+    """Test the count functionality of search results"""
+    storage = Storage(engine=engine_for_test)
+
+    # Get total count
+    total_count = storage.count_search_results()
+    assert total_count > 0
+
+    # Get filtered count
+    filtered_count = storage.count_search_results(note_includes_text="summary", language=LanguageIdentifier("en"))
+    assert filtered_count > 0
+    assert filtered_count <= total_count
+
+    # Verify count matches actual results
+    results = list(storage.search_notes_with_posts(note_includes_text="summary", language=LanguageIdentifier("en")))
+    assert len(results) == filtered_count
diff --git a/migrate/birdxplorer_migration/scripts/migrate_all.py b/migrate/birdxplorer_migration/scripts/migrate_all.py
index 8d6b879..387839e 100644
--- a/migrate/birdxplorer_migration/scripts/migrate_all.py
+++ b/migrate/birdxplorer_migration/scripts/migrate_all.py
@@ -34,91 +34,91 @@
     storage = gen_storage(settings=settings)
 
     Base.metadata.create_all(storage.engine)
 
-    with Session(storage.engine) as sess:
-        with open(os.path.join(args.data_dir, args.topics_file_name), "r", encoding="utf-8") as fin:
-            for d in csv.DictReader(fin):
-                d["topic_id"] = int(d["topic_id"])
-                d["label"] = json.loads(d["label"])
-                if sess.query(TopicRecord).filter(TopicRecord.topic_id == d["topic_id"]).count() > 0:
-                    continue
-                sess.add(TopicRecord(topic_id=d["topic_id"], label=d["label"]))
-            sess.commit()
-        with open(os.path.join(args.data_dir, args.notes_file_name), "r", encoding="utf-8") as fin:
-            for d in csv.DictReader(fin):
-                if sess.query(NoteRecord).filter(NoteRecord.note_id == d["note_id"]).count() > 0:
-                    continue
-                sess.add(
-                    NoteRecord(
-                        note_id=d["note_id"],
-                        post_id=d["post_id"],
-                        language=d["language"],
-                        summary=d["summary"],
-                        created_at=d["created_at"],
-                    )
-                )
-            sess.commit()
-        with open(
-            os.path.join(args.data_dir, args.notes_topics_association_file_name),
-            "r",
-            encoding="utf-8",
-        ) as fin:
-            for d in csv.DictReader(fin):
-                if (
-                    sess.query(NoteTopicAssociation)
-                    .filter(
-                        NoteTopicAssociation.note_id == d["note_id"],
-                        NoteTopicAssociation.topic_id == d["topic_id"],
-                    )
-                    .count()
-                    > 0
-                ):
-                    continue
-                sess.add(
-                    NoteTopicAssociation(
-                        note_id=d["note_id"],
-                        topic_id=d["topic_id"],
-                    )
-                )
-            sess.commit()
-        with open(os.path.join(args.data_dir, args.x_users_file_name), "r", encoding="utf-8") as fin:
-            for d in csv.DictReader(fin):
-                d["followers_count"] = int(d["followers_count"])
-                d["following_count"] = int(d["following_count"])
-                if sess.query(XUserRecord).filter(XUserRecord.user_id == d["user_id"]).count() > 0:
-                    continue
-                sess.add(
-                    XUserRecord(
-                        user_id=d["user_id"],
-                        name=d["name"],
-                        profile_image=d["profile_image"],
-                        followers_count=d["followers_count"],
-                        following_count=d["following_count"],
-                    )
-                )
-            sess.commit()
-        with open(os.path.join(args.data_dir, args.posts_file_name), "r", encoding="utf-8") as fin:
-            for d in csv.DictReader(fin):
-                if (
-                    args.limit_number_of_post_rows is not None
-                    and sess.query(PostRecord).count() >= args.limit_number_of_post_rows
-                ):
-                    break
-                d["like_count"] = int(d["like_count"])
-                d["repost_count"] = int(d["repost_count"])
-                d["impression_count"] = int(d["impression_count"])
-                if sess.query(PostRecord).filter(PostRecord.post_id == d["post_id"]).count() > 0:
-                    continue
-                sess.add(
-                    PostRecord(
-                        post_id=d["post_id"],
-                        user_id=d["user_id"],
-                        text=d["text"],
-                        media_details=(json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None),
-                        created_at=d["created_at"],
-                        like_count=d["like_count"],
-                        repost_count=d["repost_count"],
-                        impression_count=d["impression_count"],
-                    )
-                )
-            sess.commit()
+    # with Session(storage.engine) as sess:
+    #     with open(os.path.join(args.data_dir, args.topics_file_name), "r", encoding="utf-8") as fin:
+    #         for d in csv.DictReader(fin):
+    #             d["topic_id"] = int(d["topic_id"])
+    #             d["label"] = json.loads(d["label"])
+    #             if sess.query(TopicRecord).filter(TopicRecord.topic_id == d["topic_id"]).count() > 0:
+    #                 continue
+    #             sess.add(TopicRecord(topic_id=d["topic_id"], label=d["label"]))
+    #         sess.commit()
+    #     with open(os.path.join(args.data_dir, args.notes_file_name), "r", encoding="utf-8") as fin:
+    #         for d in csv.DictReader(fin):
+    #             if sess.query(NoteRecord).filter(NoteRecord.note_id == d["note_id"]).count() > 0:
+    #                 continue
+    #             sess.add(
+    #                 NoteRecord(
+    #                     note_id=d["note_id"],
+    #                     post_id=d["post_id"],
+    #                     language=d["language"],
+    #                     summary=d["summary"],
+    #                     created_at=d["created_at"],
+    #                 )
+    #             )
+    #         sess.commit()
+    #     with open(
+    #         os.path.join(args.data_dir, args.notes_topics_association_file_name),
+    #         "r",
+    #         encoding="utf-8",
+    #     ) as fin:
+    #         for d in csv.DictReader(fin):
+    #             if (
+    #                 sess.query(NoteTopicAssociation)
+    #                 .filter(
+    #                     NoteTopicAssociation.note_id == d["note_id"],
+    #                     NoteTopicAssociation.topic_id == d["topic_id"],
+    #                 )
+    #                 .count()
+    #                 > 0
+    #             ):
+    #                 continue
+    #             sess.add(
+    #                 NoteTopicAssociation(
+    #                     note_id=d["note_id"],
+    #                     topic_id=d["topic_id"],
+    #                 )
+    #             )
+    #         sess.commit()
+    #     with open(os.path.join(args.data_dir, args.x_users_file_name), "r", encoding="utf-8") as fin:
+    #         for d in csv.DictReader(fin):
+    #             d["followers_count"] = int(d["followers_count"])
+    #             d["following_count"] = int(d["following_count"])
+    #             if sess.query(XUserRecord).filter(XUserRecord.user_id == d["user_id"]).count() > 0:
+    #                 continue
+    #             sess.add(
+    #                 XUserRecord(
+    #                     user_id=d["user_id"],
+    #                     name=d["name"],
+    #                     profile_image=d["profile_image"],
+    #                     followers_count=d["followers_count"],
+    #                     following_count=d["following_count"],
+    #                 )
+    #             )
+    #         sess.commit()
+    #     with open(os.path.join(args.data_dir, args.posts_file_name), "r", encoding="utf-8") as fin:
+    #         for d in csv.DictReader(fin):
+    #             if (
+    #                 args.limit_number_of_post_rows is not None
+    #                 and sess.query(PostRecord).count() >= args.limit_number_of_post_rows
+    #             ):
+    #                 break
+    #             d["like_count"] = int(d["like_count"])
+    #             d["repost_count"] = int(d["repost_count"])
+    #             d["impression_count"] = int(d["impression_count"])
+    #             if sess.query(PostRecord).filter(PostRecord.post_id == d["post_id"]).count() > 0:
+    #                 continue
+    #             sess.add(
+    #                 PostRecord(
+    #                     post_id=d["post_id"],
+    #                     user_id=d["user_id"],
+    #                     text=d["text"],
+    #                     media_details=(json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None),
+    #                     created_at=d["created_at"],
+    #                     like_count=d["like_count"],
+    #                     repost_count=d["repost_count"],
+    #                     impression_count=d["impression_count"],
+    #                 )
+    #             )
+    #         sess.commit()
 
     logger.info("Migration is done")

From e0532f0b767e9ab6d950d19ad6510317b7c6fdfb Mon Sep 17 00:00:00 2001
From: yu23ki14
Date: Wed, 15 Jan 2025 10:39:22 +0900
Subject: [PATCH 3/3] use 138 common lib for test

---
 api/pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/pyproject.toml b/api/pyproject.toml
index 7bc5466..953f150 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -61,7 +61,7 @@ dev=[
     "httpx",
 ]
 prod=[
-    "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common",
+    "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@feature/138#subdirectory=common",
     "psycopg2",
     "gunicorn",
 ]
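
Reviewer note (appended, not part of the patch series): a minimal sketch of how the new search endpoint could be exercised once PATCH 2/3 is applied. The host, port, and the /api/v1/data/search path are assumptions inferred from routers/data.py and the V1DataSearchDocs naming, not confirmed by the patches, so adjust them to the actual router prefix; httpx is already listed in the api package's dev dependencies.

import httpx

# Illustrative sketch only -- the base URL and path below are assumptions.
params = {
    "note_includes_text": "地震",  # substring match against note summaries
    "language": "ja",
    "post_like_count_from": 100,  # renamed from post_favorite_count_from in PATCH 2/3
    "offset": 0,
    "limit": 50,
}
response = httpx.get("http://localhost:8000/api/v1/data/search", params=params)
response.raise_for_status()
body = response.json()
# body["data"] holds the matched note/post pairs; body["meta"] carries the
# next/prev pagination URLs built by the endpoint.
print(len(body["data"]), body["meta"])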