Source code for mongo_thingy

import asyncio
import warnings
from collections.abc import Mapping

from pymongo import MongoClient, ReturnDocument
from pymongo.errors import ConfigurationError
from thingy import DatabaseThingy, classproperty, registry

from mongo_thingy.cursor import AsyncCursor, Cursor

try:
    from motor.motor_tornado import MotorClient
except ImportError:  # pragma: no cover
    MotorClient = None

try:
    from motor.motor_asyncio import AsyncIOMotorClient
except ImportError:  # pragma: no cover
    AsyncIOMotorClient = None


class ThingyList(list):
    def distinct(self, key):
        def __get_value(item):
            if isinstance(item, BaseThingy):
                item = item.view()
            return item.get(key)

        values = set(__get_value(item) for item in self)
        return list(values)

    def view(self, name="defaults"):
        def __view(item):
            if not isinstance(item, BaseThingy):
                raise TypeError(f"Can't view type {type(item)}.")
            return item.view(name)

        return [__view(item) for item in self]


class BaseThingy(DatabaseThingy):
    """Represents a document in a collection"""

    _client = None
    _client_cls = None
    _collection = None
    _collection_name = None
    _cursor_cls = None
    _result_cls = ThingyList

    @classproperty
    def _table(cls):
        return cls._collection

    @classproperty
    def _table_name(cls):
        return cls._collection_name

    @classproperty
    def table_name(cls):
        return cls.collection_name

    @classproperty
    def collection(cls):
        return cls.get_collection()

    @classproperty
    def collection_name(cls):
        return cls.get_table_name()

    @classproperty
    def client(cls):
        return cls.get_client()

    @classmethod
    def _get_client(cls, database):
        return database.client

    @classmethod
    def _get_database(cls, collection, name):
        if collection is not None:
            return collection.database
        if cls._client and name:
            return cls._client[name]
        raise AttributeError("Undefined database.")

    @classmethod
    def _get_table(cls, database, table_name):
        return database[table_name]

    @classmethod
    def _get_database_name(cls, database):
        return database.name

    @classmethod
    def _get_table_name(cls, table):
        return table.name

    @classmethod
    def get_client(cls):
        if cls._client:
            return cls._client
        return cls._get_client(cls.database)

    @classmethod
    def get_collection(cls):
        return cls.get_table()

    @classmethod
    def add_index(cls, keys, **kwargs):
        kwargs.setdefault("background", True)
        if not hasattr(cls, "_indexes"):
            cls._indexes = []
        cls._indexes.append((keys, kwargs))

    @classmethod
    def count_documents(cls, filter=None, *args, **kwargs):
        if filter is None:
            filter = {}
        return cls.collection.count_documents(filter, *args, **kwargs)

    @classmethod
    def count(cls, filter=None, *args, **kwargs):
        warnings.warn(
            "count is deprecated. Use count_documents instead.", DeprecationWarning
        )
        return cls.count_documents(filter=filter, *args, **kwargs)

    @classmethod
    def connect(cls, *args, client_cls=None, database_name=None, **kwargs):
        if not client_cls:
            client_cls = cls._client_cls

        cls._client = client_cls(*args, **kwargs)
        try:
            cls._database = cls._client.get_database(database_name)
        except (ConfigurationError, TypeError):
            cls._database = cls._client["test"]

    @classmethod
    def disconnect(cls, *args, **kwargs):
        if cls._client:
            cls._client.close()
        cls._client = None
        cls._database = None

    @classmethod
    def distinct(cls, *args, **kwargs):
        return cls.collection.distinct(*args, **kwargs)

    @classmethod
    def find(cls, *args, view=None, **kwargs):
        delegate = cls.collection.find(*args, **kwargs)
        return cls._cursor_cls(delegate, thingy_cls=cls, view=view)

    @classmethod
    def find_one(cls, filter=None, *args, **kwargs):
        if filter is not None and not isinstance(filter, Mapping):
            filter = {"_id": filter}

        cursor = cls.find(filter, *args, **kwargs)
        return cursor.first()

    @classmethod
    def delete_many(cls, filter=None, *args, **kwargs):
        return cls.collection.delete_many(filter, *args, **kwargs)

    @classmethod
    def delete_one(cls, filter=None, *args, **kwargs):
        if filter is not None and not isinstance(filter, Mapping):
            filter = {"_id": filter}

        return cls.collection.delete_one(filter, *args, **kwargs)

    @classmethod
    def update_many(cls, filter, update, *args, **kwargs):
        return cls.collection.update_many(filter, update, *args, **kwargs)

    @classmethod
    def update_one(cls, filter, update, *args, **kwargs):
        if filter is not None and not isinstance(filter, Mapping):
            filter = {"_id": filter}

        return cls.collection.update_one(filter, update, *args, **kwargs)

    @property
    def id(self):
        return self.__dict__.get("id") or self._id

    @id.setter
    def id(self, value):
        if "id" in self.__dict__:
            self.__dict__["id"] = value
        else:
            self._id = value

    def delete(self):
        return self.get_collection().delete_one({"_id": self.id})


[docs]class Thingy(BaseThingy): _client_cls = MongoClient _cursor_cls = Cursor
[docs] @classmethod def create_index(cls, keys, **kwargs): cls.collection.create_index(keys, **kwargs)
[docs] @classmethod def create_indexes(cls): if hasattr(cls, "_indexes"): for keys, kwargs in cls._indexes: cls.create_index(keys, **kwargs)
[docs] @classmethod def find_one_and_replace(cls, filter, replacement, *args, **kwargs): if filter is not None and not isinstance(filter, Mapping): filter = {"_id": filter} kwargs.setdefault("return_document", ReturnDocument.AFTER) result = cls.collection.find_one_and_replace( filter, replacement, *args, **kwargs ) if result is not None: return cls(result)
[docs] @classmethod def find_one_and_update(cls, filter, update, *args, **kwargs): if filter is not None and not isinstance(filter, Mapping): filter = {"_id": filter} kwargs.setdefault("return_document", ReturnDocument.AFTER) result = cls.collection.find_one_and_update(filter, update, *args, **kwargs) if result is not None: return cls(result)
[docs] def save(self, force_insert=False, refresh=False): data = self.__dict__ collection = self.get_collection() if self.id is not None and not force_insert: filter = {"_id": self.id} collection.replace_one(filter, data, upsert=True) else: collection.insert_one(data) if refresh: self.__dict__ = collection.find_one(self.id) return self
[docs]class AsyncThingy(BaseThingy): _client_cls = MotorClient or AsyncIOMotorClient _cursor_cls = AsyncCursor
[docs] @classmethod async def create_index(cls, keys, **kwargs): await cls.collection.create_index(keys, **kwargs)
[docs] @classmethod async def create_indexes(cls): if hasattr(cls, "_indexes"): for keys, kwargs in cls._indexes: await cls.create_index(keys, **kwargs)
[docs] @classmethod async def find_one_and_replace(cls, filter, replacement, *args, **kwargs): if filter is not None and not isinstance(filter, Mapping): filter = {"_id": filter} kwargs.setdefault("return_document", ReturnDocument.AFTER) result = await cls.collection.find_one_and_replace( filter, replacement, *args, **kwargs ) if result is not None: return cls(result)
[docs] @classmethod async def find_one_and_update(cls, filter, update, *args, **kwargs): if filter is not None and not isinstance(filter, Mapping): filter = {"_id": filter} kwargs.setdefault("return_document", ReturnDocument.AFTER) result = await cls.collection.find_one_and_update( filter, update, *args, **kwargs ) if result is not None: return cls(result)
[docs] async def save(self, force_insert=False, refresh=False): data = self.__dict__ collection = self.get_collection() if self.id is not None and not force_insert: filter = {"_id": self.id} await collection.replace_one(filter, data, upsert=True) else: await collection.insert_one(data) if refresh: self.__dict__ = await collection.find_one(self.id) return self
[docs]def connect(*args, **kwargs): if AsyncThingy._client_cls is not None: AsyncThingy.connect(*args, **kwargs) Thingy.connect(*args, **kwargs)
def disconnect(*args, **kwargs): Thingy.disconnect(*args, **kwargs) AsyncThingy.disconnect(*args, **kwargs)
[docs]def create_indexes(): """Create indexes registered on all :class:`Thingy`""" tasks = [] for cls in registry: if issubclass(cls, Thingy): cls.create_indexes() if issubclass(cls, AsyncThingy): coroutine = cls.create_indexes() task = asyncio.create_task(coroutine) tasks.append(task) if tasks: return asyncio.wait(tasks)
__all__ = ["AsyncThingy", "Thingy", "connect", "create_indexes"]