Source code for pydantic_mongo.abstract_repository

from typing import Any, Dict, Iterable, Optional, Type, Union, cast

from pymongo import UpdateOne
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.results import DeleteResult, InsertOneResult, UpdateResult

from .base_abstract_repository import (
    BaseAbstractRepository,
    ModelWithId,
    OutputT,
    Sort,
    T,
)
from .pagination import Edge, encode_pagination_cursor, get_pagination_cursor_payload


class AbstractRepository(BaseAbstractRepository[T]):
    """A synchronous repository implementation for MongoDB using Pydantic models.

    This class provides a high-level interface for performing CRUD operations
    on MongoDB collections using Pydantic models for type safety and data
    validation.

    Example::

        class UserRepository(AbstractRepository[User]):
            class Meta:
                collection_name = 'users'

        repo = UserRepository(database)
        user = repo.find_one_by_id(user_id)

    Generic type T must be a Pydantic model with an 'id' field.
    """

    class Meta:
        # Name of the MongoDB collection; subclasses set this (see example above).
        collection_name: str

    def __init__(self, database: Database):
        """Initialize the repository with a MongoDB database connection.

        Args:
            database: PyMongo Database instance
        """
        # Stored under a name-mangled attribute; use get_collection() to reach it.
        self.__database: Database = database
        super().__init__()

    def get_collection(self) -> Collection:
        """Get the MongoDB collection associated with this repository.

        Returns:
            Collection: PyMongo Collection instance
        """
        return self.__database[self._collection_name]

    def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
        """Save a model instance to the database.

        - If the model has a truthy ``id``, the stored document is updated via
          ``$set`` with ``upsert=True`` (so it is created if missing).
        - Otherwise the model is inserted and its ``id`` attribute is set to
          the ``inserted_id`` reported by MongoDB.

        Args:
            model: The model instance to save

        Returns:
            Union[InsertOneResult, UpdateResult]: The result of the save operation
        """
        document = self.to_document(model)
        model_with_id = cast(ModelWithId, model)

        if model_with_id.id:
            # Move _id out of the update body and use it as the filter instead.
            mongo_id = document.pop("_id")
            return self.get_collection().update_one(
                {"_id": mongo_id}, {"$set": document}, upsert=True
            )

        result = self.get_collection().insert_one(document)
        model_with_id.id = result.inserted_id
        return result

    def save_many(self, models: Iterable[T]) -> None:
        """Save multiple model instances to the database in bulk.

        Models are split into two groups:

        - Models without a truthy ``id`` are inserted with a single
          ``insert_many`` call, and each model's ``id`` is set to the
          corresponding inserted id.
        - Models with a truthy ``id`` are written with a single ``bulk_write``
          of ``UpdateOne(..., upsert=True)`` operations.

        Args:
            models: Iterable of model instances to save (consumed once)
        """
        models_to_insert = []
        models_to_update = []
        for model in models:
            if cast(ModelWithId, model).id:
                models_to_update.append(model)
            else:
                models_to_insert.append(model)

        if models_to_insert:
            result = self.get_collection().insert_many(
                self.to_document(model) for model in models_to_insert
            )
            # inserted_ids follows the order of the documents passed in, so pair
            # each model with its new id positionally.
            for model, inserted_id in zip(models_to_insert, result.inserted_ids):
                cast(ModelWithId, model).id = inserted_id

        if not models_to_update:
            return

        documents_to_update = [self.to_document(model) for model in models_to_update]
        mongo_ids = [doc.pop("_id") for doc in documents_to_update]
        bulk_operations = [
            UpdateOne({"_id": mongo_id}, {"$set": document}, upsert=True)
            for mongo_id, document in zip(mongo_ids, documents_to_update)
        ]
        self.get_collection().bulk_write(bulk_operations)

    def delete(self, model: T) -> DeleteResult:
        """Delete a model instance from the database.

        Args:
            model: The model instance to delete

        Returns:
            DeleteResult: The result of the delete operation
        """
        return self.get_collection().delete_one({"_id": cast(ModelWithId, model).id})

    def delete_by_id(self, _id: Any) -> DeleteResult:
        """Delete a model instance from the database by its ID.

        Args:
            _id: The ID of the model instance to delete

        Returns:
            DeleteResult: The result of the delete operation
        """
        return self.get_collection().delete_one({"_id": _id})

    def find_one_by_id(self, _id: Any) -> Optional[T]:
        """Find a single model instance by its ID.

        Args:
            _id: The ID to search for. Must match the type of the model's ID
                field (typically ObjectId for MongoDB)

        Returns:
            Optional[T]: The found model instance or None if not found
        """
        return self.find_one_by({"id": _id})

    def find_one_by(self, query: dict) -> Optional[T]:
        """Find a single model instance by a MongoDB query.

        Args:
            query: MongoDB query dictionary (passed through ``_map_id`` first)

        Returns:
            Optional[T]: The found model instance or None if not found

        Example::

            user = repo.find_one_by({"email": "user@example.com"})
        """
        result = self.get_collection().find_one(self._map_id(query))
        return self.to_model(result) if result else None

    def find_by_with_output_type(
        self,
        output_type: Type[OutputT],
        query: dict,
        skip: Optional[int] = None,
        limit: Optional[int] = None,
        sort: Optional[Sort] = None,
        projection: Optional[Dict[str, int]] = None,
    ) -> Iterable[OutputT]:
        """Find multiple model instances with custom output type.

        This method allows querying with a different output model than the
        repository's base model type, useful for projections and
        transformations.

        Args:
            output_type: The Pydantic model class for the output
            query: MongoDB query dictionary
            skip: Number of documents to skip (None or 0 applies no skip)
            limit: Maximum number of documents to return (None or 0 applies
                no limit)
            sort: List of (field, direction) tuples for sorting
            projection: MongoDB projection dictionary

        Returns:
            Iterable[OutputT]: Lazy iterator of model instances of the
            specified output type
        """
        mapped_projection = self._map_id(projection) if projection else None
        mapped_sort = self._map_sort(sort) if sort else None

        cursor = self.get_collection().find(self._map_id(query), mapped_projection)
        if limit:
            cursor.limit(limit)
        if skip:
            cursor.skip(skip)
        if mapped_sort:
            cursor.sort(mapped_sort)

        # Documents are converted one at a time as the caller iterates.
        return (self.to_model_custom(output_type, doc) for doc in cursor)

    def find_by(
        self,
        query: dict,
        skip: Optional[int] = None,
        limit: Optional[int] = None,
        sort: Optional[Sort] = None,
        projection: Optional[Dict[str, int]] = None,
    ) -> Iterable[T]:
        """Find multiple model instances by a MongoDB query.

        Args:
            query: MongoDB query dictionary
            skip: Number of documents to skip
            limit: Maximum number of documents to return
            sort: List of (field, direction) tuples for sorting
            projection: MongoDB projection dictionary

        Returns:
            Iterable[T]: Iterator of model instances
        """
        return self.find_by_with_output_type(
            output_type=self._document_class,
            query=query,
            skip=skip,
            limit=limit,
            sort=sort,
            projection=projection,
        )

    def paginate_with_output_type(
        self,
        output_type: Type[OutputT],
        query: dict,
        limit: int,
        after: Optional[str] = None,
        before: Optional[str] = None,
        sort: Optional[Sort] = None,
        projection: Optional[Dict[str, int]] = None,
    ) -> Iterable[Edge[OutputT]]:
        """Paginate through model instances with custom output type.

        This method implements cursor-based pagination which is more reliable
        than offset-based pagination for large datasets.

        Args:
            output_type: The Pydantic model class for the output
            query: MongoDB query dictionary
            limit: Maximum number of documents per page
            after: Cursor string for fetching next page
            before: Cursor string for fetching previous page
            sort: List of (field, direction) tuples for sorting; defaults to
                ``[("_id", 1)]`` when empty or None
            projection: MongoDB projection dictionary

        Returns:
            Iterable[Edge[OutputT]]: Iterator of Edge objects containing model
            instances and pagination cursors
        """
        if not sort:
            sort = [("_id", 1)]
        # Field names handed to get_pagination_cursor_payload to build each cursor.
        sort_keys = [sort_expression[0] for sort_expression in sort]

        models = self.find_by_with_output_type(
            output_type,
            query=self.get_pagination_query(
                query=query, after=after, before=before, sort=sort
            ),
            limit=limit,
            sort=sort,
            projection=projection,
        )
        return (
            Edge[OutputT](
                node=model,
                cursor=encode_pagination_cursor(
                    get_pagination_cursor_payload(model, sort_keys)
                ),
            )
            for model in models
        )

    def paginate(
        self,
        query: dict,
        limit: int,
        after: Optional[str] = None,
        before: Optional[str] = None,
        sort: Optional[Sort] = None,
        projection: Optional[Dict[str, int]] = None,
    ) -> Iterable[Edge[T]]:
        """Paginate through model instances using cursor-based pagination.

        This method implements cursor-based pagination which is more reliable
        than offset-based pagination for large datasets.

        Args:
            query: MongoDB query dictionary
            limit: Maximum number of documents per page
            after: Cursor string for fetching next page
            before: Cursor string for fetching previous page
            sort: List of (field, direction) tuples for sorting
            projection: MongoDB projection dictionary

        Returns:
            Iterable[Edge[T]]: Iterator of Edge objects containing model
            instances and pagination cursors

        Example:
            Get first page::

                edges = repo.paginate({"status": "active"}, limit=10)

            Get next page using the last cursor::

                next_edges = repo.paginate(
                    {"status": "active"},
                    limit=10,
                    after=list(edges)[-1].cursor
                )
        """
        return self.paginate_with_output_type(
            self._document_class,
            query,
            limit,
            after=after,
            before=before,
            sort=sort,
            projection=projection,
        )