From e865c31d69a6993b7f2463e5cfa62c2e89cd3a00 Mon Sep 17 00:00:00 2001 From: Yohan Boniface Date: Thu, 26 Dec 2024 15:51:05 +0100 Subject: [PATCH] wip(sync): move sync related code to a dedicated module --- umap/asgi.py | 2 +- .../commands/run_websocket_server.py | 23 -- umap/sync/__init__.py | 0 umap/{ => sync}/consumers.py | 61 +----- umap/sync/payloads.py | 49 +++++ umap/websocket_server.py | 202 ------------------ 6 files changed, 60 insertions(+), 277 deletions(-) delete mode 100644 umap/management/commands/run_websocket_server.py create mode 100644 umap/sync/__init__.py rename umap/{ => sync}/consumers.py (55%) create mode 100644 umap/sync/payloads.py delete mode 100644 umap/websocket_server.py diff --git a/umap/asgi.py b/umap/asgi.py index 6668a94c..5b130b65 100644 --- a/umap/asgi.py +++ b/umap/asgi.py @@ -5,7 +5,7 @@ from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application from django.urls import re_path -from . import consumers +from .sync import consumers os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") # Initialize Django ASGI application early to ensure the AppRegistry diff --git a/umap/management/commands/run_websocket_server.py b/umap/management/commands/run_websocket_server.py deleted file mode 100644 index 2cb2db89..00000000 --- a/umap/management/commands/run_websocket_server.py +++ /dev/null @@ -1,23 +0,0 @@ -from django.conf import settings -from django.core.management.base import BaseCommand - -from umap import websocket_server - - -class Command(BaseCommand): - help = "Run the websocket server" - - def add_arguments(self, parser): - parser.add_argument( - "--host", - help="The server host to bind to.", - default=settings.WEBSOCKET_BACK_HOST, - ) - parser.add_argument( - "--port", - help="The server port to bind to.", - default=settings.WEBSOCKET_BACK_PORT, - ) - - def handle(self, *args, **options): - websocket_server.run(options["host"], options["port"]) diff --git a/umap/sync/__init__.py b/umap/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/umap/consumers.py b/umap/sync/consumers.py similarity index 55% rename from umap/consumers.py rename to umap/sync/consumers.py index d30f253c..f9c3492b 100644 --- a/umap/consumers.py +++ b/umap/sync/consumers.py @@ -1,14 +1,16 @@ +import logging + from channels.generic.websocket import AsyncWebsocketConsumer from django.core.signing import TimestampSigner +from pydantic import ValidationError -from .websocket_server import ( +from .payloads import ( JoinRequest, JoinResponse, ListPeersResponse, OperationMessage, PeerMessage, Request, - ValidationError, ) @@ -33,7 +35,6 @@ class SyncConsumer(AsyncWebsocketConsumer): return self.channel_layer.groups[self.map_id].keys() async def connect(self): - print("connect") self.map_id = self.scope["url_route"]["kwargs"]["map_id"] # Join room group @@ -43,7 +44,6 @@ class SyncConsumer(AsyncWebsocketConsumer): await self.send_peers_list() async def disconnect(self, close_code): - print("disconnect") await self.channel_layer.group_discard(self.map_id, self.channel_name) await self.send_peers_list() @@ -52,43 +52,32 @@ class SyncConsumer(AsyncWebsocketConsumer): await self.broadcast(message.model_dump_json()) async def broadcast(self, message): - print("broadcast", message) + # Send to one all channels await self.channel_layer.group_send( self.map_id, {"message": message, "type": "on_message"} ) async def send_to(self, channel, message): - print("pair to pair", channel, message) + # Send to one given channel await self.channel_layer.send( channel, {"message": message, "type": "on_message"} ) async def on_message(self, event): - # This is what the consummers does for a single channel + # Send to self channel await self.send(event["message"]) async def receive(self, text_data): - print("receive") - print(text_data) if text_data == "ping": return await self.send("pong") - # recompute the peers list at the time of message-sending. - # as doing so beforehand would miss new connections - # other_peers = connections.get_other_peers(websocket) - # await self.send("pong" + self.channel_name) - # await self.channel_layer.group_send( - # self.map_id, - # {"message": "pouet " + self.channel_name, "type": "broadcast"}, - # ) try: incoming = Request.model_validate_json(text_data) - except ValidationError: - error = ( + except ValidationError as error: + message = ( f"An error occurred when receiving the following message: {text_data!r}" ) - print(error) - # logging.error(error, e) + logging.error(message, error) else: match incoming.root: # Broadcast all operation messages to connected peers @@ -100,34 +89,4 @@ class SyncConsumer(AsyncWebsocketConsumer): # Send peer messages to the proper peer case PeerMessage(): - print("Received peermessage", incoming.root) await self.send_to(incoming.root.recipient, text_data) - - # Send peer messages to the proper peer - # case PeerMessage(recipient=_id): - # peer = connections.get(_id) - # if peer: - # await peer.send(text_data) - - # message = JoinRequest.model_validate_json(text_data) - # signed = TimestampSigner().unsign_object(message.token, max_age=30) - # user, map_id, permissions = signed.values() - - # # Check if permissions for this map have been granted by the server - # if "edit" in signed["permissions"]: - # connections = CONNECTIONS[map_id] - # _id = connections.join(self) - - # # Assign an ID to the joining peer and return it the list of connected peers. - # peers: list[WebSocketClientProtocol] = [ - # connections.get_id(p) for p in connections.get_all_peers() - # ] - # response = JoinResponse(uuid=_id, peers=peers) - # await self.send(response.model_dump_json()) - - # # await join_and_listen(map_id, permissions, user, websocket) - - # # text_data_json = json.loads(text_data) - # # message = text_data_json["message"] - - # # self.send(text_data=json.dumps({"message": message})) diff --git a/umap/sync/payloads.py b/umap/sync/payloads.py new file mode 100644 index 00000000..cfe2a003 --- /dev/null +++ b/umap/sync/payloads.py @@ -0,0 +1,49 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field, RootModel + + +class JoinRequest(BaseModel): + kind: Literal["JoinRequest"] = "JoinRequest" + token: str + + +class OperationMessage(BaseModel): + """Message sent from one peer to all the others""" + + kind: Literal["OperationMessage"] = "OperationMessage" + verb: Literal["upsert", "update", "delete"] + subject: Literal["map", "datalayer", "feature"] + metadata: Optional[dict] = None + key: Optional[str] = None + + +class PeerMessage(BaseModel): + """Message sent from a specific peer to another one""" + + kind: Literal["PeerMessage"] = "PeerMessage" + sender: str + recipient: str + # The message can be whatever the peers want. It's not checked by the server. + message: dict + + +class Request(RootModel): + """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" + + root: Union[PeerMessage, OperationMessage, JoinRequest] = Field( + discriminator="kind" + ) + + +class JoinResponse(BaseModel): + """Server response containing the list of peers""" + + kind: Literal["JoinResponse"] = "JoinResponse" + peers: list + uuid: str + + +class ListPeersResponse(BaseModel): + kind: Literal["ListPeersResponse"] = "ListPeersResponse" + peers: list diff --git a/umap/websocket_server.py b/umap/websocket_server.py deleted file mode 100644 index 346fe960..00000000 --- a/umap/websocket_server.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import logging -import uuid -from collections import defaultdict -from typing import Literal, Optional, Union - -import websockets -from django.conf import settings -from django.core.signing import TimestampSigner -from pydantic import BaseModel, Field, RootModel, ValidationError -from websockets import WebSocketClientProtocol -from websockets.server import serve - - -class Connections: - def __init__(self) -> None: - self._connections: set[WebSocketClientProtocol] = set() - self._ids: dict[WebSocketClientProtocol, str] = dict() - - def join(self, websocket: WebSocketClientProtocol) -> str: - self._connections.add(websocket) - _id = str(uuid.uuid4()) - self._ids[websocket] = _id - return _id - - def leave(self, websocket: WebSocketClientProtocol) -> None: - self._connections.remove(websocket) - del self._ids[websocket] - - def get(self, id) -> WebSocketClientProtocol: - # use an iterator to stop iterating as soon as we found - return next(k for k, v in self._ids.items() if v == id) - - def get_id(self, websocket: WebSocketClientProtocol): - return self._ids[websocket] - - def get_other_peers( - self, websocket: WebSocketClientProtocol - ) -> set[WebSocketClientProtocol]: - return self._connections - {websocket} - - def get_all_peers(self) -> set[WebSocketClientProtocol]: - return self._connections - - -# Contains the list of websocket connections handled by this process. -# It's a mapping of map_id to a set of the active websocket connections -CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections) - - -class JoinRequest(BaseModel): - kind: Literal["JoinRequest"] = "JoinRequest" - token: str - - -class OperationMessage(BaseModel): - """Message sent from one peer to all the others""" - - kind: Literal["OperationMessage"] = "OperationMessage" - verb: Literal["upsert", "update", "delete"] - subject: Literal["map", "datalayer", "feature"] - metadata: Optional[dict] = None - key: Optional[str] = None - - -class PeerMessage(BaseModel): - """Message sent from a specific peer to another one""" - - kind: Literal["PeerMessage"] = "PeerMessage" - sender: str - recipient: str - # The message can be whatever the peers want. It's not checked by the server. - message: dict - - -class ServerRequest(BaseModel): - """A request towards the server""" - - kind: Literal["Server"] = "Server" - action: Literal["list-peers"] - - -class Request(RootModel): - """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" - - root: Union[ServerRequest, PeerMessage, OperationMessage, JoinRequest] = Field( - discriminator="kind" - ) - - -class JoinResponse(BaseModel): - """Server response containing the list of peers""" - - kind: Literal["JoinResponse"] = "JoinResponse" - peers: list - uuid: str - - -class ListPeersResponse(BaseModel): - kind: Literal["ListPeersResponse"] = "ListPeersResponse" - peers: list - - -async def join_and_listen( - map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol -): - """Join a "room" with other connected peers, and wait for messages.""" - logging.debug(f"{user} joined room #{map_id}") - connections: Connections = CONNECTIONS[map_id] - _id: str = connections.join(websocket) - - # Assign an ID to the joining peer and return it the list of connected peers. - peers: list[WebSocketClientProtocol] = [ - connections.get_id(p) for p in connections.get_all_peers() - ] - response = JoinResponse(uuid=_id, peers=peers) - await websocket.send(response.model_dump_json()) - - # Notify all other peers of the new list of connected peers. - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - try: - async for raw_message in websocket: - if raw_message == "ping": - await websocket.send("pong") - continue - - # recompute the peers list at the time of message-sending. - # as doing so beforehand would miss new connections - other_peers = connections.get_other_peers(websocket) - try: - incoming = Request.model_validate_json(raw_message) - except ValidationError as e: - error = f"An error occurred when receiving the following message: {raw_message!r}" - logging.error(error, e) - else: - match incoming.root: - # Broadcast all operation messages to connected peers - case OperationMessage(): - websockets.broadcast(other_peers, raw_message) - - # Send peer messages to the proper peer - case PeerMessage(recipient=_id): - peer = connections.get(_id) - if peer: - await peer.send(raw_message) - - finally: - # On disconnect, remove the connection from the pool - connections.leave(websocket) - - # TODO: refactor this in a separate method. - # Notify all other peers of the new list of connected peers. - peers = [connections.get_id(p) for p in connections.get_all_peers()] - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - -async def handler(websocket: WebSocketClientProtocol): - """Main WebSocket handler. - - Check if the permission is granted and let the peer enter a room. - """ - raw_message = await websocket.recv() - - # The first event should always be 'join' - message: JoinRequest = JoinRequest.model_validate_json(raw_message) - signed = TimestampSigner().unsign_object(message.token, max_age=30) - user, map_id, permissions = signed.values() - - # Check if permissions for this map have been granted by the server - if "edit" in signed["permissions"]: - await join_and_listen(map_id, permissions, user, websocket) - - -def run(host: str, port: int): - if not settings.WEBSOCKET_ENABLED: - msg = ( - "WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. " - "See the documentation at " - "https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled " - "for more information." - ) - print(msg) - exit(1) - - async def _serve(): - async with serve(handler, host, port): - logging.debug(f"Waiting for connections on {host}:{port}") - await asyncio.Future() # run forever - - try: - asyncio.run(_serve()) - except KeyboardInterrupt: - print("Closing WebSocket server")