from __future__ import annotations
from datetime import timedelta
from onegov.core.collection import GenericCollection
from onegov.core.orm import Base
from onegov.core.orm.mixins import TimestampMixin
from sqlalchemy import func, Index
from sqlalchemy.orm import mapped_column, Mapped
from uuid import uuid4, UUID
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
# How long a recorded access keeps counting. Within this window, repeated
# accesses to the same url by one TAN session are treated as a single access.
DEFAULT_ACCESS_WINDOW = timedelta(hours=24)
class TANAccess(Base, TimestampMixin):
    """
    A single recorded access of a protected URL by a TAN session.

    This exists to keep track of which protected URLs have been accessed
    by any given TAN session.
    This allows us to throttle requests to protected resources.

    The ``created`` timestamp comes from :class:`TimestampMixin`. Lookups
    restrict on it to only consider accesses inside the access window.
    """

    __tablename__ = 'tan_accesses'

    __table_args__ = (
        # TimestampMixin by default does not generate an index for
        # the created column, so we do it here instead
        Index('ix_tan_accesses_created', 'created'),
    )

    # primary key; a fresh uuid4 is generated when none is supplied
    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        default=uuid4
    )

    # identifies the TAN session that made the access,
    # for an mTAN session this would be the phone number;
    # indexed because every lookup filters on the session
    session_id: Mapped[str] = mapped_column(index=True)

    # The url that was accessed; indexed because lookups by url
    # are used to detect repeated accesses to the same resource
    url: Mapped[str] = mapped_column(index=True)
class TANAccessCollection(GenericCollection[TANAccess]):
    """
    The :class:`TANAccess` records of one TAN session within a time window.

    Only accesses created within ``access_window`` of the current
    timestamp are visible through :meth:`query`. Lookups and counts
    therefore ignore anything older.
    """

    def __init__(
        self,
        session: Session,
        session_id: str,
        access_window: timedelta = DEFAULT_ACCESS_WINDOW,
    ):
        super().__init__(session)
        # the TAN session whose accesses this collection tracks
        self.session_id = session_id
        # how far back from "now" an access is still taken into account
        self.access_window = access_window

    @property
    def model_class(self) -> type[TANAccess]:
        return TANAccess

    def query(self) -> Query[TANAccess]:
        """
        Return this session's accesses created strictly after
        ``now - access_window``.
        """
        window_start = TANAccess.timestamp() - self.access_window
        return self.session.query(TANAccess).filter(
            TANAccess.session_id == self.session_id,
            TANAccess.created > window_start,
        )

    def add(self, *, url: str) -> TANAccess:  # type:ignore[override]
        """
        Record an access to ``url`` and return the matching entry.

        During the access window, repeated accesses to the same url are
        treated like a single access. The already recorded entry is
        returned and nothing new is written.
        """
        access = self.by_url(url)
        if access is None:
            access = TANAccess(session_id=self.session_id, url=url)
            self.session.add(access)
            # write the pending row to the database right away
            self.session.flush()
        return access

    def by_url(self, url: str) -> TANAccess | None:
        """Return an access to ``url`` inside the window, or ``None``."""
        return self.query().filter_by(url=url).first()

    def count(self) -> int:
        """Return the number of access records visible through `query`."""
        counting = self.query().with_entities(func.count(TANAccess.id))
        return counting.scalar()