from __future__ import annotations
import psycopg2
from datetime import datetime
from markupsafe import escape, Markup
from onegov.core.orm.cache import orm_cached, request_cached
from onegov.core.orm.observer import observes
from onegov.core.orm.session_manager import SessionManager, query_schemas
from onegov.core.orm.sql import as_selectable, as_selectable_from_path
from sqlalchemy import event, inspect, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import object_session
from sqlalchemy.orm import registry
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Query
from sqlalchemy_utils import TranslationHybrid as BaseTranslationHybrid
from zope.sqlalchemy import mark_changed
from sqlalchemy.exc import InterfaceError, OperationalError
from uuid import UUID as PythonUUID
from .types import JSON
from .types import MarkupText
from .types import UTCDateTime
from typing import overload, Any, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable, Iterator, Mapping
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped
from sqlalchemy_utils.i18n import _TranslatableColumn
from typing import Self, TypeGuard
from typing_extensions import TypeIs
#: Sentinel marking "no argument given", so that ``None`` can still be
#: passed explicitly as a real default (see ``get_polymorphic_class``)
MISSING = object()

#: Exception classes grouped together as database connection errors,
#: covering both the SQLAlchemy wrappers and the raw psycopg2 exceptions
#: (presumably meant for use in an ``except`` clause)
DB_CONNECTION_ERRORS = (
    OperationalError, InterfaceError,
    psycopg2.OperationalError, psycopg2.InterfaceError,
)
#: The base for all OneGov Core ORM Models
class ModelBase:
#: set by :class:`onegov.core.orm.cache.OrmCacheDescriptor`, this attribute
#: indicates if the current model was loaded from cache
is_cached = False
@overload
@classmethod
def get_polymorphic_class(cls, identity_value: str) -> type[Self]: ...
@overload
@classmethod
def get_polymorphic_class[T](
cls,
identity_value: str,
default: T
) -> type[Self] | T: ...
@classmethod
def get_polymorphic_class(
cls,
identity_value: str,
default: Any = MISSING
) -> type[Self] | Any:
""" Returns the polymorphic class if it exists, given the value
of the polymorphic identity.
Asserts that the identity is actually found, unless a default is
provided.
"""
mapper = inspect(cls).polymorphic_map.get(identity_value) # type: ignore[union-attr]
if default is MISSING:
assert mapper, 'No such polymorphic_identity: {}'.format(
identity_value
)
return mapper and mapper.class_ or default
@property
def session_manager(self) -> SessionManager | None:
# FIXME: Should we assert that there is an active SessionManager
# when we access this property? This would allow us not
# having to check that there is one everywhere, but there
# may some existing code that relies on this possibly
# returning `None`, so let's leave it for now
return SessionManager.get_active()
class Base(DeclarativeBase, ModelBase):
    """ The declarative base class for all OneGov Core ORM models.

    Its registry maps Python types used in ``Mapped[...]`` annotations to
    the column types OneGov stores them as.
    """

    registry = registry(type_annotation_map={
        datetime: UTCDateTime,
        dict[str, Any]: JSON,
        Markup: MarkupText,
        PythonUUID: UUID(as_uuid=True),
        # NOTE: I'm not happy that we use Text so liberally in OneGov, for
        #       most cases String would work just fine and would be a lot
        #       faster for filtering/searching, but alas, it is what it is
        #       Migrating dozens of columns from Text to String does not
        #       sound fun, but we may tackle it eventually...
        str: Text
    })
class TranslationHybrid(BaseTranslationHybrid):
    """ A ``sqlalchemy_utils`` TranslationHybrid that accepts the
    ``Mapped`` attributes produced by SQLAlchemy 2.0 declarative models.
    """

    # NOTE: This works around the fact that `MappedColumn` does not expose
    #       the column's `key` attribute in the way it exposes its `name`
    #       attribute, so we need to pass the actual column, instead of the
    #       `MappedColumn` to preserve the same API in SQLAlchemy 2.0.
    def __call__(
        self,
        attr: Mapped[Mapping[str, str]] | Mapped[Mapping[str, str] | None]
    ) -> hybrid_property[str | None]:
        underlying_column = attr.column  # type: ignore[union-attr]
        return super().__call__(underlying_column)
#: A translation hybrid integrated with OneGov Core. See also:
#: https://sqlalchemy-utils.readthedocs.io/en/latest/internationalization.html
#
# The locales are looked up lazily (inside the lambdas) so every access
# uses the session manager that is active at that moment.
translation_hybrid = TranslationHybrid(
    current_locale=lambda: (
        SessionManager.get_active().current_locale  # type:ignore
    ),
    default_locale=lambda: (
        SessionManager.get_active().default_locale  # type:ignore
    ),
)
class TranslationMarkupHybrid(TranslationHybrid):
    """ A TranslationHybrid that stores `markupsafe.Markup`.

    Values are passed through ``markupsafe.escape`` before being stored,
    and are wrapped in ``Markup`` when read back. A missing value is
    returned as ``None``.
    """

    def getter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object], Markup | None]:
        original_getter = super().getter_factory(attr)

        def getter(obj: object) -> Markup | None:
            value = original_getter(obj)
            if value is None:
                # NOTE: `Markup(None)` would produce the literal text
                #       'None', so a missing value has to stay `None`
                return None
            if value is self.default_value and isinstance(value, str):
                # NOTE: The default may be a plain string so we need
                #       to escape it
                return escape(self.default_value)
            # NOTE: Need to wrap in Markup, we may consider sanitizing
            #       this in the future, to guard against stored values
            #       that somehow bypassed the sanitization, but this will
            #       be expensive
            return Markup(value)  # nosec: B704
        return getter

    def setter_factory(
        self,
        attr: _TranslatableColumn
    ) -> Callable[[object, str | None], None]:
        original_setter = super().setter_factory(attr)

        def setter(obj: object, value: str | None) -> None:
            # escape everything that is stored, `None` clears the value
            if value is not None:
                value = escape(value)
            original_setter(obj, value)
        return setter

    if TYPE_CHECKING:
        def __call__(  # type: ignore[override]
            self,
            attr: _TranslatableColumn
        ) -> hybrid_property[Markup | None]:
            pass
#: A translation markup hybrid integrated with OneGov Core.
#
# Like `translation_hybrid`, the locales are resolved lazily on each access
# from whichever session manager is active at that time.
translation_markup_hybrid = TranslationMarkupHybrid(
    current_locale=lambda: (
        SessionManager.get_active().current_locale  # type:ignore
    ),
    default_locale=lambda: (
        SessionManager.get_active().default_locale  # type:ignore
    ),
)
@overload
[docs]
def find_models[T, TG](
base: type[T],
is_match: Callable[[type[T]], TypeGuard[TG]]
) -> Iterator[type[TG]]: ...
@overload
def find_models[T, TG](
base: type[T],
is_match: Callable[[type[T]], TypeIs[TG]]
) -> Iterator[type[TG]]: ...
@overload
def find_models[T](
base: type[T],
is_match: Callable[[type[T]], bool]
) -> Iterator[type[T]]: ...
def find_models[T](
base: type[T],
is_match: Callable[[type[T]], bool]
) -> Iterator[type[Any]]:
""" Finds the ORM models in the given ORM base class that match a filter.
The filter is called with each class in the instance and it is supposed
to return True if it matches.
For example, find all SQLAlchemy models that use
:class:`~onegov.core.orm.mixins.ContentMixin`::
from onegov.core.orm.mixins import ContentMixin
find_models(base, is_match=lambda cls: issubclass(cls, ContentMixin))
"""
for cls in base.__subclasses__():
if is_match(cls):
yield cls
yield from find_models(cls, is_match)
def configure_listener(
    cls: type[Base],
    key: str,
    instance: Base
) -> None:
    """ Attaches change listeners to a freshly instrumented attribute.

    The zope.sqlalchemy transaction mechanism does not notice changes made
    to objects that were loaded from the cache. Each attribute event below
    therefore marks the owning session as changed whenever the modified
    object is a cached one.
    """
    def flag_change(target: Base, *args: object, **kwargs: object) -> None:
        # non-cached objects need no help, only cached ones are flagged
        if not target.is_cached:
            return
        session = object_session(target)
        if session:
            mark_changed(session)

    for event_name in (
        'append',
        'remove',
        'set',
        'init_collection',
        'dispose_collection',
    ):
        event.listen(instance, event_name, flag_change)


event.listen(ModelBase, 'attribute_instrument', configure_listener)
def share_session_manager(query: Query[Any]) -> None:
    """ Hands the active session manager to everything the query selects.

    Runs right before a query is compiled and assigns the currently active
    session manager to the ``session_manager`` attribute of each selected
    entry's ``type``.
    """
    active_manager = SessionManager.get_active()
    for column_description in query.column_descriptions:
        target = column_description['type']
        target.session_manager = active_manager  # type: ignore[union-attr]


event.listen(Query, 'before_compile', share_session_manager, retval=False)
# Public API of this package: re-exported helpers plus the ORM base and
# the translation hybrids defined above.
__all__ = [
    'Base',
    'SessionManager',
    'as_selectable',
    'as_selectable_from_path',
    'translation_hybrid',
    'translation_markup_hybrid',
    'find_models',
    'observes',
    'orm_cached',
    'query_schemas',
    'request_cached',
]