Skip to content

Commit

Permalink
Merge pull request #116 from codeforjapan/feat/issue-104-url-data-model
Browse files Browse the repository at this point in the history
Feat/issue 104 url data model
  • Loading branch information
yu23ki14 authored Oct 6, 2024
2 parents 4b8fc14 + ddb0846 commit 943a010
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 29 deletions.
55 changes: 53 additions & 2 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from birdxplorer_common.exceptions import UserEnrollmentNotFoundError
from birdxplorer_common.models import (
LanguageIdentifier,
Link,
Note,
NoteId,
ParticipantId,
Expand Down Expand Up @@ -66,6 +67,11 @@ class PostFactory(ModelFactory[Post]):
__model__ = Post


@register_fixture(name="link_factory")
class LinkFactory(ModelFactory[Link]):
__model__ = Link


@fixture
def user_enrollment_samples(
user_enrollment_factory: UserEnrollmentFactory,
Expand All @@ -84,6 +90,17 @@ def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, N
yield topics


@fixture
def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None]:
links = [
link_factory.build(link_id="9f56ee4a-6b36-b79c-d6ca-67865e54bbd5", url="https://example.com/sh0"),
link_factory.build(link_id="f5b0ac79-20fe-9718-4a40-6030bb62d156", url="https://example.com/sh1"),
link_factory.build(link_id="76a0ac4a-a20c-b1f4-1906-d00e2e8f8bf8", url="https://example.com/sh2"),
link_factory.build(link_id="6c352be8-eca3-0d96-55bf-a9bbef1c0fc2", url="https://example.com/sh3"),
]
yield links


@fixture
def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]:
notes = [
Expand Down Expand Up @@ -160,10 +177,13 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None,


@fixture
def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]:
def post_samples(
post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link]
) -> Generator[List[Post], None, None]:
posts = [
post_factory.build(
post_id="2234567890123456781",
link=None,
x_user_id="1234567890123456781",
x_user=x_user_samples[0],
text="""\
Expand All @@ -175,9 +195,11 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[0]],
),
post_factory.build(
post_id="2234567890123456791",
link=None,
x_user_id="1234567890123456781",
x_user=x_user_samples[0],
text="""\
Expand All @@ -189,18 +211,47 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[1]],
),
post_factory.build(
post_id="2234567890123456801",
link=None,
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="""\
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""",
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""",
media_details=None,
created_at=1154921800000,
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[0], link_samples[3]],
),
post_factory.build(
post_id="2234567890123456811",
link=None,
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/",
media_details=None,
created_at=1154922900000,
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[2], link_samples[3]],
),
post_factory.build(
post_id="2234567890123456821",
link=None,
x_user_id="1234567890123456783",
x_user=x_user_samples[2],
text="empty",
media_details=None,
created_at=1154923900000,
like_count=10,
repost_count=20,
impression_count=30,
links=[],
),
]
yield posts
Expand Down
9 changes: 6 additions & 3 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def test_posts_get_limit_and_offset(client: TestClient, post_samples: List[Post]
res_json = response.json()
assert res_json == {
"data": [json.loads(d.model_dump_json()) for d in post_samples[1:3]],
"meta": {"next": None, "prev": "http://testserver/api/v1/data/posts?offset=0&limit=2"},
"meta": {
"next": "http://testserver/api/v1/data/posts?offset=3&limit=2",
"prev": "http://testserver/api/v1/data/posts?offset=0&limit=2",
},
}


Expand Down Expand Up @@ -72,7 +75,7 @@ def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples:
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)],
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)],
"meta": {"next": None, "prev": None},
}

Expand All @@ -99,7 +102,7 @@ def test_posts_get_created_at_start_filter_accepts_integer(client: TestClient, p
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)],
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)],
"meta": {"next": None, "prev": None},
}

Expand Down
73 changes: 72 additions & 1 deletion common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from random import Random
from typing import Any, Dict, List, Literal, Optional, Type, TypeAlias, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter
from pydantic import (
ConfigDict,
GetCoreSchemaHandler,
HttpUrl,
TypeAdapter,
model_validator,
)
from pydantic.alias_generators import to_camel
from pydantic.main import IncEx
from pydantic_core import core_schema
Expand Down Expand Up @@ -677,6 +685,68 @@ class XUser(BaseModel):
MediaDetails: TypeAlias = List[HttpUrl] | None


class LinkId(UUID):
"""
>>> LinkId("53dc4ed6-fc9b-54ef-1afa-90f1125098c5")
LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5')
>>> LinkId(UUID("53dc4ed6-fc9b-54ef-1afa-90f1125098c5"))
LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5')
"""

def __init__(
self,
hex: str | None = None,
int: int | None = None,
) -> None:
if isinstance(hex, UUID):
hex = str(hex)
super().__init__(hex, int=int)

@classmethod
def from_url(cls, url: HttpUrl) -> "LinkId":
"""
>>> LinkId.from_url("https://example.com/")
LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6')
"""
random_number_generator = Random()
random_number_generator.seed(str(url).encode("utf-8"))
return LinkId(int=random_number_generator.getrandbits(128))

@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.validate,
serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize, when_used="json"),
)

@classmethod
def validate(cls, v: Any) -> "LinkId":
return cls(v)

def serialize(self) -> str:
return str(self)


class Link(BaseModel):
"""
>>> Link.model_validate_json('{"linkId": "d5d15194-6574-0c01-8f6f-15abd72b2cf6", "url": "https://example.com"}')
Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
>>> Link(url="https://example.com/")
Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
>>> Link(link_id=UUID("d5d15194-6574-0c01-8f6f-15abd72b2cf6"), url="https://example.com/")
Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
""" # noqa: E501

link_id: LinkId
url: HttpUrl

@model_validator(mode="before")
def validate_link_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "link_id" not in values:
values["link_id"] = LinkId.from_url(values["url"])
return values


class Post(BaseModel):
post_id: PostId
link: Optional[HttpUrl] = None
Expand All @@ -688,6 +758,7 @@ class Post(BaseModel):
like_count: NonNegativeInt
repost_count: NonNegativeInt
impression_count: NonNegativeInt
links: List[Link] = []


class PaginationMeta(BaseModel):
Expand Down
24 changes: 22 additions & 2 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid

from .models import BinaryBool, LanguageIdentifier, MediaDetails, NonNegativeInt
from .models import BinaryBool, LanguageIdentifier
from .models import Link as LinkModel
from .models import LinkId, MediaDetails, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId
from .models import Post as PostModel
Expand All @@ -34,6 +36,7 @@ def adapt_pydantic_http_url(url: AnyUrl) -> AsIs:

class Base(DeclarativeBase):
type_annotation_map = {
LinkId: Uuid,
TopicId: Integer,
TopicLabel: JSON,
NoteId: String,
Expand Down Expand Up @@ -88,6 +91,21 @@ class XUserRecord(Base):
following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class LinkRecord(Base):
__tablename__ = "links"

link_id: Mapped[LinkId] = mapped_column(primary_key=True)
url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True)


class PostLinkAssociation(Base):
__tablename__ = "post_link"

post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True)
link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True)
link: Mapped[LinkRecord] = relationship()


class PostRecord(Base):
__tablename__ = "posts"

Expand All @@ -100,6 +118,7 @@ class PostRecord(Base):
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
links: Mapped[List[PostLinkAssociation]] = relationship()


class RowNoteRecord(Base):
Expand Down Expand Up @@ -196,6 +215,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
links=[LinkModel(link_id=link.link_id, url=link.link.url) for link in post_record.links],
)

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
Expand Down
1 change: 1 addition & 0 deletions common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"sqlalchemy",
"pydantic_settings",
"JSON-log-formatter",
"ulid-py",
]

[project.urls]
Expand Down
Loading

0 comments on commit 943a010

Please sign in to comment.