diff --git a/easymongo/__init__.py b/easymongo/__init__.py index 3ddd806..8f14123 100644 --- a/easymongo/__init__.py +++ b/easymongo/__init__.py @@ -3,4 +3,4 @@ from .easymongo import * -__version__ = '0.0.4b0' +__version__ = '1.215.0' diff --git a/easymongo/easymongo.py b/easymongo/easymongo.py index aac8ce4..d31157f 100644 --- a/easymongo/easymongo.py +++ b/easymongo/easymongo.py @@ -2,18 +2,18 @@ # -*- coding: utf-8 -*- from pymongo import MongoClient -import logging +from .errors import ConnectionError import time -class MongodbConnector: +class MongoDB: def __init__(self, - host=None, + host: str = None, port=None, - default_database=None, - default_collection=None, - connect=False, - endpoint=None): + default_database: str = None, + default_collection: str = None, + connect: bool = False, + endpoint: str = None): if endpoint is not None: host, port = endpoint.split(':') @@ -33,7 +33,7 @@ def __init__(self, def client(self): return self._client - def _get_collection(self, database_name, collection_name): + def _get_collection(self, database_name: str, collection_name: str): self.open_connection() database_name, collection_name = \ self._get_database_and_collection_names(database_name, @@ -42,19 +42,19 @@ def _get_collection(self, database_name, collection_name): collection = getattr(database, collection_name) return collection - def set_defaults(self, database_name, database_collection=None): + def set_defaults(self, database_name: str, database_collection: str = None): self._default_database = database_name if database_collection != None: self._default_collection = database_collection - def open_connection(self, attempts=10): + def open_connection(self, attempts: int = 10): if not self._client: try: self._client = MongoClient(**self._config) self._client.server_info() except Exception: if attempts == 0: - logging.exception('') + raise ConnectionError else: time.sleep(1) self.open_connection(attempts - 1) @@ -63,14 +63,21 @@ def close_connection(self): if self._client: self._client.close() - def _get_database_and_collection_names(self, database_name, collection_name): + def _get_database_and_collection_names( + self, + database_name: str, + collection_name: str, + ): if not database_name and hasattr(self, '_default_database'): database_name = self._default_database if not collection_name and hasattr(self, '_default_collection'): collection_name = self._default_collection return database_name, collection_name - def get_and_close(self, query=None, database_name=None, collection_name=None): + def get_and_close(self, + query=None, + database_name: str = None, + collection_name: str = None): result = self.get(query, database_name, collection_name) self.close_connection() return result @@ -82,13 +89,15 @@ def exists(self, query, database_name=None, collection_name=None): return True def create_index(self, - attribute=None, + attribute: str = None, keys=None, - database_name=None, - collection_name=None, - unique=False, - type_='asc', - background=True): + database_name: str = None, + collection_name: str = None, + unique: bool = False, + type_: str = None, + background: bool = True): + if type_ is None: + type_ = 'asc' if not keys and not attribute: raise TypeError if not keys: @@ -102,15 +111,15 @@ def create_index(self, def get(self, query=None, - database_name=None, - collection_name=None, - sort=None, - sort_attribute=None, - sort_type=None, - limit=None, + database_name: str = None, + collection_name: str = None, + sort: str = None, + sort_attribute: str = None, + sort_type: str = None, + limit: int = None, index=None, - index_attribute=None, - index_type=None): + index_attribute: str = None, + index_type: str = None): collection = self._get_collection(database_name, collection_name) if sort_attribute and sort_type: if sort_type == 'desc': @@ -139,12 +148,14 @@ def get(self, def get_random(self, query=None, - database_name=None, - collection_name=None, + database_name: str = None, + collection_name: str = None, sort=None, - sort_attribute=None, - sort_type=None, - limit=1): + sort_attribute: str = None, + sort_type: str = None, + limit: int = None): + if limit is None: + limit = 1 collection = self._get_collection(database_name, collection_name) if sort_attribute and sort_type: if sort_type == 'desc': @@ -160,30 +171,56 @@ def get_random(self, result = collection.aggregate(operation) return result - def update(self, query, value, database_name=None, collection_name=None): + def update(self, + query, + value, + database_name: str = None, + collection_name: str = None): collection = self._get_collection(database_name, collection_name) collection.update(query, {'$set': value}, upsert=True) - def push(self, query, key, value, database_name=None, collection_name=None): + def push(self, + query, + key: str, + value, + database_name: str = None, + collection_name: str = None): collection = self._get_collection(database_name, collection_name) collection.update(query, {'$push': {key: {'$each': value}}}) - def put(self, value, query=None, database_name=None, collection_name=None): + def put(self, + value, + query=None, + database_name: str = None, + collection_name: str = None): if query: self.update(query, value, database_name, collection_name) else: self.insert(value, database_name, collection_name) - def remove(self, query=None, database_name=None, collection_name=None): + def remove(self, + query=None, + database_name: str = None, + collection_name: str = None): collection = self._get_collection(database_name, collection_name) collection.remove(query) - def insert(self, value, database_name=None, collection_name=None): + def insert(self, + value, + database_name: str = None, + collection_name: str = None): collection = self._get_collection(database_name, collection_name) collection.insert_one(dict(value)) - def count(self, query=None, database_name=None, collection_name=None): + def count(self, + query=None, + database_name: str = None, + collection_name: str = None): collection = self._get_collection(database_name, collection_name) if not query: query = {} return collection.count_documents(filter=query) + + +# Backwards compatible +MongodbConnector = MongoDB diff --git a/easymongo/errors.py b/easymongo/errors.py new file mode 100644 index 0000000..a3b8692 --- /dev/null +++ b/easymongo/errors.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +class ConnectionError(Exception): + pass diff --git a/requirements.txt b/requirements.txt index b722bd3..680aced 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ -pymongo>=3.9.0 +pymongo>=3.9.0,<4.0.0 +dnspython>=2.1.0,<3.0.0 + diff --git a/setup.py b/setup.py index 01e87b0..d781acc 100755 --- a/setup.py +++ b/setup.py @@ -24,4 +24,4 @@ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Software Development :: Libraries :: Python Modules", ], - install_requires=['pymongo>=3.9.0']) + install_requires=['pymongo>=3.9.0,<4.0.0', 'dnspython>=2.1.0, <3.0.0']) diff --git a/upload.sh b/upload.sh index 99fbc36..da77908 100755 --- a/upload.sh +++ b/upload.sh @@ -1,4 +1,4 @@ #!/bin/bash -python setup.py bdist_wheel +python3 setup.py bdist_wheel twine upload dist/*.whl