diff --git a/umap/asgi.py b/umap/asgi.py index 2ca12ddc..1f9d618a 100644 --- a/umap/asgi.py +++ b/umap/asgi.py @@ -1,15 +1,22 @@ import os -from channels.routing import ProtocolTypeRouter +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application +from django.urls import re_path + +from .sync import consumers os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") # Initialize Django ASGI application early to ensure the AppRegistry # is populated before importing code that may import ORM models. django_asgi_app = get_asgi_application() +urlpatterns = (re_path(r"ws/sync/(?P\w+)/$", consumers.SyncConsumer.as_asgi()),) + application = ProtocolTypeRouter( { "http": django_asgi_app, + "websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)), } ) diff --git a/umap/management/commands/run_websocket_server.py b/umap/management/commands/run_websocket_server.py deleted file mode 100644 index 2cb2db89..00000000 --- a/umap/management/commands/run_websocket_server.py +++ /dev/null @@ -1,23 +0,0 @@ -from django.conf import settings -from django.core.management.base import BaseCommand - -from umap import websocket_server - - -class Command(BaseCommand): - help = "Run the websocket server" - - def add_arguments(self, parser): - parser.add_argument( - "--host", - help="The server host to bind to.", - default=settings.WEBSOCKET_BACK_HOST, - ) - parser.add_argument( - "--port", - help="The server port to bind to.", - default=settings.WEBSOCKET_BACK_PORT, - ) - - def handle(self, *args, **options): - websocket_server.run(options["host"], options["port"]) diff --git a/umap/settings/base.py b/umap/settings/base.py index f47ad236..fedefe88 100644 --- a/umap/settings/base.py +++ b/umap/settings/base.py @@ -343,3 +343,4 @@ WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False) WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost") WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001) WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001") +CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} diff --git a/umap/static/umap/js/modules/sync/engine.js b/umap/static/umap/js/modules/sync/engine.js index 212c2528..c9b32167 100644 --- a/umap/static/umap/js/modules/sync/engine.js +++ b/umap/static/umap/js/modules/sync/engine.js @@ -77,7 +77,7 @@ export class SyncEngine { start(authToken) { this.transport = new WebSocketTransport( - this._umap.properties.websocketURI, + Utils.template(this._umap.properties.websocketURI, { id: this._umap.id }), authToken, this ) @@ -125,7 +125,7 @@ export class SyncEngine { if (this.offline) return if (this.transport) { - this.transport.send('OperationMessage', message) + this.transport.send('OperationMessage', { sender: this.uuid, ...message }) } } @@ -177,6 +177,7 @@ export class SyncEngine { * @param {Object} payload */ onOperationMessage(payload) { + if (payload.sender === this.uuid) return this._operations.storeRemoteOperations([payload]) this._applyOperation(payload) } @@ -484,7 +485,7 @@ export class Operations { return ( Utils.deepEqual(local.subject, remote.subject) && Utils.deepEqual(local.metadata, remote.metadata) && - (!shouldCheckKey || (shouldCheckKey && local.key == remote.key)) + (!shouldCheckKey || (shouldCheckKey && local.key === remote.key)) ) } } diff --git a/umap/static/umap/js/modules/sync/websocket.js b/umap/static/umap/js/modules/sync/websocket.js index 26c99f26..accdbcc3 100644 --- a/umap/static/umap/js/modules/sync/websocket.js +++ b/umap/static/umap/js/modules/sync/websocket.js @@ -6,7 +6,7 @@ export class WebSocketTransport { constructor(webSocketURI, authToken, messagesReceiver) { this.receiver = messagesReceiver - this.websocket = new WebSocket(webSocketURI) + this.websocket = new WebSocket(`${webSocketURI}`) this.websocket.onopen = () => { this.send('JoinRequest', { token: authToken }) @@ -48,6 +48,7 @@ export class WebSocketTransport { } onMessage(wsMessage) { + console.log(wsMessage) if (wsMessage.data === 'pong') { this.pongReceived = true } else { diff --git a/umap/sync/__init__.py b/umap/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/umap/sync/consumers.py b/umap/sync/consumers.py new file mode 100644 index 00000000..dc722279 --- /dev/null +++ b/umap/sync/consumers.py @@ -0,0 +1,86 @@ +import logging + +from channels.generic.websocket import AsyncWebsocketConsumer +from django.core.signing import TimestampSigner +from pydantic import ValidationError + +from .payloads import ( + JoinRequest, + JoinResponse, + ListPeersResponse, + OperationMessage, + PeerMessage, + Request, +) + + +class SyncConsumer(AsyncWebsocketConsumer): + @property + def peers(self): + return self.channel_layer.groups[self.map_id].keys() + + async def connect(self): + self.map_id = self.scope["url_route"]["kwargs"]["map_id"] + + # Join room group + await self.channel_layer.group_add(self.map_id, self.channel_name) + + self.is_authenticated = False + await self.accept() + + async def disconnect(self, close_code): + await self.channel_layer.group_discard(self.map_id, self.channel_name) + await self.send_peers_list() + + async def send_peers_list(self): + message = ListPeersResponse(peers=self.peers) + await self.broadcast(message.model_dump_json()) + + async def broadcast(self, message): + # Send to all channels (including sender!) + await self.channel_layer.group_send( + self.map_id, {"message": message, "type": "on_message"} + ) + + async def send_to(self, channel, message): + # Send to one given channel + await self.channel_layer.send( + channel, {"message": message, "type": "on_message"} + ) + + async def on_message(self, event): + # Send to self channel + await self.send(event["message"]) + + async def receive(self, text_data): + if not self.is_authenticated: + message = JoinRequest.model_validate_json(text_data) + signed = TimestampSigner().unsign_object(message.token, max_age=30) + user, map_id, permissions = signed.values() + if "edit" not in permissions: + return await self.disconnect() + response = JoinResponse(uuid=self.channel_name, peers=self.peers) + await self.send(response.model_dump_json()) + await self.send_peers_list() + self.is_authenticated = True + return + + if text_data == "ping": + return await self.send("pong") + + try: + incoming = Request.model_validate_json(text_data) + except ValidationError as error: + message = ( + f"An error occurred when receiving the following message: {text_data!r}" + ) + logging.error(message, error) + else: + match incoming.root: + # Broadcast all operation messages to connected peers + case OperationMessage(): + await self.broadcast(text_data) + + # Send peer messages to the proper peer + case PeerMessage(): + await self.send_to(incoming.root.recipient, text_data) diff --git a/umap/sync/payloads.py b/umap/sync/payloads.py new file mode 100644 index 00000000..6a15a3f1 --- /dev/null +++ b/umap/sync/payloads.py @@ -0,0 +1,47 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field, RootModel + + +class JoinRequest(BaseModel): + kind: Literal["JoinRequest"] = "JoinRequest" + token: str + + +class OperationMessage(BaseModel): + """Message sent from one peer to all the others""" + + kind: Literal["OperationMessage"] = "OperationMessage" + verb: Literal["upsert", "update", "delete"] + subject: Literal["map", "datalayer", "feature"] + metadata: Optional[dict] = None + key: Optional[str] = None + + +class PeerMessage(BaseModel): + """Message sent from a specific peer to another one""" + + kind: Literal["PeerMessage"] = "PeerMessage" + sender: str + recipient: str + # The message can be whatever the peers want. It's not checked by the server. + message: dict + + +class Request(RootModel): + """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" + + root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind") + + +class JoinResponse(BaseModel): + """Server response containing the list of peers""" + + kind: Literal["JoinResponse"] = "JoinResponse" + peers: list + uuid: str + + +class ListPeersResponse(BaseModel): + kind: Literal["ListPeersResponse"] = "ListPeersResponse" + peers: list diff --git a/umap/tests/integration/conftest.py b/umap/tests/integration/conftest.py index 4601a709..620ab5ec 100644 --- a/umap/tests/integration/conftest.py +++ b/umap/tests/integration/conftest.py @@ -5,6 +5,7 @@ import time from pathlib import Path import pytest +from channels.testing import ChannelsLiveServerTestCase from playwright.sync_api import expect from ..base import mock_tiles @@ -87,3 +88,15 @@ def websocket_server(): yield ds_proc # Shut it down at the end of the pytest session ds_proc.terminate() + + +@pytest.fixture(scope="function") +def channels_live_server(request, settings): + server = ChannelsLiveServerTestCase() + server.serve_static = False + server._pre_setup() + settings.WEBSOCKET_FRONT_URI = f"{server.live_server_ws_url}/ws/sync/{{id}}/" + + yield server + + server._post_teardown() diff --git a/umap/tests/integration/test_websocket_sync.py b/umap/tests/integration/test_websocket_sync.py index c5e56e89..17a99436 100644 --- a/umap/tests/integration/test_websocket_sync.py +++ b/umap/tests/integration/test_websocket_sync.py @@ -12,7 +12,7 @@ DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*") @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_markers( - new_page, live_server, websocket_server, tilelayer + new_page, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -80,7 +80,7 @@ def test_websocket_connection_can_sync_markers( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_polygons( - context, live_server, websocket_server, tilelayer + context, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -164,7 +164,7 @@ def test_websocket_connection_can_sync_polygons( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_map_properties( - new_page, live_server, websocket_server, tilelayer + new_page, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -196,7 +196,7 @@ def test_websocket_connection_can_sync_map_properties( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_datalayer_properties( - new_page, live_server, websocket_server, tilelayer + new_page, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -225,7 +225,7 @@ def test_websocket_connection_can_sync_datalayer_properties( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_cloned_polygons( - context, live_server, websocket_server, tilelayer + context, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -288,7 +288,7 @@ def test_websocket_connection_can_sync_cloned_polygons( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_late_joining_peer( - new_page, live_server, websocket_server, tilelayer + new_page, live_server, channels_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -349,7 +349,7 @@ def test_websocket_connection_can_sync_late_joining_peer( @pytest.mark.xdist_group(name="websockets") -def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelayer): +def test_should_sync_datalayers(new_page, live_server, channels_live_server, tilelayer): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True map.save() @@ -422,7 +422,7 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay @pytest.mark.xdist_group(name="websockets") def test_create_and_sync_map( - new_page, live_server, websocket_server, tilelayer, login, user + new_page, live_server, channels_live_server, tilelayer, login, user ): # Create a syncable map with peerA peerA = login(user, prefix="Page A") diff --git a/umap/tests/test_websocket_server.py b/umap/tests/test_websocket_server.py deleted file mode 100644 index 62bc93e9..00000000 --- a/umap/tests/test_websocket_server.py +++ /dev/null @@ -1,22 +0,0 @@ -from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest - - -def test_messages_are_parsed_correctly(): - server = Request.model_validate(dict(kind="Server", action="list-peers")).root - assert type(server) is ServerRequest - - operation = Request.model_validate( - dict( - kind="OperationMessage", - verb="upsert", - subject="map", - metadata={}, - key="key", - ) - ).root - assert type(operation) is OperationMessage - - peer_message = Request.model_validate( - dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={}) - ).root - assert type(peer_message) is PeerMessage diff --git a/umap/websocket_server.py b/umap/websocket_server.py deleted file mode 100644 index 6483d648..00000000 --- a/umap/websocket_server.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import logging -import uuid -from collections import defaultdict -from typing import Literal, Optional, Union - -import websockets -from django.conf import settings -from django.core.signing import TimestampSigner -from pydantic import BaseModel, Field, RootModel, ValidationError -from websockets import WebSocketClientProtocol -from websockets.server import serve - - -class Connections: - def __init__(self) -> None: - self._connections: set[WebSocketClientProtocol] = set() - self._ids: dict[WebSocketClientProtocol, str] = dict() - - def join(self, websocket: WebSocketClientProtocol) -> str: - self._connections.add(websocket) - _id = str(uuid.uuid4()) - self._ids[websocket] = _id - return _id - - def leave(self, websocket: WebSocketClientProtocol) -> None: - self._connections.remove(websocket) - del self._ids[websocket] - - def get(self, id) -> WebSocketClientProtocol: - # use an iterator to stop iterating as soon as we found - return next(k for k, v in self._ids.items() if v == id) - - def get_id(self, websocket: WebSocketClientProtocol): - return self._ids[websocket] - - def get_other_peers( - self, websocket: WebSocketClientProtocol - ) -> set[WebSocketClientProtocol]: - return self._connections - {websocket} - - def get_all_peers(self) -> set[WebSocketClientProtocol]: - return self._connections - - -# Contains the list of websocket connections handled by this process. -# It's a mapping of map_id to a set of the active websocket connections -CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections) - - -class JoinRequest(BaseModel): - kind: Literal["JoinRequest"] = "JoinRequest" - token: str - - -class OperationMessage(BaseModel): - """Message sent from one peer to all the others""" - - kind: Literal["OperationMessage"] = "OperationMessage" - verb: Literal["upsert", "update", "delete"] - subject: Literal["map", "datalayer", "feature"] - metadata: Optional[dict] = None - key: Optional[str] = None - - -class PeerMessage(BaseModel): - """Message sent from a specific peer to another one""" - - kind: Literal["PeerMessage"] = "PeerMessage" - sender: str - recipient: str - # The message can be whatever the peers want. It's not checked by the server. - message: dict - - -class ServerRequest(BaseModel): - """A request towards the server""" - - kind: Literal["Server"] = "Server" - action: Literal["list-peers"] - - -class Request(RootModel): - """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" - - root: Union[ServerRequest, PeerMessage, OperationMessage] = Field( - discriminator="kind" - ) - - -class JoinResponse(BaseModel): - """Server response containing the list of peers""" - - kind: Literal["JoinResponse"] = "JoinResponse" - peers: list - uuid: str - - -class ListPeersResponse(BaseModel): - kind: Literal["ListPeersResponse"] = "ListPeersResponse" - peers: list - - -async def join_and_listen( - map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol -): - """Join a "room" with other connected peers, and wait for messages.""" - logging.debug(f"{user} joined room #{map_id}") - connections: Connections = CONNECTIONS[map_id] - _id: str = connections.join(websocket) - - # Assign an ID to the joining peer and return it the list of connected peers. - peers: list[WebSocketClientProtocol] = [ - connections.get_id(p) for p in connections.get_all_peers() - ] - response = JoinResponse(uuid=_id, peers=peers) - await websocket.send(response.model_dump_json()) - - # Notify all other peers of the new list of connected peers. - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - try: - async for raw_message in websocket: - if raw_message == "ping": - await websocket.send("pong") - continue - - # recompute the peers list at the time of message-sending. - # as doing so beforehand would miss new connections - other_peers = connections.get_other_peers(websocket) - try: - incoming = Request.model_validate_json(raw_message) - except ValidationError as e: - error = f"An error occurred when receiving the following message: {raw_message!r}" - logging.error(error, e) - else: - match incoming.root: - # Broadcast all operation messages to connected peers - case OperationMessage(): - websockets.broadcast(other_peers, raw_message) - - # Send peer messages to the proper peer - case PeerMessage(recipient=_id): - peer = connections.get(_id) - if peer: - await peer.send(raw_message) - - finally: - # On disconnect, remove the connection from the pool - connections.leave(websocket) - - # TODO: refactor this in a separate method. - # Notify all other peers of the new list of connected peers. - peers = [connections.get_id(p) for p in connections.get_all_peers()] - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - -async def handler(websocket: WebSocketClientProtocol): - """Main WebSocket handler. - - Check if the permission is granted and let the peer enter a room. - """ - raw_message = await websocket.recv() - - # The first event should always be 'join' - message: JoinRequest = JoinRequest.model_validate_json(raw_message) - signed = TimestampSigner().unsign_object(message.token, max_age=30) - user, map_id, permissions = signed.values() - - # Check if permissions for this map have been granted by the server - if "edit" in signed["permissions"]: - await join_and_listen(map_id, permissions, user, websocket) - - -def run(host: str, port: int): - if not settings.WEBSOCKET_ENABLED: - msg = ( - "WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. " - "See the documentation at " - "https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled " - "for more information." - ) - print(msg) - exit(1) - - async def _serve(): - async with serve(handler, host, port): - logging.debug(f"Waiting for connections on {host}:{port}") - await asyncio.Future() # run forever - - try: - asyncio.run(_serve()) - except KeyboardInterrupt: - print("Closing WebSocket server")