mirror of
https://github.com/umap-project/umap.git
synced 2025-05-04 05:31:50 +02:00
wip(sync): use django-channels to serve websockets
Co-authored-by: David Larlet <david@larlet.fr>
This commit is contained in:
parent
ebae9a8cd0
commit
c6c965a601
12 changed files with 169 additions and 260 deletions
|
@ -1,15 +1,22 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from channels.routing import ProtocolTypeRouter
|
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||||
|
from channels.security.websocket import AllowedHostsOriginValidator
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
|
from django.urls import re_path
|
||||||
|
|
||||||
|
from .sync import consumers
|
||||||
|
|
||||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
|
||||||
# Initialize Django ASGI application early to ensure the AppRegistry
|
# Initialize Django ASGI application early to ensure the AppRegistry
|
||||||
# is populated before importing code that may import ORM models.
|
# is populated before importing code that may import ORM models.
|
||||||
django_asgi_app = get_asgi_application()
|
django_asgi_app = get_asgi_application()
|
||||||
|
|
||||||
|
urlpatterns = (re_path(r"ws/sync/(?P<map_id>\w+)/$", consumers.SyncConsumer.as_asgi()),)
|
||||||
|
|
||||||
application = ProtocolTypeRouter(
|
application = ProtocolTypeRouter(
|
||||||
{
|
{
|
||||||
"http": django_asgi_app,
|
"http": django_asgi_app,
|
||||||
|
"websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
from django.conf import settings
|
|
||||||
from django.core.management.base import BaseCommand
|
|
||||||
|
|
||||||
from umap import websocket_server
|
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
|
||||||
help = "Run the websocket server"
|
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
help="The server host to bind to.",
|
|
||||||
default=settings.WEBSOCKET_BACK_HOST,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port",
|
|
||||||
help="The server port to bind to.",
|
|
||||||
default=settings.WEBSOCKET_BACK_PORT,
|
|
||||||
)
|
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
|
||||||
websocket_server.run(options["host"], options["port"])
|
|
|
@ -343,3 +343,4 @@ WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False)
|
||||||
WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost")
|
WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost")
|
||||||
WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001)
|
WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001)
|
||||||
WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001")
|
WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001")
|
||||||
|
CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}}
|
||||||
|
|
|
@ -77,7 +77,7 @@ export class SyncEngine {
|
||||||
|
|
||||||
start(authToken) {
|
start(authToken) {
|
||||||
this.transport = new WebSocketTransport(
|
this.transport = new WebSocketTransport(
|
||||||
this._umap.properties.websocketURI,
|
Utils.template(this._umap.properties.websocketURI, { id: this._umap.id }),
|
||||||
authToken,
|
authToken,
|
||||||
this
|
this
|
||||||
)
|
)
|
||||||
|
@ -125,7 +125,7 @@ export class SyncEngine {
|
||||||
|
|
||||||
if (this.offline) return
|
if (this.offline) return
|
||||||
if (this.transport) {
|
if (this.transport) {
|
||||||
this.transport.send('OperationMessage', message)
|
this.transport.send('OperationMessage', { sender: this.uuid, ...message })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,6 +177,7 @@ export class SyncEngine {
|
||||||
* @param {Object} payload
|
* @param {Object} payload
|
||||||
*/
|
*/
|
||||||
onOperationMessage(payload) {
|
onOperationMessage(payload) {
|
||||||
|
if (payload.sender === this.uuid) return
|
||||||
this._operations.storeRemoteOperations([payload])
|
this._operations.storeRemoteOperations([payload])
|
||||||
this._applyOperation(payload)
|
this._applyOperation(payload)
|
||||||
}
|
}
|
||||||
|
@ -484,7 +485,7 @@ export class Operations {
|
||||||
return (
|
return (
|
||||||
Utils.deepEqual(local.subject, remote.subject) &&
|
Utils.deepEqual(local.subject, remote.subject) &&
|
||||||
Utils.deepEqual(local.metadata, remote.metadata) &&
|
Utils.deepEqual(local.metadata, remote.metadata) &&
|
||||||
(!shouldCheckKey || (shouldCheckKey && local.key == remote.key))
|
(!shouldCheckKey || (shouldCheckKey && local.key === remote.key))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ export class WebSocketTransport {
|
||||||
constructor(webSocketURI, authToken, messagesReceiver) {
|
constructor(webSocketURI, authToken, messagesReceiver) {
|
||||||
this.receiver = messagesReceiver
|
this.receiver = messagesReceiver
|
||||||
|
|
||||||
this.websocket = new WebSocket(webSocketURI)
|
this.websocket = new WebSocket(`${webSocketURI}`)
|
||||||
|
|
||||||
this.websocket.onopen = () => {
|
this.websocket.onopen = () => {
|
||||||
this.send('JoinRequest', { token: authToken })
|
this.send('JoinRequest', { token: authToken })
|
||||||
|
@ -48,6 +48,7 @@ export class WebSocketTransport {
|
||||||
}
|
}
|
||||||
|
|
||||||
onMessage(wsMessage) {
|
onMessage(wsMessage) {
|
||||||
|
console.log(wsMessage)
|
||||||
if (wsMessage.data === 'pong') {
|
if (wsMessage.data === 'pong') {
|
||||||
this.pongReceived = true
|
this.pongReceived = true
|
||||||
} else {
|
} else {
|
||||||
|
|
0
umap/sync/__init__.py
Normal file
0
umap/sync/__init__.py
Normal file
86
umap/sync/consumers.py
Normal file
86
umap/sync/consumers.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
|
from django.core.signing import TimestampSigner
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from .payloads import (
|
||||||
|
JoinRequest,
|
||||||
|
JoinResponse,
|
||||||
|
ListPeersResponse,
|
||||||
|
OperationMessage,
|
||||||
|
PeerMessage,
|
||||||
|
Request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SyncConsumer(AsyncWebsocketConsumer):
|
||||||
|
@property
|
||||||
|
def peers(self):
|
||||||
|
return self.channel_layer.groups[self.map_id].keys()
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
self.map_id = self.scope["url_route"]["kwargs"]["map_id"]
|
||||||
|
|
||||||
|
# Join room group
|
||||||
|
await self.channel_layer.group_add(self.map_id, self.channel_name)
|
||||||
|
|
||||||
|
self.is_authenticated = False
|
||||||
|
await self.accept()
|
||||||
|
|
||||||
|
async def disconnect(self, close_code):
|
||||||
|
await self.channel_layer.group_discard(self.map_id, self.channel_name)
|
||||||
|
await self.send_peers_list()
|
||||||
|
|
||||||
|
async def send_peers_list(self):
|
||||||
|
message = ListPeersResponse(peers=self.peers)
|
||||||
|
await self.broadcast(message.model_dump_json())
|
||||||
|
|
||||||
|
async def broadcast(self, message):
|
||||||
|
# Send to all channels (including sender!)
|
||||||
|
await self.channel_layer.group_send(
|
||||||
|
self.map_id, {"message": message, "type": "on_message"}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_to(self, channel, message):
|
||||||
|
# Send to one given channel
|
||||||
|
await self.channel_layer.send(
|
||||||
|
channel, {"message": message, "type": "on_message"}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_message(self, event):
|
||||||
|
# Send to self channel
|
||||||
|
await self.send(event["message"])
|
||||||
|
|
||||||
|
async def receive(self, text_data):
|
||||||
|
if not self.is_authenticated:
|
||||||
|
message = JoinRequest.model_validate_json(text_data)
|
||||||
|
signed = TimestampSigner().unsign_object(message.token, max_age=30)
|
||||||
|
user, map_id, permissions = signed.values()
|
||||||
|
if "edit" not in permissions:
|
||||||
|
return await self.disconnect()
|
||||||
|
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
|
||||||
|
await self.send(response.model_dump_json())
|
||||||
|
await self.send_peers_list()
|
||||||
|
self.is_authenticated = True
|
||||||
|
return
|
||||||
|
|
||||||
|
if text_data == "ping":
|
||||||
|
return await self.send("pong")
|
||||||
|
|
||||||
|
try:
|
||||||
|
incoming = Request.model_validate_json(text_data)
|
||||||
|
except ValidationError as error:
|
||||||
|
message = (
|
||||||
|
f"An error occurred when receiving the following message: {text_data!r}"
|
||||||
|
)
|
||||||
|
logging.error(message, error)
|
||||||
|
else:
|
||||||
|
match incoming.root:
|
||||||
|
# Broadcast all operation messages to connected peers
|
||||||
|
case OperationMessage():
|
||||||
|
await self.broadcast(text_data)
|
||||||
|
|
||||||
|
# Send peer messages to the proper peer
|
||||||
|
case PeerMessage():
|
||||||
|
await self.send_to(incoming.root.recipient, text_data)
|
47
umap/sync/payloads.py
Normal file
47
umap/sync/payloads.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
|
|
||||||
|
class JoinRequest(BaseModel):
|
||||||
|
kind: Literal["JoinRequest"] = "JoinRequest"
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationMessage(BaseModel):
|
||||||
|
"""Message sent from one peer to all the others"""
|
||||||
|
|
||||||
|
kind: Literal["OperationMessage"] = "OperationMessage"
|
||||||
|
verb: Literal["upsert", "update", "delete"]
|
||||||
|
subject: Literal["map", "datalayer", "feature"]
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
key: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PeerMessage(BaseModel):
|
||||||
|
"""Message sent from a specific peer to another one"""
|
||||||
|
|
||||||
|
kind: Literal["PeerMessage"] = "PeerMessage"
|
||||||
|
sender: str
|
||||||
|
recipient: str
|
||||||
|
# The message can be whatever the peers want. It's not checked by the server.
|
||||||
|
message: dict
|
||||||
|
|
||||||
|
|
||||||
|
class Request(RootModel):
|
||||||
|
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
||||||
|
|
||||||
|
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind")
|
||||||
|
|
||||||
|
|
||||||
|
class JoinResponse(BaseModel):
|
||||||
|
"""Server response containing the list of peers"""
|
||||||
|
|
||||||
|
kind: Literal["JoinResponse"] = "JoinResponse"
|
||||||
|
peers: list
|
||||||
|
uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
class ListPeersResponse(BaseModel):
|
||||||
|
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
|
||||||
|
peers: list
|
|
@ -5,6 +5,7 @@ import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from channels.testing import ChannelsLiveServerTestCase
|
||||||
from playwright.sync_api import expect
|
from playwright.sync_api import expect
|
||||||
|
|
||||||
from ..base import mock_tiles
|
from ..base import mock_tiles
|
||||||
|
@ -87,3 +88,15 @@ def websocket_server():
|
||||||
yield ds_proc
|
yield ds_proc
|
||||||
# Shut it down at the end of the pytest session
|
# Shut it down at the end of the pytest session
|
||||||
ds_proc.terminate()
|
ds_proc.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def channels_live_server(request, settings):
|
||||||
|
server = ChannelsLiveServerTestCase()
|
||||||
|
server.serve_static = False
|
||||||
|
server._pre_setup()
|
||||||
|
settings.WEBSOCKET_FRONT_URI = f"{server.live_server_ws_url}/ws/sync/{{id}}/"
|
||||||
|
|
||||||
|
yield server
|
||||||
|
|
||||||
|
server._post_teardown()
|
||||||
|
|
|
@ -12,7 +12,7 @@ DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*")
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_markers(
|
def test_websocket_connection_can_sync_markers(
|
||||||
new_page, live_server, websocket_server, tilelayer
|
new_page, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -80,7 +80,7 @@ def test_websocket_connection_can_sync_markers(
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_polygons(
|
def test_websocket_connection_can_sync_polygons(
|
||||||
context, live_server, websocket_server, tilelayer
|
context, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -164,7 +164,7 @@ def test_websocket_connection_can_sync_polygons(
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_map_properties(
|
def test_websocket_connection_can_sync_map_properties(
|
||||||
new_page, live_server, websocket_server, tilelayer
|
new_page, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -196,7 +196,7 @@ def test_websocket_connection_can_sync_map_properties(
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_datalayer_properties(
|
def test_websocket_connection_can_sync_datalayer_properties(
|
||||||
new_page, live_server, websocket_server, tilelayer
|
new_page, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -225,7 +225,7 @@ def test_websocket_connection_can_sync_datalayer_properties(
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_cloned_polygons(
|
def test_websocket_connection_can_sync_cloned_polygons(
|
||||||
context, live_server, websocket_server, tilelayer
|
context, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -288,7 +288,7 @@ def test_websocket_connection_can_sync_cloned_polygons(
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_websocket_connection_can_sync_late_joining_peer(
|
def test_websocket_connection_can_sync_late_joining_peer(
|
||||||
new_page, live_server, websocket_server, tilelayer
|
new_page, live_server, channels_live_server, tilelayer
|
||||||
):
|
):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
|
@ -349,7 +349,7 @@ def test_websocket_connection_can_sync_late_joining_peer(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelayer):
|
def test_should_sync_datalayers(new_page, live_server, channels_live_server, tilelayer):
|
||||||
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
|
||||||
map.settings["properties"]["syncEnabled"] = True
|
map.settings["properties"]["syncEnabled"] = True
|
||||||
map.save()
|
map.save()
|
||||||
|
@ -422,7 +422,7 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay
|
||||||
|
|
||||||
@pytest.mark.xdist_group(name="websockets")
|
@pytest.mark.xdist_group(name="websockets")
|
||||||
def test_create_and_sync_map(
|
def test_create_and_sync_map(
|
||||||
new_page, live_server, websocket_server, tilelayer, login, user
|
new_page, live_server, channels_live_server, tilelayer, login, user
|
||||||
):
|
):
|
||||||
# Create a syncable map with peerA
|
# Create a syncable map with peerA
|
||||||
peerA = login(user, prefix="Page A")
|
peerA = login(user, prefix="Page A")
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
|
|
||||||
|
|
||||||
|
|
||||||
def test_messages_are_parsed_correctly():
|
|
||||||
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
|
|
||||||
assert type(server) is ServerRequest
|
|
||||||
|
|
||||||
operation = Request.model_validate(
|
|
||||||
dict(
|
|
||||||
kind="OperationMessage",
|
|
||||||
verb="upsert",
|
|
||||||
subject="map",
|
|
||||||
metadata={},
|
|
||||||
key="key",
|
|
||||||
)
|
|
||||||
).root
|
|
||||||
assert type(operation) is OperationMessage
|
|
||||||
|
|
||||||
peer_message = Request.model_validate(
|
|
||||||
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
|
|
||||||
).root
|
|
||||||
assert type(peer_message) is PeerMessage
|
|
|
@ -1,202 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
import websockets
|
|
||||||
from django.conf import settings
|
|
||||||
from django.core.signing import TimestampSigner
|
|
||||||
from pydantic import BaseModel, Field, RootModel, ValidationError
|
|
||||||
from websockets import WebSocketClientProtocol
|
|
||||||
from websockets.server import serve
|
|
||||||
|
|
||||||
|
|
||||||
class Connections:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._connections: set[WebSocketClientProtocol] = set()
|
|
||||||
self._ids: dict[WebSocketClientProtocol, str] = dict()
|
|
||||||
|
|
||||||
def join(self, websocket: WebSocketClientProtocol) -> str:
|
|
||||||
self._connections.add(websocket)
|
|
||||||
_id = str(uuid.uuid4())
|
|
||||||
self._ids[websocket] = _id
|
|
||||||
return _id
|
|
||||||
|
|
||||||
def leave(self, websocket: WebSocketClientProtocol) -> None:
|
|
||||||
self._connections.remove(websocket)
|
|
||||||
del self._ids[websocket]
|
|
||||||
|
|
||||||
def get(self, id) -> WebSocketClientProtocol:
|
|
||||||
# use an iterator to stop iterating as soon as we found
|
|
||||||
return next(k for k, v in self._ids.items() if v == id)
|
|
||||||
|
|
||||||
def get_id(self, websocket: WebSocketClientProtocol):
|
|
||||||
return self._ids[websocket]
|
|
||||||
|
|
||||||
def get_other_peers(
|
|
||||||
self, websocket: WebSocketClientProtocol
|
|
||||||
) -> set[WebSocketClientProtocol]:
|
|
||||||
return self._connections - {websocket}
|
|
||||||
|
|
||||||
def get_all_peers(self) -> set[WebSocketClientProtocol]:
|
|
||||||
return self._connections
|
|
||||||
|
|
||||||
|
|
||||||
# Contains the list of websocket connections handled by this process.
|
|
||||||
# It's a mapping of map_id to a set of the active websocket connections
|
|
||||||
CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections)
|
|
||||||
|
|
||||||
|
|
||||||
class JoinRequest(BaseModel):
|
|
||||||
kind: Literal["JoinRequest"] = "JoinRequest"
|
|
||||||
token: str
|
|
||||||
|
|
||||||
|
|
||||||
class OperationMessage(BaseModel):
|
|
||||||
"""Message sent from one peer to all the others"""
|
|
||||||
|
|
||||||
kind: Literal["OperationMessage"] = "OperationMessage"
|
|
||||||
verb: Literal["upsert", "update", "delete"]
|
|
||||||
subject: Literal["map", "datalayer", "feature"]
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
key: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PeerMessage(BaseModel):
|
|
||||||
"""Message sent from a specific peer to another one"""
|
|
||||||
|
|
||||||
kind: Literal["PeerMessage"] = "PeerMessage"
|
|
||||||
sender: str
|
|
||||||
recipient: str
|
|
||||||
# The message can be whatever the peers want. It's not checked by the server.
|
|
||||||
message: dict
|
|
||||||
|
|
||||||
|
|
||||||
class ServerRequest(BaseModel):
|
|
||||||
"""A request towards the server"""
|
|
||||||
|
|
||||||
kind: Literal["Server"] = "Server"
|
|
||||||
action: Literal["list-peers"]
|
|
||||||
|
|
||||||
|
|
||||||
class Request(RootModel):
|
|
||||||
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
|
||||||
|
|
||||||
root: Union[ServerRequest, PeerMessage, OperationMessage] = Field(
|
|
||||||
discriminator="kind"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class JoinResponse(BaseModel):
|
|
||||||
"""Server response containing the list of peers"""
|
|
||||||
|
|
||||||
kind: Literal["JoinResponse"] = "JoinResponse"
|
|
||||||
peers: list
|
|
||||||
uuid: str
|
|
||||||
|
|
||||||
|
|
||||||
class ListPeersResponse(BaseModel):
|
|
||||||
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
|
|
||||||
peers: list
|
|
||||||
|
|
||||||
|
|
||||||
async def join_and_listen(
|
|
||||||
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
|
|
||||||
):
|
|
||||||
"""Join a "room" with other connected peers, and wait for messages."""
|
|
||||||
logging.debug(f"{user} joined room #{map_id}")
|
|
||||||
connections: Connections = CONNECTIONS[map_id]
|
|
||||||
_id: str = connections.join(websocket)
|
|
||||||
|
|
||||||
# Assign an ID to the joining peer and return it the list of connected peers.
|
|
||||||
peers: list[WebSocketClientProtocol] = [
|
|
||||||
connections.get_id(p) for p in connections.get_all_peers()
|
|
||||||
]
|
|
||||||
response = JoinResponse(uuid=_id, peers=peers)
|
|
||||||
await websocket.send(response.model_dump_json())
|
|
||||||
|
|
||||||
# Notify all other peers of the new list of connected peers.
|
|
||||||
message = ListPeersResponse(peers=peers)
|
|
||||||
websockets.broadcast(
|
|
||||||
connections.get_other_peers(websocket), message.model_dump_json()
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for raw_message in websocket:
|
|
||||||
if raw_message == "ping":
|
|
||||||
await websocket.send("pong")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# recompute the peers list at the time of message-sending.
|
|
||||||
# as doing so beforehand would miss new connections
|
|
||||||
other_peers = connections.get_other_peers(websocket)
|
|
||||||
try:
|
|
||||||
incoming = Request.model_validate_json(raw_message)
|
|
||||||
except ValidationError as e:
|
|
||||||
error = f"An error occurred when receiving the following message: {raw_message!r}"
|
|
||||||
logging.error(error, e)
|
|
||||||
else:
|
|
||||||
match incoming.root:
|
|
||||||
# Broadcast all operation messages to connected peers
|
|
||||||
case OperationMessage():
|
|
||||||
websockets.broadcast(other_peers, raw_message)
|
|
||||||
|
|
||||||
# Send peer messages to the proper peer
|
|
||||||
case PeerMessage(recipient=_id):
|
|
||||||
peer = connections.get(_id)
|
|
||||||
if peer:
|
|
||||||
await peer.send(raw_message)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# On disconnect, remove the connection from the pool
|
|
||||||
connections.leave(websocket)
|
|
||||||
|
|
||||||
# TODO: refactor this in a separate method.
|
|
||||||
# Notify all other peers of the new list of connected peers.
|
|
||||||
peers = [connections.get_id(p) for p in connections.get_all_peers()]
|
|
||||||
message = ListPeersResponse(peers=peers)
|
|
||||||
websockets.broadcast(
|
|
||||||
connections.get_other_peers(websocket), message.model_dump_json()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def handler(websocket: WebSocketClientProtocol):
|
|
||||||
"""Main WebSocket handler.
|
|
||||||
|
|
||||||
Check if the permission is granted and let the peer enter a room.
|
|
||||||
"""
|
|
||||||
raw_message = await websocket.recv()
|
|
||||||
|
|
||||||
# The first event should always be 'join'
|
|
||||||
message: JoinRequest = JoinRequest.model_validate_json(raw_message)
|
|
||||||
signed = TimestampSigner().unsign_object(message.token, max_age=30)
|
|
||||||
user, map_id, permissions = signed.values()
|
|
||||||
|
|
||||||
# Check if permissions for this map have been granted by the server
|
|
||||||
if "edit" in signed["permissions"]:
|
|
||||||
await join_and_listen(map_id, permissions, user, websocket)
|
|
||||||
|
|
||||||
|
|
||||||
def run(host: str, port: int):
|
|
||||||
if not settings.WEBSOCKET_ENABLED:
|
|
||||||
msg = (
|
|
||||||
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
|
|
||||||
"See the documentation at "
|
|
||||||
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
|
|
||||||
"for more information."
|
|
||||||
)
|
|
||||||
print(msg)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
async def _serve():
|
|
||||||
async with serve(handler, host, port):
|
|
||||||
logging.debug(f"Waiting for connections on {host}:{port}")
|
|
||||||
await asyncio.Future() # run forever
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.run(_serve())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("Closing WebSocket server")
|
|
Loading…
Reference in a new issue