# Source code for pay.collections.payable

from __future__ import annotations

from onegov.core.collection import Pagination
from onegov.core.orm.utils import QueryChain
from onegov.pay import Payment
from sqlalchemy.orm import joinedload, selectinload


from typing import overload, Literal, TYPE_CHECKING
if TYPE_CHECKING:
    from sqlalchemy.orm import DeclarativeBase, Session
    from typing import Self


# FIXME: This should be Intersection[DeclarativeBase, Payable] once this
#        feature gets added to typing_extensions
[docs] class PayableCollection[PayableT: DeclarativeBase](Pagination[PayableT]): """ Provides a collection of all payable models. This collection is meant to be read-only, so there's no add/delete methods. To add payments to payable models just use the payment property and directly assign a new or an existing payment. """
[docs] page: int
@overload def __init__( self: PayableCollection[PayableT], session: Session, cls: type[PayableT], page: int = 0 ): ... @overload def __init__( self: PayableCollection[DeclarativeBase], session: Session, cls: Literal['*'] = '*', page: int = 0 ): ... def __init__( self, session: Session, cls: Literal['*'] | type[PayableT] = '*', page: int = 0 ):
[docs] self.session = session
[docs] self.cls = cls
self.page = page if TYPE_CHECKING: # we override the method that would not be type safe since the type # of query changed from the base class Pagination
[docs] def transform_batch_query( # type:ignore[override] self, query: QueryChain[PayableT] # type:ignore[override] ) -> QueryChain[PayableT]: ...
@property
[docs] def classes(self) -> tuple[type[DeclarativeBase], ...]: if self.cls != '*': return (self.cls, ) assert Payment.registered_links is not None return tuple(link.cls for link in Payment.registered_links.values())
[docs] def query(self) -> QueryChain[PayableT]: return QueryChain(tuple( self.session.query(cls).options( # type: ignore[misc] joinedload(cls.payment) if hasattr(cls, 'payment') else selectinload(cls.payments) # type: ignore[attr-defined] ) for cls in self.classes ))
[docs] def __eq__(self, other: object) -> bool: if not isinstance(other, PayableCollection): return False return self.cls == other.cls and self.page == other.page
[docs] def subset(self) -> QueryChain[PayableT]: # type:ignore[override] return self.query()
@property
[docs] def page_index(self) -> int: return self.page
[docs] def page_by_index(self, index: int) -> Self: return self.__class__(self.session, self.cls, index)