diff --git a/.github/workflows/test-docs.yml b/.github/workflows/test-docs.yml index db76112d..f2ce937a 100644 --- a/.github/workflows/test-docs.yml +++ b/.github/workflows/test-docs.yml @@ -20,7 +20,11 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 - + redis: + image: redis + options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 + ports: + - 6379:6379 strategy: fail-fast: false matrix: @@ -48,6 +52,8 @@ jobs: DJANGO_SETTINGS_MODULE: 'umap.tests.settings' UMAP_SETTINGS: 'umap/tests/settings.py' PLAYWRIGHT_TIMEOUT: '20000' + REDIS_HOST: localhost + REDIS_PORT: 6379 lint: runs-on: ubuntu-latest steps: diff --git a/pyproject.toml b/pyproject.toml index 5ed21b1f..5273f9b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dev = [ "isort==5.13.2", ] test = [ + "daphne==4.1.2", "factory-boy==3.3.1", "playwright>=1.39", "pytest==8.3.4", @@ -70,10 +71,8 @@ s3 = [ "django-storages[s3]==1.14.4", ] sync = [ - "channels==4.2.0", - "daphne==4.1.2", "pydantic==2.10.5", - "websockets==13.1", + "redis==5.2.1", ] [project.scripts] @@ -104,3 +103,6 @@ format_css=true blank_line_after_tag="load,extends" line_break_after_multiline_tag=true +[lint] +# Disable autoremove of unused import. +unfixable = ["F401"] diff --git a/umap/asgi.py b/umap/asgi.py index 2ca12ddc..47d69a93 100644 --- a/umap/asgi.py +++ b/umap/asgi.py @@ -1,15 +1,20 @@ import os -from channels.routing import ProtocolTypeRouter +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") + from django.core.asgi import get_asgi_application -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") +from .sync.app import application as ws_application + # Initialize Django ASGI application early to ensure the AppRegistry # is populated before importing code that may import ORM models. django_asgi_app = get_asgi_application() -application = ProtocolTypeRouter( - { - "http": django_asgi_app, - } -) + +async def application(scope, receive, send): + if scope["type"] == "http": + await django_asgi_app(scope, receive, send) + elif scope["type"] == "websocket": + await ws_application(scope, receive, send) + else: + raise NotImplementedError(f"Unknown scope type {scope['type']}") diff --git a/umap/management/commands/run_websocket_server.py b/umap/management/commands/run_websocket_server.py deleted file mode 100644 index 2cb2db89..00000000 --- a/umap/management/commands/run_websocket_server.py +++ /dev/null @@ -1,23 +0,0 @@ -from django.conf import settings -from django.core.management.base import BaseCommand - -from umap import websocket_server - - -class Command(BaseCommand): - help = "Run the websocket server" - - def add_arguments(self, parser): - parser.add_argument( - "--host", - help="The server host to bind to.", - default=settings.WEBSOCKET_BACK_HOST, - ) - parser.add_argument( - "--port", - help="The server port to bind to.", - default=settings.WEBSOCKET_BACK_PORT, - ) - - def handle(self, *args, **options): - websocket_server.run(options["host"], options["port"]) diff --git a/umap/settings/base.py b/umap/settings/base.py index f47ad236..788a03b3 100644 --- a/umap/settings/base.py +++ b/umap/settings/base.py @@ -342,4 +342,5 @@ LOGGING = { WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False) WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost") WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001) -WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001") + +REDIS_URL = "redis://localhost:6379" diff --git a/umap/static/umap/js/modules/sync/engine.js b/umap/static/umap/js/modules/sync/engine.js index 212c2528..05994544 100644 --- a/umap/static/umap/js/modules/sync/engine.js +++ b/umap/static/umap/js/modules/sync/engine.js @@ -62,6 +62,7 @@ export class SyncEngine { this._reconnectDelay = RECONNECT_DELAY this.websocketConnected = false this.closeRequested = false + this.peerId = Utils.generateId() } async authenticate() { @@ -76,10 +77,14 @@ export class SyncEngine { } start(authToken) { + const path = this._umap.urls.get('ws_sync', { map_id: this._umap.id }) + const protocol = window.location.protocol === 'http:' ? 'ws:' : 'wss:' this.transport = new WebSocketTransport( - this._umap.properties.websocketURI, + `${protocol}//${window.location.host}${path}`, authToken, - this + this, + this.peerId, + this._umap.properties.user?.name ) } @@ -125,7 +130,7 @@ export class SyncEngine { if (this.offline) return if (this.transport) { - this.transport.send('OperationMessage', message) + this.transport.send('OperationMessage', { sender: this.peerId, ...message }) } } @@ -142,7 +147,7 @@ export class SyncEngine { } getNumberOfConnectedPeers() { - if (this.peers) return this.peers.length + if (this.peers) return Object.keys(this.peers).length return 0 } @@ -177,6 +182,7 @@ export class SyncEngine { * @param {Object} payload */ onOperationMessage(payload) { + if (payload.sender === this.peerId) return this._operations.storeRemoteOperations([payload]) this._applyOperation(payload) } @@ -188,9 +194,8 @@ export class SyncEngine { * @param {string} payload.uuid The server-assigned uuid for this peer * @param {string[]} payload.peers The list of peers uuids */ - onJoinResponse({ uuid, peers }) { - debug('received join response', { uuid, peers }) - this.uuid = uuid + onJoinResponse({ peer, peers }) { + debug('received join response', { peer, peers }) this.onListPeersResponse({ peers }) // Get one peer at random @@ -211,7 +216,7 @@ export class SyncEngine { * @param {string[]} payload.peers The list of peers uuids */ onListPeersResponse({ peers }) { - debug('received peerinfo', { peers }) + debug('received peerinfo', peers) this.peers = peers this.updaters.map.update({ key: 'numberOfConnectedPeers' }) } @@ -286,7 +291,7 @@ export class SyncEngine { sendToPeer(recipient, verb, payload) { payload.verb = verb this.transport.send('PeerMessage', { - sender: this.uuid, + sender: this.peerId, recipient: recipient, message: payload, }) @@ -298,7 +303,7 @@ export class SyncEngine { * @returns {string|bool} the selected peer uuid, or False if none was found. */ _getRandomPeer() { - const otherPeers = this.peers.filter((p) => p !== this.uuid) + const otherPeers = Object.keys(this.peers).filter((p) => p !== this.peerId) if (otherPeers.length > 0) { const random = Math.floor(Math.random() * otherPeers.length) return otherPeers[random] @@ -484,7 +489,7 @@ export class Operations { return ( Utils.deepEqual(local.subject, remote.subject) && Utils.deepEqual(local.metadata, remote.metadata) && - (!shouldCheckKey || (shouldCheckKey && local.key == remote.key)) + (!shouldCheckKey || (shouldCheckKey && local.key === remote.key)) ) } } diff --git a/umap/static/umap/js/modules/sync/websocket.js b/umap/static/umap/js/modules/sync/websocket.js index 26c99f26..5a18f880 100644 --- a/umap/static/umap/js/modules/sync/websocket.js +++ b/umap/static/umap/js/modules/sync/websocket.js @@ -3,13 +3,13 @@ const PING_INTERVAL = 30000 const FIRST_CONNECTION_TIMEOUT = 2000 export class WebSocketTransport { - constructor(webSocketURI, authToken, messagesReceiver) { + constructor(webSocketURI, authToken, messagesReceiver, peerId, username) { this.receiver = messagesReceiver this.websocket = new WebSocket(webSocketURI) this.websocket.onopen = () => { - this.send('JoinRequest', { token: authToken }) + this.send('JoinRequest', { token: authToken, peer: peerId, username }) this.receiver.onConnection() } this.websocket.addEventListener('message', this.onMessage.bind(this)) @@ -21,6 +21,10 @@ export class WebSocketTransport { } } + this.websocket.onerror = (error) => { + console.log('WS ERROR', error) + } + this.ensureOpen = setInterval(() => { if (this.websocket.readyState !== WebSocket.OPEN) { this.websocket.close() @@ -34,6 +38,7 @@ export class WebSocketTransport { // See https://making.close.com/posts/reliable-websockets/ for more details. this.pingInterval = setInterval(() => { if (this.websocket.readyState === WebSocket.OPEN) { + console.log('sending ping') this.websocket.send('ping') this.pongReceived = false setTimeout(() => { @@ -63,6 +68,7 @@ export class WebSocketTransport { } close() { + console.log('Closing') this.receiver.closeRequested = true this.websocket.close() } diff --git a/umap/sync/__init__.py b/umap/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/umap/sync/app.py b/umap/sync/app.py new file mode 100644 index 00000000..27c52262 --- /dev/null +++ b/umap/sync/app.py @@ -0,0 +1,181 @@ +import asyncio +import logging + +import redis.asyncio as redis +from django.conf import settings +from django.core.signing import TimestampSigner +from django.urls import path +from pydantic import ValidationError + +from .payloads import ( + JoinRequest, + JoinResponse, + ListPeersResponse, + OperationMessage, + PeerMessage, + Request, +) + + +async def application(scope, receive, send): + path = scope["path"].lstrip("/") + for pattern in urlpatterns: + if matched := pattern.resolve(path): + await matched.func(scope, receive, send, **matched.kwargs) + break + else: + await send({"type": "websocket.close"}) + + +async def sync(scope, receive, send, **kwargs): + peer = Peer(kwargs["map_id"]) + peer._send = send + while True: + event = await receive() + + if event["type"] == "websocket.connect": + try: + await peer.connect() + await send({"type": "websocket.accept"}) + except ValueError: + await send({"type": "websocket.close"}) + + if event["type"] == "websocket.disconnect": + await peer.disconnect() + break + + if event["type"] == "websocket.receive": + if event["text"] == "ping": + await send({"type": "websocket.send", "text": "pong"}) + else: + await peer.receive(event["text"]) + + +class Peer: + def __init__(self, map_id, username=None): + self.username = username or "" + self.map_id = map_id + self.is_authenticated = False + self._subscriptions = [] + + @property + def room_key(self): + return f"umap:{self.map_id}" + + @property + def peer_key(self): + return f"user:{self.map_id}:{self.peer_id}" + + async def get_peers(self): + known = await self.client.hgetall(self.room_key) + active = await self.client.pubsub_channels(f"user:{self.map_id}:*") + if not active: + # Poor man way of deleting stale usernames from the store + # HEXPIRE command is not in the open source Redis version + await self.client.delete(self.room_key) + await self.store_username() + active = [name.split(b":")[-1] for name in active] + if self.peer_id.encode() not in active: + # Our connection may not yet be active + active.append(self.peer_id.encode()) + return {k: v for k, v in known.items() if k in active} + + async def store_username(self): + await self.client.hset(self.room_key, self.peer_id, self.username) + + async def listen_to_channel(self, channel_name): + async def reader(pubsub): + await pubsub.subscribe(channel_name) + while True: + if pubsub.connection is None: + # It has been unsubscribed/closed. + break + try: + message = await pubsub.get_message(ignore_subscribe_messages=True) + except Exception as err: + print(err) + break + if message is not None: + await self.send(message["data"].decode()) + await asyncio.sleep(0.001) # Be nice with the server + + async with self.client.pubsub() as pubsub: + self._subscriptions.append(pubsub) + asyncio.create_task(reader(pubsub)) + + async def listen(self): + await self.listen_to_channel(self.room_key) + await self.listen_to_channel(self.peer_key) + + async def connect(self): + self.client = redis.from_url(settings.REDIS_URL) + + async def disconnect(self): + await self.client.hdel(self.room_key, self.peer_id) + for pubsub in self._subscriptions: + await pubsub.unsubscribe() + await pubsub.close() + await self.send_peers_list() + await self.client.aclose() + + async def send_peers_list(self): + message = ListPeersResponse(peers=await self.get_peers()) + await self.broadcast(message.model_dump_json()) + + async def broadcast(self, message): + print("BROADCASTING", message) + # Send to all channels (including sender!) + await self.client.publish(self.room_key, message) + + async def send_to(self, peer_id, message): + print("SEND TO", peer_id, message) + # Send to one given channel + await self.client.publish(f"user:{self.map_id}:{peer_id}", message) + + async def receive(self, text_data): + if not self.is_authenticated: + print("AUTHENTICATING", text_data) + message = JoinRequest.model_validate_json(text_data) + signed = TimestampSigner().unsign_object(message.token, max_age=30) + user, map_id, permissions = signed.values() + assert str(map_id) == self.map_id + if "edit" not in permissions: + return await self.disconnect() + self.peer_id = message.peer + self.username = message.username + print("AUTHENTICATED", self.peer_id) + await self.store_username() + await self.listen() + response = JoinResponse(peer=self.peer_id, peers=await self.get_peers()) + await self.send(response.model_dump_json()) + await self.send_peers_list() + self.is_authenticated = True + return + + try: + incoming = Request.model_validate_json(text_data) + except ValidationError as error: + message = ( + f"An error occurred when receiving the following message: {text_data!r}" + ) + logging.error(message, error) + else: + match incoming.root: + # Broadcast all operation messages to connected peers + case OperationMessage(): + await self.broadcast(text_data) + + # Send peer messages to the proper peer + case PeerMessage(): + await self.send_to(incoming.root.recipient, text_data) + + async def send(self, text): + print(" FORWARDING TO", self.peer_id, text) + try: + await self._send({"type": "websocket.send", "text": text}) + except Exception as err: + print("Error sending message:", text) + print(err) + + +urlpatterns = [path("ws/sync/", name="ws_sync", view=sync)] diff --git a/umap/sync/payloads.py b/umap/sync/payloads.py new file mode 100644 index 00000000..9ab2bf1a --- /dev/null +++ b/umap/sync/payloads.py @@ -0,0 +1,49 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field, RootModel + + +class JoinRequest(BaseModel): + kind: Literal["JoinRequest"] = "JoinRequest" + token: str + peer: str + username: Optional[str] = "" + + +class OperationMessage(BaseModel): + """Message sent from one peer to all the others""" + + kind: Literal["OperationMessage"] = "OperationMessage" + verb: Literal["upsert", "update", "delete"] + subject: Literal["map", "datalayer", "feature"] + metadata: Optional[dict] = None + key: Optional[str] = None + + +class PeerMessage(BaseModel): + """Message sent from a specific peer to another one""" + + kind: Literal["PeerMessage"] = "PeerMessage" + sender: str + recipient: str + # The message can be whatever the peers want. It's not checked by the server. + message: dict + + +class Request(RootModel): + """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" + + root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind") + + +class JoinResponse(BaseModel): + """Server response containing the list of peers""" + + kind: Literal["JoinResponse"] = "JoinResponse" + peers: dict + peer: str + + +class ListPeersResponse(BaseModel): + kind: Literal["ListPeersResponse"] = "ListPeersResponse" + peers: dict diff --git a/umap/tests/integration/conftest.py b/umap/tests/integration/conftest.py index 4601a709..bbb202d1 100644 --- a/umap/tests/integration/conftest.py +++ b/umap/tests/integration/conftest.py @@ -1,12 +1,13 @@ import os import re -import subprocess -import time -from pathlib import Path import pytest +from daphne.testing import DaphneProcess +from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler from playwright.sync_api import expect +from umap.asgi import application + from ..base import mock_tiles @@ -67,23 +68,15 @@ def login(new_page, settings, live_server): return do_login -@pytest.fixture -def websocket_server(): - # Find the test-settings, and put them in the current environment - settings_path = (Path(__file__).parent.parent / "settings.py").absolute().as_posix() - os.environ["UMAP_SETTINGS"] = settings_path +@pytest.fixture(scope="function") +def asgi_live_server(request, live_server): + server = DaphneProcess("localhost", lambda: ASGIStaticFilesHandler(application)) + server.start() + server.ready.wait() + port = server.port.value + server.url = f"http://localhost:{port}" - ds_proc = subprocess.Popen( - [ - "umap", - "run_websocket_server", - ], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - time.sleep(2) - # Ensure it started properly before yielding - assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") - yield ds_proc - # Shut it down at the end of the pytest session - ds_proc.terminate() + yield server + + server.terminate() + server.join() diff --git a/umap/tests/integration/test_websocket_sync.py b/umap/tests/integration/test_websocket_sync.py index 96946ce8..a69e134f 100644 --- a/umap/tests/integration/test_websocket_sync.py +++ b/umap/tests/integration/test_websocket_sync.py @@ -1,6 +1,8 @@ import re import pytest +import redis +from django.conf import settings from playwright.sync_api import expect from umap.models import DataLayer, Map @@ -9,11 +11,21 @@ from ..base import DataLayerFactory, MapFactory DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*") +pytestmark = pytest.mark.django_db + + +def setup_function(): + # Sync client to prevent headache with pytest / pytest-asyncio and async + client = redis.from_url(settings.REDIS_URL) + # Make sure there are no dead peers in the Redis hash, otherwise asking for + # operations from another peer may never be answered + # FIXME this should not happen in an ideal world + assert client.connection_pool.connection_kwargs["db"] == 15 + client.flushdb() + @pytest.mark.xdist_group(name="websockets") -def test_websocket_connection_can_sync_markers( - new_page, live_server, websocket_server, tilelayer -): +def test_websocket_connection_can_sync_markers(new_page, asgi_live_server, tilelayer): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True map.save() @@ -21,9 +33,9 @@ def test_websocket_connection_can_sync_markers( # Create two tabs peerA = new_page("Page A") - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = new_page("Page B") - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") a_marker_pane = peerA.locator(".leaflet-marker-pane > div") b_marker_pane = peerB.locator(".leaflet-marker-pane > div") @@ -79,9 +91,7 @@ def test_websocket_connection_can_sync_markers( @pytest.mark.xdist_group(name="websockets") -def test_websocket_connection_can_sync_polygons( - context, live_server, websocket_server, tilelayer -): +def test_websocket_connection_can_sync_polygons(context, asgi_live_server, tilelayer): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True map.save() @@ -89,9 +99,9 @@ def test_websocket_connection_can_sync_polygons( # Create two tabs peerA = context.new_page() - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = context.new_page() - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") b_map_el = peerB.locator("#map") @@ -164,7 +174,7 @@ def test_websocket_connection_can_sync_polygons( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_map_properties( - new_page, live_server, websocket_server, tilelayer + new_page, asgi_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -173,9 +183,9 @@ def test_websocket_connection_can_sync_map_properties( # Create two tabs peerA = new_page() - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = new_page() - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") # Name change is synced peerA.get_by_role("link", name="Edit map name and caption").click() @@ -198,7 +208,7 @@ def test_websocket_connection_can_sync_map_properties( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_datalayer_properties( - new_page, live_server, websocket_server, tilelayer + new_page, asgi_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -207,9 +217,9 @@ def test_websocket_connection_can_sync_datalayer_properties( # Create two tabs peerA = new_page() - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = new_page() - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") # Layer addition, name and type are synced peerA.get_by_role("link", name="Manage layers").click() @@ -227,7 +237,7 @@ def test_websocket_connection_can_sync_datalayer_properties( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_cloned_polygons( - context, live_server, websocket_server, tilelayer + context, asgi_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -236,9 +246,9 @@ def test_websocket_connection_can_sync_cloned_polygons( # Create two tabs peerA = context.new_page() - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = context.new_page() - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") b_map_el = peerB.locator("#map") @@ -290,7 +300,7 @@ def test_websocket_connection_can_sync_cloned_polygons( @pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_late_joining_peer( - new_page, live_server, websocket_server, tilelayer + new_page, asgi_live_server, tilelayer ): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True @@ -299,7 +309,7 @@ def test_websocket_connection_can_sync_late_joining_peer( # Create first peer (A) and have it join immediately peerA = new_page("Page A") - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") # Add a marker from peer A a_create_marker = peerA.get_by_title("Draw a marker") @@ -326,7 +336,7 @@ def test_websocket_connection_can_sync_late_joining_peer( # Now create peer B and have it join peerB = new_page("Page B") - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") # Check if peer B has received all the updates b_marker_pane = peerB.locator(".leaflet-marker-pane > div") @@ -351,7 +361,7 @@ def test_websocket_connection_can_sync_late_joining_peer( @pytest.mark.xdist_group(name="websockets") -def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelayer): +def test_should_sync_datalayers(new_page, asgi_live_server, tilelayer): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True map.save() @@ -360,9 +370,9 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay # Create two tabs peerA = new_page("Page A") - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = new_page("Page B") - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") # Create a new layer from peerA peerA.get_by_role("link", name="Manage layers").click() @@ -423,9 +433,7 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay @pytest.mark.xdist_group(name="websockets") -def test_should_sync_datalayers_delete( - new_page, live_server, websocket_server, tilelayer -): +def test_should_sync_datalayers_delete(new_page, asgi_live_server, tilelayer): map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map.settings["properties"]["syncEnabled"] = True map.save() @@ -464,9 +472,9 @@ def test_should_sync_datalayers_delete( # Create two tabs peerA = new_page("Page A") - peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerB = new_page("Page B") - peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit") + peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit") peerA.get_by_role("button", name="Open browser").click() expect(peerA.get_by_text("datalayer 1")).to_be_visible() @@ -489,12 +497,10 @@ def test_should_sync_datalayers_delete( @pytest.mark.xdist_group(name="websockets") -def test_create_and_sync_map( - new_page, live_server, websocket_server, tilelayer, login, user -): +def test_create_and_sync_map(new_page, asgi_live_server, tilelayer, login, user): # Create a syncable map with peerA peerA = login(user, prefix="Page A") - peerA.goto(f"{live_server.url}/en/map/new/") + peerA.goto(f"{asgi_live_server.url}/en/map/new/") with peerA.expect_response(re.compile("./map/create/.*")): peerA.get_by_role("button", name="Save Draft").click() peerA.get_by_role("link", name="Map advanced properties").click() diff --git a/umap/tests/settings.py b/umap/tests/settings.py index b776c083..41de66d3 100644 --- a/umap/tests/settings.py +++ b/umap/tests/settings.py @@ -29,3 +29,5 @@ PASSWORD_HASHERS = [ WEBSOCKET_ENABLED = True WEBSOCKET_BACK_PORT = "8010" WEBSOCKET_FRONT_URI = "ws://localhost:8010" + +REDIS_URL = "redis://localhost:6379/15" diff --git a/umap/tests/test_websocket_server.py b/umap/tests/test_websocket_server.py deleted file mode 100644 index 62bc93e9..00000000 --- a/umap/tests/test_websocket_server.py +++ /dev/null @@ -1,22 +0,0 @@ -from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest - - -def test_messages_are_parsed_correctly(): - server = Request.model_validate(dict(kind="Server", action="list-peers")).root - assert type(server) is ServerRequest - - operation = Request.model_validate( - dict( - kind="OperationMessage", - verb="upsert", - subject="map", - metadata={}, - key="key", - ) - ).root - assert type(operation) is OperationMessage - - peer_message = Request.model_validate( - dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={}) - ).root - assert type(peer_message) is PeerMessage diff --git a/umap/utils.py b/umap/utils.py index 26cf581d..561ae702 100644 --- a/umap/utils.py +++ b/umap/utils.py @@ -7,23 +7,36 @@ from django.core.serializers.json import DjangoJSONEncoder from django.urls import URLPattern, URLResolver, get_resolver -def _urls_for_js(urls=None): +def _get_url_names(module): + def _get_names(resolver): + names = [] + for pattern in resolver.url_patterns: + if getattr(pattern, "url_patterns", None): + # Do not add "admin" and other third party apps urls. + if not pattern.namespace: + names.extend(_get_names(pattern)) + elif getattr(pattern, "name", None): + names.append(pattern.name) + return names + + return _get_names(get_resolver(module)) + + +def _urls_for_js(): """ Return templated URLs prepared for javascript. """ - if urls is None: - # prevent circular import - from .urls import i18n_urls, urlpatterns - - urls = [ - url.name for url in urlpatterns + i18n_urls if getattr(url, "name", None) - ] - urls = dict(zip(urls, [get_uri_template(url) for url in urls])) + urls = {} + for module in ["umap.urls", "umap.sync.app"]: + names = _get_url_names(module) + urls.update( + dict(zip(names, [get_uri_template(url, module=module) for url in names])) + ) urls.update(getattr(settings, "UMAP_EXTRA_URLS", {})) return urls -def get_uri_template(urlname, args=None, prefix=""): +def get_uri_template(urlname, args=None, prefix="", module=None): """ Utility function to return an URI Template from a named URL in django Copied from django-digitalpaper. @@ -45,7 +58,7 @@ def get_uri_template(urlname, args=None, prefix=""): paths = template % dict([p, "{%s}" % p] for p in args) return "%s/%s" % (prefix, paths) - resolver = get_resolver(None) + resolver = get_resolver(module) parts = urlname.split(":") if len(parts) > 1 and parts[0] in resolver.namespace_dict: namespace = parts[0] diff --git a/umap/views.py b/umap/views.py index d1952405..c8c09476 100644 --- a/umap/views.py +++ b/umap/views.py @@ -609,7 +609,6 @@ class MapDetailMixin(SessionMixin): "umap_version": VERSION, "featuresHaveOwner": settings.UMAP_DEFAULT_FEATURES_HAVE_OWNERS, "websocketEnabled": settings.WEBSOCKET_ENABLED, - "websocketURI": settings.WEBSOCKET_FRONT_URI, "importers": settings.UMAP_IMPORTERS, "defaultLabelKeys": settings.UMAP_LABEL_KEYS, } diff --git a/umap/websocket_server.py b/umap/websocket_server.py deleted file mode 100644 index 6483d648..00000000 --- a/umap/websocket_server.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import logging -import uuid -from collections import defaultdict -from typing import Literal, Optional, Union - -import websockets -from django.conf import settings -from django.core.signing import TimestampSigner -from pydantic import BaseModel, Field, RootModel, ValidationError -from websockets import WebSocketClientProtocol -from websockets.server import serve - - -class Connections: - def __init__(self) -> None: - self._connections: set[WebSocketClientProtocol] = set() - self._ids: dict[WebSocketClientProtocol, str] = dict() - - def join(self, websocket: WebSocketClientProtocol) -> str: - self._connections.add(websocket) - _id = str(uuid.uuid4()) - self._ids[websocket] = _id - return _id - - def leave(self, websocket: WebSocketClientProtocol) -> None: - self._connections.remove(websocket) - del self._ids[websocket] - - def get(self, id) -> WebSocketClientProtocol: - # use an iterator to stop iterating as soon as we found - return next(k for k, v in self._ids.items() if v == id) - - def get_id(self, websocket: WebSocketClientProtocol): - return self._ids[websocket] - - def get_other_peers( - self, websocket: WebSocketClientProtocol - ) -> set[WebSocketClientProtocol]: - return self._connections - {websocket} - - def get_all_peers(self) -> set[WebSocketClientProtocol]: - return self._connections - - -# Contains the list of websocket connections handled by this process. -# It's a mapping of map_id to a set of the active websocket connections -CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections) - - -class JoinRequest(BaseModel): - kind: Literal["JoinRequest"] = "JoinRequest" - token: str - - -class OperationMessage(BaseModel): - """Message sent from one peer to all the others""" - - kind: Literal["OperationMessage"] = "OperationMessage" - verb: Literal["upsert", "update", "delete"] - subject: Literal["map", "datalayer", "feature"] - metadata: Optional[dict] = None - key: Optional[str] = None - - -class PeerMessage(BaseModel): - """Message sent from a specific peer to another one""" - - kind: Literal["PeerMessage"] = "PeerMessage" - sender: str - recipient: str - # The message can be whatever the peers want. It's not checked by the server. - message: dict - - -class ServerRequest(BaseModel): - """A request towards the server""" - - kind: Literal["Server"] = "Server" - action: Literal["list-peers"] - - -class Request(RootModel): - """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" - - root: Union[ServerRequest, PeerMessage, OperationMessage] = Field( - discriminator="kind" - ) - - -class JoinResponse(BaseModel): - """Server response containing the list of peers""" - - kind: Literal["JoinResponse"] = "JoinResponse" - peers: list - uuid: str - - -class ListPeersResponse(BaseModel): - kind: Literal["ListPeersResponse"] = "ListPeersResponse" - peers: list - - -async def join_and_listen( - map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol -): - """Join a "room" with other connected peers, and wait for messages.""" - logging.debug(f"{user} joined room #{map_id}") - connections: Connections = CONNECTIONS[map_id] - _id: str = connections.join(websocket) - - # Assign an ID to the joining peer and return it the list of connected peers. - peers: list[WebSocketClientProtocol] = [ - connections.get_id(p) for p in connections.get_all_peers() - ] - response = JoinResponse(uuid=_id, peers=peers) - await websocket.send(response.model_dump_json()) - - # Notify all other peers of the new list of connected peers. - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - try: - async for raw_message in websocket: - if raw_message == "ping": - await websocket.send("pong") - continue - - # recompute the peers list at the time of message-sending. - # as doing so beforehand would miss new connections - other_peers = connections.get_other_peers(websocket) - try: - incoming = Request.model_validate_json(raw_message) - except ValidationError as e: - error = f"An error occurred when receiving the following message: {raw_message!r}" - logging.error(error, e) - else: - match incoming.root: - # Broadcast all operation messages to connected peers - case OperationMessage(): - websockets.broadcast(other_peers, raw_message) - - # Send peer messages to the proper peer - case PeerMessage(recipient=_id): - peer = connections.get(_id) - if peer: - await peer.send(raw_message) - - finally: - # On disconnect, remove the connection from the pool - connections.leave(websocket) - - # TODO: refactor this in a separate method. - # Notify all other peers of the new list of connected peers. - peers = [connections.get_id(p) for p in connections.get_all_peers()] - message = ListPeersResponse(peers=peers) - websockets.broadcast( - connections.get_other_peers(websocket), message.model_dump_json() - ) - - -async def handler(websocket: WebSocketClientProtocol): - """Main WebSocket handler. - - Check if the permission is granted and let the peer enter a room. - """ - raw_message = await websocket.recv() - - # The first event should always be 'join' - message: JoinRequest = JoinRequest.model_validate_json(raw_message) - signed = TimestampSigner().unsign_object(message.token, max_age=30) - user, map_id, permissions = signed.values() - - # Check if permissions for this map have been granted by the server - if "edit" in signed["permissions"]: - await join_and_listen(map_id, permissions, user, websocket) - - -def run(host: str, port: int): - if not settings.WEBSOCKET_ENABLED: - msg = ( - "WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. " - "See the documentation at " - "https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled " - "for more information." - ) - print(msg) - exit(1) - - async def _serve(): - async with serve(handler, host, port): - logging.debug(f"Waiting for connections on {host}:{port}") - await asyncio.Future() # run forever - - try: - asyncio.run(_serve()) - except KeyboardInterrupt: - print("Closing WebSocket server")