Skip to content

Commit

Permalink
Migrate from pymongo to motor mongo driver
Browse files Browse the repository at this point in the history
  • Loading branch information
ankith26 committed Jun 15, 2024
1 parent 4b488a6 commit 263bbc1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 58 deletions.
85 changes: 27 additions & 58 deletions bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
- `query()`: Returns user details, uses Discord ID to find in DB.
- `roll()`: Returns user details, uses roll number to find in DB.
- `roll_or_query_error()`: Replies eror message if server is not academic or author is not a bot admin.
- `on_ready()`: Logs a message when the bot joins a server.
- `main()`: Reads server config, loads DB and starts bot.
- `on_ready()`: Executes startup actions.
"""

Expand All @@ -37,15 +36,15 @@
import discord
from discord.ext import commands

from pymongo import MongoClient, database
from motor.motor_asyncio import AsyncIOMotorClient

from config_verification import ConfigEntry, server_configs

if not server_configs:
sys.exit(1)

TOKEN = os.getenv("DISCORD_TOKEN")
MONGO_DATABASE = os.getenv("MONGO_DATABASE")
TOKEN = os.environ["DISCORD_TOKEN"]
MONGO_DATABASE = os.environ["MONGO_DATABASE"]
MONGO_URI = os.getenv("MONGO_URI")

PROTOCOL = os.getenv("PROTOCOL")
Expand All @@ -68,7 +67,10 @@
bot = commands.Bot(command_prefix=".", intents=intent)
# to get message privelege

db: database.Database | None = None # assigned in main function
mongo_client = AsyncIOMotorClient(
f"{MONGO_URI}/{MONGO_DATABASE}?retryWrites=true&w=majority"
)
users = mongo_client[MONGO_DATABASE]["users"]

# Yes, global variable. Not the most ideal thing but is efficient
token_to_id: dict[str, tuple[int, float]] = {}
Expand All @@ -89,10 +91,6 @@ async def webserver():
"""

async def authenticate(request: web.Request):
if db is None:
# should not happen, but if it somehow does, it's a server error
return web.Response(status=500)

token = request.match_info["token"]
try:
discord_id = str(token_to_id.pop(token)[0])
Expand All @@ -113,7 +111,7 @@ async def authenticate(request: web.Request):
# client sent bad request
return web.Response(status=400)

db.users.update_one(search, {"$set": updated}, upsert=True)
await users.update_one(search, {"$set": updated}, upsert=True)
return web.Response()

app = web.Application()
Expand All @@ -125,20 +123,20 @@ async def authenticate(request: web.Request):
await web.TCPSite(runner, BOT_PRIVATE_IP, 80).start()


def get_users_from_discordid(user_id: int):
async def get_users_from_discordid(user_id: int):
"""
Finds users from the database, given their ID and returns
a list containing those users.
"""
users: list[DBEntry] = []
if db is not None:
users = list(db.users.find({"discordId": str(user_id)}))
return users
ret: list[DBEntry] = []
async for document in users.find({"discordId": str(user_id)}):
ret.append(document)
return ret


def is_verified(user_id: int):
async def is_verified(user_id: int):
"""Checks if any user with the given ID exists in the DB or not."""
return True if get_users_from_discordid(user_id) else False
return bool(await get_users_from_discordid(user_id))


@commands.check
Expand All @@ -147,9 +145,9 @@ def check_bot_admin(ctx: commands.Context):
return ctx.author.id in BOT_ADMINS


def get_realname_from_discordid(user_id: int):
async def get_realname_from_discordid(user_id: int):
"""Returns the real name of the first user who matches the given ID."""
users = get_users_from_discordid(user_id)
users = await get_users_from_discordid(user_id)
assert users
return users[0]["name"]

Expand Down Expand Up @@ -202,7 +200,7 @@ async def set_nickname(member: discord.Member, server_config: ConfigEntry):
set the given user's nickname to their name fetched from the database.
"""
if server_config["setrealname"]:
realname = get_realname_from_discordid(member.id)
realname = await get_realname_from_discordid(member.id)
await member.edit(nick=realname)


Expand Down Expand Up @@ -260,7 +258,7 @@ async def verify_user(ctx: commands.Context):
return

author = ctx.message.author
if is_verified(author.id):
if await is_verified(author.id):
# user has already previously verified
await post_verification(ctx, author)
return
Expand Down Expand Up @@ -294,7 +292,7 @@ async def verify_user(ctx: commands.Context):
while time.time() < expire_time and token in token_to_id:
await asyncio.sleep(1)

if token not in token_to_id and is_verified(author.id):
if token not in token_to_id and await is_verified(author.id):
await post_verification(ctx, author)
return

Expand Down Expand Up @@ -359,15 +357,7 @@ async def query(
If present, replies with their name, email and roll number. Otherwise
replies telling the author that the mentioned user is not registed with CAS.
"""
if db is None:
await ctx.reply(
"The bot is currently initializing and the command cannot be processed.\n"
"Please wait for some time and then try again.",
ephemeral=True,
)
return

user = db.users.find_one({"discordId": str(identifier.id)})
user = await users.find_one({"discordId": str(identifier.id)})
if user:
await ctx.reply(
f"Name: {user['name']}\nEmail: {user['email']}\nRoll Number: {user['rollno']}",
Expand All @@ -394,15 +384,7 @@ async def roll(
Same as the `query` command, except the user is mentioned by roll number
instead of Discord ID.
"""
if db is None:
await ctx.reply(
"The bot is currently initializing and the command cannot be processed.\n"
"Please wait for some time and then try again.",
ephemeral=True,
)
return

user = db.users.find_one({"rollno": str(identifier)})
user = await users.find_one({"rollno": str(identifier)})
if user:
await ctx.reply(
f"Name: {user['name']}\nEmail: {user['email']}\nRoll Number: {user['rollno']}",
Expand Down Expand Up @@ -457,24 +439,11 @@ async def on_ready():
except Exception as e:
print(e)

bot.loop.create_task(webserver())


def main():
"""
First it checks if each server has a valid configuration. If not, it exits with an error.
Otherwise, It iniates a client for a MongoDB instance and fetches the database from there,
setting the global variable `db`. Then it starts the bot.
"""
global db # pylint: disable=global-statement
users_count = await users.count_documents({})
print(f"Connected to database! The collection `users` has {users_count} documents.")

mongo_client = MongoClient(
f"{MONGO_URI}/{MONGO_DATABASE}?retryWrites=true&w=majority"
)
db = mongo_client.get_database(MONGO_DATABASE)

bot.run(TOKEN)
bot.loop.create_task(webserver())


if __name__ == "__main__":
main()
bot.run(TOKEN)
1 change: 1 addition & 0 deletions bot/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ discord.py==2.3.2
dnspython==2.6.1
frozenlist==1.4.1
idna==3.7
motor==3.4.0
multidict==6.0.5
pymongo==4.7.3
yarl==1.9.4

0 comments on commit 263bbc1

Please sign in to comment.