Source code for org.models.file

""" Contains the models describing files and images. """
from __future__ import annotations

import sedate

from datetime import datetime
from dateutil.relativedelta import relativedelta
from functools import cached_property
from itertools import chain, groupby
from onegov.core.orm import as_selectable
from onegov.core.orm.mixins import dict_property, meta_property
from onegov.file import File, FileSet, FileCollection, FileSetCollection
from onegov.file import SearchableFile
from onegov.file.utils import IMAGE_MIME_TYPES_AND_SVG
from onegov.org import _
from onegov.org.models.extensions import AccessExtension
from onegov.org.utils import widest_access
from onegov.search import ORMSearchable
from operator import attrgetter, itemgetter
from sedate import standardize_date, utcnow
from sqlalchemy import asc, desc, select, nullslast  # type: ignore

from typing import (
    overload, Any, Generic, Literal, NamedTuple, TypeVar, TYPE_CHECKING)
if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Iterator
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql import Select
    from typing import Self

# generic type variables used by the ``query_intervals`` overloads
_T = TypeVar('_T')
_RowT = TypeVar('_RowT')


class IdRow(NamedTuple):
    """ A result row which carries nothing but the id of a file. """

    id: str


class FileRow(NamedTuple):
    """ A single row of the general file listing.

    The fields carry the same names as the columns selected by
    ``GeneralFileCollection.file_list``.

    """

    number: int
    id: str
    name: str
    order: str
    signed: bool
    upload_date: datetime
    publish_end_date: datetime
    content_type: str
# type variable standing for a concrete file model, bounded by ``File``
FileT = TypeVar('FileT', bound=File)
class DateInterval(NamedTuple):
    """ A named span of time, delimited by an inclusive start and end. """

    # the label of the interval (e.g. "This month")
    name: str

    # the first moment which still belongs to the interval
    start: datetime

    # the last moment which still belongs to the interval
    end: datetime
class GroupFilesByDateMixin(Generic[FileT]):
    """ Groups the files of a collection by natural language date ranges
    like "This month" or "Last year".

    The class this mixin is combined with has to provide a ``query()``
    method which returns a query of files.

    """

    if TYPE_CHECKING:
        def query(self) -> Query[FileT]: ...

    def get_date_intervals(
        self,
        today: datetime
    ) -> Iterator[DateInterval]:
        """ Yields the date intervals used for grouping, relative to the
        given date, from the newest to the oldest:

        "In future", "This month", "Last month", "This year" (skipped in
        January and February), "Last year" and "Older".

        :param today:
            The reference date, it is standardized to UTC first.

        """
        today = standardize_date(today, 'UTC')

        # an absolute day=31 is clamped to the last day of the month by
        # relativedelta, while day=1 always results in the first day
        #
        # NOTE(review): only the day is replaced here, so the boundaries
        # inherit the time of day of `today` - confirm that this is intended
        month_end = today + relativedelta(day=31)
        month_start = today - relativedelta(day=1)
        next_month_start = month_start + relativedelta(months=1)
        in_distant_future = next_month_start + relativedelta(years=100)

        yield DateInterval(
            name=_('In future'),
            start=next_month_start,
            end=in_distant_future)

        yield DateInterval(
            name=_('This month'),
            start=month_start,
            end=month_end)

        last_month_end = month_start - relativedelta(microseconds=1)
        last_month_start = month_start - relativedelta(months=1)

        yield DateInterval(
            name=_('Last month'),
            start=last_month_start,
            end=last_month_end)

        # in January and February the current year is already fully covered
        # by "This month" and "Last month", so there is no "This year"
        if month_start.month not in (1, 2):
            this_year_end = last_month_start - relativedelta(microseconds=1)
            this_year_start = this_year_end.replace(
                month=1, day=1, hour=0, minute=0, second=0, microsecond=0)

            yield DateInterval(
                name=_('This year'),
                start=this_year_start,
                end=this_year_end)
        else:
            this_year_end = None
            this_year_start = None

        # "Last year" ends right before the oldest interval yielded so far,
        # in January that is the start of last year's December
        last_year_end = this_year_start or last_month_start
        last_year_end -= relativedelta(microseconds=1)
        last_year_start = last_year_end.replace(
            month=1, day=1, hour=0, minute=0, second=0, microsecond=0)

        yield DateInterval(
            name=_('Last year'),
            start=last_year_start,
            end=last_year_end)

        older_end = last_year_start - relativedelta(microseconds=1)
        older_start = datetime(2000, 1, 1, tzinfo=today.tzinfo)

        yield DateInterval(
            name=_('Older'),
            start=older_start,
            end=older_end)

    @overload
    def query_intervals(
        self,
        intervals: Iterable[DateInterval],
        before_filter: Callable[[Query[FileT]], Query[_RowT]],
        process: Callable[[_RowT], _T]
    ) -> Iterator[tuple[str, _T]]: ...

    @overload
    def query_intervals(
        self,
        intervals: Iterable[DateInterval],
        before_filter: None,
        process: Callable[[FileT], _T]
    ) -> Iterator[tuple[str, _T]]: ...

    @overload
    def query_intervals(
        self,
        intervals: Iterable[DateInterval],
        before_filter: None = None,
        *,
        process: Callable[[FileT], _T]
    ) -> Iterator[tuple[str, _T]]: ...

    @overload
    def query_intervals(
        self,
        intervals: Iterable[DateInterval],
        before_filter: Callable[[Query[FileT]], Query[Any]] | None = None,
        process: None = None
    ) -> Iterator[tuple[str, Any]]: ...

    def query_intervals(
        self,
        intervals: Iterable[DateInterval],
        before_filter: Callable[[Query[FileT]], Query[Any]] | None = None,
        process: Callable[[Any], Any] | None = None
    ) -> Iterator[tuple[str, Any]]:
        """ Yields ``(interval name, result)`` tuples for the files created
        within each of the given intervals.

        The base query is ordered by creation date in descending order and
        one query is run per interval.

        :param intervals:
            The intervals to query, results are yielded in this order.

        :param before_filter:
            Optional callable which receives the ordered base query and
            returns the query to use (e.g. to limit the selected columns).

        :param process:
            Optional callable applied to each result before it is yielded.
            Without it, the results are yielded unchanged.

        """
        base_query = self.query().order_by(desc(File.created))

        if before_filter:
            base_query = before_filter(base_query)

        for interval in intervals:
            query = base_query.filter(File.created >= interval.start)
            query = query.filter(File.created <= interval.end)

            for result in query.all():
                if process is not None:
                    yield interval.name, process(result)
                else:
                    # results used to be dropped silently without a
                    # process callable, which contradicts the overloads
                    yield interval.name, result

    @overload
    def grouped_by_date(
        self,
        today: datetime | None = None,
        id_only: Literal[True] = True
    ) -> groupby[str, tuple[str, str]]: ...

    @overload
    def grouped_by_date(
        self,
        today: datetime | None,
        id_only: Literal[False]
    ) -> groupby[str, tuple[str, FileT]]: ...

    @overload
    def grouped_by_date(
        self,
        today: datetime | None = None,
        *,
        id_only: Literal[False]
    ) -> groupby[str, tuple[str, FileT]]: ...

    def grouped_by_date(
        self,
        today: datetime | None = None,
        id_only: bool = True
    ) -> groupby[str, tuple[str, FileT | str]]:
        """ Returns all files grouped by natural language dates.

        By default, only ids are returned, as this is enough to build the
        necessary links, which is what you usually want from a file.

        The given date is expected to be in UTC.

        """

        intervals = tuple(self.get_date_intervals(today or utcnow()))

        files: Iterator[tuple[str, str | FileT]]
        if id_only:
            def before_filter(query: Query[FileT]) -> Query[IdRow]:
                return query.with_entities(File.id)

            def process(result: IdRow) -> str:
                return result.id

            files = self.query_intervals(intervals, before_filter, process)
        else:
            def process_file(result: FileT) -> FileT:
                return result

            files = self.query_intervals(intervals, None, process_file)

        # the results arrive interval by interval, so consecutive grouping
        # by the interval name (first tuple item) is sufficient
        return groupby(files, key=itemgetter(0))
class GeneralFile(File, SearchableFile):
    """ The file model stored with the polymorphic identity 'general'. """

    __mapper_args__ = {'polymorphic_identity': 'general'}

    #: the access of all the linked models
    linked_accesses: dict_property[dict[str, str]]
    linked_accesses = meta_property(default=dict)

    @property
    def access(self) -> str:
        """ The effective access of the file.

        Publications are always 'public'. Other files get the widest
        access of their linked models.

        """
        if not self.publication:
            accesses = self.linked_accesses
            if accesses:
                return widest_access(*accesses.values())

            # a file which is not a publication and has no linked
            # accesses is considered secret
            return 'secret'

        return 'public'

    @property
    def es_public(self) -> bool:
        """ Truthy if the file is published and publicly accessible. """
        published = self.published
        return published and self.access == 'public'
class ImageFile(File):
    """ The file model stored with the polymorphic identity 'image'. """

    __mapper_args__ = {'polymorphic_identity': 'image'}
class ImageSet(FileSet, AccessExtension, ORMSearchable):
    """ A file set stored with the polymorphic identity 'image', whose
    title and lead are searchable.

    """

    __mapper_args__ = {'polymorphic_identity': 'image'}

    es_properties = {
        'title': {'type': 'localized'},
        'lead': {'type': 'localized'}
    }

    @property
    def es_public(self) -> bool:
        """ True if the access of the set is 'public'. """
        return self.access == 'public'

    @property
    def es_suggestions(self) -> dict[str, list[str]]:
        """ The search suggestions, consisting of the lowercased title. """
        lowered_title = self.title.lower()
        return {
            'input': [lowered_title]
        }

    # values stored in the meta data of the set
    lead: dict_property[str | None] = meta_property()
    view: dict_property[str | None] = meta_property()

    # the sort order of the files and its direction, see `ordered_files`
    order: dict_property[str] = meta_property(default='by-last-change')
    order_direction: dict_property[str] = meta_property(default='desc')

    show_images_on_homepage: dict_property[bool | None] = meta_property()

    @property
    def ordered_files(self) -> list[File]:
        """ The files of the set, sorted according to `order` and
        `order_direction`.

        Supported orders are 'by-last-change', 'by-name' and 'by-caption',
        anything else raises an AssertionError.

        """
        descending = self.order_direction == 'desc'

        if self.order == 'by-last-change':
            # the files are already sorted, since this relationship
            # is sorted by last change in descending order
            if descending:
                return self.files
            return list(reversed(self.files))

        sort_key: Callable[[File], str]
        if self.order == 'by-name':
            sort_key = attrgetter('name')
        elif self.order == 'by-caption':
            # we can't use attrgetter since note is nullable
            def sort_key(file: File) -> str:
                return file.note or ''
        else:
            raise AssertionError('unreachable')

        # for the rest we sort by attribute name
        return sorted(self.files, key=sort_key, reverse=descending)
class ImageSetCollection(FileSetCollection[ImageSet]):
    """ A file set collection limited to the type 'image'. """

    def __init__(self, session: Session) -> None:
        """ Binds the collection to the given session. """
        super().__init__(session, type='image')
class GeneralFileCollection(
    FileCollection[GeneralFile],
    GroupFilesByDateMixin[GeneralFile]
):
    """ The collection of files of type 'general'.

    Next to the regular collection it offers a lightweight, sortable
    listing of the files (`files`) and a way to assign each listed record
    to a group (`group`).

    """

    supported_content_types = 'all'

    # the type comments after each column are part of the statement text
    # handed to as_selectable and therefore have to stay on their lines
    file_list = as_selectable("""
        SELECT
            row_number() OVER () as number, -- Integer
            id,                             -- Text
            name,                           -- Text
            "order",                        -- Text
            signed,                         -- Boolean
            created as upload_date,         -- UTCDateTime
            publish_end_date,               -- UTCDateTime
            reference->>'content_type'
                AS content_type             -- Text
        FROM files
        WHERE type = 'general'
    """)

    def __init__(self, session: Session, order_by: str = 'name') -> None:
        """ Creates the collection.

        :param session:
            The database session to use.

        :param order_by:
            The ordering of the listing, `statement` knows 'name', 'date'
            and 'publish_end_date'. Ordering by name is ascending, all
            other orderings are descending.

        """
        super().__init__(session, type='general', allow_duplicates=False)

        self.order_by = order_by
        self.direction = 'ascending' if order_by == 'name' else 'descending'

        # the interval matched by the last call to `group`
        self._last_interval: DateInterval | None = None

    def for_order(self, order: str) -> Self:
        """ Returns a new collection of the same class using the given
        ordering.

        """
        return self.__class__(self.session, order_by=order)

    @cached_property
    def intervals(self) -> tuple[DateInterval, ...]:
        """ The date intervals relative to now, computed once. """
        return tuple(self.get_date_intervals(today=sedate.utcnow()))

    @property
    def statement(self) -> Select:
        """ The select statement of the file listing, ordered according to
        `order_by` and `direction`, with NULL values sorted last.

        """
        columns = self.file_list.c

        order_columns = {
            'name': columns.order,
            'date': columns.upload_date,
            'publish_end_date': columns.publish_end_date,
        }

        # unknown orderings fall back to the ordering by name
        order = order_columns.get(self.order_by, columns.order)
        direction = asc if self.direction == 'ascending' else desc

        return select(columns).order_by(nullslast(direction(order)))

    @property
    def files(self) -> Query[FileRow]:
        """ Executes `statement` on the session and returns the result. """
        return self.session.execute(self.statement)

    def group(self, record: FileRow) -> str:
        """ Returns the name of the group the given record belongs to.

        When ordered by 'date' or 'publish_end_date' this is the name of
        the matching date interval, 'Older' if no interval matches and
        'None' for records without a publish end date. For every other
        ordering it is the uppercased first character of the order
        attribute, with all digits collected under '0-9'.

        """
        if self.order_by not in ('date', 'publish_end_date'):
            # ordering by name, which is also the default
            initial = record.order[0]
            return '0-9' if initial.isdigit() else initial.upper()

        if self.order_by == 'date':
            value = record.upload_date
        else:
            value = record.publish_end_date

            # does not depend on the interval, so check it only once
            if not value:
                return _('None')

        # this method is usually called for each item in a sorted set,
        # we optimise for that by caching the last matching interval
        # and checking that one first the next time
        candidates: Iterable[DateInterval]
        if self._last_interval:
            candidates = chain((self._last_interval, ), self.intervals)
        else:
            candidates = self.intervals

        for interval in candidates:
            if interval.start <= value <= interval.end:
                self._last_interval = interval
                return interval.name

        return _('Older')
class BaseImageFileCollection(
    FileCollection[FileT],
    GroupFilesByDateMixin[FileT]
):
    """ A file collection with date grouping whose supported content types
    are set to ``IMAGE_MIME_TYPES_AND_SVG``.

    """

    supported_content_types = IMAGE_MIME_TYPES_AND_SVG
class ImageFileCollection(BaseImageFileCollection[ImageFile]):
    """ The image file collection for files of type 'image', created with
    ``allow_duplicates=False``.

    """

    def __init__(self, session: Session) -> None:
        """ Binds the collection to the given session. """
        super().__init__(session, type='image', allow_duplicates=False)