Source code for org.models.tan

from datetime import timedelta
from uuid import uuid4

from onegov.core.collection import GenericCollection
from onegov.core.orm import Base
from onegov.core.orm.mixins import TimestampMixin
from onegov.core.orm.types import UUID
from sqlalchemy import func, Column, Index, Text


from typing import TYPE_CHECKING
if TYPE_CHECKING:
    import uuid
    from sqlalchemy.orm import Query, Session


[docs] DEFAULT_ACCESS_WINDOW = timedelta(days=1)
[docs] class TANAccess(Base, TimestampMixin): """ This exists to keep track of which protected URLs have been accessed by any given TAN session. This allows us to throttle requests to protected resources. """
[docs] __tablename__ = 'tan_accesses'
[docs] __table_args__ = ( # TimestampMixin by default does not generate an index for # the created column, so we do it here instead Index('ix_tan_accesses_created', 'created'), )
[docs] id: 'Column[uuid.UUID]' = Column( UUID, # type: ignore[arg-type] primary_key=True, default=uuid4 )
# for an mTAN session this would be the phone number
[docs] session_id: 'Column[str]' = Column( Text, index=True, nullable=False )
# The url that was accessed
[docs] url: 'Column[str]' = Column( Text, index=True, nullable=False )
[docs] class TANAccessCollection(GenericCollection[TANAccess]): def __init__( self, session: 'Session', session_id: str, access_window: timedelta = DEFAULT_ACCESS_WINDOW, ): super().__init__(session)
[docs] self.session_id = session_id
[docs] self.access_window = access_window
@property
[docs] def model_class(self) -> type[TANAccess]: return TANAccess
[docs] def query(self) -> 'Query[TANAccess]': cutoff = TANAccess.timestamp() - self.access_window return self.session.query(TANAccess).filter( TANAccess.session_id == self.session_id ).filter(TANAccess.created > cutoff)
[docs] def add(self, *, url: str) -> TANAccess: # type:ignore[override] access = self.by_url(url) if access is not None: # during the access_window subsequent accesses to the same # url are treated like a single access return access access = TANAccess(session_id=self.session_id, url=url) self.session.add(access) self.session.flush() return access
[docs] def by_url(self, url: str) -> TANAccess | None: return self.query().filter(TANAccess.url == url).first()
[docs] def count(self) -> int: return self.query().with_entities(func.count(TANAccess.id)).scalar()