Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add starboard #42

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ async def error_handler(ctx: arc.GatewayContext, exc: Exception) -> None:
raise exc


@client.set_startup_hook
@client.add_startup_hook
async def startup_hook(_: arc.GatewayClient) -> None:
await init_db()
35 changes: 32 additions & 3 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import BigInteger, Integer, SmallInteger
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from src.config import DB_HOST, DB_NAME, DB_PASSWORD, DB_USER

Expand All @@ -8,10 +9,38 @@
)


Base = declarative_base()
class Base(AsyncAttrs, DeclarativeBase):
pass


Session = async_sessionmaker(bind=engine)


async def init_db() -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)


# TODO: add reprs?


class StarboardSettings(Base):
__tablename__ = "starboard_settings"

guild_id: Mapped[int] = mapped_column(BigInteger, nullable=False, primary_key=True)
channel_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
threshold: Mapped[int] = mapped_column(SmallInteger, nullable=False, default=3)
error: Mapped[int | None] = mapped_column(SmallInteger, nullable=True, default=None)


class Starboard(Base):
__tablename__ = "starboard"

id: Mapped[int] = mapped_column(
Integer, nullable=False, primary_key=True, autoincrement=True
)
channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
message_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
stars: Mapped[int] = mapped_column(SmallInteger, nullable=False)
starboard_channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
starboard_message_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
253 changes: 253 additions & 0 deletions src/extensions/starboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
from __future__ import annotations

import enum
import logging

import arc
import hikari
from sqlalchemy import insert, select, update

from src.config import ROLE_IDS
from src.database import Session, Starboard, StarboardSettings
from src.hooks import restrict_to_roles

logger = logging.getLogger(__name__)

plugin = arc.GatewayPlugin("Starboard")


class StarboardSettingsError(enum.IntEnum):
CHANNEL_FORBIDDEN = 0
CHANNEL_NOT_FOUND = 1


# TODO: handle star remove
@plugin.listen()
async def on_reaction(
event: hikari.GuildReactionAddEvent,
) -> None:
logger.debug("Received guild reaction add event")

if event.emoji_name != "⭐":
return

message = await plugin.client.rest.fetch_message(event.channel_id, event.message_id)
star_count = sum(r.emoji == "⭐" for r in message.reactions)

# get starboard settings
async with Session() as session:
stmt = select(StarboardSettings).where(
StarboardSettings.guild_id == event.guild_id
)
result = await session.execute(stmt)
settings = result.scalars().first()

if (
not settings
or star_count < settings.threshold
or not settings.channel_id
or settings.error is not None
or event.channel_id == settings.channel_id
):
return

# get starred message
async with Session() as session:
stmt = select(Starboard).where(Starboard.message_id == event.message_id)
result = await session.execute(stmt)
starboard = result.scalars().first()

embed = hikari.Embed(
title=f"⭐ {star_count} - *jump to message*",
url=message.make_link(event.guild_id),
description=message.content,
timestamp=message.timestamp,
).set_author(
name=message.author.username,
icon=message.author.display_avatar_url,
)

images = [
att
for att in message.attachments
if att.media_type and att.media_type.startswith("image")
]
if images:
embed.set_image(images[0])

embeds = [embed, *message.embeds[:1]]

await create_starboard_message(event, settings, starboard, embeds, star_count)


async def create_starboard_message(
event: hikari.GuildReactionAddEvent,
settings: StarboardSettings,
starboard: Starboard | None,
embeds: list[hikari.Embed],
star_count: int,
) -> None:
assert settings.channel_id # already verified to exist

try:
if not starboard:
# starboard message does not exist, create it
logger.debug("Creating message")

message = await plugin.client.rest.create_message(
settings.channel_id,
embeds=embeds,
)

async with Session() as session:
session.add(
Starboard(
channel_id=event.channel_id,
message_id=event.message_id,
stars=star_count,
starboard_channel_id=settings.channel_id,
starboard_message_id=message.id,
)
)
await session.commit()
else:
# starboard message should exist
try:
# attempt to edit it
logger.debug("Editing message")

await plugin.client.rest.edit_message(
starboard.starboard_channel_id,
starboard.starboard_message_id,
embeds=embeds,
)
except hikari.NotFoundError:
# the message does not exist, create a new one
logger.debug("Starboard message does not exist, creating new")

message = await plugin.client.rest.create_message(
settings.channel_id,
embeds=embeds,
)
async with Session() as session:
stmt = (
update(Starboard)
.where(
Starboard.starboard_message_id
== starboard.starboard_message_id
)
.values(
starboard_message_id=message.id,
)
)
await session.execute(stmt)
await session.commit()

except hikari.ForbiddenError:
# bot cannot access the starboard channel
logger.debug("Can't access starboard channel")

async with Session() as session:
stmt = (
update(StarboardSettings)
.where(StarboardSettings.guild_id == event.guild_id)
.values(error=StarboardSettingsError.CHANNEL_FORBIDDEN)
)
await session.execute(stmt)
await session.commit()
except hikari.NotFoundError:
# the starboard channel does not exist
logger.debug("Can't find starboard channel")

async with Session() as session:
stmt = (
update(StarboardSettings)
.where(StarboardSettings.guild_id == event.guild_id)
.values(error=StarboardSettingsError.CHANNEL_NOT_FOUND)
)
await session.execute(stmt)
await session.commit()


@plugin.include
@arc.with_hook(restrict_to_roles(role_ids=[ROLE_IDS["committee"]]))
@arc.slash_command(
"starboard",
"Edit or view starboard settings.",
default_permissions=hikari.Permissions.MANAGE_GUILD,
)
async def starboard_settings(
ctx: arc.GatewayContext,
channel: arc.Option[
hikari.TextableGuildChannel | None,
arc.ChannelParams("The channel to post starboard messages to."),
] = None,
threshold: arc.Option[
int | None,
arc.IntParams(
"The minimum number of stars before this message is posted to the starboard",
min=1,
),
] = None,
) -> None:
assert ctx.guild_id

async with Session() as session:
stmt = select(StarboardSettings).where(
StarboardSettings.guild_id == ctx.guild_id
)
result = await session.execute(stmt)
settings = result.scalars().first()

if not channel and not threshold:
if not settings:
await ctx.respond(
"This server has no starboard settings.",
flags=hikari.MessageFlag.EPHEMERAL,
)
else:
embed = hikari.Embed(
title="Starboard Settings",
description=(
f"**Channel:** <#{settings.channel_id}>\n"
f"**Threshold:** {settings.threshold}\n"
),
)
if settings.error is not None:
error = StarboardSettingsError(settings.error)
assert embed.description
embed.description += f"**Error:** {error.name.replace("_", " ").title()}"

await ctx.respond(embed)

return

# TODO: use returning statement to get back new row
# then send embed

if not settings:
# TODO: use add not insert
stmt = insert(StarboardSettings).values(guild_id=ctx.guild_id)
else:
stmt = update(StarboardSettings).where(
StarboardSettings.guild_id == ctx.guild_id
)

# TODO: simplify logic
if channel and threshold:
stmt = stmt.values(channel_id=channel.id, threshold=threshold, error=None)
elif channel:
stmt = stmt.values(channel_id=channel.id, error=None)
elif threshold:
stmt = stmt.values(threshold=threshold)

async with Session() as session:
await session.execute(stmt)
await session.commit()

await ctx.respond("Settings updated.", flags=hikari.MessageFlag.EPHEMERAL)


@arc.loader
def loader(client: arc.GatewayClient) -> None:
client.add_plugin(plugin)
Loading