Source code for websockets.server

from __future__ import annotations

import hashlib
import hmac
import http
from asyncio import Future
from functools import cached_property, partial
from http.cookies import SimpleCookie
from json import dumps, loads
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

import transaction
from itsdangerous import BadSignature, Signer
from markupsafe import escape
from websockets.exceptions import ConnectionClosed, InvalidOrigin
from websockets.legacy.protocol import broadcast
from websockets.legacy.server import WebSocketServerProtocol, serve

from onegov.chat.collections import ChatCollection
from onegov.chat.utils import param_from_path
from onegov.core import cache
from onegov.core.browser_session import BrowserSession
from onegov.core.orm import Base, SessionManager
from onegov.user import User, UserCollection
from onegov.websockets import log
from onegov.websockets.security import (WebsocketSecurityError,
                                        consume_websocket_token)

if TYPE_CHECKING:
    from collections.abc import Collection
    from uuid import UUID

    from sqlalchemy.orm import Session
    from websockets import Headers
    from websockets.legacy.server import HTTPResponse

    from onegov.chat.models import Chat
    from onegov.core.types import JSONObject, JSONObject_ro
    from onegov.server.config import Config


# Listening connections, keyed by schema or '<schema>-<channel>'
# (filled by handle_listen, read by handle_broadcast / handle_status).
CONNECTIONS: dict[str, set[WebSocketServerProtocol]] = {}

# Shared secret compared in handle_authentication; assigned by main().
TOKEN = ''  # nosec: B105

# Sentinel distinguishing "not cached" from a cached ``None`` chat.
NOTFOUND = object()

SESSIONS: dict[str, Session] = {}

# Connected staff members per schema.
STAFF_CONNECTIONS: dict[str, set[WebSocketServerProtocol]] = {}

# schema -> username -> user, filled by WebSocketServer.populate_staff.
STAFF: dict[str, dict[str, User]] = {}  # For Authentication of User

# schema -> chat id -> chat, used by WebSocketServer.get_chat.
ACTIVE_CHATS: dict[str, dict[UUID, Chat]] = {}  # For DB

# schema -> channel -> connections taking part in that chat channel.
CHANNELS: dict[str, dict[str, set[WebSocketServerProtocol]]] = {}
class WebSocketServer(WebSocketServerProtocol):
    """ A websocket server connection.

    This protocol handles multiple websocket applications:
    - Ticket notifications
    - Ticker
    - Chat

    Chat behaves differently from the others and will eventually be carved
    out into a separate service. To not interfere with any existing
    functionality, we try to refrain from making backwards-incompatible
    changes. That way, ticker and notifications should continue to work
    without any modification.

    TODO: Move chat to a dedicated service.

    """

    schema: str
    user_id: str | None
    signed_session_id: str | None

    def __init__(
        self,
        config: Config,
        session_manager: SessionManager,
        *args: Any,
        **kwargs: Any
    ):
        """ Stores the server configuration and session manager; all
        remaining arguments are passed on to the protocol base class.

        """
        super().__init__(*args, **kwargs)
        self.config = config
        self.session_manager = session_manager

        # process_request() only assigns these for chat requests. Give them
        # a defined value so that browser_session yields an empty session
        # instead of raising AttributeError on every other connection.
        self.user_id = None
        self.signed_session_id = None

    async def process_request(
        self,
        path: str,
        headers: Headers
    ) -> HTTPResponse | None:
        """ Intercept initial HTTP request.

        Before establishing a WebSocket connection, a client sends a HTTP
        request to "upgrade" the connection to a WebSocket connection.

        Chat
        ----
        We authenticate the user before creating the WebSocket connection.
        The user is identified based on the session cookie. In addition to
        the cookie, we require a one-time token that the user must have
        obtained prior to requesting the WebSocket connection.

        Returns ``None`` to continue with the handshake, or a
        ``(status, headers, body)`` response to reject the request.

        """
        url = urlparse(path)
        if '/chats' not in url.path:
            # For non-chat requests (e.g., ticker) we'll skip the dance below
            # and let the protocol handle authentication
            # (handle_authentication).
            return None

        try:
            cookie: SimpleCookie = SimpleCookie(headers['Cookie'])
            session_id = cookie['session_id'].value
        except KeyError:
            log.error(
                'No session cookie found in request. '
                'Check that you sent the request from the same origin as '
                f'the WebSocket server ({self.host})'
            )
            return http.HTTPStatus.BAD_REQUEST, [], b''

        self.signed_session_id = session_id

        try:
            self.schema = param_from_path('schema', path)
        except ValueError as err:
            log.error(
                f'Unable to retrieve schema from path: {path}',
                exc_info=err
            )
            return http.HTTPStatus.BAD_REQUEST, [], b''

        # browser_session requires self.schema
        self.user_id = self.browser_session.get('userid')

        try:
            # Consume the presented token or deny the connection. The token
            # acts like CSRF token to protect against Cross-Site WebSocket
            # Hijacks.
            consume_websocket_token(path, self.browser_session)
        except WebsocketSecurityError as err:
            log.error('Rejecting WebSocket connection.', exc_info=err)
            return http.HTTPStatus.UNAUTHORIZED, [], b''

        try:
            # Checking the origin is done at a later stage by handshake(),
            # this check is totally superfluous. However, rejecting clients
            # because of a wrong origin would get unnoticed otherwise. You
            # can safely delete this block in the future.
            #
            # TODO: Pass in valid origins. Is there already a list of allowed
            #       origins?
            self.process_origin(headers, self.origins)
        except InvalidOrigin as err:
            log.debug('WebSocket connection will be rejected.', exc_info=err)

        self.populate_staff()
        return None

    def populate_staff(self) -> None:
        """ Populate staff users.

        Stores all users with the role ``editor`` or ``admin`` of the
        current schema in ``STAFF``, keyed by username.

        """
        STAFF[self.schema] = {
            user.username: user for user in (
                UserCollection(self.session)
                .query()
                .filter(User.role.in_(['editor', 'admin']))
            )
        }
        transaction.commit()

    async def get_chat(self, id: UUID) -> Chat:
        """ Returns the chat with the given id.

        The chat is looked up in ``ACTIVE_CHATS`` first and queried from the
        database if it is not cached. A stored chat that is no longer active
        is cached and returned as ``None``.

        """
        chat = ACTIVE_CHATS.setdefault(self.schema, {}).get(id, NOTFOUND)

        # Force (cached) session to fetch latest state of the database,
        # otherwise new chats are not visible to this session.
        self.session.expire_all()
        transaction.commit()

        if chat is NOTFOUND:
            chat = ChatCollection(self.session).by_id(id)
            log.debug(f'searching for chat with id {id}')
            log.debug(f'chat from collection {chat}')
            if chat and not chat.active:
                chat = None
            ACTIVE_CHATS[self.schema][id] = chat  # type: ignore
            transaction.commit()

        return chat  # type: ignore

    async def update_database(self) -> None:
        """ Flushes the session and commits the current transaction. """
        self.session.flush()
        transaction.commit()

    def unsign(self, text: str) -> str | None:
        """ Unsigns a signed text, returning None if unsuccessful. """
        identity_secret = (
            self.application_config['identity_secret']
            + self.application_id_hash
        )
        try:
            signer = Signer(identity_secret, salt='generic-signer')
            return signer.unsign(text).decode('utf-8')
        except BadSignature:
            return None

    @property
    def session(self) -> Session:
        """ The database session, bound to the schema of this connection.

        NOTE(review): every access also replaces ``ACTIVE_CHATS[schema]``
        with an empty dict, i.e. it drops the chats cached for the schema.
        Confirm whether that is intended.

        """
        self.session_manager.set_current_schema(self.schema)
        session = self.session_manager.session()
        ACTIVE_CHATS[self.schema] = {}
        return session

    @property
    def application_id_hash(self) -> str:
        """ The application_id as hash, use this if the application_id can
        be read by the user -> this obfuscates things slightly.

        """
        # sha-1 should be enough, because even if somebody was able to get
        # the cleartext value I honestly couldn't tell you what it could be
        # used for ...
        return hashlib.new(  # nosec: B324
            'sha1',
            self.application_id.encode('utf-8'),
            usedforsecurity=False
        ).hexdigest()

    @property
    def session_cache(self) -> cache.RedisCacheRegion:
        """ A cache that is kept for a long-ish time. """
        day = 60 * 60 * 24
        return cache.get(
            namespace=f'{self.application_id}:sessions',
            expiration_time=7 * day,
            redis_url=self.application_config.get(
                'redis_url', 'redis://127.0.0.1:6379/0')
        )

    @property
    def namespace(self) -> str:
        """ The part of the schema before the first dash. """
        return self.schema.split('-', 1)[0]

    @property
    def application_id(self) -> str:
        """ The schema with its first dash replaced by a slash. """
        return '/'.join(self.schema.split('-', 1))

    @property
    def application_config(self) -> dict[str, Any]:
        """ The configuration of the application whose namespace matches
        this connection, or an empty dict if there is none.

        """
        for c in self.config.applications:
            if c.namespace == self.namespace:
                return c.configuration
        return {}

    @cached_property
    def browser_session(self) -> BrowserSession | dict[str, Any]:
        """ The browser session identified by the signed session cookie.

        An empty dict is returned if there is no session cookie or if its
        signature is invalid. The result is cached per connection.

        """
        if self.signed_session_id is None:
            return {}

        session_id = self.unsign(self.signed_session_id)
        if session_id is None:
            return {}

        return BrowserSession(
            cache=self.session_cache,
            token=session_id,
        )
def get_payload(
    message: str | bytes,
    expected: Collection[str]
) -> JSONObject | None:
    """ Deserialize JSON payload and check type.

    :param message: The raw websocket message.
    :param expected: The accepted values of the payload's ``type`` key.
    :return: The decoded JSON object, or ``None`` if the message is not
        valid JSON, is not an object or has a missing/unexpected ``type``.

    """
    try:
        payload = loads(message)
        # An explicit check rather than ``assert``: assertions are stripped
        # when Python runs with -O, which would let any payload through.
        valid = isinstance(payload, dict) and payload['type'] in expected
    except (KeyError, TypeError, ValueError, RecursionError):
        valid = False

    if not valid:
        log.warning('Invalid payload received')
        return None
    return payload
async def error(
    websocket: WebSocketServerProtocol,
    message: str,
    close: bool = True
) -> None:
    """ Sends an error.

    The connection is closed afterwards unless ``close`` is false.

    """
    response = dumps({'type': 'error', 'message': message})
    await websocket.send(response)
    if not close:
        return
    await websocket.close()
async def acknowledge(websocket: WebSocketServerProtocol) -> None:
    """ Sends an acknowledge. """
    response = dumps({'type': 'acknowledged'})
    await websocket.send(response)
async def handle_listen(
    websocket: WebSocketServerProtocol,
    payload: JSONObject_ro
) -> None:
    """ Handles listening clients.

    Registers the connection under the requested schema (and optional
    channel) and keeps it registered until the connection is closed.

    """
    assert payload.get('type') == 'register'

    schema = payload.get('schema')
    if not (schema and isinstance(schema, str)):
        await error(websocket, f'invalid schema: {schema}')
        return

    channel = payload.get('channel')
    if not (channel is None or isinstance(channel, str)):
        await error(websocket, f'invalid channel: {channel}')
        return

    await acknowledge(websocket)

    # Channel-less listeners are registered under the bare schema.
    if channel:
        schema_channel = f'{schema}-{channel}'
    else:
        schema_channel = schema

    log.debug(f'{websocket.id} listens @ {schema_channel}')
    CONNECTIONS.setdefault(schema_channel, set()).add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # Look the set up again when unregistering.
        CONNECTIONS.setdefault(schema_channel, set()).discard(websocket)
async def handle_authentication(
    websocket: WebSocketServerProtocol,
    payload: JSONObject_ro
) -> None:
    """ Handles authentication.

    Compares the token of the payload with the server's ``TOKEN``. On
    failure an error is sent, which also closes the connection; on success
    the request is acknowledged.

    """
    assert payload.get('type') == 'authenticate'

    token = payload.get('token')
    if not token or not isinstance(token, str):
        await error(websocket, 'invalid token')
        return

    # Compare in constant time so the response time does not leak how many
    # leading characters of the secret were guessed correctly. The values
    # are compared as bytes because compare_digest() only accepts ASCII
    # strings; 'surrogatepass' keeps lone surrogates from raising.
    presented = token.encode('utf-8', 'surrogatepass')
    wanted = TOKEN.encode('utf-8', 'surrogatepass')
    if not hmac.compare_digest(presented, wanted):
        await error(websocket, 'authentication failed')
        return

    await acknowledge(websocket)
    log.debug(f'{websocket.id} authenticated')
async def handle_status(
    websocket: WebSocketServerProtocol,
    payload: JSONObject_ro
) -> None:
    """ Handles status requests.

    Replies with the number of listening connections per key of
    ``CONNECTIONS``.

    """
    assert payload.get('type') == 'status'

    await acknowledge(websocket)

    counts = {}
    for key, values in CONNECTIONS.items():
        counts[key] = len(values)

    status = {
        'type': 'status',
        'message': {'connections': counts}
    }
    await websocket.send(dumps(status))
    log.debug(f'{websocket.id} status sent')
async def handle_broadcast(
    websocket: WebSocketServerProtocol,
    payload: JSONObject_ro
) -> None:
    """ Handles broadcasts.

    Sends the message of the payload as notification to all connections
    listening on the given schema (and optional channel).

    """
    assert payload.get('type') == 'broadcast'

    message = payload.get('message')
    schema = payload.get('schema')
    channel = payload.get('channel')

    if not (schema and isinstance(schema, str)):
        await error(websocket, f'invalid schema: {schema}')
        return
    if not (channel is None or isinstance(channel, str)):
        await error(websocket, f'invalid channel: {channel}')
        return
    if not message:
        await error(websocket, 'missing message')
        return

    await acknowledge(websocket)

    if channel:
        schema_channel = f'{schema}-{channel}'
    else:
        schema_channel = schema

    connections = CONNECTIONS.get(schema_channel, set())
    if connections:
        notification = dumps({
            'type': 'notification',
            'message': message
        })
        broadcast(connections, notification)

    log.debug(
        f'{websocket.id} sent {message}'
        f' to {len(connections)} receiver(s) @ {schema_channel}'
    )
async def handle_manage(
    websocket: WebSocketServerProtocol,
    authentication_payload: JSONObject_ro
) -> None:
    """ Handles managing clients.

    Authenticates the client and then processes ``broadcast`` and
    ``status`` commands until the connection is closed. Any other message
    is answered with an error, which closes the connection.

    """
    await handle_authentication(websocket, authentication_payload)

    # A failed authentication closes the connection (error() closes by
    # default). Check this explicitly rather than relying on the loop below
    # to end by itself, so no command is ever served unauthenticated.
    if not websocket.open:
        return

    async for message in websocket:
        payload = get_payload(message, ('broadcast', 'status'))
        if payload and payload['type'] == 'broadcast':
            await handle_broadcast(websocket, payload)
        elif payload and payload['type'] == 'status':
            await handle_status(websocket, payload)
        else:
            await error(
                websocket,
                # FIXME: technically message can be bytes
                f'invalid command: {message}'  # type:ignore
            )
async def handle_customer_chat(
    websocket: WebSocketServer,
    payload: JSONObject_ro
) -> None:
    """ Starts a chat. Handles listening to messages on channel.

    The channel is the ``active_chat_id`` stored in the browser session of
    the connection. Messages of type ``message`` are forwarded to everyone
    in the channel and appended to the stored chat history. As long as no
    user is assigned to the chat and the customer is alone in the channel,
    a chat request is sent to all connected staff members.

    """
    schema = payload.get('schema')
    if not schema or not isinstance(schema, str):
        await error(websocket, f'invalid schema: {schema}')
        return

    if 'active_chat_id' not in websocket.browser_session:
        log.error(
            'Unable to find active_chat_id in session, aborting.'
        )
        return None

    channel = websocket.browser_session['active_chat_id']

    await acknowledge(websocket)

    all_channels = CHANNELS.setdefault(schema, {})
    channel_connections = all_channels.setdefault(channel.hex, set())
    channel_connections.add(websocket)
    staff_connections = STAFF_CONNECTIONS.setdefault(schema, set())

    chat = await websocket.get_chat(channel.hex)

    log.debug(f'added {websocket.id} to channel-connections')

    while websocket.open:
        try:
            message = await websocket.recv()
            log.debug(f'customer {websocket.id!r} got the message {message!r}')

            # Parse once; the decoded message is needed for the type check
            # as well as for the request and the chat history.
            content = loads(message)
            if content['type'] == 'message':
                stored = ChatCollection(websocket.session).by_id(channel)
                if not stored:
                    log.error(f'Unable to find stored chat with {channel=}')
                    continue
                chat = stored

                closed_connections = []
                # Iterate over a snapshot: other handlers may add to or
                # remove from the shared set while we await send().
                for client in tuple(channel_connections):
                    try:
                        await client.send(dumps({
                            'type': 'notification',
                            'message': message,
                        }))
                    except ConnectionClosed as err:
                        log.error(
                            'Attempting to communicate with a closed '
                            'connection, removing client from channels.',
                            exc_info=err
                        )
                        closed_connections.append(client)

                for connection in closed_connections:
                    channel_connections.discard(connection)

                # If customer is the only connection send chat request
                if len(channel_connections) == 1 and not chat.user_id:
                    log.debug('only client in channel, sending request.')
                    for client in tuple(staff_connections):
                        await client.send(dumps({
                            'type': 'notification',
                            'message': dumps({
                                'type': 'request',
                                'text': content['text'],
                                'userId': content['userId'],
                                'user': content['user'],
                                'topic': chat.topic,
                                'channel': channel.hex
                            })
                        }))

                # Assign a new list rather than appending in place.
                chat_history = chat.chat_history.copy()
                chat_history.append({
                    'userId': escape(content['userId']),
                    'user': escape(content['user']),
                    'text': escape(content['text']),
                    'time': escape(content['time']),
                })
                chat.chat_history = chat_history

        except Exception as e:
            log.exception('The debugged error message is -', exc_info=e)
            # discard() rather than remove(): the connection may already
            # have been dropped from the channel, and a KeyError raised in
            # here would escape the handler.
            channel_connections.discard(websocket)
            log.debug(f'removed {websocket.id} from channel-connections')
        finally:
            await websocket.update_database()

    return None
async def handle_staff_chat(
    websocket: WebSocketServer,
    payload: JSONObject_ro
) -> None:
    """ Registers staff member and listens to messages.

    Nothing beyond the acknowledgement happens unless the user of the
    connection is listed in ``STAFF`` for the given schema. Every received
    message is forwarded to the currently joined channel. Depending on its
    type the message additionally

    - ``message``: is appended to the stored chat history,
    - ``reconnect``: joins the given channel,
    - ``end-chat``: marks the given chat as inactive,
    - ``accepted``: joins the channel, hides the request for all other
      staff members, sends the chat history and assigns the chat,
    - ``request-chat-history``: joins the channel and sends its history.

    """
    schema = payload.get('schema')
    if not schema or not isinstance(schema, str):
        await error(websocket, f'invalid schema: {schema}')
        return

    # Only the side effects of the property are needed here (see
    # WebSocketServer.session); the returned session is not used.
    _ = websocket.session

    await acknowledge(websocket)

    # .get(): the schema is taken from the payload and STAFF might not
    # have an entry for it; treat that like an unknown user rather than
    # failing with a KeyError.
    if websocket.user_id not in STAFF.get(schema, {}):
        return

    log.debug('User is in Database.')
    all_channels = CHANNELS.setdefault(schema, {})
    staff_connections = STAFF_CONNECTIONS.setdefault(schema, set())
    staff_connections.add(websocket)
    channel_connections: set[WebSocketServerProtocol] = set()
    open_channel = ''

    log.debug(f'added {websocket.id} to staff-connections')

    while websocket.open:
        try:
            message = await websocket.recv()
            content = loads(message)
            log.debug(
                f'staff member {websocket.id!r} '
                f'got the message {message!r}'
            )

            # Forward each websocket message, no matter the type
            log.debug(
                f'current channel connections: {channel_connections}')
            closed_connections = []
            # Iterate over a snapshot: other handlers may add to or remove
            # from the shared set while we await send().
            for client in tuple(channel_connections):
                try:
                    await client.send(dumps({
                        'type': 'notification',
                        'message': message,
                    }))
                except ConnectionClosed as err:
                    log.error(
                        'Attempting to communicate with a closed '
                        'connection, removing client from channels.',
                        exc_info=err
                    )
                    closed_connections.append(client)

            for connection in closed_connections:
                channel_connections.discard(connection)

            kind = content['type']

            # If the type is a message, save to DB
            if kind == 'message':
                chat = (
                    ChatCollection(websocket.session)
                    .by_id(open_channel)
                )
                if not chat:
                    log.error(
                        f'Unable to find stored chat with {open_channel=}'
                    )
                    continue

                log.debug(f'staff received message {content}')
                chat_history = chat.chat_history.copy()
                chat_history.append({
                    'userId': escape(content['userId']),
                    'user': escape(content['user']),
                    'text': escape(content['text']),
                    'time': escape(content['time']),
                })
                chat.chat_history = chat_history

            elif kind == 'reconnect':
                log.debug(f'reconnecting to channel {content["channel"]}')
                channel_connections = all_channels.setdefault(
                    content['channel'], set()
                )
                channel_connections.add(websocket)

            elif kind == 'end-chat':
                log.debug(f'ending chat with id {content["channel"]}')
                chat = ChatCollection(websocket.session).by_id(
                    escape(content['channel'])
                )
                if not chat:
                    log.error(
                        "Unable to find stored chat "
                        f"with {content['channel']=}"
                    )
                    continue
                chat.active = False

            elif kind == 'accepted':
                log.debug('staff-member accepted-request')
                open_channel = content['channel']
                channel_connections = all_channels.setdefault(
                    open_channel, set()
                )
                channel_connections.add(websocket)
                chat = ChatCollection(websocket.session).by_id(
                    open_channel)
                if not chat:
                    log.error(
                        'Unable to find stored chat '
                        f'with {open_channel=}'
                    )
                    continue

                # Tell everyone else you've accepted
                for client in tuple(staff_connections):
                    if client != websocket:
                        inner = dumps({
                            'type': 'hide-request',
                            'channel': open_channel
                        })
                        await client.send(dumps({
                            'type': 'notification',
                            'message': inner,
                        }))

                inner = dumps({
                    'type': 'chat-history',
                    'history': chat.chat_history,
                    'channel': open_channel
                })
                await websocket.send(dumps({
                    'type': 'notification',
                    'message': inner,
                }))
                log.debug('sent chat history')

                # FIXME: Rather than escape we should try to parse this
                #        as an UUID, since otherwise the DB update will
                #        fail anyways
                chat.user_id = escape(content['userId'])  # type:ignore

            elif kind == 'request-chat-history':
                open_channel = content['channel']
                chat = ChatCollection(websocket.session).by_id(
                    open_channel)
                if not chat:
                    log.error(
                        'Unable to find stored chat '
                        f'with {open_channel=}'
                    )
                    continue

                channel_connections = all_channels.setdefault(
                    open_channel, set())
                channel_connections.add(websocket)
                log.debug('staff member reconnected')

                inner = dumps({
                    'type': 'chat-history',
                    'history': chat.chat_history,
                    'channel': open_channel
                })
                await websocket.send(dumps({
                    'type': 'notification',
                    'message': inner,
                }))

        except Exception as e:
            log.exception('The debugged error message is -', exc_info=e)
            staff_connections.discard(websocket)
            log.debug(f'removed {websocket.id} from staff-connections')
        finally:
            await websocket.update_database()
async def handle_start(websocket: WebSocketServerProtocol) -> None:
    """ Entry point of every connection.

    Reads the first message and hands the connection over to the handler
    matching its ``type``. Anything else is answered with an error.

    """
    log.debug(f'{websocket.id} connected')
    message = await websocket.recv()
    payload = get_payload(
        message,
        ('authenticate', 'register', 'customer_chat', 'staff_chat')
    )
    kind = payload['type'] if payload else None

    # The chat handlers rely on attributes only WebSocketServer provides.
    # When the server runs without config (see main), connections are plain
    # WebSocketServerProtocol instances; reject chat commands on those
    # instead of failing with an AttributeError inside the handler.
    supports_chat = isinstance(websocket, WebSocketServer)

    if kind == 'authenticate':
        await handle_manage(websocket, payload)
    elif kind == 'register':
        await handle_listen(websocket, payload)
    elif kind == 'customer_chat' and supports_chat:
        await handle_customer_chat(websocket, payload)  # type: ignore
    elif kind == 'staff_chat' and supports_chat:
        await handle_staff_chat(websocket, payload)  # type: ignore
    else:
        # FIXME: technically message can be bytes
        await error(websocket, f'invalid command: {message}')  # type:ignore
    log.debug(f'{websocket.id} disconnected')
async def main(
    host: str,
    port: int,
    token: str,
    config: Config | None = None
) -> None:
    """ Serves websocket connections on the given host and port.

    Stores ``token`` in the module-level ``TOKEN``. If a config is given,
    connections use the WebSocketServer protocol together with a session
    manager for the DSN of the first configured application.

    """
    global TOKEN
    TOKEN = token

    log.debug(f'Serving on ws://{host}:{port}')

    if config:
        dsn = config.applications[0].configuration['dsn']
        session_manager = SessionManager(
            dsn, Base, session_config={'autoflush': False}
        )
        create_protocol = partial(WebSocketServer, config, session_manager)
        server = serve(
            handle_start, host, port, create_protocol=create_protocol
        )
    else:
        server = serve(handle_start, host, port)

    async with server:
        # A future nobody resolves keeps the server running.
        await Future()