From 9d5555da3b62099e0c0d7cf5861e4e53a317c3f9 Mon Sep 17 00:00:00 2001 From: sushichan044 Date: Tue, 8 Oct 2024 17:23:58 +0900 Subject: [PATCH] =?UTF-8?q?=E6=97=A2=E5=AD=98=E3=81=AE=E3=83=86=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=81=8C=E5=A3=8A=E3=82=8C=E3=82=8B=E9=83=A8=E5=88=86?= =?UTF-8?q?=E3=82=92=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/tests/conftest.py | 19 +++++---- common/birdxplorer_common/models.py | 9 +--- common/birdxplorer_common/storage.py | 25 ++++-------- common/tests/conftest.py | 61 ++++++++++++++++------------ common/tests/test_storage.py | 2 +- 5 files changed, 57 insertions(+), 59 deletions(-) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 08c3162..00479d4 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -4,6 +4,14 @@ from typing import List, Type, Union from unittest.mock import MagicMock, patch +from dotenv import load_dotenv +from fastapi.testclient import TestClient +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pydantic import HttpUrl +from pytest import fixture + from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, @@ -27,13 +35,6 @@ PostgresStorageSettings, ) from birdxplorer_common.storage import Storage -from dotenv import load_dotenv -from fastapi.testclient import TestClient -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pydantic import HttpUrl -from pytest import fixture def gen_random_twitter_timestamp() -> int: @@ -272,7 +273,7 @@ def post_samples( x_user_id="1234567890123456782", x_user=x_user_samples[1], text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", - media_details=None, + media_details=[], created_at=1154922900000, like_count=10, repost_count=20, @@ -285,7 +286,7 @@ def post_samples( x_user_id="1234567890123456783", x_user=x_user_samples[2], text="empty", - media_details=None, + media_details=[], created_at=1154923900000, like_count=10, repost_count=20, diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 8589a3e..908d52a 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -17,14 +17,9 @@ from uuid import UUID from pydantic import BaseModel as PydanticBaseModel -from pydantic import ( - ConfigDict, - GetCoreSchemaHandler, - HttpUrl, - TypeAdapter, - model_validator, -) +from pydantic import ConfigDict from pydantic import Field as PydanticField +from pydantic import GetCoreSchemaHandler, HttpUrl, TypeAdapter, model_validator from pydantic.alias_generators import to_camel from pydantic.main import IncEx from pydantic_core import core_schema diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 6b527c1..181c81d 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,20 +7,15 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid +from .models import BinaryBool, LanguageIdentifier +from .models import Link as LinkModel +from .models import LinkId, Media, MediaDetails, MediaType, NonNegativeInt +from .models import Note as NoteModel +from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId +from .models import Post as PostModel +from .models import PostId, SummaryString +from .models import Topic as TopicModel from .models import ( - BinaryBool, - LanguageIdentifier, - LinkId, - Media, - MediaDetails, - MediaType, - NonNegativeInt, - NoteId, - NotesClassification, - NotesHarmful, - ParticipantId, - PostId, - SummaryString, TopicId, TopicLabel, TwitterTimestamp, @@ -28,10 +23,6 @@ UserId, UserName, ) -from .models import Link as LinkModel -from .models import Note as NoteModel -from .models import Post as PostModel -from .models import Topic as TopicModel from .models import XUser as XUserModel from .settings import GlobalSettings diff --git a/common/tests/conftest.py b/common/tests/conftest.py index a38c7cb..7a67411 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -3,6 +3,17 @@ from collections.abc import Generator from typing import List, Type +from dotenv import load_dotenv +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest import fixture +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.sql import text + from birdxplorer_common.models import ( Link, Media, @@ -26,16 +37,6 @@ TopicRecord, XUserRecord, ) -from dotenv import load_dotenv -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pytest import fixture -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session -from sqlalchemy.sql import text def gen_random_twitter_timestamp() -> int: @@ -316,7 +317,7 @@ def post_samples( x_user_id="1234567890123456782", x_user=x_user_samples[1], text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", - media_details=None, + media_details=[], created_at=1154922900000, like_count=10, repost_count=20, @@ -329,7 +330,7 @@ def post_samples( x_user_id="1234567890123456783", x_user=x_user_samples[2], text="empty", - media_details=None, + media_details=[], created_at=1154923900000, like_count=10, repost_count=20, @@ -451,6 +452,27 @@ def x_user_records_sample( yield res +@fixture +def media_records_sample( + media_samples: List[Media], + engine_for_test: Engine, +) -> Generator[List[MediaRecord], None, None]: + res = [ + MediaRecord( + media_key=d.media_key, + url=d.url, + type=d.type, + width=d.width, + height=d.height, + ) + for d in media_samples + ] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + @fixture def link_records_sample( link_samples: List[Link], @@ -467,23 +489,12 @@ def link_records_sample( def post_records_sample( x_user_records_sample: List[XUserRecord], media_records_sample: List[MediaRecord], + link_records_sample: List[LinkRecord], link_samples: List[Link], post_samples: List[Post], engine_for_test: Engine, ) -> Generator[List[PostRecord], None, None]: - res = [ - PostRecord( - post_id=d.post_id, - user_id=d.x_user_id, - text=d.text, - media_details=d.media_details, - created_at=d.created_at, - like_count=d.like_count, - repost_count=d.repost_count, - impression_count=d.impression_count, - ) - for d in post_samples - ] + res = [] with Session(engine_for_test) as sess: for post in post_samples: inst = PostRecord( diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 07d6fb2..ce854b1 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -45,7 +45,7 @@ def test_get_topic_list( [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], - [dict(with_media=True), [0, 1, 2]], + [dict(with_media=True), [0, 1, 2, 3, 4]], [dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]], ], )