Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
brunneis committed May 31, 2021
1 parent 0345466 commit ed2d69a
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 42 deletions.
2 changes: 1 addition & 1 deletion easymongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

from .easymongo import *

__version__ = '0.0.4b0'
__version__ = '1.215.0'
113 changes: 75 additions & 38 deletions easymongo/easymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
# -*- coding: utf-8 -*-

from pymongo import MongoClient
import logging
from .errors import ConnectionError
import time


class MongodbConnector:
class MongoDB:
def __init__(self,
host=None,
host: str = None,
port=None,
default_database=None,
default_collection=None,
connect=False,
endpoint=None):
default_database: str = None,
default_collection: str = None,
connect: bool = False,
endpoint: str = None):

if endpoint is not None:
host, port = endpoint.split(':')
Expand All @@ -33,7 +33,7 @@ def __init__(self,
def client(self):
return self._client

def _get_collection(self, database_name, collection_name):
def _get_collection(self, database_name: str, collection_name: str):
self.open_connection()
database_name, collection_name = \
self._get_database_and_collection_names(database_name,
Expand All @@ -42,19 +42,19 @@ def _get_collection(self, database_name, collection_name):
collection = getattr(database, collection_name)
return collection

def set_defaults(self, database_name, database_collection=None):
def set_defaults(self, database_name: str, database_collection: str = None):
self._default_database = database_name
if database_collection != None:
self._default_collection = database_collection

def open_connection(self, attempts=10):
def open_connection(self, attempts: int = 10):
if not self._client:
try:
self._client = MongoClient(**self._config)
self._client.server_info()
except Exception:
if attempts == 0:
logging.exception('')
raise ConnectionError
else:
time.sleep(1)
self.open_connection(attempts - 1)
Expand All @@ -63,14 +63,21 @@ def close_connection(self):
if self._client:
self._client.close()

def _get_database_and_collection_names(self, database_name, collection_name):
def _get_database_and_collection_names(
self,
database_name: str,
collection_name: str,
):
if not database_name and hasattr(self, '_default_database'):
database_name = self._default_database
if not collection_name and hasattr(self, '_default_collection'):
collection_name = self._default_collection
return database_name, collection_name

def get_and_close(self, query=None, database_name=None, collection_name=None):
def get_and_close(self,
query=None,
database_name: str = None,
collection_name: str = None):
result = self.get(query, database_name, collection_name)
self.close_connection()
return result
Expand All @@ -82,13 +89,15 @@ def exists(self, query, database_name=None, collection_name=None):
return True

def create_index(self,
attribute=None,
attribute: str = None,
keys=None,
database_name=None,
collection_name=None,
unique=False,
type_='asc',
background=True):
database_name: str = None,
collection_name: str = None,
unique: bool = False,
type_: str = None,
background: bool = True):
if type_ is None:
type_ = 'asc'
if not keys and not attribute:
raise TypeError
if not keys:
Expand All @@ -102,15 +111,15 @@ def create_index(self,

def get(self,
query=None,
database_name=None,
collection_name=None,
sort=None,
sort_attribute=None,
sort_type=None,
limit=None,
database_name: str = None,
collection_name: str = None,
sort: str = None,
sort_attribute: str = None,
sort_type: str = None,
limit: int = None,
index=None,
index_attribute=None,
index_type=None):
index_attribute: str = None,
index_type: str = None):
collection = self._get_collection(database_name, collection_name)
if sort_attribute and sort_type:
if sort_type == 'desc':
Expand Down Expand Up @@ -139,12 +148,14 @@ def get(self,

def get_random(self,
query=None,
database_name=None,
collection_name=None,
database_name: str = None,
collection_name: str = None,
sort=None,
sort_attribute=None,
sort_type=None,
limit=1):
sort_attribute: str = None,
sort_type: str = None,
limit: int = None):
if limit is None:
limit = 1
collection = self._get_collection(database_name, collection_name)
if sort_attribute and sort_type:
if sort_type == 'desc':
Expand All @@ -160,30 +171,56 @@ def get_random(self,
result = collection.aggregate(operation)
return result

def update(self, query, value, database_name=None, collection_name=None):
def update(self,
query,
value,
database_name: str = None,
collection_name: str = None):
collection = self._get_collection(database_name, collection_name)
collection.update(query, {'$set': value}, upsert=True)

def push(self, query, key, value, database_name=None, collection_name=None):
def push(self,
query,
key: str,
value,
database_name: str = None,
collection_name: str = None):
collection = self._get_collection(database_name, collection_name)
collection.update(query, {'$push': {key: {'$each': value}}})

def put(self, value, query=None, database_name=None, collection_name=None):
def put(self,
value,
query=None,
database_name: str = None,
collection_name: str = None):
if query:
self.update(query, value, database_name, collection_name)
else:
self.insert(value, database_name, collection_name)

def remove(self, query=None, database_name=None, collection_name=None):
def remove(self,
query=None,
database_name: str = None,
collection_name: str = None):
collection = self._get_collection(database_name, collection_name)
collection.remove(query)

def insert(self, value, database_name=None, collection_name=None):
def insert(self,
value,
database_name: str = None,
collection_name: str = None):
collection = self._get_collection(database_name, collection_name)
collection.insert_one(dict(value))

def count(self, query=None, database_name=None, collection_name=None):
def count(self,
query=None,
database_name: str = None,
collection_name: str = None):
collection = self._get_collection(database_name, collection_name)
if not query:
query = {}
return collection.count_documents(filter=query)


# Backwards compatible
MongodbConnector = MongoDB
5 changes: 5 additions & 0 deletions easymongo/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

class ConnectionError(Exception):
pass
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pymongo>=3.9.0
pymongo>=3.9.0,<4.0.0
dnspython>=2.1.0,<3.0.0

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
"Programming Language :: Python :: Implementation :: PyPy",
"Topic :: Software Development :: Libraries :: Python Modules",
],
install_requires=['pymongo>=3.9.0'])
install_requires=['pymongo>=3.9.0,<4.0.0', 'dnspython>=2.1.0, <3.0.0'])
2 changes: 1 addition & 1 deletion upload.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
python setup.py bdist_wheel
python3 setup.py bdist_wheel
twine upload dist/*.whl

0 comments on commit ed2d69a

Please sign in to comment.