Source code for org.models.tan

from __future__ import annotations

from datetime import timedelta
from onegov.core.collection import GenericCollection
from onegov.core.orm import Base
from onegov.core.orm.mixins import TimestampMixin
from sqlalchemy import func, Index
from sqlalchemy.orm import mapped_column, Mapped
from uuid import uuid4, UUID


from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from sqlalchemy.orm import Query, Session


# How long a recorded TAN access stays relevant unless a collection is
# given a different window (24 hours, i.e. one day).
DEFAULT_ACCESS_WINDOW = timedelta(hours=24)
class TANAccess(Base, TimestampMixin):
    """ Records that a TAN session accessed a protected URL.

    Each row ties a session identifier to one URL. Keeping these rows
    makes it possible to see how many protected resources a given TAN
    session has requested, which is what request throttling is based on.
    """

    __tablename__ = 'tan_accesses'

    # ``TimestampMixin`` does not put an index on its ``created`` column,
    # so one is declared explicitly here.
    __table_args__ = (
        Index('ix_tan_accesses_created', 'created'),
    )

    # Surrogate primary key, generated on the Python side via ``uuid4``.
    id: Mapped[UUID] = mapped_column(
        default=uuid4,
        primary_key=True,
    )

    # Identifier of the TAN session that made the access; for an mTAN
    # session this would be the phone number.
    session_id: Mapped[str] = mapped_column(index=True)

    # The protected URL that was accessed.
    url: Mapped[str] = mapped_column(index=True)
class TANAccessCollection(GenericCollection[TANAccess]):
    """ The protected-URL accesses of a single TAN session.

    Only accesses whose ``created`` timestamp lies within
    ``access_window`` of the current timestamp are visible through this
    collection.

    :param session: Database session used for querying and writing.
    :param session_id: Identifier of the TAN session being tracked.
    :param access_window: How far back an access is still taken into
        account; defaults to ``DEFAULT_ACCESS_WINDOW``.
    """

    def __init__(
        self,
        session: Session,
        session_id: str,
        access_window: timedelta = DEFAULT_ACCESS_WINDOW,
    ):
        super().__init__(session)
        self.session_id = session_id
        self.access_window = access_window

    @property
    def model_class(self) -> type[TANAccess]:
        return TANAccess

    def query(self) -> Query[TANAccess]:
        """ Return this session's accesses that fall inside the window. """
        window_start = TANAccess.timestamp() - self.access_window
        # Passing both criteria to one ``filter`` call combines them with
        # AND, exactly like chaining two ``filter`` calls.
        return self.session.query(TANAccess).filter(
            TANAccess.session_id == self.session_id,
            TANAccess.created > window_start,
        )

    def add(self, *, url: str) -> TANAccess:  # type:ignore[override]
        """ Record an access to ``url`` and return the matching entry.

        While an access to the same URL already exists inside the window,
        that existing entry is returned instead of creating a new one, so
        repeated visits count as a single access. A newly created entry is
        added to the session and flushed right away.
        """
        entry = self.by_url(url)
        if entry is None:
            entry = TANAccess(session_id=self.session_id, url=url)
            self.session.add(entry)
            self.session.flush()
        return entry

    def by_url(self, url: str) -> TANAccess | None:
        """ Return an in-window access to ``url``, or ``None`` if absent. """
        matching = self.query().filter(TANAccess.url == url)
        return matching.first()

    def count(self) -> int:
        """ Return the number of in-window accesses of this session. """
        counting = self.query().with_entities(func.count(TANAccess.id))
        return counting.scalar()