Source code for core.orm

from __future__ import annotations

import psycopg2

from datetime import datetime
from markupsafe import escape, Markup
from onegov.core.orm.cache import orm_cached, request_cached
from onegov.core.orm.observer import observes
from onegov.core.orm.session_manager import SessionManager, query_schemas
from onegov.core.orm.sql import as_selectable, as_selectable_from_path
from sqlalchemy import event, inspect, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import object_session
from sqlalchemy.orm import registry
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Query
from sqlalchemy_utils import TranslationHybrid as BaseTranslationHybrid
from zope.sqlalchemy import mark_changed
from sqlalchemy.exc import InterfaceError, OperationalError
from uuid import UUID as PythonUUID

from .types import JSON
from .types import MarkupText
from .types import UTCDateTime


from typing import overload, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
    from collections.abc import Callable, Iterator, Mapping
    from sqlalchemy.ext.hybrid import hybrid_property
    from sqlalchemy.orm import Mapped
    from sqlalchemy_utils.i18n import _TranslatableColumn
    from typing import Self

_T = TypeVar('_T')


MISSING = object()

DB_CONNECTION_ERRORS = (
    OperationalError,
    InterfaceError,
    psycopg2.OperationalError,
    psycopg2.InterfaceError,
)


#: The base for all OneGov Core ORM Models
class ModelBase:

    #: set by :class:`onegov.core.orm.cache.OrmCacheDescriptor`, this attribute
    #: indicates if the current model was loaded from cache
    is_cached = False

    @overload
    @classmethod
    def get_polymorphic_class(cls, identity_value: str) -> type[Self]: ...

    @overload
    @classmethod
    def get_polymorphic_class(cls, identity_value: str, default: _T
                              ) -> type[Self] | _T: ...

    @classmethod
    def get_polymorphic_class(
        cls,
        identity_value: str,
        default: _T = MISSING  # type:ignore[assignment]
    ) -> type[Self] | _T:
        """ Returns the polymorphic class if it exists, given the value
        of the polymorphic identity.

        Asserts that the identity is actually found, unless a default is
        provided.

        """
        mapper = inspect(cls).polymorphic_map.get(identity_value)  # type: ignore[union-attr]

        if default is MISSING:
            assert mapper, 'No such polymorphic_identity: {}'.format(
                identity_value
            )

        return mapper and mapper.class_ or default

    @property
    def session_manager(self) -> SessionManager | None:
        # FIXME: Should we assert that there is an active SessionManager
        #        when we access this property? This would allow us not
        #        having to check that there is one everywhere, but there
        #        may some existing code that relies on this possibly
        #        returning `None`, so let's leave it for now
        return SessionManager.get_active()


[docs] class Base(DeclarativeBase, ModelBase):
[docs] registry = registry(type_annotation_map={ datetime: UTCDateTime, dict[str, Any]: JSON, Markup: MarkupText, PythonUUID: UUID(as_uuid=True), # NOTE: I'm not happy that we use Text so liberally in OneGov, for # most cases String would work just fine and would be a lot # faster for filtering/searching, but alas, it is what it is # Migrating dozens of columns from Text to String does not # sound fun, but we may tackle it eventually... str: Text })
class TranslationHybrid(BaseTranslationHybrid): # NOTE: This works around the fact that `MappedColumn` does not expose # the column's `key` attribute in the way it exposes its `name` # attribute, so we need to pass the actual column, instead of the # `MappedColumn` to preserve the same API in SQLAlchemy 2.0. def __call__( self, attr: Mapped[Mapping[str, str]] | Mapped[Mapping[str, str] | None] ) -> hybrid_property[str | None]: return super().__call__(attr.column) # type: ignore[union-attr] #: A translation hybrid integrated with OneGov Core. See also: #: http://sqlalchemy-utils.readthedocs.org/en/latest/internationalization.html
[docs] translation_hybrid = TranslationHybrid( current_locale=lambda: SessionManager.get_active().current_locale, # type:ignore default_locale=lambda: SessionManager.get_active().default_locale, # type:ignore )
class TranslationMarkupHybrid(TranslationHybrid): """ A TranslationHybrid that stores `markupsafe.Markup`. """ def getter_factory( self, attr: _TranslatableColumn ) -> Callable[[object], Markup | None]: original_getter = super().getter_factory(attr) def getter(obj: object) -> Markup | None: value = original_getter(obj) if value is self.default_value and isinstance(value, str): # NOTE: The default may be a plain string so we need # to escape it return escape(self.default_value) # NOTE: Need to wrap in Markup, we may consider sanitizing # this in the future, to guard against stored values # that somehow bypassed the sanitization, but this will # be expensive return Markup(value) # nosec: B704 return getter def setter_factory( self, attr: _TranslatableColumn ) -> Callable[[object, str | None], None]: original_setter = super().setter_factory(attr) def setter(obj: object, value: str | None) -> None: if value is not None: value = escape(value) original_setter(obj, value) return setter if TYPE_CHECKING: def __call__( # type: ignore[override] self, attr: _TranslatableColumn ) -> hybrid_property[Markup | None]: pass #: A translation markup hybrid integrated with OneGov Core. translation_markup_hybrid = TranslationMarkupHybrid( current_locale=lambda: SessionManager.get_active().current_locale, # type:ignore default_locale=lambda: SessionManager.get_active().default_locale, # type:ignore )
[docs] def find_models( base: type[_T], is_match: Callable[[type[_T]], bool] ) -> Iterator[type[_T]]: """ Finds the ORM models in the given ORM base class that match a filter. The filter is called with each class in the instance and it is supposed to return True if it matches. For example, find all SQLAlchemy models that use :class:`~onegov.core.orm.mixins.ContentMixin`:: from onegov.core.orm.mixins import ContentMixin find_models(base, is_match=lambda cls: issubclass(cls, ContentMixin)) """ for cls in base.__subclasses__(): if is_match(cls): yield cls yield from find_models(cls, is_match)
def configure_listener( cls: type[Base], key: str, instance: Base ) -> None: """ The zope.sqlalchemy transaction mechanism doesn't recognize changes to cached objects. The following code intercepts all object changes and marks the transaction as changed if there was a change to a cached object. """ def mark_as_changed(obj: Base, *args: object, **kwargs: object) -> None: if obj.is_cached and (session := object_session(obj)): mark_changed(session) event.listen(instance, 'append', mark_as_changed) event.listen(instance, 'remove', mark_as_changed) event.listen(instance, 'set', mark_as_changed) event.listen(instance, 'init_collection', mark_as_changed) event.listen(instance, 'dispose_collection', mark_as_changed) event.listen(ModelBase, 'attribute_instrument', configure_listener) def share_session_manager(query: Query[Any]) -> None: session_manager = SessionManager.get_active() for desc in query.column_descriptions: desc['type'].session_manager = session_manager # type: ignore[union-attr] event.listen(Query, 'before_compile', share_session_manager, retval=False) __all__ = [ 'Base', 'SessionManager', 'as_selectable', 'as_selectable_from_path', 'translation_hybrid', 'find_models', 'observes', 'orm_cached', 'query_schemas', 'request_cached' ]