from __future__ import annotations
import psycopg2
from markupsafe import escape, Markup
from onegov.core.orm.cache import orm_cached, request_cached
from onegov.core.orm.observer import observes
from onegov.core.orm.session_manager import SessionManager, query_schemas
from onegov.core.orm.sql import as_selectable, as_selectable_from_path
from sqlalchemy import event, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import object_session
from sqlalchemy.orm import Query
from sqlalchemy_utils import TranslationHybrid
from zope.sqlalchemy import mark_changed
from sqlalchemy.exc import InterfaceError, OperationalError
from typing import overload, Any, ClassVar, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from sqlalchemy import Column
from sqlalchemy_utils.i18n import _TranslatableColumn
from typing import Self
_T = TypeVar('_T')

#: Sentinel marking an omitted ``default`` argument (see
#: ``ModelBase.get_polymorphic_class``), so that an explicit ``None``
#: can still be passed and returned as the default.
MISSING = object()

#: Exception types signalling operational or interface level database
#: errors, covering both SQLAlchemy's exception classes and the ones raised
#: by the psycopg2 driver.
#: NOTE(review): presumably caught by callers to detect a lost/unreachable
#:               database connection - the usage is not visible here.
DB_CONNECTION_ERRORS = (
    # SQLAlchemy exception classes
    OperationalError, InterfaceError,
    # psycopg2 driver exception classes
    psycopg2.OperationalError, psycopg2.InterfaceError,
)
#: The base for all OneGov Core ORM Models
class ModelBase:
    """ Common functionality shared by all OneGov Core ORM models.

    This class is passed as ``cls`` to ``declarative_base`` to create
    ``Base``, so every declarative model inherits these helpers.

    """

    #: set by :class:`onegov.core.orm.cache.OrmCacheDescriptor`, this attribute
    #: indicates if the current model was loaded from cache
    is_cached = False

    # FIXME: These are temporary and help mypy know that these attributes
    #        exist on the Base ORM class
    __tablename__: ClassVar[str]

    @overload
    @classmethod
    def get_polymorphic_class(cls, identity_value: str) -> type[Self]: ...

    @overload
    @classmethod
    def get_polymorphic_class(
        cls,
        identity_value: str,
        default: _T
    ) -> type[Self] | _T: ...

    @classmethod
    def get_polymorphic_class(
        cls,
        identity_value: str,
        default: _T = MISSING  # type:ignore[assignment]
    ) -> type[Self] | _T:
        """ Returns the polymorphic class if it exists, given the value
        of the polymorphic identity.

        :param identity_value: The polymorphic identity to look up in the
            polymorphic map of this class' mapper.
        :param default: Returned instead of raising if no class is
            registered for ``identity_value``.
        :raises AssertionError: If the identity is not found and no
            ``default`` was provided.

        """
        mapper = inspect(cls).polymorphic_map.get(identity_value)
        if mapper is not None:
            return mapper.class_

        if default is MISSING:
            # NOTE: Raised explicitly rather than through `assert`, which
            #       is stripped under `python -O` and would then hand the
            #       `MISSING` sentinel back to the caller. The exception
            #       type is kept for backwards compatibility.
            raise AssertionError(
                f'No such polymorphic_identity: {identity_value}'
            )

        return default

    @property
    def session_manager(self) -> SessionManager | None:
        """ The currently active session manager, if there is one. """
        # FIXME: Should we assert that there is an active SessionManager
        #        when we access this property? This would allow us not
        #        having to check that there is one everywhere, but there
        #        may some existing code that relies on this possibly
        #        returning `None`, so let's leave it for now
        return SessionManager.get_active()
#: The declarative base class every OneGov Core ORM model derives from,
#: built on top of :class:`ModelBase` so models inherit its helpers.
Base = declarative_base(cls=ModelBase)
#: A translation hybrid integrated with OneGov Core. See also:
#: https://sqlalchemy-utils.readthedocs.io/en/latest/internationalization.html
#:
#: The current and default locale are read lazily from the active
#: :class:`SessionManager` whenever the hybrid asks for them. The lambdas
#: assume a session manager is active; ``SessionManager.get_active()`` may
#: return ``None``, hence the ``type:ignore``.
translation_hybrid = TranslationHybrid(
    current_locale=lambda:
        SessionManager.get_active().current_locale,  # type:ignore
    default_locale=lambda:
        SessionManager.get_active().default_locale,  # type:ignore
)
class TranslationMarkupHybrid(TranslationHybrid):
    """ A TranslationHybrid that stores `markupsafe.Markup`.

    Values assigned through the hybrid are passed through
    :func:`markupsafe.escape` before being stored, and values read through
    it are returned as :class:`markupsafe.Markup` (or ``None`` if there is
    no value at all).

    """

    def getter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object], Markup | None]:
        original_getter = super().getter_factory(attr)

        def getter(obj: object) -> Markup | None:
            value = original_getter(obj)
            if value is None:
                # NOTE: `Markup(None)` would produce the literal string
                #       'None', so a missing translation stays `None`
                return None
            if value is self.default_value and isinstance(value, str):
                # NOTE: The default may be a plain string so we need
                #       to escape it
                return escape(self.default_value)
            # NOTE: Need to wrap in Markup, we may consider sanitizing
            #       this in the future, to guard against stored values
            #       that somehow bypassed the sanitization, but this will
            #       be expensive
            return Markup(value)  # noqa: RUF035
        return getter

    def setter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object, str | None], None]:
        original_setter = super().setter_factory(attr)

        def setter(obj: object, value: str | None) -> None:
            # `None` is stored as-is, everything else is escaped first
            if value is not None:
                value = escape(value)
            original_setter(obj, value)
        return setter

    if TYPE_CHECKING:
        # FIXME: In SQLAlchemy 2.0 this should return a hybrid_property
        def __call__(  # type: ignore[override]
            self,
            attr: _TranslatableColumn
        ) -> Column[Markup | None]:
            pass
#: A translation markup hybrid integrated with OneGov Core.
#:
#: Works like ``translation_hybrid`` but reads and writes values as
#: `markupsafe.Markup`. Both locales are resolved lazily from the active
#: :class:`SessionManager` (assumed to exist, hence the ``type:ignore``).
translation_markup_hybrid = TranslationMarkupHybrid(
    default_locale=lambda:
        SessionManager.get_active().default_locale,  # type:ignore
    current_locale=lambda:
        SessionManager.get_active().current_locale,  # type:ignore
)
def find_models(
    base: type[_T],
    is_match: Callable[[type[_T]], bool]
) -> Iterator[type[_T]]:
    """ Finds the ORM models in the given ORM base class that match a filter.

    The filter is called with every subclass (direct or indirect) of
    ``base`` and is supposed to return True if it matches. Classes are
    visited depth-first in ``__subclasses__()`` order. Each class is
    considered only once, even when it is reachable through several parents
    (multiple inheritance).

    For example, find all SQLAlchemy models that use
    :class:`~onegov.core.orm.mixins.ContentMixin`::

        from onegov.core.orm.mixins import ContentMixin
        find_models(base, is_match=lambda cls: issubclass(cls, ContentMixin))

    """
    seen: set[type[_T]] = set()

    def walk(parent: type[_T]) -> Iterator[type[_T]]:
        for cls in parent.__subclasses__():
            # with multiple inheritance a class is listed by each of its
            # parents, so only consider it the first time we reach it
            if cls in seen:
                continue
            seen.add(cls)

            if is_match(cls):
                yield cls

            yield from walk(cls)

    yield from walk(base)
def configure_listener(
    cls: type[Base],
    key: str,
    instance: Base
) -> None:
    """ The zope.sqlalchemy transaction mechanism doesn't recognize changes to
    cached objects. The following code intercepts all object changes and marks
    the transaction as changed if there was a change to a cached object.

    This function is registered for the ``attribute_instrument`` event on
    :class:`ModelBase` (see below). For each instrumented attribute it
    attaches one listener to the ``append``, ``remove``, ``set``,
    ``init_collection`` and ``dispose_collection`` events.

    NOTE(review): SQLAlchemy passes the instrumented attribute as the third
    argument of ``attribute_instrument``, so the ``Base`` annotation on
    ``instance`` looks inaccurate. Confirm before relying on it.

    """

    def flag_cached_change(target: Base, *args: Any, **kwargs: Any) -> None:
        # only objects loaded from the cache need the explicit nudge
        if not target.is_cached:
            return
        mark_changed(object_session(target))

    for identifier in (
        'append',
        'remove',
        'set',
        'init_collection',
        'dispose_collection',
    ):
        event.listen(instance, identifier, flag_cached_change)


event.listen(ModelBase, 'attribute_instrument', configure_listener)
def share_session_manager(query: Query[Any]) -> None:
    """ Hands the active :class:`SessionManager` to the query's targets.

    Registered as a ``before_compile`` listener on :class:`Query` (see
    below). For every entry in ``query.column_descriptions``, the object
    stored under its ``'type'`` key gets a ``session_manager`` attribute
    set to the currently active session manager, which may be ``None``.

    NOTE(review): for entity queries ``'type'`` is presumably the mapped
    class. For plain column expressions it is a column type object, which
    receives the attribute as well. Confirm this is intended.

    """
    active_manager = SessionManager.get_active()
    targets = (entry['type'] for entry in query.column_descriptions)
    for target in targets:
        target.session_manager = active_manager


event.listen(Query, 'before_compile', share_session_manager, retval=False)
# Public API of this module. `translation_markup_hybrid` is listed next to
# its sibling `translation_hybrid`, since both are module-level hybrids
# meant for use by models.
__all__ = [
    'Base',
    'SessionManager',
    'as_selectable',
    'as_selectable_from_path',
    'translation_hybrid',
    'translation_markup_hybrid',
    'find_models',
    'observes',
    'orm_cached',
    'query_schemas',
    'request_cached',
]