from __future__ import annotations
import requests
import stripe
import transaction
from contextlib import contextmanager
from datetime import datetime
from decimal import Decimal
from functools import cached_property
from markupsafe import Markup
from onegov.core.orm.mixins import dict_property, meta_property
from onegov.pay import log
from onegov.pay.models.payment import Payment
from onegov.pay.models.payment_provider import PaymentProvider
from onegov.pay.utils import Price
from sqlalchemy.orm import object_session
from uuid import UUID, uuid4, uuid5
from typing import Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Collection, Callable, Iterator, Mapping
from onegov.pay.types import FeePolicy
from sqlalchemy.orm import relationship, Query, Session
from transaction.interfaces import ITransaction
from typing_extensions import ParamSpec
[docs]
_R = TypeVar('_R', bound=stripe.api_resources.abstract.ListableAPIResource)
_P = ParamSpec('_P')
@contextmanager
[docs]
def stripe_api_key(key: str | None) -> Iterator[None]:
old_key = stripe.api_key
stripe.api_key = key
yield
stripe.api_key = old_key
# instantiate once to get keep-alive support
stripe.default_http_client = stripe.http_client.RequestsClient()
# our stripe payment ids are generated by using the token and a UUID namespace
[docs]
STRIPE_NAMESPACE = UUID('aebb1a32-282b-4521-838d-92a1136624d1')
[docs]
class StripeCaptureManager:
""" Captures an open stripe charge when the transaction finishes.
If there is an error during this step, it is logged, but the transaction
still continues successfully.
The user is then supposed to manually capture the charge.
"""
[docs]
transaction_manager = transaction.manager
def __init__(self, api_key: str, charge_id: str):
[docs]
self.charge_id = charge_id
@classmethod
[docs]
def capture_charge(cls, api_key: str, charge_id: str) -> None:
transaction.get().join(cls(api_key, charge_id))
[docs]
def sortKey(self) -> str:
return 'charge'
[docs]
def tpc_vote(self, transaction: ITransaction) -> None:
with stripe_api_key(self.api_key):
self.charge = stripe.Charge.retrieve(self.charge_id)
[docs]
def tpc_finish(self, transaction: ITransaction) -> None:
try:
with stripe_api_key(self.api_key):
self.charge.capture()
except Exception:
# we can never fail or we might end up with an incosistent
# database -> so must swallow any errors and report them
log.exception('Stripe charge with capture id {} failed'.format(
self.charge_id
))
[docs]
def commit(self, transaction: ITransaction) -> None:
pass
[docs]
def abort(self, transaction: ITransaction) -> None:
pass
[docs]
def tpc_begin(self, transaction: ITransaction) -> None:
pass
[docs]
def tpc_abort(self, transaction: ITransaction) -> None:
pass
[docs]
class StripeFeePolicy:
""" All stripe fee calculations in one place (should they ever change). """
@classmethod
[docs]
def from_amount(cls, amount: Decimal | float) -> float:
""" Gets the fee for the given amount. """
return round(float(amount) * cls.percentage + cls.fixed, 2)
@classmethod
[docs]
def compensate(cls, amount: Decimal | float) -> float:
""" Increases the amount in such a way that the stripe fee is included
in the effective charge (that is, the user paying the charge is paying
the fee as well).
"""
return round((float(amount) + cls.fixed) / (1 - cls.percentage), 2)
[docs]
class StripePayment(Payment):
[docs]
__mapper_args__ = {'polymorphic_identity': 'stripe_connect'}
[docs]
fee_policy: FeePolicy = StripeFeePolicy
#: the date of the payout
[docs]
payout_date: dict_property[datetime | None] = meta_property()
#: the id of the payout
[docs]
payout_id: dict_property[str | None] = meta_property()
#: the fee deducted by stripe
[docs]
effective_fee: dict_property[float | None] = meta_property()
if TYPE_CHECKING:
# our provider should always be StripeConnect, we could
# assert if we really wanted to make sure, but it would
# add a lot of assertions...
[docs]
provider: relationship[StripeConnect]
@property
[docs]
def fee(self) -> Decimal:
""" The calculated fee or the effective fee if available.
The effective fee is taken from the payout records. In practice
these values should always be the same.
"""
if self.effective_fee:
return Decimal(self.effective_fee)
assert self.amount is not None
return Decimal(self.fee_policy.from_amount(self.amount))
@property
[docs]
def remote_url(self) -> str:
if self.provider.livemode:
base = 'https://dashboard.stripe.com/payments/{}'
else:
base = 'https://dashboard.stripe.com/test/payments/{}'
return base.format(self.remote_id)
@property
[docs]
def charge(self) -> stripe.Charge:
with stripe_api_key(self.provider.access_token):
return stripe.Charge.retrieve(self.remote_id)
[docs]
def refund(self) -> stripe.Refund:
with stripe_api_key(self.provider.access_token):
refund = stripe.Refund.create(charge=self.remote_id)
self.state = 'cancelled'
return refund
[docs]
def sync(self, remote_obj: stripe.Charge | None = None) -> None:
charge = remote_obj or self.charge
if not charge.captured:
self.state = 'open'
elif charge.refunded:
self.state = 'cancelled'
elif charge.status == 'failed':
self.state = 'failed'
elif charge.captured and charge.paid:
self.state = 'paid'
[docs]
class StripeConnect(PaymentProvider[StripePayment]):
[docs]
__mapper_args__ = {'polymorphic_identity': 'stripe_connect'}
[docs]
fee_policy: FeePolicy = StripeFeePolicy
#: The Stripe Connect client id
[docs]
client_id: dict_property[str | None] = meta_property()
#: The API key of the connect user
[docs]
client_secret: dict_property[str | None] = meta_property()
#: The oauth_redirect gateway in use (see seantis/oauth_redirect on github)
[docs]
oauth_gateway: dict_property[str | None] = meta_property()
#: The auth code required by oauth_redirect
[docs]
oauth_gateway_auth: dict_property[str | None] = meta_property()
#: The oauth_redirect secret that should be used
[docs]
oauth_gateway_secret: dict_property[str | None] = meta_property()
#: The authorization code provided by OAuth
[docs]
authorization_code: dict_property[str | None] = meta_property()
#: The public stripe key
[docs]
publishable_key: dict_property[str | None] = meta_property()
#: The stripe user id as confirmed by OAuth
[docs]
user_id: dict_property[str | None] = meta_property()
#: The refresh token provided by OAuth
[docs]
refresh_token: dict_property[str | None] = meta_property()
#: The access token provieded by OAuth
[docs]
access_token: dict_property[str | None] = meta_property()
#: The id of the latest processed balance transaction
[docs]
latest_payout: dict_property[stripe.Payout | None] = meta_property()
#: Should the fee be charged to the customer or not?
[docs]
charge_fee_to_customer: dict_property[bool | None] = meta_property()
[docs]
def adjust_price(self, price: Price | None) -> Price | None:
if price and self.charge_fee_to_customer:
new_price = self.fee_policy.compensate(price.amount)
new_fee = self.fee_policy.from_amount(new_price)
return Price(
new_price,
price.currency,
new_fee,
price.credit_card_payment
)
return price
@property
[docs]
def livemode(self) -> bool:
assert self.access_token is not None
return not self.access_token.startswith('sk_test')
@property
[docs]
def payment_class(self) -> type[StripePayment]:
return StripePayment
@property
[docs]
def title(self) -> str:
return 'Stripe Connect'
@property
[docs]
def url(self) -> str:
return 'https://dashboard.stripe.com/'
@property
[docs]
def public_identity(self) -> str:
account = self.account
assert account is not None
business_name = getattr(account, 'business_name', '')
if business_name:
return f'{account.business_name} / {account.email}'
return account.email
@property
[docs]
def identity(self) -> str | None:
return self.user_id
@cached_property
[docs]
def account(self) -> stripe.Account | None:
with stripe_api_key(self.access_token):
return stripe.Account.retrieve(id=self.user_id)
@property
[docs]
def connected(self) -> bool:
return self.account and True or False
[docs]
def charge(
self,
amount: Decimal,
currency: str,
token: str
) -> StripePayment:
session = object_session(self)
payment = self.payment(
id=uuid5(STRIPE_NAMESPACE, token),
amount=amount,
currency=currency,
state='open'
)
with stripe_api_key(self.access_token):
charge = stripe.Charge.create(
amount=round(amount * 100, 0),
currency=currency,
source=token,
capture=False,
idempotency_key=token,
metadata={
'payment_id': payment.id.hex
}
)
assert self.access_token is not None
StripeCaptureManager.capture_charge(self.access_token, charge.id)
payment.remote_id = charge.id
# we do *not* want to lose this information, so even though the
# caller should make sure the payment is stored, we make sure
session.add(payment)
return payment
[docs]
def oauth_url(
self,
redirect_uri: str,
state: str | None = None,
user_fields: dict[str, Any] | None = None
) -> str:
""" Generates an oauth url to be shown in the browser. """
return stripe.OAuth.authorize_url(
client_id=self.client_id,
client_secret=self.client_secret,
scope='read_write',
redirect_uri=redirect_uri,
stripe_user=user_fields,
state=state
)
[docs]
def prepare_oauth_request(
self,
redirect_uri: str,
success_url: str,
error_url: str,
user_fields: dict[str, Any] | None = None
) -> str:
""" Registers the oauth request with the oauth_gateway and returns
an url that is ready to be used for the complete oauth request.
"""
assert (self.oauth_gateway
and self.oauth_gateway_auth
and self.oauth_gateway_secret)
register = f'{self.oauth_gateway}/register/{self.oauth_gateway_auth}'
payload = {
'url': redirect_uri,
'secret': self.oauth_gateway_secret,
'method': 'GET',
'success_url': success_url,
'error_url': error_url
}
response = requests.post(register, json=payload, timeout=60)
assert response.status_code == 200
return self.oauth_url(
redirect_uri='{}/redirect'.format(self.oauth_gateway),
state=response.json()['token'],
user_fields=user_fields
)
[docs]
def process_oauth_response(
self,
request_params: Mapping[str, Any]
) -> None:
""" Takes the parameters of an incoming oauth request and stores
them on the payment provider if successful.
"""
if 'error' in request_params:
raise RuntimeError('Stripe OAuth request failed ({}: {})'.format(
request_params['error'], request_params['error_description']
))
assert (
request_params['oauth_redirect_secret']
== self.oauth_gateway_secret)
self.authorization_code = request_params['code']
with stripe_api_key(self.client_secret):
data = stripe.OAuth.token(
grant_type='authorization_code',
code=self.authorization_code,
)
assert data['scope'] == 'read_write'
self.publishable_key = data['stripe_publishable_key']
self.user_id = data['stripe_user_id']
self.refresh_token = data['refresh_token']
self.access_token = data['access_token']
[docs]
def sync(self) -> None:
session = object_session(self)
self.sync_payment_states(session)
self.sync_payouts(session)
[docs]
def sync_payment_states(self, session: Session) -> None:
def payments(ids: Collection[UUID]) -> Query[StripePayment]:
q = session.query(self.payment_class)
q = q.filter(self.payment_class.id.in_(ids))
return q
charges: Iterator[stripe.Charge] = self.paged(
stripe.Charge.list,
limit=100
)
by_payment = {
payment_id: charge
for charge in charges
if (payment_id := charge.metadata.get('payment_id')) is not None
}
for payment in payments(by_payment.keys()):
payment.sync(remote_obj=by_payment[payment.id.hex])
[docs]
def sync_payouts(self, session: Session) -> None:
"""
See https://stripe.com/docs/api/balance_transactions/list
and https://stripe.com/docs/api/payouts
"""
payouts: Iterator[stripe.Payout] = self.paged(
stripe.Payout.list, limit=100, status='paid'
)
latest_payout = None
paid_charges = {}
for payout in payouts:
if latest_payout is None:
latest_payout = payout
if payout.id == self.latest_payout:
break
# We have to skip payouts that were manually triggered
# otherwise resulting in error with http_status of 400
if not payout.automatic:
continue
transactions: Iterator[stripe.BalanceTransaction] = self.paged(
stripe.BalanceTransaction.list,
limit=100,
payout=payout.id,
type='charge'
)
for charge in transactions:
paid_charges[charge.source] = (
datetime.fromtimestamp(payout.arrival_date),
payout.id,
charge.fee / 100
)
if paid_charges:
q = session.query(self.payment_class)
q = q.filter(self.payment_class.remote_id.in_(paid_charges.keys()))
for p in q:
p.payout_date, p.payout_id, p.effective_fee = (
paid_charges[p.remote_id])
self.latest_payout = latest_payout and latest_payout.id
[docs]
def paged(
self,
method: Callable[_P, stripe.ListObject],
*args: _P.args,
**kwargs: _P.kwargs
) -> Iterator[_R]:
with stripe_api_key(self.access_token):
list_obj = method(*args, **kwargs)
yield from list_obj.auto_paging_iter()