from __future__ import annotations
import psycopg2
from markupsafe import escape, Markup
from onegov.core.orm.cache import orm_cached, request_cached
from import observes
from onegov.core.orm.session_manager import SessionManager, query_schemas
from onegov.core.orm.sql import as_selectable, as_selectable_from_path
from sqlalchemy import event, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import object_session
from sqlalchemy.orm import Query
from sqlalchemy_utils import TranslationHybrid
from zope.sqlalchemy import mark_changed
from sqlalchemy.exc import InterfaceError, OperationalError
from typing import overload, Any, ClassVar, TypeVar, TYPE_CHECKING
from import Callable, Iterator
from sqlalchemy import Column
from sqlalchemy_utils.i18n import _TranslatableColumn
from typing import Self
# Generic type variable used in the signatures of ``get_polymorphic_class``
# (type of the ``default`` argument) and ``find_models`` (type of the base).
_T = TypeVar('_T')
# Sentinel used as the default of ``get_polymorphic_class``'s ``default``
# parameter; it is compared by identity to detect that no default was passed.
MISSING = object()
#: The base for all OneGov Core ORM Models
class ModelBase:

    #: set by :class:`onegov.core.orm.cache.OrmCacheDescriptor`, this attribute
    #: indicates if the current model was loaded from cache
    is_cached = False

    if TYPE_CHECKING:
        # FIXME: These are temporary and help mypy know that these attributes
        #        exist on the Base ORM class
        __tablename__: ClassVar[str]

        # NOTE: Typing-only overloads: without a default the result is
        #       always a class, with a default it may be the default.
        @overload
        @classmethod
        def get_polymorphic_class(
            cls,
            identity_value: str
        ) -> type[Self]: ...

        @overload
        @classmethod
        def get_polymorphic_class(
            cls,
            identity_value: str,
            default: _T
        ) -> type[Self] | _T: ...

    @classmethod
    def get_polymorphic_class(
        cls,
        identity_value: str,
        default: _T = MISSING  # type:ignore[assignment]
    ) -> type[Self] | _T:
        """ Returns the polymorphic class if it exists, given the value
        of the polymorphic identity.

        Asserts that the identity is actually found, unless a default is
        provided.

        :param identity_value:
            The polymorphic identity to look up in the polymorphic map of
            the mapper of ``cls``.

        :param default:
            Returned instead of a class if the identity is unknown. If
            omitted, an unknown identity triggers an ``AssertionError``.

        """
        mapper = inspect(cls).polymorphic_map.get(identity_value)

        if default is MISSING:
            # NOTE: Only enforced when the caller gave no fallback value
            assert mapper, 'No such polymorphic_identity: {}'.format(
                identity_value
            )

        return mapper.class_ if mapper else default

    @property
    def session_manager(self) -> SessionManager | None:
        """ The currently active session manager, as returned by
        :meth:`SessionManager.get_active` (possibly ``None``).

        """
        # FIXME: Should we assert that there is an active SessionManager
        #        when we access this property? This would allow us not
        #        having to check that there is one everywhere, but there
        #        may some existing code that relies on this possibly
        #        returning `None`, so let's leave it for now
        return SessionManager.get_active()
# Declarative base class created by SQLAlchemy's ``declarative_base`` with
# ``ModelBase`` passed as its ``cls`` argument.
Base = declarative_base(cls=ModelBase)
#: A translation hybrid integrated with OneGov Core. See also the
#: SQLAlchemy-Utils internationalization documentation.
translation_hybrid = TranslationHybrid(
    # NOTE: The locales are passed as callables so they are looked up on the
    #       session manager that is active at access time, instead of being
    #       read once at import time (when there may be no active manager).
    current_locale=lambda:
        SessionManager.get_active().current_locale,  # type:ignore
    default_locale=lambda:
        SessionManager.get_active().default_locale,  # type:ignore
)
class TranslationMarkupHybrid(TranslationHybrid):
    """ A TranslationHybrid that stores `markupsafe.Markup`. """

    def getter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object], Markup | None]:
        """ Wraps the getter of the parent class, so the value it returns
        is handed out as `markupsafe.Markup`.

        """

        original_getter = super().getter_factory(attr)

        def getter(obj: object) -> Markup | None:
            value = original_getter(obj)

            if value is None:
                # NOTE: Without this guard `Markup(None)` would turn a
                #       missing value into the literal text 'None'
                return None

            if value is self.default_value and isinstance(value, str):
                # NOTE: The default may be a plain string so we need
                #       to escape it
                return escape(self.default_value)

            # NOTE: Need to wrap in Markup, we may consider sanitizing
            #       this in the future, to guard against stored values
            #       that somehow bypassed the sanitization, but this will
            #       be expensive
            return Markup(value)  # nosec: B704

        return getter

    def setter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object, str | None], None]:
        """ Wraps the setter of the parent class, so every value that is
        not ``None`` is passed through `markupsafe.escape` before it is
        handed to the original setter.

        """

        original_setter = super().setter_factory(attr)

        def setter(obj: object, value: str | None) -> None:
            if value is not None:
                value = escape(value)
            original_setter(obj, value)

        return setter

    if TYPE_CHECKING:
        # NOTE(review): Declared for the type checker only, the runtime
        #               behavior is inherited from TranslationHybrid;
        #               confirm this matches the intended override.
        # FIXME: In SQLAlchemy 2.0 this should return a hybrid_property
        def __call__(  # type: ignore[override]
            self,
            attr: _TranslatableColumn
        ) -> Column[Markup | None]: ...
#: A translation markup hybrid integrated with OneGov Core.
translation_markup_hybrid = TranslationMarkupHybrid(
    # NOTE: The locales are passed as callables so they are looked up on the
    #       session manager that is active at access time, instead of being
    #       read once at import time (when there may be no active manager).
    current_locale=lambda:
        SessionManager.get_active().current_locale,  # type:ignore
    default_locale=lambda:
        SessionManager.get_active().default_locale,  # type:ignore
)
def find_models(
    base: type[_T],
    is_match: Callable[[type[_T]], bool]
) -> Iterator[type[_T]]:
    """ Finds the ORM models in the given ORM base class that match a filter.

    The filter is called with each class in the instance and it is supposed
    to return True if it matches.

    For example, find all SQLAlchemy models that use the ``ContentMixin``::

        from onegov.core.orm.mixins import ContentMixin
        find_models(base, is_match=lambda cls: issubclass(cls, ContentMixin))

    The subclasses of ``base`` are walked recursively and depth-first;
    ``base`` itself is never yielded. No de-duplication takes place, so a
    class that is reachable through more than one parent is yielded once
    per path.

    """
    for cls in base.__subclasses__():
        if is_match(cls):
            yield cls

        # descend regardless of the match, a subclass may still qualify
        yield from find_models(cls, is_match)
def configure_listener(
    cls: type[Base],
    key: str,
    instance: Base
) -> None:
    """ The zope.sqlalchemy transaction mechanism doesn't recognize changes to
    cached objects. The following code intercepts all object changes and marks
    the transaction as changed if there was a change to a cached object.

    The same handler is attached to the ``append``, ``remove``, ``set``,
    ``init_collection`` and ``dispose_collection`` events of ``instance``.

    """

    def mark_as_changed(obj: Base, *args: Any, **kwargs: Any) -> None:
        # NOTE: Only objects flagged as loaded from cache need the nudge
        if obj.is_cached:
            mark_changed(object_session(obj))

    event.listen(instance, 'append', mark_as_changed)
    event.listen(instance, 'remove', mark_as_changed)
    event.listen(instance, 'set', mark_as_changed)
    event.listen(instance, 'init_collection', mark_as_changed)
    event.listen(instance, 'dispose_collection', mark_as_changed)
# Register ``configure_listener`` for the 'attribute_instrument' event on
# ModelBase, so the change-tracking handlers defined there get attached.
event.listen(ModelBase, 'attribute_instrument', configure_listener)
def share_session_manager(query: Query[Any]) -> None:
    """ Hands the active session manager to everything the query selects.

    The manager returned by ``SessionManager.get_active()`` is fetched once
    and then stored as ``session_manager`` on the ``'type'`` entry of each
    of the query's column descriptions.

    """
    active_manager = SessionManager.get_active()

    for description in query.column_descriptions:
        selected = description['type']
        selected.session_manager = active_manager
# Run ``share_session_manager`` on the 'before_compile' event of Query.
# ``retval=False`` is passed explicitly; the listener returns None and only
# acts through its side effect on the selected types.
event.listen(Query, 'before_compile', share_session_manager, retval=False)
__all__ = [