Skip to content

Commit

Permalink
fix: fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdragun authored and solumath committed Apr 4, 2024
1 parent fc3c92e commit ada88b9
Show file tree
Hide file tree
Showing 24 changed files with 74 additions and 71 deletions.
8 changes: 4 additions & 4 deletions cogs/bettermeme/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ async def __repost_message(self, ctx: ReactionContext, reactions: list[disnake.R

if len(other_attachments) > 0:
# Files are getting send as files
files = [file for file in other_attachments if isinstance(file, disnake.File)]
files = files[:10] if files else None
files_list = [file for file in other_attachments if isinstance(file, disnake.File)]
files = files_list[:10] if files_list else None

# And urls as string in separated message
urls = [file for file in other_attachments if isinstance(file, str)]
urls = "\n".join(urls) if urls else None
urls_list = [file for file in other_attachments if isinstance(file, str)]
urls = "\n".join(urls_list) if urls_list else None

secondary_message = await self.repost_channel.send(urls, files=files)
secondary_message_id = secondary_message.id
Expand Down
17 changes: 9 additions & 8 deletions cogs/fitwide/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import features
from .messages_cz import MessagesCZ

user_logins = []
user_logins: list[str] = []


async def autocomp_user_logins(inter: disnake.ApplicationCommandInteraction, user_input: str):
Expand Down Expand Up @@ -91,7 +91,7 @@ async def role_check(

year_roles = {year: disnake.utils.get(guild.roles, name=year) for year in years}

weird_members = {
weird_members: dict[disnake.Role, dict[disnake.Role, list[disnake.Member]]] = {
year_y: {year_x: [] for year_x in year_roles.values()} for year_y in year_roles.values()
}

Expand Down Expand Up @@ -125,7 +125,7 @@ async def role_check(
correct_role = disnake.utils.get(guild.roles, name=year)

if correct_role not in member.roles:
for role_name, role in year_roles.items():
for role in year_roles.values():
if role in member.roles and correct_role in weird_members[role].keys():
weird_members[role][correct_role].append(member)
break
Expand Down Expand Up @@ -336,11 +336,11 @@ async def update_db(
old_logins = [value for (value,) in login_results]

for line in data:
line = line.split(":")
login = line[0]
name = line[4].split(",", 1)[0]
line_split = line.split(":")
login = line_split[0]
name = line_split[4].split(",", 1)[0]
try:
year_fields = line[4].split(",")[1].split(" ")
year_fields = line_split[4].split(",")[1].split(" ")
year = " ".join(year_fields if "mail=" not in year_fields[-1] else year_fields[:-1])
mail = year_fields[-1].replace("mail=", "") if "mail=" in year_fields[-1] else None
except IndexError:
Expand All @@ -364,7 +364,8 @@ async def update_db(
session.merge(person)

cnt_new = 0
for person in session.query(ValidPersonDB):
all_persons = ValidPersonDB.get_all_persons()
for person in all_persons:
if person.login not in found_logins:
try:
# check for muni
Expand Down
6 changes: 3 additions & 3 deletions cogs/help/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def add_fields(self, embed, items):
value += item["description"] if item["description"] else ""
embed.add_field(name=name, value=value if value else None, inline=False)

async def api(self, message: commands.Context, params: list):
async def api(self, message: commands.Context, params: dict[str, str]):
"""Sending commands help to grillbot"""
mock_message = copy.copy(message)
mock_view = commands.view.StringView("")
Expand Down Expand Up @@ -132,9 +132,9 @@ async def api(self, message: commands.Context, params: list):

@cooldowns.default_cooldown
@commands.command(aliases=["god"], brief=MessagesCZ.title)
async def help(self, ctx: commands.Context, *command):
async def help(self, ctx: commands.Context, *command_list: str):
# Subcommand help
command = " ".join(command)
command = " ".join(command_list)
if command:
command_obj = self.bot.get_command(command)
if not command_obj:
Expand Down
3 changes: 2 additions & 1 deletion cogs/icons/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ async def on_button_click(self, inter: disnake.MessageInteraction):

async def cog_slash_command_error(
self, inter: disnake.ApplicationCommandInteraction, error: Exception
) -> None:
) -> bool:
if isinstance(error, utils.PCommandOnCooldown):
await inter.response.send_message(str(error), ephemeral=True)
return True
return False
2 changes: 1 addition & 1 deletion cogs/icons/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def do_set_icon(icon: disnake.Role, user: disnake.Member) -> None:
await user.add_roles(icon)


async def icon_autocomp(inter: disnake.ApplicationCommandInteraction, partial: str) -> str:
async def icon_autocomp(inter: disnake.ApplicationCommandInteraction, partial: str) -> list[str]:
icon_roles = get_icon_roles(inter.guild)
names = (icon_name(role) for role in icon_roles)
return [name for name in names if partial.casefold() in name.casefold()]
Expand Down
5 changes: 4 additions & 1 deletion cogs/moderation/features.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from datetime import datetime, timezone
from typing import TypeAlias

import disnake

import utils

from .views import View

SLOWMODE_CHANNEL_TYPES = disnake.TextChannel | disnake.Thread | disnake.VoiceChannel | disnake.ForumChannel
SLOWMODE_CHANNEL_TYPES: TypeAlias = (
disnake.TextChannel | disnake.Thread | disnake.VoiceChannel | disnake.ForumChannel
)

MODERATION_TRUE = "moderation:resolve:true"
MODERATION_FALSE = "moderation:resolve:false"
Expand Down
12 changes: 6 additions & 6 deletions cogs/poll/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ async def poll_create(
poll_options: dict,
poll_view: disnake.ui.View,
):
inter = poll_args.get("inter")
inter: disnake.ApplicationCommandInteraction = poll_args.get("inter")
attachment = poll_args.get("attachment")
anonymous = poll_args.get("anonymous")
end = poll_args.get("end")
end_str: str = poll_args.get("end")

is_end_valid, end = await poll_features.check_end(inter, end)
is_end_valid, end = await poll_features.check_end(inter, end_str)
if not is_end_valid:
return

Expand Down Expand Up @@ -186,7 +186,7 @@ async def list_polls(
await inter.send(MessagesCZ.no_active_polls)
return

content = ""
content: str = ""
for poll in polls:
message = await utils.get_message_from_url(self.bot, poll.message_url)
if not message or not message.embeds:
Expand All @@ -202,8 +202,8 @@ async def list_polls(
await inter.send(MessagesCZ.no_active_polls)
return

content = utils.cut_string_by_words(header + content, 1900, "\n")
for content_part in content:
content_list = utils.cut_string_by_words(header + content, 1900, "\n")
for content_part in content_list:
await inter.send(content_part, ephemeral=True)

async def task_end_poll(self, poll: PollDB) -> None:
Expand Down
2 changes: 1 addition & 1 deletion cogs/poll/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def check_end(inter: disnake.ApplicationCommandInteraction, end: str) -> t
return True, end


async def parse_attachment(attachment: disnake.Attachment) -> Union[str, disnake.File, None]:
async def parse_attachment(attachment: disnake.Attachment) -> tuple[str | None, disnake.File | None]:
"""parses the attachment url to get the attachment as file"""
if attachment is None or attachment.content_type is None:
return None, None
Expand Down
6 changes: 2 additions & 4 deletions cogs/report/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Cog implementing anonymous reporting from users.
"""

from typing import Optional

import disnake
from disnake.ext import commands

Expand All @@ -22,13 +20,13 @@ def __init__(self, bot: commands.Bot):
super().__init__()
self.bot = bot

async def check_blocked_bot(self, inter: disnake.Interaction) -> Optional[disnake.Message]:
async def check_blocked_bot(self, inter: disnake.Interaction) -> disnake.Message | None:
try:
dm_message = await inter.author.send(MessagesCZ.check_dm, view=TrashView())
return dm_message
except disnake.Forbidden:
await inter.send(MessagesCZ.blocked_bot(user=inter.author.id), ephemeral=True)
return
return None

@cooldowns.default_cooldown
@commands.message_command(name="Report message", guild_ids=[Base.config.guild_id])
Expand Down
11 changes: 5 additions & 6 deletions cogs/warden/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,15 @@ async def scan_history(self, ctx: commands.Context, limit: int | str):
limit: [all | <int>]
"""
# parse parameter
if limit == "all":
limit = None
else:
if limit != "all":
try:
limit = int(limit)
if limit < 1:
raise ValueError
except ValueError:
raise commands.BadArgument("Expected 'all' or positive integer")

messages = await ctx.channel.history(limit=limit).flatten()
messages = await ctx.channel.history(limit=limit if limit != "all" else None).flatten()

title = "**INITIATING...**\n\nLoaded {} messages"
await asyncio.sleep(0.5)
Expand Down Expand Up @@ -178,13 +176,14 @@ async def checkDuplicate(self, message: disnake.Message):
duplicate = post
hamming_min = hamming

duplicates[duplicate] = hamming_min
if duplicate is not None:
duplicates[duplicate] = hamming_min

for duplicate, hamming_min in duplicates.items():
if hamming_min <= self.limit_soft:
await self._announceDuplicate(message, duplicate, hamming_min)

async def _announceDuplicate(self, message: disnake.Message, original: object, hamming: int):
async def _announceDuplicate(self, message: disnake.Message, original: ImageDB, hamming: int):
"""Send message that a post is a original
original: object
hamming: Hamming distance between the image and closest database entry
Expand Down
2 changes: 1 addition & 1 deletion database/better_meme.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from database import database, session


class BetterMemeDB(database.base):
class BetterMemeDB(database.base): # type: ignore
__tablename__ = "bot_better_meme"

member_ID = Column(String, primary_key=True)
Expand Down
8 changes: 3 additions & 5 deletions database/contestvote.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

from typing import Optional

from sqlalchemy import BigInteger, Column, String

from database import database, session


class ContestVoteDB(database.base):
class ContestVoteDB(database.base): # type: ignore
__tablename__ = "contest_vote"

user_id = Column(String, nullable=False)
Expand Down Expand Up @@ -35,8 +33,8 @@ def get_user(cls, user_id: str) -> ContestVoteDB:
return session.query(cls).filter_by(user_id=str(user_id)).first()

@classmethod
def get_contribution_author(cls, contribution_id: int) -> Optional[str]:
def get_contribution_author(cls, contribution_id: int) -> str | None:
contribution = session.query(cls).filter_by(contribution_id=contribution_id).first()
if contribution:
return contribution.user_id
return
return None
2 changes: 1 addition & 1 deletion database/hugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
UserHugStats = namedtuple("UserHugStats", ("given", "received"))


class HugsTableDB(database.base):
class HugsTableDB(database.base): # type: ignore
__tablename__ = "bot_hugs"

member_id = Column(BIGINT, primary_key=True, nullable=False, unique=True, autoincrement=False)
Expand Down
7 changes: 3 additions & 4 deletions database/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from __future__ import annotations

from datetime import datetime
from typing import List

from sqlalchemy import BigInteger, Column, DateTime, String
from sqlalchemy.orm import Query

from database import database, session


class ImageDB(database.base):
class ImageDB(database.base): # type: ignore
__tablename__ = "images"

attachment_id = Column(BigInteger, primary_key=True)
Expand Down Expand Up @@ -39,15 +38,15 @@ def add_image(cls, channel_id: int, message_id: int, attachment_id: int, dhash:
session.commit()

@classmethod
def getHash(cls, dhash: str) -> List[ImageDB]:
def getHash(cls, dhash: str) -> list[ImageDB]:
return session.query(cls).filter(cls.dhash == dhash).all()

@classmethod
def getByMessage(cls, message_id: int) -> ImageDB:
return session.query(cls).filter(cls.message_id == message_id).one_or_none()

@classmethod
def getAll(cls) -> Query:
def getAll(cls) -> list[ImageDB]:
return session.query(cls)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion database/meme_repost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from database import database, session


class MemeRepostDB(database.base):
class MemeRepostDB(database.base): # type: ignore
__tablename__ = "bot_meme_reposts"

original_message_id = Column(String, primary_key=True, nullable=False, unique=True)
Expand Down
4 changes: 2 additions & 2 deletions database/pin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from database import database, session


class PinMapDB(database.base):
class PinMapDB(database.base): # type: ignore
__tablename__ = "bot_pin_map"

channel_id = Column(String, primary_key=True)
Expand All @@ -19,7 +19,7 @@ def find_channel_by_id(cls, channel_id: str) -> PinMapDB:

@classmethod
def add_or_update_channel(cls, channel_id: str, message_id: str) -> None:
item: cls = cls.find_channel_by_id(channel_id)
item: PinMapDB = cls.find_channel_by_id(channel_id)

if item is None:
item = cls(channel_id=channel_id, message_id=message_id)
Expand Down
14 changes: 7 additions & 7 deletions database/poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from database import database, session


class VoterDB(database.base):
class VoterDB(database.base): # type: ignore
__tablename__ = "voter"

id = Column(String, primary_key=True)
Expand All @@ -38,7 +38,7 @@ class PollType(IntEnum):
opinion = 3 # Opinion is agree/neutral/disagree poll where you can only vote for one option


class PollDB(database.base):
class PollDB(database.base): # type: ignore
__tablename__ = "poll"

id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -100,14 +100,14 @@ def get_pending_polls(cls) -> list[PollDB]:
return session.query(cls).filter(or_(cls.end is None, cls.end > datetime.now(timezone.utc)))

@classmethod
def get_pending_polls_by_type(cls, type: PollType) -> list[PollDB]:
return cls.get_pending_polls().filter(cls.poll_type == type).all()
def get_pending_polls_by_type(cls, type: int) -> list[PollDB]:
return cls.get_pending_polls().filter(cls.poll_type == type).all() # type: ignore

@classmethod
def get_author_id(cls, poll_id: int) -> str | None:
poll = cls.get(poll_id)
if not poll:
return
return None
return poll.author_id

def remove_voter(self, voter: VoterDB) -> None:
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_winning_options(self) -> list[PollOptionDB]:
return winning_options


class PollOptionDB(database.base):
class PollOptionDB(database.base): # type: ignore
__tablename__ = "poll_option"

id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -201,7 +201,7 @@ def add(cls, text: str, emoji: str, poll_id: int) -> None:

def remove_voter(self, voter: VoterDB | str) -> None:
if isinstance(voter, str):
voter = VoterDB.get(voter, self.id)
voter = VoterDB.get(voter, self.id) or ""

if voter:
session.delete(voter)
Expand Down
Loading

0 comments on commit ada88b9

Please sign in to comment.