from __future__ import annotations
import base64
import bleach
from urlextract import URLExtract, CacheFileError
from bleach.linkifier import TLDS
import errno
import fcntl
import gzip
import hashlib
import importlib
import inspect
import magic
import mimetypes
import morepath
import operator
import os.path
import re
import shutil
import sqlalchemy
import urllib.request
from collections.abc import Iterable
from contextlib import contextmanager
from cProfile import Profile
from functools import reduce
from importlib import import_module
from io import BytesIO, StringIO
from itertools import groupby, islice
from markupsafe import escape
from markupsafe import Markup
from onegov.core import log
from onegov.core.cache import lru_cache
from onegov.core.custom import json
from onegov.core.errors import AlreadyLockedError
from purl import URL
from threading import Thread
from time import perf_counter
from unidecode import unidecode
from uuid import UUID, uuid4
from webob import static
from yubico_client import Yubico # type:ignore[import-untyped]
from yubico_client.yubico_exceptions import ( # type:ignore[import-untyped]
SignatureVerificationError, StatusCodeError)
from typing import overload, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from collections.abc import Callable, Collection, Iterator
from fs.base import FS, SubFS
from re import Match
from sqlalchemy import Column
from sqlalchemy.orm import Session
from types import ModuleType
from webob import Response
from .request import CoreRequest
from .types import FileDict, LaxFileDict
# http://stackoverflow.com/a/13500078
[docs]
_unwanted_url_chars = re.compile(r'[\.\(\)\\/\s<>\[\]{},:;?!@&=+$#@%|\*"\'`]+')
[docs]
_double_dash = re.compile(r'[-]+')
[docs]
_number_suffix = re.compile(r'-([0-9]+)$')
[docs]
_repeated_spaces = re.compile(r'\s\s+')
[docs]
_uuid = re.compile(
r'^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}$')
# only temporary until bleach has a release > 1.4.1 -
[docs]
_email_regex = re.compile(
r"([a-z0-9!#$%&'*+\/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+\/=?^_`"
r"{|}~-]+)*(@|\sat\s)(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?(\.|"
r"\sdot\s))+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)"
)
# detects multiple successive newlines
[docs]
_multiple_newlines = re.compile(r'\n{2,}', re.MULTILINE)
# detect starting strings of phone inside a link
# regex pattern for swiss phone numbers
[docs]
_phone_ch_country_code = r'(\+41|0041|0[0-9]{2})'
[docs]
_phone_ch = re.compile(_phone_ch_country_code + r'([ \r\f\t\d]+)')
# Adds a regex group to capture if a leading a tag is present or if the
# number is part of the href attributes
[docs]
_phone_ch_html_safe = re.compile(
_phone_inside_a_tags + _phone_ch_country_code + r'([ \r\f\t\d]+)')
# for yubikeys
[docs]
ALPHABET = 'cbdefghijklnrtuv'
[docs]
ALPHABET_RE = re.compile(r'^[cbdefghijklnrtuv]{12,44}$')
@contextmanager
[docs]
def local_lock(namespace: str, key: str) -> Iterator[None]:
""" Locks the given namespace/key combination on the current system,
automatically freeing it after the with statement has been completed or
once the process is killed.
Usage::
with lock('namespace', 'key'):
pass
"""
name = f'{namespace}-{key}'.replace('/', '-')
# NOTE: hardcoding /tmp is a bit piggy, but on the other hand we
# don't want different processes to miss each others locks
# just because one of them has a different TMPDIR, can we
# come up with a more robust way of doing this, e.g. with
# named semaphores?
with open(f'/tmp/{name}', 'w+') as f: # nosec:B108
try:
fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
yield
fcntl.flock(f, fcntl.LOCK_UN)
except BlockingIOError as exception:
raise AlreadyLockedError from exception
[docs]
def normalize_for_url(text: str) -> str:
""" Takes the given text and makes it fit to be used for an url.
That means replacing spaces and other unwanted characters with '-',
lowercasing everything and turning unicode characters into their closest
ascii equivalent using Unidecode.
See https://pypi.python.org/pypi/Unidecode
"""
# German is our main language, so we are extra considerate about it
# (unidecode turns ü into u)
text = text.replace('ü', 'ue')
text = text.replace('ä', 'ae')
text = text.replace('ö', 'oe')
clean = _unwanted_url_chars.sub('-', unidecode(text).strip(' ').lower())
clean = _double_dash.sub('-', clean)
clean = clean.strip('-')
return clean
[docs]
def increment_name(name: str) -> str:
""" Takes the given name and adds a numbered suffix beginning at 1.
For example::
foo => foo-1
foo-1 => foo-2
"""
match = _number_suffix.search(name)
if match:
number_str = match.group(1)
next_number = int(number_str) + 1
return f'{name[:-len(number_str)]}{next_number}'
else:
return f'{name}-1'
[docs]
def remove_repeated_spaces(text: str) -> str:
""" Removes repeated spaces in the text ('a b' -> 'a b'). """
return _repeated_spaces.sub(' ', text)
@contextmanager
[docs]
def profile(filename: str) -> Iterator[None]:
""" Profiles the wrapped code and stores the result in the profiles folder
with the given filename.
"""
profiler = Profile()
profiler.enable()
yield
profiler.disable()
profiler.create_stats()
profiler.dump_stats('profiles/{}'.format(filename))
@contextmanager
[docs]
def timing(name: str | None = None) -> Iterator[None]:
""" Runs the wrapped code and prints the time in ms it took to run it.
The name is printed in front of the time, if given.
"""
start = perf_counter()
yield
duration_ms = 1000.0 * (perf_counter() - start)
if name:
print(f'{name}: {duration_ms:.0f} ms')
else:
print(f'{duration_ms:.0f} ms')
@lru_cache(maxsize=32)
[docs]
def module_path_root(module: ModuleType | str) -> str:
if isinstance(module, str):
module = importlib.import_module(module)
assert module is not None
return os.path.dirname(inspect.getfile(module))
[docs]
def module_path(module: ModuleType | str, subpath: str) -> str:
""" Returns a subdirectory in the given python module.
:mod:
A python module (actual module or string)
:subpath:
Subpath below that python module. Leading slashes ('/') are ignored.
"""
parent = module_path_root(module)
path = os.path.join(parent, subpath.strip('/'))
# always be paranoid with path manipulation
assert is_subpath(parent, path)
return path
[docs]
def touch(file_path: str) -> None:
""" Touches the file on the given path. """
try:
os.utime(file_path, None)
except Exception:
open(file_path, 'a').close()
[docs]
class Bunch:
""" A simple but handy "collector of a bunch of named stuff" class.
See `<https://code.activestate.com/recipes/\
52308-the-simple-but-handy-collector-of-a-bunch-of-named/>`_.
For example::
point = Bunch(x=1, y=2)
assert point.x == 1
assert point.y == 2
point.z = 3
assert point.z == 3
Allows the creation of simple nested bunches, for example::
request = Bunch(**{'app.settings.org.my_setting': True})
assert request.app.settings.org.my_setting is True
"""
def __init__(self, **kwargs: Any):
self.__dict__.update(
(key, value)
for key, value in kwargs.items()
if '.' not in key
)
for key, value in kwargs.items():
if '.' in key:
name, _, key = key.partition('.')
setattr(self, name, Bunch(**{key: value}))
if TYPE_CHECKING:
# let mypy know that any attribute access could be valid
[docs]
def __getattr__(self, name: str) -> Any: ...
def __setattr__(self, name: str, value: Any) -> None: ...
def __delattr__(self, name: str) -> None: ...
[docs]
def __eq__(self, other: object) -> bool:
if type(other) is type(self):
return self.__dict__ == other.__dict__
return False
[docs]
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
[docs]
def render_file(file_path: str, request: CoreRequest) -> Response:
""" Takes the given file_path (content) and renders it to the browser.
The file must exist on the local system and be readable by the current
process.
"""
def hash_path(path: str) -> str:
return hashlib.new( # nosec:B324
'sha1',
path.encode('utf-8'),
usedforsecurity=False
).hexdigest()
# this is a very cachable result - though it's possible that a file
# changes it's content type, it should usually not, especially since
# we emphasize the use of random filenames
@request.app.cache.cache_on_arguments(to_str=hash_path)
def get_content_type(file_path: str) -> str:
content_type = mimetypes.guess_type(file_path)[0]
if not content_type:
content_type = magic.from_file(file_path, mime=True)
return content_type
return request.get_response(
static.FileApp(file_path, content_type=get_content_type(file_path)))
[docs]
def hash_dictionary(dictionary: dict[str, Any]) -> str:
""" Computes a sha256 hash for the given dictionary. The dictionary
is expected to only contain values that can be serialized by json.
That includes int, decimal, string, boolean.
Note that this function is not meant to be used for hashing secrets. Do
not include data in this dictionary that is secret!
"""
dict_as_string = json.dumps(dictionary, sort_keys=True).encode('utf-8')
return hashlib.new( # nosec:B324
'sha1',
dict_as_string,
usedforsecurity=False
).hexdigest()
@overload
[docs]
def groupbylist(
iterable: Iterable[_T],
key: None = ...
) -> list[tuple[_T, list[_T]]]: ...
@overload
def groupbylist(
iterable: Iterable[_T],
key: Callable[[_T], _KT]
) -> list[tuple[_KT, list[_T]]]: ...
def groupbylist(
iterable: Iterable[_T],
key: Callable[[_T], Any] | None = None
) -> list[tuple[Any, list[_T]]]:
""" Works just like Python's ``itertools.groupby`` function, but instead
of returning generators, it returns lists.
"""
return [(k, list(g)) for k, g in groupby(iterable, key=key)]
[docs]
def linkify_phone(text: str) -> Markup:
""" Takes a string and replaces valid phone numbers with html links. If a
phone number is matched, it will be replaced by the result of a callback
function, that does further checks on the regex match. If these checks do
not pass, the matched number will remain unchanged.
"""
def strip_whitespace(number: str) -> str:
return re.sub(r'\s', '', number)
def is_valid_length(number: str) -> bool:
if number.startswith('+00'):
return False
if number.startswith('00'):
return len(number) == 13
elif number.startswith('0'):
return len(number) == 10
elif number.startswith('+'):
return len(number) == 12
return False
def handle_match(match: Match[str]) -> str:
inside_html = match.group(1)
number = f'{match.group(2)}{match.group(3)}'
assert not number.endswith('\n')
if inside_html:
return match.group(0)
if is_valid_length(strip_whitespace(number)):
number = remove_repeated_spaces(number).strip()
return Markup(
'<a href="tel:{number}">{number}</a> '
).format(number=number)
return match.group(0)
# NOTE: re.sub isn't Markup aware, so we need to re-wrap
return Markup( # noqa: RUF035
_phone_ch_html_safe.sub(handle_match, escape(text)))
@lru_cache(maxsize=None)
[docs]
def top_level_domains() -> set[str]:
try:
return URLExtract()._load_cached_tlds()
except CacheFileError:
pass
# fallback
return {'agency', 'ngo', 'swiss', 'gle'}
[docs]
def linkify(text: str | None) -> Markup:
""" Takes plain text and injects html links for urls and email addresses.
By default the text is html escaped before it is linkified. This accounts
for the fact that we usually use this for text blocks that we mean to
extend with email addresses and urls.
If html is already possible, why linkify it?
Note: We need to clean the html after we've created it (linkify
parses escaped html and turns it into real html). As a consequence it
is possible to have html urls in the text that won't be escaped.
"""
if not text:
return Markup('')
def remove_dots(tlds: set[str]) -> list[str]:
return [domain[1:] for domain in tlds]
# bleach.linkify supports only a fairly limited amount of tlds
additional_tlds = top_level_domains()
if any(domain in text for domain in additional_tlds):
all_tlds = list(set(TLDS + remove_dots(additional_tlds)))
# Longest first, to prevent eager matching, if for example
# .co is matched before .com
all_tlds.sort(key=len, reverse=True)
bleach_linker = bleach.Linker(
url_re=bleach.linkifier.build_url_re(tlds=all_tlds),
email_re=bleach.linkifier.build_email_re(tlds=all_tlds),
parse_email=True if '@' in text else False
)
# NOTE: bleach's linkify always returns a plain string
# so we need to re-wrap
linkified = linkify_phone(Markup( # noqa: RUF035
bleach_linker.linkify(escape(text)))
)
else:
# NOTE: bleach's linkify always returns a plain string
# so we need to re-wrap
linkified = linkify_phone(Markup( # noqa: RUF035
bleach.linkify(escape(text), parse_email=True))
)
# NOTE: this is already vetted markup, don't clean it
if isinstance(text, Markup):
return linkified
return Markup(bleach.clean( # noqa: RUF035
linkified,
tags=['a'],
attributes={'a': ['href', 'rel']},
protocols=['http', 'https', 'mailto', 'tel']
))
[docs]
def paragraphify(text: str) -> Markup:
""" Takes a text with newlines groups them into paragraphs according to the
following rules:
If there's a single newline between two lines, a <br> will replace that
newline.
If there are multiple newlines between two lines, each line will become
a paragraph and the extra newlines are discarded.
"""
text = text and text.replace('\r', '').strip('\n')
if not text:
return Markup('')
was_markup = isinstance(text, Markup)
return Markup('').join(
Markup('<p>{}</p>').format(
(
# NOTE: re.split returns a plain str, so we need to restore
# markup based on whether it was markup before
Markup(p) if was_markup # noqa: RUF035
else escape(p)
).replace('\n', Markup('<br>'))
)
for p in _multiple_newlines.split(text)
)
[docs]
def to_html_ul(
value: str | None,
convert_dashes: bool = True,
with_title: bool = False
) -> Markup:
""" Linkify and convert to text to one or multiple ul's or paragraphs.
"""
if not value:
return Markup('')
value = value.replace('\r', '').strip('\n')
value = value.replace('\n\n', '\n \n')
if not convert_dashes:
return Markup('<p>{}</p>').format(
Markup('<br>').join(linkify(value).splitlines())
)
elements = []
temp: list[Markup] = []
def ul(inner: str) -> Markup:
return Markup('<ul class="bulleted">{}</ul>').format(inner)
def li(inner: str) -> Markup:
return Markup('<li>{}</li>').format(inner)
def p(inner: str) -> Markup:
return Markup('<p>{}</p>').format(inner)
was_list = False
for i, line in enumerate(value.splitlines()):
if not line:
continue
line = linkify(line)
is_list = line.startswith('-')
new_p_or_ul = True if line == ' ' else False
line = line.lstrip('-').strip()
if with_title:
elements.append(p(
Markup('<span class="title">{}</span>').format(line)))
with_title = False
else:
if new_p_or_ul or (was_list != is_list and i > 0):
elements.append(
ul(Markup('').join(temp)) if was_list
else p(Markup('<br>').join(temp))
)
temp = []
was_list = False
if not new_p_or_ul:
temp.append(li(line) if is_list else line)
new_p_or_ul = False
was_list = is_list
if temp:
elements.append(
ul(Markup('').join(temp)) if was_list
else p(Markup('<br>').join(temp))
)
return Markup('').join(elements)
@overload
[docs]
def ensure_scheme(url: str, default: str = 'http') -> str: ...
@overload
def ensure_scheme(url: None, default: str = 'http') -> None: ...
def ensure_scheme(url: str | None, default: str = 'http') -> str | None:
""" Makes sure that the given url has a scheme in front, if none
was provided.
"""
if not url:
return url
# purl (or to be precise urlparse) will parse empty host names ('abc.xyz')
# wrongly, assuming the abc.xyz is a path. by adding a double slash if
# there isn't one already, we can circumvent that problem
if '//' not in url:
url = '//' + url
_url = URL(url)
if _url.scheme():
return url
return _url.scheme(default).as_string()
[docs]
def is_uuid(value: str | UUID) -> bool:
""" Returns true if the given value is a uuid. The value may be a string
or of type UUID. If it's a string, the uuid is checked with a regex.
"""
if isinstance(value, str):
return _uuid.match(str(value)) and True or False
return isinstance(value, UUID)
[docs]
def is_non_string_iterable(obj: object) -> bool:
""" Returns true if the given obj is an iterable, but not a string. """
return not isinstance(obj, (str, bytes)) and isinstance(obj, Iterable)
[docs]
def relative_url(absolute_url: str | None) -> str:
""" Removes everything in front of the path, including scheme, host,
username, password and port.
"""
url = URL._mutate(
URL(absolute_url),
scheme=None,
username=None,
password=None,
host=None,
port=None
)
return url.as_string()
[docs]
def is_subpath(directory: str, path: str) -> bool:
""" Returns true if the given path is inside the given directory. """
directory = os.path.join(os.path.realpath(directory), '')
path = os.path.realpath(path)
# return true, if the common prefix of both is equal to directory
# e.g. /a/b/c/d.rst and directory is /a/b, the common prefix is /a/b
return os.path.commonprefix([path, directory]) == directory
@overload
[docs]
def is_sorted(
iterable: Iterable[SupportsRichComparison],
key: Callable[[SupportsRichComparison], SupportsRichComparison] = ...,
reverse: bool = ...
) -> bool: ...
@overload
def is_sorted(
iterable: Iterable[_T],
key: Callable[[_T], SupportsRichComparison],
reverse: bool = ...
) -> bool: ...
# FIXME: Do we really want to allow any Iterable? This seems like a bad
# idea to me... Iterators will be consumed and the Iterable might
# be infinite. This seems like it should be a Container instead,
# then we also don't need to use tee or list to make a copy
def is_sorted(
iterable: Iterable[Any],
key: Callable[[Any], SupportsRichComparison] = lambda i: i,
reverse: bool = False
) -> bool:
""" Returns True if the iterable is sorted. """
# NOTE: we previously used `tee` here, but since `sorted` consumes
# the entire iterator, this is the exact case where tee is
# slower than just pulling the entire sequence into a list
seq = list(iterable)
for a, b in zip(seq, sorted(seq, key=key, reverse=reverse)):
if a is not b:
return False
return True
[docs]
def morepath_modules(cls: type[morepath.App]) -> Iterator[str]:
""" Returns all morepath modules which should be scanned for the given
morepath application class.
We can't reliably know the actual morepath modules that
need to be scanned, which is why we assume that each module has
one namespace (like 'more.transaction' or 'onegov.core').
"""
for base in cls.__mro__:
if not issubclass(base, morepath.App):
continue
if base is morepath.App:
continue
module = '.'.join(base.__module__.split('.')[:2])
if module.startswith('test'):
continue
yield module
[docs]
def scan_morepath_modules(cls: type[morepath.App]) -> None:
""" Tries to scan all the morepath modules required for the given
application class. This is not guaranteed to stay reliable as there is
no sure way to discover all modules required by the application class.
"""
for module in sorted(morepath_modules(cls)):
morepath.scan(import_module(module))
[docs]
def get_unique_hstore_keys(
session: Session,
column: Column[dict[str, Any]]
) -> set[str]:
""" Returns a set of keys found in an hstore column over all records
of its table.
"""
base = session.query(column.keys()).with_entities( # type:ignore
sqlalchemy.func.skeys(column).label('keys'))
query = sqlalchemy.select(
[sqlalchemy.func.array_agg(sqlalchemy.column('keys'))],
distinct=True
).select_from(base.subquery())
keys = session.execute(query).scalar()
return set(keys) if keys else set()
[docs]
def makeopendir(fs: FS, directory: str) -> SubFS[FS]:
""" Creates and opens the given directory in the given PyFilesystem. """
if not fs.isdir(directory):
fs.makedir(directory)
return fs.opendir(directory)
[docs]
def append_query_param(url: str, key: str, value: str) -> str:
""" Appends a single query parameter to an url. This is faster than
using Purl, if and only if we only add one query param.
Also this function assumes that the value is already url encoded.
"""
template = '{}&{}={}' if '?' in url else '{}?{}={}'
return template.format(url, key, value)
[docs]
class PostThread(Thread):
""" POSTs the given data with the headers to the URL.
Example::
data = {'a': 1, 'b': 2}
data = json.dumps(data).encode('utf-8')
PostThread(
'https://example.com/post',
data,
(
('Content-Type', 'application/json; charset=utf-8'),
('Content-Length', len(data))
)
).start()
This only works for external URLs! If posting to server itself is
needed, use a process instead of the thread!
"""
def __init__(
self,
url: str,
data: bytes,
headers: Collection[tuple[str, str]],
timeout: float = 30
):
Thread.__init__(self)
[docs]
self.timeout = timeout
[docs]
def run(self) -> None:
try:
# Validate URL protocol before opening it, since it's possible to
# open ftp:// and file:// as well.
if not self.url.lower().startswith('http'):
raise ValueError from None
request = urllib.request.Request(self.url)
for header in self.headers:
request.add_header(header[0], header[1])
urllib.request.urlopen( # nosec B310
request, self.data, self.timeout
)
except Exception as e:
log.error(
'Error while sending a POST request to {}: {}'.format(
self.url, str(e)
)
)
[docs]
def toggle(collection: set[_T], item: _T | None) -> set[_T]:
""" Returns a new set where the item has been toggled. """
if item is None:
return collection
if item in collection:
return collection - {item}
else:
return collection | {item}
[docs]
def binary_to_dictionary(
binary: bytes,
filename: str | None = None
) -> FileDict:
""" Takes raw binary filedata and stores it in a dictionary together
with metadata information.
The data is compressed before it is stored int he dictionary. Use
:func:`dictionary_to_binary` to get the original binary data back.
"""
assert isinstance(binary, bytes)
mimetype = magic.from_buffer(binary, mime=True)
# according to https://tools.ietf.org/html/rfc7111, text/csv should be used
if mimetype == 'application/csv':
mimetype = 'text/csv'
gzipdata = BytesIO()
with gzip.GzipFile(fileobj=gzipdata, mode='wb') as f:
f.write(binary)
return {
'data': base64.b64encode(gzipdata.getvalue()).decode('ascii'),
'filename': filename,
'mimetype': mimetype,
'size': len(binary)
}
[docs]
def dictionary_to_binary(dictionary: LaxFileDict) -> bytes:
""" Takes a dictionary created by :func:`binary_to_dictionary` and returns
the original binary data.
"""
data = base64.b64decode(dictionary['data'])
with gzip.GzipFile(fileobj=BytesIO(data), mode='r') as f:
return f.read()
@overload
@overload
def safe_format(
format: str,
dictionary: dict[str, _T],
types: set[type[_T]] = ...,
adapt: Callable[[str], str] | None = ...,
raise_on_missing: bool = ...
) -> str: ...
def safe_format(
format: str,
dictionary: dict[str, Any],
types: set[type[Any]] | None = None,
adapt: Callable[[str], str] | None = None,
raise_on_missing: bool = False
) -> str:
""" Takes a user-supplied string with format blocks and returns a string
where those blocks are replaced by values in a dictionary.
For example::
>>> safe_format('[user] has logged in', {'user': 'admin'})
'admin has logged in'
:param format:
The format to use. Square brackets denote dictionary keys. To
literally print square bracktes, mask them by doubling ('[[' -> '[')
:param dictionary:
The dictionary holding the variables to use. If the key is not found
in the dictionary, the bracket is replaced with an empty string.
:param types:
A set of types supported by the dictionary. Limiting this to safe
types like builtins (str, int, float) ensure that no values are
accidentally leaked through faulty __str__ representations.
Note that inheritance is ignored. Supported types need to be
whitelisted explicitly.
:param adapt:
An optional callable that receives the key before it is used. Returns
the same key or an altered version.
:param raise_on_missing:
True if missing keys should result in a runtime error (defaults to
False).
This is strictly meant for formats provided by users. Python's string
formatting options are clearly superior to this, however it is less
secure!
"""
types = types or {int, str, float}
output = StringIO()
buffer = StringIO()
opened = 0
for ix, char in enumerate(format):
if char == '[':
opened += 1
if char == ']':
opened -= 1
if opened == 1 and char != '[' and char != ']':
print(char, file=buffer, end='')
continue
if opened == 2 or opened == -2:
if buffer.tell():
raise RuntimeError('Unexpected bracket inside bracket found')
print(char, file=output, end='')
opened = 0
continue
if buffer.tell():
k = adapt(buffer.getvalue()) if adapt else buffer.getvalue()
if raise_on_missing and k not in dictionary:
raise RuntimeError("Key '{}' is unknown".format(k))
v = dictionary.get(k, '')
t = type(v)
if t not in types:
raise RuntimeError("Invalid type for '{}': {}".format(k, t))
print(v, file=output, end='')
buffer = StringIO()
if char != '[' and char != ']':
print(char, file=output, end='')
if opened != 0:
raise RuntimeError("Uneven number of brackets in '{}'".format(format))
return output.getvalue()
[docs]
def is_valid_yubikey(
client_id: str,
secret_key: str,
expected_yubikey_id: str,
yubikey: str
) -> bool:
""" Asks the yubico validation servers if the given yubikey OTP is valid.
:client_id:
The yubico API client id.
:secret_key:
The yubico API secret key.
:expected_yubikey_id:
The expected yubikey id. The yubikey id is defined as the first twelve
characters of any yubikey value. Each user should have a yubikey
associated with it's account. If the yubikey value comes from a
different key, the key is invalid.
:yubikey:
The actual yubikey value that should be verified.
:return: True if yubico confirmed the validity of the key.
"""
assert client_id and secret_key and expected_yubikey_id and yubikey
assert len(expected_yubikey_id) == 12
try:
return Yubico(client_id, secret_key).verify(yubikey)
except StatusCodeError as e:
if e.status_code != 'REPLAYED_OTP':
raise e
return False
except SignatureVerificationError:
return False
[docs]
def yubikey_otp_to_serial(otp: str) -> int | None:
""" Takes a Yubikey OTP and calculates the serial number of the key.
The serial key is printed on the yubikey, in decimal and as a QR code.
Example::
>>> yubikey_otp_to_serial(
'ccccccdefghdefghdefghdefghdefghdefghdefghklv')
2311522
Adapted from Java::
https://github.com/Yubico/yubikey-salesforce-client/blob/
e38e46ee90296a852374a8b744555e99d16b6ca7/src/classes/Modhex.cls
If the key cannot be calculated, None is returned. This can happen if
they key is malformed.
"""
if not is_valid_yubikey_format(otp):
return None
token = 'cccc' + otp[:12]
toggle = False
keep = 0
bytesarray = []
for char in token:
n = ALPHABET.index(char)
toggle = not toggle
if toggle:
keep = n
else:
bytesarray.append((keep << 4) | n)
value = 0
# in Java, shifts on integers are masked with 0x1f using AND
# https://docs.oracle.com/javase/specs/jls/se8/html/jls-15.html#jls-15.19
mask_value = 0x1f
for i in range(8):
shift = (4 - 1 - i) * 8
value += (bytesarray[i] & 255) << (shift & mask_value)
return value
[docs]
def yubikey_public_id(otp: str) -> str:
""" Returns the yubikey identity given a token. """
return otp[:12]
[docs]
def dict_path(dictionary: dict[str, _T], path: str) -> _T:
""" Gets the value of the given dictionary at the given path.
For example::
>>> data = {'foo': {'bar': True}}
>>> dict_path(data, 'foo.bar')
True
"""
if not dictionary:
raise KeyError()
return reduce(operator.getitem, path.split('.'), dictionary) # type:ignore
[docs]
def safe_move(src: str, dst: str, tmp_dst: str | None = None) -> None:
""" Rename a file from ``src`` to ``dst``.
Optionally provide a ``tmp_dst`` where the file will be copied to
before being renamed. This needs to be on the same filesystem as
``tmp``, otherwise this will fail.
* Moves must be atomic. ``shutil.move()`` is not atomic.
* Moves must work across filesystems. Often temp directories and the
cache directories live on different filesystems. ``os.rename()`` can
throw errors if run across filesystems.
So we try ``os.rename()``, but if we detect a cross-filesystem copy, we
switch to ``shutil.move()`` with some wrappers to make it atomic.
Via https://alexwlchan.net/2019/03/atomic-cross-filesystem-moves-in-python
"""
try:
os.rename(src, dst)
except OSError as err:
if err.errno == errno.EXDEV:
# Generate a unique ID, and copy `<src>` to the target directory
# with a temporary name `<dst>.<ID>.tmp`. Because we're copying
# across a filesystem boundary, this initial copy may not be
# atomic. We intersperse a random UUID so if different processes
# are copying into `<dst>`, they don't overlap in their tmp copies.
copy_id = uuid4()
tmp_dst = f'{tmp_dst or dst}.{copy_id}.tmp'
shutil.copyfile(src, tmp_dst)
# Then do an atomic rename onto the new name, and clean up the
# source file.
os.rename(tmp_dst, dst)
os.unlink(src)
else:
raise
@overload
[docs]
def batched(
iterable: Iterable[_T],
batch_size: int,
container_factory: type[tuple] = ... # type:ignore[type-arg]
) -> Iterator[tuple[_T, ...]]: ...
@overload
def batched(
iterable: Iterable[_T],
batch_size: int,
container_factory: type[list] # type:ignore[type-arg]
) -> Iterator[list[_T]]: ...
# NOTE: If there were higher order TypeVars, we could properly infer
# the type of the Container, for now we just add overloads for
# two of the most common container_factories
@overload
def batched(
iterable: Iterable[_T],
batch_size: int,
container_factory: Callable[[Iterator[_T]], Collection[_T]]
) -> Iterator[Collection[_T]]: ...
def batched(
iterable: Iterable[_T],
batch_size: int,
container_factory: Callable[[Iterator[_T]], Collection[_T]] = tuple
) -> Iterator[Collection[_T]]:
""" Splits an iterable into batches of batch_size and puts them
inside a given collection (tuple by default).
The container_factory is necessary in order to consume the iterator
returned by islice. Otherwise this function would never return.
"""
iterator = iter(iterable)
while True:
batch = container_factory(islice(iterator, batch_size))
if len(batch) == 0:
return
yield batch