# Source code for onegov.core.orm

from __future__ import annotations

import psycopg2

from datetime import datetime
from markupsafe import escape, Markup
from onegov.core.orm.cache import orm_cached, request_cached
from onegov.core.orm.observer import observes
from onegov.core.orm.session_manager import SessionManager, query_schemas
from onegov.core.orm.sql import as_selectable, as_selectable_from_path
from sqlalchemy import event, inspect, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import object_session
from sqlalchemy.orm import registry
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Query
from sqlalchemy_utils import TranslationHybrid as BaseTranslationHybrid
from zope.sqlalchemy import mark_changed
from sqlalchemy.exc import InterfaceError, OperationalError
from uuid import UUID as PythonUUID

from .types import JSON
from .types import MarkupText
from .types import UTCDateTime


from typing import overload, Any, TYPE_CHECKING
if TYPE_CHECKING:
    from collections.abc import Callable, Iterator, Mapping
    from sqlalchemy.ext.hybrid import hybrid_property
    from sqlalchemy.orm import Mapped
    from sqlalchemy_utils.i18n import _TranslatableColumn
    from typing import Self, TypeGuard
    from typing_extensions import TypeIs


#: Sentinel default for :meth:`ModelBase.get_polymorphic_class`; it lets
#: the method tell "no default given" apart from an explicit ``None``.
MISSING = object()

#: Exception types treated as database connection errors: the SQLAlchemy
#: wrapper exceptions as well as the raw psycopg2 exceptions.
# NOTE(review): not referenced in this module's visible code; presumably
#               used by callers in ``except DB_CONNECTION_ERRORS:`` — confirm
DB_CONNECTION_ERRORS: tuple[type[Exception], ...] = (
    OperationalError,
    InterfaceError,
    psycopg2.OperationalError,
    psycopg2.InterfaceError,
)


#: The base for all OneGov Core ORM Models
class ModelBase:

    #: set by :class:`onegov.core.orm.cache.OrmCacheDescriptor`, this attribute
    #: indicates if the current model was loaded from cache
    is_cached = False

    @overload
    @classmethod
    def get_polymorphic_class(cls, identity_value: str) -> type[Self]: ...

    @overload
    @classmethod
    def get_polymorphic_class[T](
        cls,
        identity_value: str,
        default: T
    ) -> type[Self] | T: ...

    @classmethod
    def get_polymorphic_class(
        cls,
        identity_value: str,
        default: Any = MISSING
    ) -> type[Self] | Any:
        """ Returns the polymorphic class if it exists, given the value
        of the polymorphic identity.

        Asserts that the identity is actually found, unless a default is
        provided.

        """
        mapper = inspect(cls).polymorphic_map.get(identity_value)  # type: ignore[union-attr]

        if default is MISSING:
            assert mapper, 'No such polymorphic_identity: {}'.format(
                identity_value
            )

        return mapper and mapper.class_ or default

    @property
    def session_manager(self) -> SessionManager | None:
        # FIXME: Should we assert that there is an active SessionManager
        #        when we access this property? This would allow us not
        #        having to check that there is one everywhere, but there
        #        may some existing code that relies on this possibly
        #        returning `None`, so let's leave it for now
        return SessionManager.get_active()


class Base(DeclarativeBase, ModelBase):
    """ Declarative base class combining SQLAlchemy's ``DeclarativeBase``
    with :class:`ModelBase`. """

    #: Maps Python annotation types to the column types used for
    #: ``Mapped[...]`` attributes that do not specify a type explicitly.
    # NOTE: On the right-hand side `registry` still resolves to the imported
    #       `sqlalchemy.orm.registry`, since the class attribute of the same
    #       name only exists once this assignment has completed.
    registry = registry(type_annotation_map={
        datetime: UTCDateTime,
        dict[str, Any]: JSON,
        Markup: MarkupText,
        PythonUUID: UUID(as_uuid=True),
        # NOTE: I'm not happy that we use Text so liberally in OneGov, for
        #       most cases String would work just fine and would be a lot
        #       faster for filtering/searching, but alas, it is what it is
        #       Migrating dozens of columns from Text to String does not
        #       sound fun, but we may tackle it eventually...
        str: Text,
    })
class TranslationHybrid(BaseTranslationHybrid):
    """ ``sqlalchemy_utils.TranslationHybrid`` that accepts SQLAlchemy 2.0
    mapped columns. """

    # NOTE: This works around the fact that `MappedColumn` does not expose
    #       the column's `key` attribute in the way it exposes its `name`
    #       attribute, so we need to pass the actual column, instead of the
    #       `MappedColumn` to preserve the same API in SQLAlchemy 2.0.
    def __call__(
        self,
        attr: Mapped[Mapping[str, str]] | Mapped[Mapping[str, str] | None]
    ) -> hybrid_property[str | None]:
        """ Builds the translated hybrid property for ``attr`` by handing
        its underlying ``Column`` to the sqlalchemy_utils implementation.

        """
        return super().__call__(attr.column)  # type: ignore[union-attr]


#: A translation hybrid integrated with OneGov Core. See also:
#: https://sqlalchemy-utils.readthedocs.io/en/latest/internationalization.html
translation_hybrid = TranslationHybrid(
    current_locale=lambda: SessionManager.get_active().current_locale,  # type:ignore
    default_locale=lambda: SessionManager.get_active().default_locale,  # type:ignore
)
class TranslationMarkupHybrid(TranslationHybrid): """ A TranslationHybrid that stores `markupsafe.Markup`. """ def getter_factory( self, attr: _TranslatableColumn ) -> Callable[[object], Markup | None]: original_getter = super().getter_factory(attr) def getter(obj: object) -> Markup | None: value = original_getter(obj) if value is self.default_value and isinstance(value, str): # NOTE: The default may be a plain string so we need # to escape it return escape(self.default_value) # NOTE: Need to wrap in Markup, we may consider sanitizing # this in the future, to guard against stored values # that somehow bypassed the sanitization, but this will # be expensive return Markup(value) # nosec: B704 return getter def setter_factory( self, attr: _TranslatableColumn ) -> Callable[[object, str | None], None]: original_setter = super().setter_factory(attr) def setter(obj: object, value: str | None) -> None: if value is not None: value = escape(value) original_setter(obj, value) return setter if TYPE_CHECKING: def __call__( # type: ignore[override] self, attr: _TranslatableColumn ) -> hybrid_property[Markup | None]: pass #: A translation markup hybrid integrated with OneGov Core. translation_markup_hybrid = TranslationMarkupHybrid( current_locale=lambda: SessionManager.get_active().current_locale, # type:ignore default_locale=lambda: SessionManager.get_active().default_locale, # type:ignore ) @overload
[docs] def find_models[T, TG]( base: type[T], is_match: Callable[[type[T]], TypeGuard[TG]] ) -> Iterator[type[TG]]: ...
@overload def find_models[T, TG]( base: type[T], is_match: Callable[[type[T]], TypeIs[TG]] ) -> Iterator[type[TG]]: ... @overload def find_models[T]( base: type[T], is_match: Callable[[type[T]], bool] ) -> Iterator[type[T]]: ... def find_models[T]( base: type[T], is_match: Callable[[type[T]], bool] ) -> Iterator[type[Any]]: """ Finds the ORM models in the given ORM base class that match a filter. The filter is called with each class in the instance and it is supposed to return True if it matches. For example, find all SQLAlchemy models that use :class:`~onegov.core.orm.mixins.ContentMixin`:: from onegov.core.orm.mixins import ContentMixin find_models(base, is_match=lambda cls: issubclass(cls, ContentMixin)) """ for cls in base.__subclasses__(): if is_match(cls): yield cls yield from find_models(cls, is_match) def configure_listener( cls: type[Base], key: str, instance: Base ) -> None: """ The zope.sqlalchemy transaction mechanism doesn't recognize changes to cached objects. The following code intercepts all object changes and marks the transaction as changed if there was a change to a cached object. 
""" def mark_as_changed(obj: Base, *args: object, **kwargs: object) -> None: if obj.is_cached and (session := object_session(obj)): mark_changed(session) event.listen(instance, 'append', mark_as_changed) event.listen(instance, 'remove', mark_as_changed) event.listen(instance, 'set', mark_as_changed) event.listen(instance, 'init_collection', mark_as_changed) event.listen(instance, 'dispose_collection', mark_as_changed) event.listen(ModelBase, 'attribute_instrument', configure_listener) def share_session_manager(query: Query[Any]) -> None: session_manager = SessionManager.get_active() for desc in query.column_descriptions: desc['type'].session_manager = session_manager # type: ignore[union-attr] event.listen(Query, 'before_compile', share_session_manager, retval=False) __all__ = [ 'Base', 'SessionManager', 'as_selectable', 'as_selectable_from_path', 'translation_hybrid', 'find_models', 'observes', 'orm_cached', 'query_schemas', 'request_cached' ]