wip(sync): use django-channels to serve websockets

Co-authored-by: David Larlet <david@larlet.fr>
This commit is contained in:
Yohan Boniface 2024-12-23 13:37:01 +01:00
parent ebae9a8cd0
commit c6c965a601
12 changed files with 169 additions and 260 deletions

View file

@ -1,15 +1,22 @@
import os import os
from channels.routing import ProtocolTypeRouter from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from django.urls import re_path
from .sync import consumers
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
# Initialize Django ASGI application early to ensure the AppRegistry # Initialize Django ASGI application early to ensure the AppRegistry
# is populated before importing code that may import ORM models. # is populated before importing code that may import ORM models.
django_asgi_app = get_asgi_application() django_asgi_app = get_asgi_application()
urlpatterns = (re_path(r"ws/sync/(?P<map_id>\w+)/$", consumers.SyncConsumer.as_asgi()),)
application = ProtocolTypeRouter( application = ProtocolTypeRouter(
{ {
"http": django_asgi_app, "http": django_asgi_app,
"websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)),
} }
) )

View file

@ -1,23 +0,0 @@
from django.conf import settings
from django.core.management.base import BaseCommand
from umap import websocket_server
class Command(BaseCommand):
help = "Run the websocket server"
def add_arguments(self, parser):
parser.add_argument(
"--host",
help="The server host to bind to.",
default=settings.WEBSOCKET_BACK_HOST,
)
parser.add_argument(
"--port",
help="The server port to bind to.",
default=settings.WEBSOCKET_BACK_PORT,
)
def handle(self, *args, **options):
websocket_server.run(options["host"], options["port"])

View file

@ -343,3 +343,4 @@ WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False)
WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost") WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost")
WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001) WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001)
WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001") WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001")
CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}}

View file

@ -77,7 +77,7 @@ export class SyncEngine {
start(authToken) { start(authToken) {
this.transport = new WebSocketTransport( this.transport = new WebSocketTransport(
this._umap.properties.websocketURI, Utils.template(this._umap.properties.websocketURI, { id: this._umap.id }),
authToken, authToken,
this this
) )
@ -125,7 +125,7 @@ export class SyncEngine {
if (this.offline) return if (this.offline) return
if (this.transport) { if (this.transport) {
this.transport.send('OperationMessage', message) this.transport.send('OperationMessage', { sender: this.uuid, ...message })
} }
} }
@ -177,6 +177,7 @@ export class SyncEngine {
* @param {Object} payload * @param {Object} payload
*/ */
onOperationMessage(payload) { onOperationMessage(payload) {
if (payload.sender === this.uuid) return
this._operations.storeRemoteOperations([payload]) this._operations.storeRemoteOperations([payload])
this._applyOperation(payload) this._applyOperation(payload)
} }
@ -484,7 +485,7 @@ export class Operations {
return ( return (
Utils.deepEqual(local.subject, remote.subject) && Utils.deepEqual(local.subject, remote.subject) &&
Utils.deepEqual(local.metadata, remote.metadata) && Utils.deepEqual(local.metadata, remote.metadata) &&
(!shouldCheckKey || (shouldCheckKey && local.key == remote.key)) (!shouldCheckKey || (shouldCheckKey && local.key === remote.key))
) )
} }
} }

View file

@ -6,7 +6,7 @@ export class WebSocketTransport {
constructor(webSocketURI, authToken, messagesReceiver) { constructor(webSocketURI, authToken, messagesReceiver) {
this.receiver = messagesReceiver this.receiver = messagesReceiver
this.websocket = new WebSocket(webSocketURI) this.websocket = new WebSocket(`${webSocketURI}`)
this.websocket.onopen = () => { this.websocket.onopen = () => {
this.send('JoinRequest', { token: authToken }) this.send('JoinRequest', { token: authToken })
@ -48,6 +48,7 @@ export class WebSocketTransport {
} }
onMessage(wsMessage) { onMessage(wsMessage) {
console.log(wsMessage)
if (wsMessage.data === 'pong') { if (wsMessage.data === 'pong') {
this.pongReceived = true this.pongReceived = true
} else { } else {

0
umap/sync/__init__.py Normal file
View file

86
umap/sync/consumers.py Normal file
View file

@ -0,0 +1,86 @@
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
from django.core.signing import TimestampSigner
from pydantic import ValidationError
from .payloads import (
JoinRequest,
JoinResponse,
ListPeersResponse,
OperationMessage,
PeerMessage,
Request,
)
class SyncConsumer(AsyncWebsocketConsumer):
@property
def peers(self):
return self.channel_layer.groups[self.map_id].keys()
async def connect(self):
self.map_id = self.scope["url_route"]["kwargs"]["map_id"]
# Join room group
await self.channel_layer.group_add(self.map_id, self.channel_name)
self.is_authenticated = False
await self.accept()
async def disconnect(self, close_code):
await self.channel_layer.group_discard(self.map_id, self.channel_name)
await self.send_peers_list()
async def send_peers_list(self):
message = ListPeersResponse(peers=self.peers)
await self.broadcast(message.model_dump_json())
async def broadcast(self, message):
# Send to all channels (including sender!)
await self.channel_layer.group_send(
self.map_id, {"message": message, "type": "on_message"}
)
async def send_to(self, channel, message):
# Send to one given channel
await self.channel_layer.send(
channel, {"message": message, "type": "on_message"}
)
async def on_message(self, event):
# Send to self channel
await self.send(event["message"])
async def receive(self, text_data):
if not self.is_authenticated:
message = JoinRequest.model_validate_json(text_data)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
if "edit" not in permissions:
return await self.disconnect()
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
await self.send(response.model_dump_json())
await self.send_peers_list()
self.is_authenticated = True
return
if text_data == "ping":
return await self.send("pong")
try:
incoming = Request.model_validate_json(text_data)
except ValidationError as error:
message = (
f"An error occurred when receiving the following message: {text_data!r}"
)
logging.error(message, error)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case OperationMessage():
await self.broadcast(text_data)
# Send peer messages to the proper peer
case PeerMessage():
await self.send_to(incoming.root.recipient, text_data)

47
umap/sync/payloads.py Normal file
View file

@ -0,0 +1,47 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field, RootModel
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind")
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: list
uuid: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: list

View file

@ -5,6 +5,7 @@ import time
from pathlib import Path from pathlib import Path
import pytest import pytest
from channels.testing import ChannelsLiveServerTestCase
from playwright.sync_api import expect from playwright.sync_api import expect
from ..base import mock_tiles from ..base import mock_tiles
@ -87,3 +88,15 @@ def websocket_server():
yield ds_proc yield ds_proc
# Shut it down at the end of the pytest session # Shut it down at the end of the pytest session
ds_proc.terminate() ds_proc.terminate()
@pytest.fixture(scope="function")
def channels_live_server(request, settings):
server = ChannelsLiveServerTestCase()
server.serve_static = False
server._pre_setup()
settings.WEBSOCKET_FRONT_URI = f"{server.live_server_ws_url}/ws/sync/{{id}}/"
yield server
server._post_teardown()

View file

@ -12,7 +12,7 @@ DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*")
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_markers( def test_websocket_connection_can_sync_markers(
new_page, live_server, websocket_server, tilelayer new_page, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -80,7 +80,7 @@ def test_websocket_connection_can_sync_markers(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_polygons( def test_websocket_connection_can_sync_polygons(
context, live_server, websocket_server, tilelayer context, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -164,7 +164,7 @@ def test_websocket_connection_can_sync_polygons(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_map_properties( def test_websocket_connection_can_sync_map_properties(
new_page, live_server, websocket_server, tilelayer new_page, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -196,7 +196,7 @@ def test_websocket_connection_can_sync_map_properties(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_datalayer_properties( def test_websocket_connection_can_sync_datalayer_properties(
new_page, live_server, websocket_server, tilelayer new_page, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -225,7 +225,7 @@ def test_websocket_connection_can_sync_datalayer_properties(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_cloned_polygons( def test_websocket_connection_can_sync_cloned_polygons(
context, live_server, websocket_server, tilelayer context, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -288,7 +288,7 @@ def test_websocket_connection_can_sync_cloned_polygons(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_late_joining_peer( def test_websocket_connection_can_sync_late_joining_peer(
new_page, live_server, websocket_server, tilelayer new_page, live_server, channels_live_server, tilelayer
): ):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
@ -349,7 +349,7 @@ def test_websocket_connection_can_sync_late_joining_peer(
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelayer): def test_should_sync_datalayers(new_page, live_server, channels_live_server, tilelayer):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS) map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True map.settings["properties"]["syncEnabled"] = True
map.save() map.save()
@ -422,7 +422,7 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay
@pytest.mark.xdist_group(name="websockets") @pytest.mark.xdist_group(name="websockets")
def test_create_and_sync_map( def test_create_and_sync_map(
new_page, live_server, websocket_server, tilelayer, login, user new_page, live_server, channels_live_server, tilelayer, login, user
): ):
# Create a syncable map with peerA # Create a syncable map with peerA
peerA = login(user, prefix="Page A") peerA = login(user, prefix="Page A")

View file

@ -1,22 +0,0 @@
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
def test_messages_are_parsed_correctly():
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
assert type(server) is ServerRequest
operation = Request.model_validate(
dict(
kind="OperationMessage",
verb="upsert",
subject="map",
metadata={},
key="key",
)
).root
assert type(operation) is OperationMessage
peer_message = Request.model_validate(
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
).root
assert type(peer_message) is PeerMessage

View file

@ -1,202 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
import uuid
from collections import defaultdict
from typing import Literal, Optional, Union
import websockets
from django.conf import settings
from django.core.signing import TimestampSigner
from pydantic import BaseModel, Field, RootModel, ValidationError
from websockets import WebSocketClientProtocol
from websockets.server import serve
class Connections:
def __init__(self) -> None:
self._connections: set[WebSocketClientProtocol] = set()
self._ids: dict[WebSocketClientProtocol, str] = dict()
def join(self, websocket: WebSocketClientProtocol) -> str:
self._connections.add(websocket)
_id = str(uuid.uuid4())
self._ids[websocket] = _id
return _id
def leave(self, websocket: WebSocketClientProtocol) -> None:
self._connections.remove(websocket)
del self._ids[websocket]
def get(self, id) -> WebSocketClientProtocol:
# use an iterator to stop iterating as soon as we found
return next(k for k, v in self._ids.items() if v == id)
def get_id(self, websocket: WebSocketClientProtocol):
return self._ids[websocket]
def get_other_peers(
self, websocket: WebSocketClientProtocol
) -> set[WebSocketClientProtocol]:
return self._connections - {websocket}
def get_all_peers(self) -> set[WebSocketClientProtocol]:
return self._connections
# Contains the list of websocket connections handled by this process.
# It's a mapping of map_id to a set of the active websocket connections
CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections)
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class ServerRequest(BaseModel):
"""A request towards the server"""
kind: Literal["Server"] = "Server"
action: Literal["list-peers"]
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[ServerRequest, PeerMessage, OperationMessage] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: list
uuid: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: list
async def join_and_listen(
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
):
"""Join a "room" with other connected peers, and wait for messages."""
logging.debug(f"{user} joined room #{map_id}")
connections: Connections = CONNECTIONS[map_id]
_id: str = connections.join(websocket)
# Assign an ID to the joining peer and return it the list of connected peers.
peers: list[WebSocketClientProtocol] = [
connections.get_id(p) for p in connections.get_all_peers()
]
response = JoinResponse(uuid=_id, peers=peers)
await websocket.send(response.model_dump_json())
# Notify all other peers of the new list of connected peers.
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
try:
async for raw_message in websocket:
if raw_message == "ping":
await websocket.send("pong")
continue
# recompute the peers list at the time of message-sending.
# as doing so beforehand would miss new connections
other_peers = connections.get_other_peers(websocket)
try:
incoming = Request.model_validate_json(raw_message)
except ValidationError as e:
error = f"An error occurred when receiving the following message: {raw_message!r}"
logging.error(error, e)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case OperationMessage():
websockets.broadcast(other_peers, raw_message)
# Send peer messages to the proper peer
case PeerMessage(recipient=_id):
peer = connections.get(_id)
if peer:
await peer.send(raw_message)
finally:
# On disconnect, remove the connection from the pool
connections.leave(websocket)
# TODO: refactor this in a separate method.
# Notify all other peers of the new list of connected peers.
peers = [connections.get_id(p) for p in connections.get_all_peers()]
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
async def handler(websocket: WebSocketClientProtocol):
"""Main WebSocket handler.
Check if the permission is granted and let the peer enter a room.
"""
raw_message = await websocket.recv()
# The first event should always be 'join'
message: JoinRequest = JoinRequest.model_validate_json(raw_message)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
# Check if permissions for this map have been granted by the server
if "edit" in signed["permissions"]:
await join_and_listen(map_id, permissions, user, websocket)
def run(host: str, port: int):
if not settings.WEBSOCKET_ENABLED:
msg = (
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
"See the documentation at "
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
"for more information."
)
print(msg)
exit(1)
async def _serve():
async with serve(handler, host, port):
logging.debug(f"Waiting for connections on {host}:{port}")
await asyncio.Future() # run forever
try:
asyncio.run(_serve())
except KeyboardInterrupt:
print("Closing WebSocket server")