from __future__ import annotations
import inspect
import importlib
import pkg_resources
import transaction
from contextlib import contextmanager
from inspect import getmembers, isfunction, ismethod
from itertools import chain
from onegov.core import LEVELS
from onegov.core.orm import Base, find_models
from onegov.core.orm.mixins import TimestampMixin
from onegov.core.orm.types import JSON
from sqlalchemy import Column, Text
from sqlalchemy import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import load_only
from sqlalchemy.pool import StaticPool
from toposort import toposort
from typing import cast, overload, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from collections.abc import (
Callable, Collection, Iterable, Iterator, Mapping, Sequence)
# FIXME: Switch to importlib.resources
from pkg_resources import Distribution, EntryPoint
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Query, Session
from types import CodeType, ModuleType
from typing import ParamSpec, Protocol, TypeAlias, TypeGuard
from .request import CoreRequest
_T_co = TypeVar('_T_co', covariant=True)
_P = ParamSpec('_P')
class _Task(Protocol[_P, _T_co]):
__name__: str
__code__: CodeType
task_name: str
always_run: bool
requires: str | None
raw: bool
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T_co: ...
RawFunc: TypeAlias = Callable[[Connection, Sequence[str]], Any]
UpgradeFunc: TypeAlias = Callable[['UpgradeContext'], Any]
RawTask: TypeAlias = _Task[[Connection, Sequence[str]], Any]
UpgradeTask: TypeAlias = _Task[['UpgradeContext'], Any]
TaskCallback: TypeAlias = Callable[[_Task[..., Any]], Any]
[docs]
class UpgradeState(Base, TimestampMixin):
""" Keeps the state of all upgrade steps over all modules. """
[docs]
__tablename__ = 'upgrades'
# the name of the module (e.g. onegov.core)
[docs]
module = Column(Text, primary_key=True)
# a json holding the state of the upgrades
[docs]
state = Column(JSON, nullable=False, default=dict)
@property
[docs]
def executed_tasks(self) -> set[str]:
if not self.state:
return set()
else:
return set(self.state.get('executed_tasks', []))
[docs]
def was_already_executed(self, task: _Task[..., Any]) -> bool:
return task.task_name in self.executed_tasks
[docs]
def mark_as_executed(self, task: _Task[..., Any]) -> None:
if not self.state:
self.state = {}
if 'executed_tasks' in self.state:
self.state['executed_tasks'].append(task.task_name)
self.state['executed_tasks'] = list(
set(self.state['executed_tasks']))
else:
self.state['executed_tasks'] = [task.task_name]
self.state.changed() # type:ignore[attr-defined]
[docs]
def get_distributions_with_entry_map(
key: str
) -> Iterator[tuple[Distribution, dict[str, EntryPoint]]]:
""" Iterates through all distributions with entry_maps and yields
each distribution along side the entry map with the given key.
"""
for distribution in pkg_resources.working_set:
if hasattr(distribution, 'get_entry_map'):
entry_map = distribution.get_entry_map(key)
if entry_map:
yield distribution, entry_map
[docs]
def get_upgrade_modules() -> Iterator[tuple[str, ModuleType]]:
""" Returns all modules that registered themselves for onegov.core
upgrades like this::
entry_points={
'onegov': [
'upgrade = onegov.mypackage.upgrade'
]
}
To add multiple upgrades in a single setup.py file, the following syntax
may be used. This will become the default in the future::
entry_points= {
'onegov_upgrades': [
'onegov.mypackage = onegov.mypackage.upgrade'
]
}
Note that the part before the ``=``-sign should be kept the same, even if
the location changes. Otherwise completed updates may be run again!
"""
yield 'onegov.core', importlib.import_module('onegov.core.upgrades')
distributions = get_distributions_with_entry_map('onegov')
for distribution, entry_map in distributions:
if 'upgrade' in entry_map:
yield (
'.'.join(entry_map['upgrade'].module_name.split('.')[:2]),
importlib.import_module(entry_map['upgrade'].module_name)
)
distributions = get_distributions_with_entry_map('onegov_upgrades')
for distribution, entry_map in distributions:
for entry in entry_map.values():
yield entry.name, importlib.import_module(entry.module_name)
[docs]
class upgrade_task: # noqa: N801
""" Marks the decorated function as an upgrade task. Upgrade tasks should
be defined outside classes (except for testing) - that is in the root of
the module (directly in onegov/form/upgrades.py for example).
Each upgrade task has the following properties:
:name:
The name of the task in readable English. Do not change the name of
the task once you pushed your changes. This name is used as an id!
:always_run:
By default, tasks are run once on if they haven't been run yet, or if
the module was not yet under upgrade control (it just added itself
through the entry_points mechanism).
There are times where this is not wanted.
If you enable upgrades in your module and you plan to run an upgrade
task in the same release, then you need to 'always_run' and detect
yourself if you need to do any changes.
If you need to enable upgrades and run a task at the same time which
can only be run once and is impossible to make idempotent, then you
need to make two releases. One that enables upgade control and a second
one that
Also, if you want to run an upgrade task every time you upgrade (for
example, for maintenance), then you should also use it.
The easiest way to use upgrade control is to enable it from the first
release on.
Another way to tackle this is to write update methods idempotent as
often as possible and to always run them.
:requires:
A reference to another projects upgrade task that should be run first.
For example::
requires='onegov.form:Add new form field'
If the reference cannot be found, the upgrade fails.
:raw:
A raw task is run without setting up a request. Instead a database
connection and a list of targeted schemas is passed. It's up to the
upgrade function to make sure the right schemas are targeted through
that connection!
This is useful for upgrades which need to execute some raw SQL that
cannot be executed in a request context. For example, this is a way
to migrate the upgrade states table of the upgrade itself.
Raw tasks are always run and there state is not recorded. So the
raw task needs to be idempotent and return False if there was no
change.
Only use this if you have no other way. This is a very blunt instrument
only used under special circumstances.
Note, raw tasks may be normal function OR generators. If they are
generators, then with each yield the transaction is either commited
or rolled-back depending on the dry-run cli argument.
Always run tasks may return ``False`` if they wish to be considered
as not run (and therefore not shown in the upgrade runner).
"""
def __init__(
self,
name: str,
always_run: bool = False,
requires: str | None = None,
raw: bool = False
):
if raw:
assert always_run, 'raw tasks must always run'
assert not requires, 'raw tasks may not require other tasks'
[docs]
self.always_run = always_run
[docs]
self.requires = requires
# NOTE: We only allow the two supported function signatures
@overload
[docs]
def __call__(self, fn: UpgradeFunc) -> UpgradeTask: ...
@overload
def __call__(self, fn: RawFunc) -> RawTask: ...
def __call__(self, fn: Callable[_P, _T]) -> _Task[_P, _T]:
fn = cast('_Task[_P, _T]', fn)
fn.task_name = self.name
fn.always_run = self.always_run
fn.requires = self.requires
fn.raw = self.raw
return fn
[docs]
def is_task(function: Callable[_P, _T]) -> TypeGuard[_Task[_P, _T]]:
""" Returns True if the given function is an uprade task. """
if not (isfunction(function) or ismethod(function)):
return False
return hasattr(function, 'task_name')
[docs]
def get_module_tasks(module: ModuleType) -> Iterator[_Task[..., Any]]:
""" Goes through a module or class and returns all upgrade tasks. """
for name, function in getmembers(module, is_task):
yield function
[docs]
def get_tasks_by_id(
upgrade_modules: Iterable[tuple[str, ModuleType]]
) -> dict[str, _Task[..., Any]]:
""" Takes a list of upgrade modules or classes and returns the tasks
keyed by id.
"""
tasks = {}
fn_names = set()
for distribution, upgrade_module in upgrade_modules:
for function in get_module_tasks(upgrade_module):
task_id = f'{distribution}:{function.task_name}'
assert task_id not in tasks, 'Duplicate task'
tasks[task_id] = function
# make sure we don't have duplicate function names - it works, but
# it makes debugging harder
msg = f'Duplicate function name: {function.__name__}'
assert function.__name__ not in fn_names, msg
fn_names.add(function.__name__)
return tasks
[docs]
def get_module_order_key(
tasks: Mapping[str, _Task[..., Any]]
) -> Callable[[str], SupportsRichComparison]:
""" Returns a sort order key which orders task_ids in order of their
module dependencies. That is a task from onegov.core is sorted before
a task in onegov.user, because onegov.user depends on onegov.core.
This is used to order unrelated tasks in a sane way.
"""
sorted_modules = {module: ix for ix, module in enumerate(chain(*LEVELS))}
modules = set()
for task in tasks:
modules.add(task.split(':', 1)[0])
def sortkey(task: str) -> SupportsRichComparison:
module = task.split(':', 1)[0]
return (
# sort by level (unknown models first)
sorted_modules.get(module, float('-inf')),
# then by module name
module,
# then by appearance of update function in the code
tasks[task].__code__.co_firstlineno
)
return sortkey
[docs]
def get_tasks(
upgrade_modules: Iterable[tuple[str, ModuleType]] | None = None
) -> list[tuple[str, _Task[..., Any]]]:
""" Takes a list of upgrade modules or classes and returns the
tasks that should be run in the order they should be run.
The order takes dependencies into account.
"""
tasks = get_tasks_by_id(upgrade_modules or get_upgrade_modules())
for task_id, task in tasks.items():
if task.requires:
assert not tasks[task.requires].raw, 'Raw tasks cannot be required'
graph: dict[str, set[str]] = {}
for task_id, task in tasks.items():
if task_id not in graph:
graph[task_id] = set()
if task.requires is not None:
graph[task_id].add(task.requires)
sorted_tasks = []
order_key = get_module_order_key(tasks)
for task_ids in toposort(graph):
sorted_tasks.extend(sorted(task_ids, key=order_key))
return [(task_id, tasks[task_id]) for task_id in sorted_tasks]
[docs]
def register_modules(
session: Session,
modules: Iterable[tuple[str, ModuleType]],
tasks: Collection[tuple[str, _Task[..., Any]]]
) -> None:
""" Sets up the state tracking for all modules. Initially, all tasks
are marekd as executed, because we assume tasks to upgrade older
deployments to a new deployment.
If this is a new deployment we do not need to execute these tasks.
Tasks where this is not desired should be marked as 'always_run'.
They will then manage their own state (i.e. check if they need to
run or not).
This function is idempotent.
"""
for module, upgrade_module in modules:
query = session.query(UpgradeState)
query = query.filter(UpgradeState.module == module)
if query.first():
continue
session.add(
UpgradeState(module=module, state={
'executed_tasks': [
task.task_name for task_id, task in tasks
if task_id.startswith(module)
]
})
)
session.flush()
[docs]
def register_all_modules_and_tasks(session: Session) -> None:
""" Registers all the modules and all the tasks. """
register_modules(session, get_upgrade_modules(), get_tasks())
[docs]
class UpgradeTransaction:
""" Holds the session and the alembic operations connection together and
commits/aborts both at the same time.
Not really a two-phase commit solution. That could be accomplished but
for now this suffices.
"""
def __init__(self, context: UpgradeContext):
[docs]
self.operations_transaction = context.operations_connection.begin()
[docs]
self.session = context.session
[docs]
def flush(self) -> None:
self.session.flush()
[docs]
def commit(self) -> None:
self.operations_transaction.commit()
transaction.commit()
[docs]
def abort(self) -> None:
self.operations_transaction.rollback()
transaction.abort()
[docs]
class UpgradeContext:
""" Holdes the context of the upgrade. An instance of this is passed
to each upgrade task.
"""
# FIXME: alembic currently auto generates type stubs for alembic.op but
# not for Operations, we should probably make a PR, for Operations
# to be auto-generated as well, so we get the available methods
def __init__(self, request: CoreRequest):
# alembic is a somewhat heavy import (thanks to the integrated mako)
# -> since we really only ever need it during upgrades we lazy load
from alembic.migration import MigrationContext
from alembic.operations import Operations
# we need to reset the request session property every time because
# we reuse the request over various session managers
if 'session' in request.__dict__:
del request.__dict__['session']
[docs]
self.request = request
self.session = session = request.session
[docs]
self.app = request.app
[docs]
self.schema = request.app.session_manager.current_schema
[docs]
self.engine = self.session.bind
# The locale of the upgrade is the default locale
request.app.session_manager.set_locale(
current_locale=self.request.locale,
default_locale=self.request.default_locale
)
# Make sure the connection is the same for the session, the engine
# and the alembic migrations context. Otherwise the upgrade locks up.
# If that happens again, test with this:
#
# http://docs.sqlalchemy.org/en/latest/core/pooling.html
# #sqlalchemy.pool.AssertionPool
#
conn = session._connection_for_bind( # type:ignore[attr-defined]
session.bind)
[docs]
self.operations_connection: Connection = conn
self.operations = Operations(
MigrationContext.configure(self.operations_connection))
[docs]
def begin(self) -> UpgradeTransaction:
return UpgradeTransaction(self)
[docs]
def has_column(self, table: str, column: str) -> bool:
inspector = Inspector(self.operations_connection)
return column in {c['name'] for c in inspector.get_columns(
table, schema=self.schema
)}
[docs]
def has_enum(self, enum: str) -> bool:
return self.session.execute(f"""
SELECT EXISTS (
SELECT 1 FROM pg_type
WHERE typname = '{enum}'
and typnamespace = (
SELECT oid FROM pg_namespace
WHERE nspname = '{self.schema}'
)
)
""").scalar()
[docs]
def has_table(self, table: str) -> bool:
inspector = Inspector(self.operations_connection)
return table in inspector.get_table_names(schema=self.schema)
[docs]
def models(self, table: str) -> Iterator[Any]:
def has_matching_tablename(model: Any) -> bool:
if not hasattr(model, '__tablename__'):
return False # abstract bases
return model.__tablename__ == table
for base in self.request.app.session_manager.bases:
yield from find_models(base, has_matching_tablename)
[docs]
def records_per_table(
self,
table: str,
columns: Iterable[Column[Any]] | None = None
) -> Iterator[Any]:
if columns is None:
def filter_columns(q: Query[Any]) -> Query[Any]:
return q
else:
column_names = tuple(c.name for c in columns)
def filter_columns(q: Query[Any]) -> Query[Any]:
return q.options(load_only(*column_names))
for model in self.models(table):
yield from filter_columns(self.session.query(model))
@contextmanager
[docs]
def stop_search_updates(self) -> Iterator[None]:
# XXX this would be better handled with a more general approach
# that doesn't require knowledge of onegov.search
if hasattr(self.app, 'es_orm_events'):
self.app.es_orm_events.stopped = True
yield
self.app.es_orm_events.stopped = False
else:
yield
[docs]
def is_empty_table(self, table: str) -> bool:
return self.operations_connection.execute(
f'SELECT * FROM {table} LIMIT 1').rowcount == 0
[docs]
def add_column_with_defaults(
self,
table: str,
column: Column[_T],
default: _T | Callable[[Any], _T]
) -> None:
# XXX while adding default values we shouldn't reindex the data
# since this is what the default add_column code does and will be
# doing in the future when Postgres 11 supports defaults for
# new columns - the way this is implemented now is with tight coupling
# that we shouldn't do, but we want to replace this when Postgres 11
# comes around
with self.stop_search_updates():
nullable_column = column.copy()
nullable_column.nullable = True
self.operations.add_column(table, nullable_column)
if not self.is_empty_table(table):
# fill it with defaults if the tables has any rows
for record in self.records_per_table(table, (column, )):
value = default(record) if callable(default) else default
setattr(record, column.name, value)
if hasattr(getattr(record, column.name), 'changed'):
getattr(record, column.name).changed()
# make sure all generated values are flushed
self.session.flush()
# activate non-null constraint
if not column.nullable:
self.operations.alter_column(
table, column.name, nullable=False)
# propagate to db
self.session.flush()
[docs]
class RawUpgradeRunner:
""" Runs the given raw tasks. """
def __init__(
self,
tasks: Sequence[tuple[str, RawTask]],
commit: bool = True,
on_task_success: TaskCallback | None = None,
on_task_fail: TaskCallback | None = None
):
# FIXME: We should just move the arguments around or set
# a lambda that does noething as the default, or
# do what we do in UpgradeRunner
assert on_task_success is not None
assert on_task_fail is not None
[docs]
self._on_task_success = on_task_success
[docs]
self._on_task_fail = on_task_fail
[docs]
def run_upgrade(self, dsn: str, schemas: Sequence[str]) -> int:
# it's possible for applications to exist without schemas, if the
# application doesn't use multi-tennants and has not been run yet
if not schemas:
return 0
engine = create_engine(dsn, poolclass=StaticPool)
connection = engine.connect()
executions = 0
for id, task in self.tasks:
t = connection.begin()
try:
success = False
if inspect.isgeneratorfunction(task):
for changed in task(connection, schemas):
success = success or changed
if self.commit:
t.commit()
else:
t.rollback()
t = connection.begin()
else:
success = task(connection, schemas)
if success:
executions += 1
self._on_task_success(task)
except Exception:
t.rollback()
self._on_task_fail(task)
raise
else:
if self.commit:
t.commit()
else:
t.rollback()
return executions
[docs]
class UpgradeRunner:
""" Runs the given basic tasks. """
[docs]
states: dict[str, UpgradeState]
def __init__(
self,
modules: Sequence[tuple[str, ModuleType]],
tasks: Sequence[tuple[str, UpgradeTask]],
commit: bool = True,
on_task_success: TaskCallback | None = None,
on_task_fail: TaskCallback | None = None
):
self.states = {}
[docs]
self._on_task_success = on_task_success
[docs]
self._on_task_fail = on_task_fail
[docs]
def on_task_success(self, task: UpgradeTask) -> None:
if self._on_task_success is not None:
self._on_task_success(task)
[docs]
def on_task_fail(self, task: UpgradeTask) -> None:
if self._on_task_fail is not None:
self._on_task_fail(task)
[docs]
def get_state(self, context: UpgradeContext, module: str) -> UpgradeState:
if module not in self.states:
query = context.session.query(UpgradeState)
query = query.filter(UpgradeState.module == module)
# assume that all modules have been ran through register_module
self.states[module] = query.one()
return context.session.merge(self.states[module], load=True)
[docs]
def register_modules(self, request: CoreRequest) -> None:
register_modules(request.session, self.modules, self.tasks)
if self.commit:
transaction.commit()
[docs]
def get_module_from_task_id(self, task_id: str) -> str:
return task_id.split(':', 1)[0]
[docs]
def run_upgrade(self, request: CoreRequest) -> int:
self.register_modules(request)
tasks = ((self.get_module_from_task_id(i), t) for i, t in self.tasks)
executed = 0
for module, task in tasks:
context = UpgradeContext(request)
state = self.get_state(context, module)
if not task.always_run and state.was_already_executed(task):
continue
upgrade = context.begin()
try:
result = task(context)
# mark all tasks as executed, even 'always run' ones
state.mark_as_executed(task)
upgrade.flush()
# always-run tasks which return False are considered
# to have not run
hidden = task.always_run and result is False
if not hidden:
executed += 1
self.on_task_success(task)
except Exception:
upgrade.abort()
self.on_task_fail(task)
raise
else:
if self.commit:
upgrade.commit()
else:
upgrade.abort()
finally:
context.session.invalidate()
context.engine.dispose()
return executed