# Source code for core.collection

import math

from functools import cached_property
from sqlalchemy import or_
from sqlalchemy.inspection import inspect

from onegov.core.orm import func


from typing import Any, Generic, Literal, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only; none
    # of these names are bound at runtime, which is why the rest of the
    # module refers to them through quoted annotations.
    from _typeshed import SupportsItems
    from abc import abstractmethod
    from collections.abc import Collection, Iterable, Iterator, Sequence
    from sqlalchemy import Column
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql.elements import ClauseElement
    from typing import Protocol
    from typing import Self
    from uuid import UUID

    from onegov.core.orm import Base

    # TODO: Maybe PKType should be generic as well? Or if we always
    #       use the same kind of primary key, then we can reduce
    #       this type union to something more specific
    PKType = UUID | str | int

    # a text column that may or may not be nullable
    TextColumn = Column[str] | Column[str | None]

    # NOTE: To avoid referencing onegov.form from onegov.core and
    #       introducing a cross-dependency, we use a Protocol to
    #       forward declare exactly the attributes we require to
    #       be implemented for a Form class, so it can be used with
    #       our GenericCollection class
    class _FormThatSupportsGetUsefulData(Protocol):
        def get_useful_data(self) -> SupportsItems[str, Any]: ...
[docs] _M = TypeVar('_M', bound='Base')
[docs] class GenericCollection(Generic[_M]): def __init__(self, session: 'Session', **kwargs: Any):
[docs] self.session = session
@property
[docs] def model_class(self) -> type[_M]: raise NotImplementedError
@cached_property
[docs] def primary_key(self) -> 'Column[str] | Column[UUID] | Column[int]': return inspect(self.model_class).primary_key[0]
[docs] def query(self) -> 'Query[_M]': return self.session.query(self.model_class)
[docs] def by_id(self, id: 'PKType') -> _M | None: return self.query().filter(self.primary_key == id).first()
[docs] def by_ids(self, ids: 'Collection[PKType]') -> list[_M]: # FIXME: This type error is a bug in the sqlalchemy-stubs # plugin, it might go away with SQLAlchemy 2.0, since # Column is being treated like a descriptor, even though # it is not, and there's a hidden descriptor inserted # into the DeclarativeBase to wrap the Column. return self.query().filter( self.primary_key.in_(ids) # type:ignore[union-attr] ).all() if ids else []
# NOTE: Subclasses should be more specific, so we get type # safety on the constructor of the model, ideally the # subclasses also set kwargs to Never at that point # so we get an error if we use an argument that doesn't # exist for the given model
[docs] def add(self, **kwargs: Any) -> _M: item = self.model_class(**kwargs) self.session.add(item) self.session.flush() return item
[docs] def add_by_form( self, form: '_FormThatSupportsGetUsefulData', properties: 'Iterable[str] | None' = None ) -> _M: cls = self.model_class return self.add(**{ # fields k: v for k, v in form.get_useful_data().items() if hasattr(cls, k) }, **{ # attributes k: getattr(form, k) for k in properties or () })
[docs] def delete(self, item: _M) -> None: self.session.delete(item) self.session.flush()
class SearcheableCollection(GenericCollection[_M]):
    """ Adds a full-text filter to the query of a collection.

    Requires a self.locale and self.term. Additionally the names of the
    columns to search have to be provided through self.term_filter_cols.

    """

    @staticmethod
    def match_term(
        column: 'Column[str] | Column[str | None]',
        language: str,
        term: str
    ) -> 'ClauseElement':
        """ Builds the clause ``to_tsvector(language, column) @@
        to_tsquery(language, term)``.

        Usage:
        model.filter(match_term(model.col, 'german', 'my search term'))

        """
        document_tsvector = func.to_tsvector(language, column)  # type:ignore
        ts_query_object = func.to_tsquery(language, term)  # type:ignore
        return document_tsvector.op('@@')(ts_query_object)

    @staticmethod
    def term_to_tsquery_string(term: str) -> str:
        """ Returns the current search term transformed to use within
        Postgres ``to_tsquery`` function.

        Removes all unwanted characters, replaces prefix matching, joins
        word together using FOLLOWED BY.

        Words that have no usable characters left after the cleanup are
        dropped, even if they end with ``*``. An empty or missing term
        results in an empty string.

        """

        def cleanup(word: str, whitelist_chars: str = ',.-_') -> str:
            # FIXME: str.translate or even re.sub might be faster
            result = ''.join(
                char for char in word
                if char.isalnum() or char in whitelist_chars
            )

            if not result:
                # without this guard a word like '*' turns into a bare
                # ':*', which has no operand and is rejected by to_tsquery
                return ''

            return f'{result}:*' if word.endswith('*') else result

        parts = (cleanup(part) for part in (term or '').split())
        return ' <-> '.join(part for part in parts if part)

    def filter_text_by_locale(
        self,
        column: 'Column[str] | Column[str | None]',
        term: str,
        locale: str | None = None
    ) -> 'ClauseElement':
        """ Returns an SqlAlchemy filter statement based on the search term.
        If no locale is provided, it will use english as language. Locales
        which are not part of the mapping below use english as well.

        ``to_tsquery`` creates a tsquery value from term, which must consist
        of single tokens separated by the Boolean operators
        & (AND), | (OR) and ! (NOT).

        ``to_tsvector`` parses a textual document into tokens, reduces the
        tokens to lexemes, and returns a tsvector which lists the lexemes
        together with their positions in the document. The document is
        processed according to the specified or default text search
        configuration.

        """

        # FIXME: Move this to a ClassVar or global
        mapping = {
            'de_CH': 'german',
            'fr_CH': 'french',
            'it_CH': 'italian',
            'rm_CH': 'english',
            None: 'english',
        }
        return self.__class__.match_term(
            column, mapping.get(locale, 'english'), term)

    if TYPE_CHECKING:
        # NOTE: This enforces the properties to be implemented in subclasses
        @property
        @abstractmethod
        def locale(self) -> str: ...

        @property
        @abstractmethod
        def term(self) -> str: ...

        @property
        @abstractmethod
        def term_filter_cols(self) -> dict[str, 'TextColumn']: ...
    else:
        @property
        def term_filter_cols(self) -> dict[str, 'TextColumn']:
            """ Returns a dict of column names to search in with term.
            Must be attributes of self.model_class.

            """
            raise NotImplementedError

    @property
    def term_filter(self) -> 'Iterator[ClauseElement]':
        """ Lazily yields one filter clause per name in
        self.term_filter_cols, each matching the cleaned up self.term
        against that attribute of the model class using self.locale.

        """
        assert self.term_filter_cols
        term = self.__class__.term_to_tsquery_string(self.term)

        return (
            self.filter_text_by_locale(
                getattr(self.model_class, col), term, self.locale)
            for col in self.term_filter_cols
        )

    def query(self) -> 'Query[_M]':
        """ Returns the query of the parent class. If both a term and a
        locale are set, records have to match the term in at least one of
        the searched columns.

        """
        if not self.term or not self.locale:
            return super().query()

        return super().query().filter(or_(*self.term_filter))
# FIXME: We are a little bit inconsistent about what's a base class # and what's a mixin and how we use it downstream, we should # probably try to clean that up a bit, so we always do it the # same way...
class RangedPagination(Generic[_M]):
    """ Provides a pagination that supports loading multiple pages at once.

    This is useful in a context where a single button is used to 'load more'
    results one by one. In this case we need an URL that represents what's
    happening on the screen (multiple pages are shown at the same time).

    """

    # how many items are shown per page
    batch_size = 20

    # the largest allowed difference between the last and the first page of
    # a range, ranges exceeding this limit may be clipped by using
    # `limit_range`.
    range_limit = 5

    def subset(self) -> 'Query[_M]':
        """ Returns an SQLAlchemy query containing all records that should
        be considered for pagination.

        Has to be implemented by the subclass.

        """
        raise NotImplementedError

    @cached_property
    def cached_subset(self) -> 'Query[_M]':
        """ The result of :meth:`subset`, cached on the instance. """
        return self.subset()

    @property
    def page_range(self) -> tuple[int, int]:
        """ Returns the current page range (starting at (0, 0)).

        Has to be implemented by the subclass.

        """
        raise NotImplementedError

    def by_page_range(self, page_range: tuple[int, int]) -> 'Self':
        """ Returns an instance of the collection limited to the given
        page range.

        Has to be implemented by the subclass.

        """
        raise NotImplementedError

    def limit_range(
        self,
        page_range: 'Sequence[int] | None',
        direction: Literal['up', 'down']
    ) -> tuple[int, int]:
        """ Limits the range to the range limit in the given direction.

        For example, with a range limit of 10, 0-99 will be limited to 89-99
        with 'up' and to 0-10 with 'down'.

        A single value is used as both start and end, of longer sequences
        only the first two values are considered. A reversed range is put
        in order before it is limited.

        """
        assert direction in ('up', 'down')

        if not page_range:
            # NOTE(review): a missing range starts out as 0-9 (before the
            #               limit is applied) - confirm this is intended
            first, last = 0, 9
        elif len(page_range) == 1:
            first = last = page_range[0]
        else:
            first, last = page_range[:2]

        if last < first:
            first, last = last, first

        # the range fits the limit, nothing to clip
        if last - first <= self.range_limit:
            return (first, last)

        # keep the start and cut off the end
        if direction == 'down':
            return first, first + self.range_limit

        # keep the end and cut off the start
        return max(0, last - self.range_limit), last

    def transform_batch_query(self, query: 'Query[_M]') -> 'Query[_M]':
        """ Allows subclasses to transform the given query before it is
        used to retrieve the batch.

        This is a good place to add additional loading that should only
        apply to the batch (say joining other values to the batch which
        are then not loaded by the whole query).

        The default implementation returns the query unchanged.

        """
        return query

    @cached_property
    def subset_count(self) -> int:
        """ Returns the total number of elements this pagination
        represents.

        """
        # an ordering does not change the count, dropping it makes the
        # count faster
        unordered = self.cached_subset.order_by(None)
        return unordered.count()

    @cached_property
    def batch(self) -> tuple[_M, ...]:
        """ Returns the elements on the current page range. """
        first_page, last_page = self.page_range
        size = self.batch_size

        # the slice runs from the first item of the first page up to (but
        # not including) the first item after the last page
        start = first_page * size
        stop = last_page * size + size

        query = self.cached_subset.slice(start, stop)
        return tuple(self.transform_batch_query(query))

    @property
    def pages_count(self) -> int:
        """ Returns the number of pages.

        A falsy batch size results in a page count of 1.

        """
        size = self.batch_size
        return math.ceil(self.subset_count / size) if size else 1

    @property
    def previous(self) -> 'Self | None':
        """ Returns the previous page or None. """
        first_page, _last_page = self.page_range

        if first_page <= 0:
            return None

        preceding = first_page - 1
        return self.by_page_range((preceding, preceding))

    @property
    def next(self) -> 'Self | None':
        """ Returns the next page range or None. """
        _first_page, last_page = self.page_range
        following = last_page + 1

        if following < self.pages_count:
            return self.by_page_range((following, following))

        return None