from __future__ import annotations
import requests
from attr import attrs, attrib
from base64 import urlsafe_b64encode
from jwt import PyJWKClient, decode_complete, get_algorithm_by_name
from jwt.exceptions import InvalidIssuerError, InvalidSignatureError
from oauthlib.oauth2.rfc6749.endpoints import AuthorizationEndpoint
from oauthlib.oauth2.rfc6749.endpoints import MetadataEndpoint
from oauthlib.oauth2.rfc6749.endpoints import TokenEndpoint
from oauthlib.oauth2.rfc6749.grant_types import AuthorizationCodeGrant
from requests_oauthlib import OAuth2Session
from secrets import compare_digest
from typing import Any, Self, TYPE_CHECKING
if TYPE_CHECKING:
from onegov.core.request import CoreRequest
from onegov.user.auth.provider import (
HasApplicationIdAndNamespace, OIDCProvider)
@attrs(auto_attribs=True)
class OIDCAttributes:
    """
    Holds the names of the expected OIDC claims.

    These claims may either be included in the JWT id token
    or in the response to the user endpoint.
    """

    # the unique id in the OIDC provider
    source_id: str
    # The username (should be an email), use for User.username
    username: str
    # The users first name if available, use for User.realname
    first_name: str
    # The users last name if available, use for User.realname
    last_name: str
    # the name of the group
    group: str
    # Can be used if first / last name are not available to fill User.realname
    preferred_username: str

    @classmethod
    def from_cfg(cls, cfg: dict[str, Any]) -> Self:
        """
        Builds the claim names from a configuration mapping.

        Every claim name can be overridden in ``cfg``. Missing entries fall
        back to ``'sub'`` (source_id), ``'email'`` (username), ``'group'``
        (group), ``'given_name'`` (first_name), ``'family_name'``
        (last_name) and ``'preferred_username'`` (preferred_username).

        :param cfg: The ``attributes`` section of a client configuration.
        :returns: The configured claim names.
        """
        return cls(
            source_id=cfg.get('source_id', 'sub'),
            username=cfg.get('username', 'email'),
            group=cfg.get('group', 'group'),
            first_name=cfg.get('first_name', 'given_name'),
            last_name=cfg.get('last_name', 'family_name'),
            # previously hard-coded; now configurable like the other claims,
            # with the same default so existing configs are unaffected
            preferred_username=cfg.get(
                'preferred_username', 'preferred_username'
            ),
        )
@attrs()
class OIDCClient:
    """
    The OIDC client configuration of a single tenant.

    Provides helpers to start an OAuth2 session against the provider and
    to validate the id token returned by it.
    """

    client_id: str = attrib()

    client_secret: str = attrib()

    button_text: str = attrib()

    # Needed attributes in the jwt token
    attributes: OIDCAttributes = attrib()

    primary: bool = attrib()

    # Required OAuth scope in addition to "openid"
    scope: list[str] = attrib(factory=list)

    # Override/amend discovered metadata
    fixed_metadata: dict[str, Any] = attrib(factory=dict)

    # The expected issuer (``iss`` claim) of the id tokens. It is passed by
    # ``OIDCConnections.from_cfg`` and checked in ``validate_token``. It is
    # keyword-only so the positional order of the fields above is unchanged.
    issuer: str = attrib(kw_only=True)

    # NOTE(review): ``validate_token`` relies on ``self.metadata(request)``,
    # which is not defined in the visible source. From its usage it must
    # return a mapping with a ``'jwks_client'`` entry (presumably a
    # ``PyJWKClient``) and optionally
    # ``'id_token_signing_alg_values_supported'``. Confirm it is provided
    # (and presumably applies ``fixed_metadata``).

    def session(
        self,
        provider: OIDCProvider,
        request: CoreRequest,
        *,
        with_openid_scope: bool = False,
    ) -> OAuth2Session:
        """
        Returns a requests session tied to an OAuth2 client.

        The session uses our client id and redirects back to the
        ``redirect`` view of the given provider.

        :param provider: The provider this client is used for.
        :param request: The current request, used to build the redirect url.
        :param with_openid_scope: If true, ``openid`` is requested in
            addition to the configured scopes.
        :raises TypeError: If the configured scope is not a list.
        """
        # An explicit check instead of ``assert`` so it still runs under
        # ``python -O``. A string scope would otherwise be split into single
        # characters by the unpacking below.
        if not isinstance(self.scope, list):
            raise TypeError('Invalid scope, expected list')

        provider_cls = type(provider)
        redirect_url = request.class_link(
            provider_cls, {'name': provider.name}, name='redirect')

        return OAuth2Session(
            self.client_id,
            scope=['openid', *self.scope] if with_openid_scope else self.scope,
            redirect_uri=redirect_url,
        )

    def validate_token(
        self,
        request: CoreRequest,
        token: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validates the ``id_token`` of a token response and returns its claims.

        The signature is verified with the key from the provider's JWKS. The
        audience (our client id), the issuer and the claims required by OIDC
        are checked as well. If the response also contains an
        ``access_token``, it is bound to the id token via ``at_hash``.

        :param request: The current request, used to look up the metadata.
        :param token: The token response; must contain ``id_token`` and may
            contain ``access_token``.
        :returns: The decoded payload (claims) of the id token.
        :raises KeyError: If ``token`` contains no ``id_token``.
        :raises InvalidSignatureError: If an access token is present and
            ``at_hash`` is missing or does not match it.

        Any other error raised by ``decode_complete`` (bad signature,
        audience, issuer, expired token, ...) propagates unchanged.
        """
        metadata = self.metadata(request)
        access_token = token.get('access_token')
        id_token = token['id_token']

        jwks_client = metadata['jwks_client']
        signing_key = jwks_client.get_signing_key_from_jwt(id_token)

        # OIDC Discovery allows 'none' in the advertised algorithms, but an
        # unsigned id token must never be accepted
        algorithms = [
            name for name in metadata.get(
                'id_token_signing_alg_values_supported',
                ['RS256']
            )
            if name != 'none'
        ]

        # TODO: Should we provide some configurable leeway for exp?
        data = decode_complete(
            id_token,
            key=signing_key,
            audience=self.client_id,
            issuer=self.issuer,
            algorithms=algorithms,
            # the following claims are required for OIDC
            options={'require': [
                'iss',
                'sub',
                'aud',
                'exp',
                'iat'
            ]}
        )
        header = data['header']
        payload = data['payload']

        if access_token:
            # validate the access_token using at_hash: the base64url encoded
            # left-most half of the access token's hash, computed with the
            # hash function of the id token's signing algorithm.
            # NOTE(review): OIDC Core makes at_hash optional for the code
            # flow; this rejects id tokens without it whenever an access
            # token is returned. Confirm this strictness is intended.
            alg = get_algorithm_by_name(header['alg'])
            digest = alg.compute_hash_digest(access_token.encode('utf-8'))
            half_digest = digest[:len(digest) // 2]
            # base64url as used by JOSE/OIDC omits the '=' padding, which
            # urlsafe_b64encode adds. Strip it on both sides so compliant
            # (unpadded) values match and padded ones keep working.
            at_hash = urlsafe_b64encode(half_digest).rstrip(b'=')
            given_at_hash = payload.get('at_hash')
            if not isinstance(given_at_hash, str) or not compare_digest(
                at_hash,
                given_at_hash.encode('utf-8').rstrip(b'=')
            ):
                raise InvalidSignatureError('at_hash was missing or incorrect')

        return payload
@attrs
class OIDCConnections:
    """
    Registry of the configured OIDC clients.

    Clients are keyed by application id or by namespace.
    """

    # instantiated connections for every tenant
    connections: dict[str, OIDCClient] = attrib()

    def client(self, app: HasApplicationIdAndNamespace) -> OIDCClient | None:
        """
        Looks up the client configured for ``app``.

        An entry for the application id wins over an entry for the
        namespace. Returns ``None`` when neither has an entry.
        """
        registry = self.connections
        for lookup_key in (app.application_id, app.namespace):
            if lookup_key in registry:
                return registry[lookup_key]
        return None

    @classmethod
    def from_cfg(cls, config: dict[str, Any]) -> Self:
        """
        Builds one :class:`OIDCClient` per entry of ``config``.

        ``issuer``, ``client_id``, ``client_secret`` and ``button_text`` are
        required per entry. ``scope``, ``attributes``, ``primary`` and
        ``fixed_metadata`` are optional.

        :param config: Mapping of application id / namespace to the client
            configuration of that tenant.
        :returns: The registry holding all configured clients.
        """
        built: dict[str, OIDCClient] = {}
        for tenant_key, tenant_cfg in config.items():
            built[tenant_key] = OIDCClient(
                issuer=tenant_cfg['issuer'],
                client_id=tenant_cfg['client_id'],
                client_secret=tenant_cfg['client_secret'],
                scope=tenant_cfg.get('scope', []),
                attributes=OIDCAttributes.from_cfg(
                    tenant_cfg.get('attributes', {})
                ),
                button_text=tenant_cfg['button_text'],
                primary=tenant_cfg.get('primary', False),
                fixed_metadata=tenant_cfg.get('fixed_metadata', {}),
            )
        return cls(connections=built)