From c2ef7359d714b1e5d99cf5478afcca14780e48e5 Mon Sep 17 00:00:00 2001 From: Battlefield Duck Date: Wed, 1 Nov 2023 02:12:42 +0800 Subject: [PATCH] Support mongodb --- .gitignore | 4 +- discordgsm/database.py | 510 ++++++++++++++++++++++++++++++-------- discordgsm/environment.py | 2 +- discordgsm/gamedig.py | 4 + discordgsm/games.csv | 4 +- discordgsm/main.py | 21 +- discordgsm/server.py | 68 +++-- requirements.txt | 1 + 8 files changed, 471 insertions(+), 143 deletions(-) diff --git a/.gitignore b/.gitignore index 244e6ba..6163f80 100644 --- a/.gitignore +++ b/.gitignore @@ -130,8 +130,6 @@ dmypy.json # discordgsm data/logs/*txt +data/exports/ data/servers.db -data/servers.sql -node_modules public/static/guilds.json -sponsors.json diff --git a/discordgsm/database.py b/discordgsm/database.py index f50c44b..42d74e0 100644 --- a/discordgsm/database.py +++ b/discordgsm/database.py @@ -1,17 +1,24 @@ +from __future__ import annotations +from enum import Enum + import json import os +from pathlib import Path import sqlite3 import sys from argparse import ArgumentParser -from typing import Dict, List, Tuple +from pymongo import DeleteOne, MongoClient, UpdateMany, UpdateOne import psycopg2 from dotenv import load_dotenv + if __name__ == '__main__': from server import Server + from server import QueryServer else: from discordgsm.server import Server + from discordgsm.server import QueryServer load_dotenv() @@ -21,6 +28,19 @@ def stringify(data: dict): return json.dumps(data, ensure_ascii=False, separators=(',', ':')) +class Driver(Enum): + SQLite = 'sqlite' + PostgreSQL = 'pgsql' + MongoDB = 'mongodb' + + +drivers = [driver.value for driver in Driver] + + +class InvalidDriverError(Exception): + pass + + class Database: """Database with connection and cursor prepared""" @@ -34,20 +54,29 @@ def __exit__(self, type, value, traceback): self.close() def connect(self): - DB_CONNECTION = os.getenv('DB_CONNECTION', '') - DATABASE_URL = os.getenv('DATABASE_URL', '') - - if DATABASE_URL.startswith('postgres://') or DATABASE_URL.startswith('postgresql://') or DB_CONNECTION == 'pgsql': - self.type = 'pgsql' - self.conn = psycopg2.connect(DATABASE_URL, sslmode=os.getenv('POSTGRES_SSL_MODE', 'require')) + DB_CONNECTION: str = os.getenv('DB_CONNECTION', 'sqlite') + DATABASE_URL: str = os.getenv('DATABASE_URL', '') + + if DATABASE_URL.startswith('postgres://') or DATABASE_URL.startswith('postgresql://') or DB_CONNECTION == Driver.PostgreSQL.value: + self.driver = Driver.PostgreSQL + self.conn = psycopg2.connect( + DATABASE_URL, sslmode=os.getenv('POSTGRES_SSL_MODE', 'require')) + elif DB_CONNECTION == Driver.MongoDB.value: + self.driver = Driver.MongoDB + self.conn = MongoClient(DATABASE_URL) + self.collection = self.conn.get_default_database()['servers'] else: - self.type = 'sqlite' - self.conn = sqlite3.connect(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'servers.db')) + self.driver = Driver.SQLite + self.conn = sqlite3.connect(os.path.join(os.path.dirname( + os.path.realpath(__file__)), '..', 'data', 'servers.db')) def create_table_if_not_exists(self): + if self.driver == Driver.MongoDB: + return + cursor = self.cursor() - if self.type == 'pgsql': + if self.driver == Driver.PostgreSQL: cursor.execute(''' CREATE TABLE IF NOT EXISTS servers ( id BIGSERIAL PRIMARY KEY, @@ -64,7 +93,7 @@ def create_table_if_not_exists(self): style_id TEXT NOT NULL, style_data TEXT NOT NULL )''') - elif self.type == 'sqlite': + elif self.driver == Driver.SQLite: cursor.execute(''' CREATE TABLE IF NOT EXISTS servers ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -99,12 +128,36 @@ def cursor(self): return cursor def transform(self, sql: str): - if self.type == 'pgsql': + if self.driver == Driver.PostgreSQL: return sql.replace('?', '%s').replace('IFNULL', 'COALESCE') return sql # sqlite def statistics(self): + if self.driver == Driver.MongoDB: + messages = len(self.collection.distinct("message_id")) + channels = len(self.collection.distinct("channel_id")) + guilds = len(self.collection.distinct("guild_id")) + + pipeline = [ + {"$group": { + "_id": { + "game_id": "$game_id", + "address": "$address", + "query_port": "$query_port", + "query_extra": "$query_extra" + } + }} + ] + unique_servers = len(list(self.collection.aggregate(pipeline))) + + return { + 'messages': messages, + 'channels': channels, + 'guilds': guilds, + 'unique_servers': unique_servers, + } + sql = ''' SELECT DISTINCT (SELECT COUNT(DISTINCT message_id) FROM servers) as messages, @@ -127,37 +180,74 @@ def statistics(self): } def games_servers_count(self): + if self.driver == Driver.MongoDB: + pipeline = [ + {"$group": {"_id": "$game_id", "count": {"$sum": 1}}} + ] + results = self.collection.aggregate(pipeline) + servers_count = {str(row['_id']): int(row['count']) + for row in results} + results.close() + return servers_count + cursor = self.cursor() - cursor.execute(self.transform('SELECT game_id, COUNT(*) FROM servers GROUP BY game_id')) + cursor.execute(self.transform( + 'SELECT game_id, COUNT(*) FROM servers GROUP BY game_id')) servers_count = {str(row[0]): int(row[1]) for row in cursor.fetchall()} cursor.close() return servers_count - def all_servers(self, channel_id: int = None, guild_id: int = None, message_id: int = None, game_id: str = None, filter_secret: bool = False): + def all_servers(self, *, channel_id: int = None, guild_id: int = None, message_id: int = None, game_id: str = None, filter_secret=False): """Get all servers""" + if self.driver == Driver.MongoDB: + if channel_id: + results = self.collection.find( + {"channel_id": channel_id}).sort("position") + elif guild_id: + results = self.collection.find( + {"guild_id": guild_id}).sort("position") + elif message_id: + results = self.collection.find( + {"message_id": message_id}).sort("position") + elif game_id: + results = self.collection.find( + {"game_id": game_id}).sort("position") + else: + results = self.collection.find({}).sort("position") + + servers = [Server.from_docs(doc, filter_secret) for doc in results] + results.close() + + return servers + cursor = self.cursor() if channel_id: - cursor.execute(self.transform('SELECT * FROM servers WHERE channel_id = ? ORDER BY position'), (channel_id,)) + cursor.execute(self.transform( + 'SELECT * FROM servers WHERE channel_id = ? ORDER BY position'), (channel_id,)) elif guild_id: - cursor.execute(self.transform('SELECT * FROM servers WHERE guild_id = ? ORDER BY position'), (guild_id,)) + cursor.execute(self.transform( + 'SELECT * FROM servers WHERE guild_id = ? ORDER BY position'), (guild_id,)) elif message_id: - cursor.execute(self.transform('SELECT * FROM servers WHERE message_id = ? ORDER BY position'), (message_id,)) + cursor.execute(self.transform( + 'SELECT * FROM servers WHERE message_id = ? ORDER BY position'), (message_id,)) elif game_id: - cursor.execute(self.transform('SELECT * FROM servers WHERE game_id = ? ORDER BY id'), (game_id,)) + cursor.execute(self.transform( + 'SELECT * FROM servers WHERE game_id = ? ORDER BY id'), (game_id,)) else: cursor.execute('SELECT * FROM servers ORDER BY position') - servers = [Server.from_list(row, filter_secret) for row in cursor.fetchall()] + servers = [Server.from_list(row, filter_secret) + for row in cursor.fetchall()] cursor.close() return servers - def all_channels_servers(self, servers: List[Server] = None): + def all_channels_servers(self, servers: list[Server] = None): """Convert or get servers to dict grouped by channel id""" all_servers = servers if servers is not None else self.all_servers() - channels_servers: Dict[int, List[Server]] = {} + channels_servers: dict[int, list[Server]] = {} for server in all_servers: if server.channel_id in channels_servers: @@ -167,10 +257,10 @@ def all_channels_servers(self, servers: List[Server] = None): return channels_servers - def all_messages_servers(self, servers: List[Server] = None): + def all_messages_servers(self, servers: list[Server] = None): """Convert or get servers to dict grouped by message id""" all_servers = servers if servers is not None else self.all_servers() - messages_servers: Dict[int, List[Server]] = {} + messages_servers: dict[int, list[Server]] = {} for server in all_servers: if server.message_id: @@ -183,21 +273,63 @@ def all_messages_servers(self, servers: List[Server] = None): def distinct_servers(self): """Get distinct servers (Query server purpose) (Only fetch game_id, address, query_port, query_extra, status, result)""" + if self.driver == Driver.MongoDB: + pipeline = [ + {"$group": { + "_id": { + "game_id": "$game_id", + "address": "$address", + "query_port": "$query_port", + "query_extra": "$query_extra", + "status": "$status", + "result": "$result" + } + }} + ] + results = self.collection.aggregate(pipeline) + servers = [QueryServer(**row['_id']) for row in results] + results.close() + return servers + cursor = self.cursor() - cursor.execute('SELECT DISTINCT game_id, address, query_port, query_extra, status, result FROM servers') - servers = [Server.from_distinct_query(row) for row in cursor.fetchall()] + cursor.execute( + 'SELECT DISTINCT game_id, address, query_port, query_extra, status, result FROM servers') + servers = [QueryServer.create(row) for row in cursor.fetchall()] cursor.close() return servers def add_server(self, s: Server): + if self.driver == Driver.MongoDB: + try: + max_position = self.collection.find_one({'channel_id': s.channel_id}, sort=[('position', -1)])["position"] + except TypeError: + max_position = 0 + + self.collection.insert_one({ + "position": max_position + 1, + "guild_id": s.guild_id, + "channel_id": s.channel_id, + "game_id": s.game_id, + "address": s.address, + "query_port": s.query_port, + "query_extra": s.query_extra, + "status": s.status, + "result": s.result, + "style_id": s.style_id, + "style_data": s.style_data + }) + + return self.find_server(s.channel_id, s.address, s.query_port) + sql = ''' INSERT INTO servers (position, guild_id, channel_id, game_id, address, query_port, query_extra, status, result, style_id, style_data) VALUES ((SELECT IFNULL(MAX(position + 1), 0) FROM servers WHERE channel_id = ?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''' try: cursor = self.cursor() - cursor.execute(self.transform(sql), (s.channel_id, s.guild_id, s.channel_id, s.game_id, s.address, s.query_port, stringify(s.query_extra), s.status, stringify(s.result), s.style_id, stringify(s.style_data))) + cursor.execute(self.transform(sql), (s.channel_id, s.guild_id, s.channel_id, s.game_id, s.address, s.query_port, stringify( + s.query_extra), s.status, stringify(s.result), s.style_id, stringify(s.style_data))) self.conn.commit() except psycopg2.Error as e: self.conn.rollback() @@ -207,7 +339,20 @@ def add_server(self, s: Server): return self.find_server(s.channel_id, s.address, s.query_port) - def update_servers_message_id(self, servers: List[Server]): + def update_servers_message_id(self, servers: list[Server]): + if self.driver == Driver.MongoDB: + operations = [ + UpdateOne( + {"_id": server.id}, + {"$set": {"message_id": server.message_id}} + ) for server in servers + ] + + if operations: + self.collection.bulk_write(operations) + + return + sql = 'UPDATE servers SET message_id = ? WHERE id = ?' parameters = [(server.message_id, server.id) for server in servers] cursor = self.cursor() @@ -215,45 +360,80 @@ def update_servers_message_id(self, servers: List[Server]): self.conn.commit() cursor.close() - def update_servers(self, servers: List[Server]): + def update_servers(self, servers: list[Server], *, channel_id: int = None): + if channel_id is not None: + return self.__update_servers_channel_id(servers, channel_id) + """Update servers status and result""" - parameters = [(server.status, stringify(server.result), server.game_id, server.address, server.query_port, stringify(server.query_extra)) for server in servers] + if self.driver == Driver.MongoDB: + operations = [ + UpdateMany( + {"game_id": server.game_id, "address": server.address, + "query_port": server.query_port, "query_extra": server.query_extra}, + {"$set": {"status": server.status, "result": server.result}} + ) for server in servers + ] + + if operations: + self.collection.bulk_write(operations) + + return + + parameters = [(server.status, stringify(server.result), server.game_id, server.address, + server.query_port, stringify(server.query_extra)) for server in servers] sql = 'UPDATE servers SET status = ?, result = ? WHERE game_id = ? AND address = ? AND query_port = ? AND query_extra = ?' cursor = self.cursor() cursor.executemany(self.transform(sql), parameters) self.conn.commit() cursor.close() - def delete_server(self, server: Server): - sql = 'DELETE FROM servers WHERE id = ?' - cursor = self.cursor() - cursor.execute(self.transform(sql), (server.id,)) - self.conn.commit() - cursor.close() + def delete_servers(self, *, guild_id: int = None, channel_id: int = None, servers: list[Server] = None): + if guild_id is None and channel_id is None and servers is None: + return + + if self.driver == Driver.MongoDB: + if guild_id is not None: + self.collection.delete_many({"guild_id": guild_id}) + elif channel_id is not None: + self.collection.delete_many({"channel_id": channel_id}) + elif servers is not None: + operations = [DeleteOne({"_id": server.id}) + for server in servers] + + if operations: + self.collection.bulk_write(operations) + else: + cursor = self.cursor() - def factory_reset(self, guild_id: int): - sql = 'DELETE FROM servers WHERE guild_id = ?' - cursor = self.cursor() - cursor.execute(self.transform(sql), (guild_id,)) - self.conn.commit() - cursor.close() + if guild_id is not None: + sql = 'DELETE FROM servers WHERE guild_id = ?' + cursor.execute(self.transform(sql), (guild_id,)) + elif channel_id is not None: + sql = 'DELETE FROM servers WHERE channel_id = ?' + cursor.execute(self.transform(sql), (channel_id,)) + elif servers is not None: + sql = 'DELETE FROM servers WHERE id = ?' + parameters = [(server.id,) for server in servers] + cursor.executemany(self.transform(sql), parameters) - def delete_servers(self, channel_id: int): - sql = 'DELETE FROM servers WHERE channel_id = ?' - cursor = self.cursor() - cursor.execute(self.transform(sql), (channel_id,)) - self.conn.commit() - cursor.close() + self.conn.commit() + cursor.close() + + def find_server(self, channel_id: int, address: str = None, query_port: int = None): + if self.driver == Driver.MongoDB: + result = self.collection.find_one( + {"channel_id": channel_id, "address": address, "query_port": query_port}) + + if not result: + raise self.ServerNotFoundError() + + return Server.from_docs(result) - def find_server(self, channel_id: int, address: str = None, query_port: str = None, message_id: int = None): cursor = self.cursor() - if message_id is not None: - sql = 'SELECT * FROM servers WHERE channel_id = ? AND message_id = ?' - cursor.execute(self.transform(sql), (channel_id, message_id,)) - else: - sql = 'SELECT * FROM servers WHERE channel_id = ? AND address = ? AND query_port = ?' - cursor.execute(self.transform(sql), (channel_id, address, query_port)) + sql = 'SELECT * FROM servers WHERE channel_id = ? AND address = ? AND query_port = ?' + cursor.execute(self.transform( + sql), (channel_id, address, query_port)) row = cursor.fetchone() cursor.close() @@ -277,96 +457,216 @@ def modify_server_position(self, server1: Server, direction: bool): if server1.message_id is None or server2.message_id is None: return [] - return self.swap_servers_positon(server1, server2) + return self.__swap_servers_positon(server1, server2) - def swap_servers_positon(self, server1: Server, server2: Server): - sql = 'UPDATE servers SET position = case when position = ? then ? else ? end, message_id = case when message_id = ? then ? else ? end WHERE id IN (?, ?)' - cursor = self.cursor() - cursor.execute(self.transform(sql), (server1.position, server2.position, server1.position, server1.message_id, server2.message_id, server1.message_id, server1.id, server2.id)) - self.conn.commit() - cursor.close() + def __swap_servers_positon(self, server1: Server, server2: Server): + if self.driver == Driver.MongoDB: + # Update server1's position and message_id to server2's values + self.collection.update_one({"_id": server1.id}, {"$set": {"position": server2.position, "message_id": server2.message_id}}) + + # Update server2's position and message_id to the original server1's values + self.collection.update_one({"_id": server2.id}, {"$set": {"position": server1.position, "message_id": server1.message_id}}) + else: + sql = 'UPDATE servers SET position = case when position = ? then ? else ? end, message_id = case when message_id = ? then ? else ? end WHERE id IN (?, ?)' + cursor = self.cursor() + cursor.execute(self.transform(sql), (server1.position, server2.position, server1.position, + server1.message_id, server2.message_id, server1.message_id, server1.id, server2.id)) + self.conn.commit() + cursor.close() + # Swap the position and message_id values in the server objects server1.position, server2.position = server2.position, server1.position server1.message_id, server2.message_id = server2.message_id, server1.message_id return [server1, server2] def server_exists(self, channel_id: int, address: str, query_port: str): - sql = 'SELECT id FROM servers WHERE channel_id = ? AND address = ? AND query_port = ?' - cursor = self.cursor() - cursor.execute(self.transform(sql), (channel_id, address, query_port)) - exists = True if cursor.fetchone() else False - cursor.close() + if self.driver == Driver.MongoDB: + exists = self.collection.find_one({"channel_id": channel_id, "address": address, "query_port": query_port}) is not None + else: + sql = 'SELECT id FROM servers WHERE channel_id = ? AND address = ? AND query_port = ?' + cursor = self.cursor() + cursor.execute(self.transform(sql), (channel_id, address, query_port)) + exists = True if cursor.fetchone() else False + cursor.close() return exists def update_server_style_id(self, server: Server): + if self.driver == Driver.MongoDB: + self.collection.update_one({"_id": server.id}, {"$set": {"style_id": server.style_id}}) + return + sql = 'UPDATE servers SET style_id = ? WHERE id = ?' cursor = self.cursor() cursor.execute(self.transform(sql), (server.style_id, server.id)) self.conn.commit() cursor.close() - def update_server_style_data(self, server: Server): - sql = 'UPDATE servers SET style_data = ? WHERE id = ?' - cursor = self.cursor() - cursor.execute(self.transform(sql), (stringify(server.style_data), server.id)) - self.conn.commit() - cursor.close() + def update_servers_style_data(self, servers: list[Server]): + if self.driver == Driver.MongoDB: + if operations := [ + UpdateOne( + {"_id": server.id}, + {"$set": {"style_data": server.style_data}} + ) for server in servers + ]: + self.collection.bulk_write(operations) + + return - def update_servers_style_data(self, servers: List[Server]): sql = 'UPDATE servers SET style_data = ? WHERE id = ?' - parameters = [(stringify(server.style_data), server.id) for server in servers] + parameters = [(stringify(server.style_data), server.id) + for server in servers] cursor = self.cursor() cursor.executemany(self.transform(sql), parameters) self.conn.commit() cursor.close() - def update_servers_channel_id(self, servers: List[Server]): + def __update_servers_channel_id(self, servers: list[Server], channel_id: int): + if self.driver == Driver.MongoDB: + max_position = self.collection.find_one({'channel_id': channel_id}, sort=[('position', -1)])["position"] + + operations = [] + + for server in servers: + max_position += 1 + + operations.append( + UpdateOne( + {"_id": server.id}, + {"$set": {"channel_id": channel_id, "position": max_position}} + ) + ) + + if operations: + self.collection.bulk_write(operations) + + return + sql = 'UPDATE servers SET channel_id = ?, position = (SELECT IFNULL(MAX(position + 1), 0) FROM servers WHERE channel_id = ?) WHERE id = ?' - parameters = [(server.channel_id, server.channel_id, server.id) for server in servers] + parameters = [(channel_id, channel_id, server.id) + for server in servers] cursor = self.cursor() cursor.executemany(self.transform(sql), parameters) self.conn.commit() cursor.close() - def export(self): - cursor = self.cursor() - cursor.execute('SELECT * FROM servers') - file = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'servers.sql') - - def values_builder(row: Tuple): - values = [] - - for value in row: - if value is None: - values.append('NULL') - elif isinstance(value, str): - value = str(value).replace("'", "''") - values.append(f"'{value}'") - else: - values.append(str(value)) + def export(self, *, to_driver: str): + if to_driver not in drivers: + raise InvalidDriverError( + f"'{to_driver}' is not a valid driver. Valid drivers are: {', '.join(drivers)}") - return ', '.join(values) + export_path = os.path.join(os.path.dirname( + os.path.realpath(__file__)), '..', 'data', 'exports') + Path(export_path).mkdir(parents=True, exist_ok=True) - with open(file, 'w', encoding='utf-8') as f: - f.writelines([f'INSERT INTO servers VALUES ({values_builder(row)});\n' for row in cursor]) + if to_driver == Driver.MongoDB.value: + servers = self.all_servers() + documents = [server.__dict__ for server in servers] + file = os.path.join(export_path, 'servers.json') - if self.type == 'pgsql': - # Sync the id sequence, fix postgresql duplicate key violates unique constraint - f.write("SELECT SETVAL((SELECT PG_GET_SERIAL_SEQUENCE('\"servers\"', 'id')), (SELECT (MAX('id') + 1) FROM 'servers'), FALSE);\n") + with open(file, 'w', encoding='utf-8') as f: + json.dump(documents, f, indent=2) + else: + file = os.path.join(export_path, 'servers.sql') + + if self.driver == Driver.SQLite.value: + # Export data to SQL file + with open(file, 'w', encoding='utf-8') as f: + for line in self.conn.iterdump(): + f.write('%s\n' % line) + elif self.driver == Driver.PostgreSQL.value: + cursor = self.cursor() + + # Export data to SQL file + with open(file, 'w', encoding='utf-8') as f: + cursor.copy_expert( + "COPY servers TO STDOUT WITH CSV DELIMITER ';'", f) + + cursor.close() + elif self.driver == Driver.MongoDB.value: + print("MongoDB does not support exporting to SQL file directly.") + return print(f'Exported to {os.path.abspath(file)}') + def import_(self, *, filename: str): + # Define the path to the exports directory + export_path = os.path.join(os.path.dirname( + os.path.realpath(__file__)), '..', 'data', 'exports') + + # Create the exports directory if it doesn't exist + Path(export_path).mkdir(parents=True, exist_ok=True) + + # Check if the filename ends with '.json' + if filename.endswith('.json'): + # If the driver is not MongoDB, raise an error + if self.driver != Driver.MongoDB: + raise ValueError( + "Invalid driver for JSON file. Expected 'mongodb'.") + + # Check if the filename ends with '.sql' + elif filename.endswith('.sql'): + # If the driver is not PostgreSQL or SQLite, raise an error + if self.driver not in [Driver.PostgreSQL, Driver.SQLite]: + raise ValueError( + "Invalid driver for SQL file. Expected 'pgsql' or 'sqlite'.") + + # Check if the file exists + file_path = os.path.join(export_path, filename) + if not os.path.exists(file_path): + raise FileNotFoundError( + f"The file {filename} does not exist in the export path.") + + # Load the data and insert it into the database + with open(file_path, 'r', encoding='utf-8') as file: + # If the driver is MongoDB + if self.driver == Driver.MongoDB: + # Load the JSON data + servers = json.load(file) + + # Insert the data into the MongoDB collection + result = self.collection.insert_many(servers) + print(f"Imported {len(result.inserted_ids)} servers.") + # If the driver is PostgreSQL or SQLite + elif self.driver in [Driver.PostgreSQL, Driver.SQLite]: + # Read the SQL commands + sql_script = file.read() + + self.create_table_if_not_exists() + + # Execute the SQL commands + cursor = self.cursor() + + if self.driver == Driver.PostgreSQL: + cursor.execute(sql_script) + if self.driver == Driver.SQLite: + cursor.executescript(sql_script) + + # Commit the changes and close the cursor + self.conn.commit() + cursor.close() + + print(f"Imported {len(sql_script.splitlines())} servers.") + class ServerNotFoundError(Exception): pass if __name__ == '__main__': + database = Database() + parser = ArgumentParser() - subparsers = parser.add_subparsers(dest='subparser_name') + subparsers = parser.add_subparsers(dest='action') subparsers.add_parser('all') - subparsers.add_parser('export') + export = subparsers.add_parser('export') + export.add_argument('--to_driver', choices=drivers, + default=database.driver.value) + + # Add a parser for the 'import' action + import_ = subparsers.add_parser('import') + import_.add_argument('--filename', required=True) args = parser.parse_args() @@ -374,10 +674,10 @@ class ServerNotFoundError(Exception): parser.print_help(sys.stderr) sys.exit(-1) - database = Database() - - if args.subparser_name == 'all': + if args.action == 'all': for server in database.all_servers(): print(server) - elif args.subparser_name == 'export': - database.export() + elif args.action == 'export': + database.export(to_driver=args.to_driver) + elif args.action == 'import': + database.import_(filename=args.filename) diff --git a/discordgsm/environment.py b/discordgsm/environment.py index 32cda3e..b4d5d34 100644 --- a/discordgsm/environment.py +++ b/discordgsm/environment.py @@ -67,7 +67,7 @@ def __int__(self) -> int: Variable('APP_ADVERTISE_TYPE', 'Presence advertise type. server_count = 0, individually = 1, player_stats = 2', AdvertiseType, default=0), Variable('TASK_QUERY_SERVER', 'Query servers task scheduled time in seconds.', float, default=60), Variable('TASK_QUERY_SERVER_TIMEOUT', 'Query servers task timeout in seconds.', float, default=15), - Variable('DB_CONNECTION', 'Database type. Accepted value: sqlite, pgsql', str, default='sqlite'), + Variable('DB_CONNECTION', 'Database type. Accepted value: sqlite, pgsql, mongodb', str, default='sqlite'), Variable('DATABASE_URL', 'Database connection url.', str), Variable('COMMAND_QUERY_PUBLIC', 'Whether the /queryserver command should be available to all users.', bool, default=False), Variable('COMMAND_QUERY_COOLDOWN', 'The /queryserver command cooldown in seconds. (Administrator will not be affected)', float, default=5), diff --git a/discordgsm/gamedig.py b/discordgsm/gamedig.py index 6ec9715..274757e 100644 --- a/discordgsm/gamedig.py +++ b/discordgsm/gamedig.py @@ -120,6 +120,10 @@ def is_port_valid(port: str): return 0 <= port_number <= 65535 async def query(self, server: Server): + # Backward compatibility + if server.game_id == 'forrest': + server.game_id = 'forest' + return await self.run({**{ 'type': server.game_id, 'host': server.address, diff --git a/discordgsm/games.csv b/discordgsm/games.csv index db037b1..ddb8185 100644 --- a/discordgsm/games.csv +++ b/discordgsm/games.csv @@ -64,10 +64,10 @@ crysis2,Crysis 2 (2011),gamespy3,port=64000 crysiswars,Crysis Wars (2008),gamespy3,port=64100 cs15,Counter-Strike 1.5 (2002),won,port=27015 cs16,Counter-Strike 1.6 (2003),source,port=27015 +cs2,Counter-Strike 2 (2023),source,port=27015 cscz,Counter-Strike: Condition Zero (2004),source,port=27015 csgo,Counter-Strike: Global Offensive (2012),source,port=27015 css,Counter-Strike: Source (2004),source,port=27015 -cs2,Counter-Strike 2 (2023),source,port=27015 daikatana,Daikatana (2000),quake2,port=27982;port_query_offset=10 darkesthour,Darkest Hour: Europe '44-'45 (2008),unreal2,port=7757;port_query_offset=1 @@ -127,12 +127,12 @@ halo,Halo (2003),gamespy2,port=2302 halo2,Halo 2 (2007),gamespy2,port=2302 heretic2,Heretic II (1998),gamespy1,port=27900;port_query_offset=1 hexen2,Hexen II (1997),hexen2,port=26900;port_query_offset=50 +hfnaw,Holdfast: Nations at War (2020),source,port=27000 hidden,The Hidden (2005),source,port=27015 hl2dm,Half-Life 2: Deathmatch (2004),source,port=27015 hldm,Half-Life Deathmatch (1998),source,port=27015 hldms,Half-Life Deathmatch: Source (2005),source,port=27015 hll,Hell Let Loose (2021),source,port=27015 -hfnaw,Holdfast: Nations at War (2020),source,port=27000 homefront,Homefront (2011),source,port=27015 homeworld2,Homeworld 2 (2003),gamespy1,port_query=6500 hurtworld,Hurtworld (2015),source,port=12871;port_query=12881 diff --git a/discordgsm/main.py b/discordgsm/main.py index 57adbe6..0dfc45d 100644 --- a/discordgsm/main.py +++ b/discordgsm/main.py @@ -44,7 +44,7 @@ def cache_message(message: Message): intents = discord.Intents.default() shard_ids = [int(shard_id) for shard_id in os.getenv('APP_SHARD_IDS').replace(';', ',').split(',') if shard_id] if len(os.getenv('APP_SHARD_IDS', '')) > 0 else None shard_count = int(os.getenv('APP_SHARD_COUNT', '1')) -client = Client(intents=intents) if not public else AutoShardedClient(intents=intents, shard_ids=shard_ids, shard_count=shard_count) +client = AutoShardedClient(intents=intents, shard_ids=shard_ids, shard_count=shard_count) # region Application event @@ -53,7 +53,7 @@ async def on_ready(): """Called when the client is done preparing the data received from Discord.""" await client.wait_until_ready() - Logger.info(f'Connected to {database.type} database') + Logger.info(f'Connected to {database.driver.value} database') Logger.info(f'Logged on as {client.user}') Logger.info(f'Add to Server: {invite_link}') @@ -92,7 +92,7 @@ async def on_guild_join(guild: discord.Guild): @client.event async def on_guild_remove(guild: discord.Guild): """Remove all associated servers in database when discordgsm leaves""" - database.factory_reset(guild.id) + database.delete_servers(guild_id=guild.id) Logger.info(f'{client.user} left {guild.name}({guild.id}), associated servers were deleted.') @@ -266,7 +266,7 @@ async def modal_on_submit(interaction: Interaction): for item in params.values(): item.default = item._value = str(item._value).strip() - game_id, address, query_port = game['id'], str(query_param['host']), str(query_param['port']) + game_id, address, query_port = game['id'], str(query_param['host']), int(str(query_param['port'])) # Validate the port number for key in params.keys(): @@ -396,7 +396,7 @@ async def command_delserver(interaction: Interaction, address: str, query_port: if server := await find_server(interaction, address, query_port): await interaction.response.defer(ephemeral=True) - database.delete_server(server) + database.delete_servers(servers=[server]) if await resend_channel_messages(interaction): await interaction.delete_original_response() @@ -428,7 +428,7 @@ async def command_factoryreset(interaction: Interaction): async def button_callback(interaction: Interaction): await interaction.response.defer(ephemeral=True) servers = database.all_servers(guild_id=interaction.guild.id) - database.factory_reset(interaction.guild.id) + database.delete_servers(guild_id=interaction.guild.id) async def purge_channel(channel_id: int): channel = client.get_channel(channel_id) @@ -544,7 +544,7 @@ async def command_editstyledata(interaction: Interaction, address: str, query_po async def modal_on_submit(interaction: Interaction): await interaction.response.defer(ephemeral=True) server.style_data.update({k: str(v) for k, v in edit_fields.items()}) - database.update_server_style_data(server) + database.update_servers_style_data([server]) await refresh_channel_messages(interaction) modal.on_submit = modal_on_submit @@ -567,11 +567,8 @@ async def command_switch(interaction: Interaction, channel: discord.TextChannel, if servers := await find_servers(interaction, address, query_port): await interaction.response.defer(ephemeral=True) + database.update_servers(servers, channel_id=channel.id) - for server in servers: - server.channel_id = channel.id - - database.update_servers_channel_id(servers) await resend_channel_messages(None, interaction.channel.id) await resend_channel_messages(None, channel.id) @@ -683,7 +680,7 @@ async def modal_on_submit(interaction: Interaction): webhook_url = str(text_input_webhook_url).strip() content = str(text_input_webhook_content).strip() server.style_data.update({'_alert_webhook_url': webhook_url, '_alert_content': content}) - database.update_server_style_data(server) + database.update_servers_style_data([server]) modal.on_submit = modal_on_submit await interaction.response.send_modal(modal) diff --git a/discordgsm/server.py b/discordgsm/server.py index 8f30c49..ff9e47e 100644 --- a/discordgsm/server.py +++ b/discordgsm/server.py @@ -8,6 +8,27 @@ from discordgsm.gamedig import GamedigResult +@dataclass +class QueryServer: + game_id: str + address: str + query_port: int + query_extra: dict + status: bool + result: GamedigResult + + @staticmethod + def create(row: tuple) -> QueryServer: + return QueryServer( + game_id=row[0], + address=row[1], + query_port=row[2], + query_extra=json.loads(row[3]), + status=row[4] == 1, + result=json.loads(row[5]), + ) + + @dataclass class Server: id: int @@ -17,7 +38,7 @@ class Server: message_id: Optional[int] game_id: str address: str - query_port: str + query_port: int query_extra: dict status: bool result: GamedigResult @@ -25,7 +46,7 @@ class Server: style_data: dict @staticmethod - def new(guild_id: int, channel_id: int, game_id: str, address: str, query_port: str, query_extra: dict, result: GamedigResult) -> Server: + def new(guild_id: int, channel_id: int, game_id: str, address: str, query_port: int, query_extra: dict, result: GamedigResult) -> Server: return Server( id=None, position=None, @@ -42,24 +63,6 @@ def new(guild_id: int, channel_id: int, game_id: str, address: str, query_port: style_data={} ) - @staticmethod - def from_distinct_query(row: tuple) -> Server: - return Server( - id=None, - position=None, - guild_id=None, - channel_id=None, - message_id=None, - game_id=row[0], - address=row[1], - query_port=row[2], - query_extra=json.loads(row[3]), - status=row[4] == 1, - result=json.loads(row[5]), - style_id=None, - style_data=None, - ) - @staticmethod def from_list(row: tuple, filter_secret=False) -> Server: query_extra: dict = json.loads(row[8]) @@ -85,3 +88,28 @@ def from_list(row: tuple, filter_secret=False) -> Server: style_id=row[11], style_data=style_data, ) + + @staticmethod + def from_docs(data: dict, filter_secret=False) -> Server: + server = Server( + id=data['_id'], + position=data['position'], + guild_id=data['guild_id'], + channel_id=data['channel_id'], + message_id=data.get('message_id'), + game_id=data['game_id'], + address=data['address'], + query_port=data['query_port'], + query_extra=data['query_extra'], + status=data['status'], + result=data['result'], + style_id=data['style_id'], + style_data=data['style_data'] + ) + + if filter_secret: + # Filter key started with _ and filter the description since it may contain secrets + server.query_extra = {k: v for k, v in server.query_extra.items() if not str(k).startswith('_')} + server.style_data = {k: v for k, v in server.style_data.items() if not str(k).startswith('_') and k != 'description'} + + return server diff --git a/requirements.txt b/requirements.txt index e06c243..c2b3826 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ Flask==3.0.0 gunicorn==21.2.0 opengsq==2.1.2 psycopg2-binary==2.9.9 +pymongo==4.5.0 python-dotenv==1.0.0 pywin32==306;platform_system=="Windows" tzdata==2023.3