wip(sync): POC of using Redis for pubsub (#2426)
Some checks are pending
Test & Docs / tests (postgresql, 3.10) (push) Waiting to run
Test & Docs / tests (postgresql, 3.12) (push) Waiting to run
Test & Docs / lint (push) Waiting to run
Test & Docs / docs (push) Waiting to run

## TODO

- [x] add expire to peers registry hash in redis, as for now when the
server closes the connection we have extra users (edit: we cleaned
manually, as HEXPIRE is not available in FOSS version of Redis)
- [x] make that the peer uuid is created by the client, so when it
reconnects, it uses the same, and does not create a new one
- [ ] see if we can use a connection_pool
- [x] use dynamic websocket_uri (that must include the map id)
- [x] integrate Redis in playwright tests
This commit is contained in:
Yohan Boniface 2025-01-27 15:58:30 +01:00 committed by GitHub
commit 60918e6ca5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 361 additions and 340 deletions

View file

@ -20,7 +20,11 @@ jobs:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
redis:
image: redis
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
ports:
- 6379:6379
strategy:
fail-fast: false
matrix:
@ -48,6 +52,8 @@ jobs:
DJANGO_SETTINGS_MODULE: 'umap.tests.settings'
UMAP_SETTINGS: 'umap/tests/settings.py'
PLAYWRIGHT_TIMEOUT: '20000'
REDIS_HOST: localhost
REDIS_PORT: 6379
lint:
runs-on: ubuntu-latest
steps:

View file

@ -54,6 +54,7 @@ dev = [
"isort==5.13.2",
]
test = [
"daphne==4.1.2",
"factory-boy==3.3.1",
"playwright>=1.39",
"pytest==8.3.4",
@ -70,10 +71,8 @@ s3 = [
"django-storages[s3]==1.14.4",
]
sync = [
"channels==4.2.0",
"daphne==4.1.2",
"pydantic==2.10.5",
"websockets==13.1",
"redis==5.2.1",
]
[project.scripts]
@ -104,3 +103,6 @@ format_css=true
blank_line_after_tag="load,extends"
line_break_after_multiline_tag=true
[lint]
# Disable autoremove of unused import.
unfixable = ["F401"]

View file

@ -1,15 +1,20 @@
import os
from channels.routing import ProtocolTypeRouter
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
from django.core.asgi import get_asgi_application
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
from .sync.app import application as ws_application
# Initialize Django ASGI application early to ensure the AppRegistry
# is populated before importing code that may import ORM models.
django_asgi_app = get_asgi_application()
application = ProtocolTypeRouter(
{
"http": django_asgi_app,
}
)
async def application(scope, receive, send):
if scope["type"] == "http":
await django_asgi_app(scope, receive, send)
elif scope["type"] == "websocket":
await ws_application(scope, receive, send)
else:
raise NotImplementedError(f"Unknown scope type {scope['type']}")

View file

@ -1,23 +0,0 @@
from django.conf import settings
from django.core.management.base import BaseCommand
from umap import websocket_server
class Command(BaseCommand):
help = "Run the websocket server"
def add_arguments(self, parser):
parser.add_argument(
"--host",
help="The server host to bind to.",
default=settings.WEBSOCKET_BACK_HOST,
)
parser.add_argument(
"--port",
help="The server port to bind to.",
default=settings.WEBSOCKET_BACK_PORT,
)
def handle(self, *args, **options):
websocket_server.run(options["host"], options["port"])

View file

@ -342,4 +342,5 @@ LOGGING = {
WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False)
WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost")
WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001)
WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001")
REDIS_URL = "redis://localhost:6379"

View file

@ -62,6 +62,7 @@ export class SyncEngine {
this._reconnectDelay = RECONNECT_DELAY
this.websocketConnected = false
this.closeRequested = false
this.peerId = Utils.generateId()
}
async authenticate() {
@ -76,10 +77,14 @@ export class SyncEngine {
}
start(authToken) {
const path = this._umap.urls.get('ws_sync', { map_id: this._umap.id })
const protocol = window.location.protocol === 'http:' ? 'ws:' : 'wss:'
this.transport = new WebSocketTransport(
this._umap.properties.websocketURI,
`${protocol}//${window.location.host}${path}`,
authToken,
this
this,
this.peerId,
this._umap.properties.user?.name
)
}
@ -125,7 +130,7 @@ export class SyncEngine {
if (this.offline) return
if (this.transport) {
this.transport.send('OperationMessage', message)
this.transport.send('OperationMessage', { sender: this.peerId, ...message })
}
}
@ -142,7 +147,7 @@ export class SyncEngine {
}
getNumberOfConnectedPeers() {
if (this.peers) return this.peers.length
if (this.peers) return Object.keys(this.peers).length
return 0
}
@ -177,6 +182,7 @@ export class SyncEngine {
* @param {Object} payload
*/
onOperationMessage(payload) {
if (payload.sender === this.peerId) return
this._operations.storeRemoteOperations([payload])
this._applyOperation(payload)
}
@ -188,9 +194,8 @@ export class SyncEngine {
* @param {string} payload.uuid The server-assigned uuid for this peer
* @param {string[]} payload.peers The list of peers uuids
*/
onJoinResponse({ uuid, peers }) {
debug('received join response', { uuid, peers })
this.uuid = uuid
onJoinResponse({ peer, peers }) {
debug('received join response', { peer, peers })
this.onListPeersResponse({ peers })
// Get one peer at random
@ -211,7 +216,7 @@ export class SyncEngine {
* @param {string[]} payload.peers The list of peers uuids
*/
onListPeersResponse({ peers }) {
debug('received peerinfo', { peers })
debug('received peerinfo', peers)
this.peers = peers
this.updaters.map.update({ key: 'numberOfConnectedPeers' })
}
@ -286,7 +291,7 @@ export class SyncEngine {
sendToPeer(recipient, verb, payload) {
payload.verb = verb
this.transport.send('PeerMessage', {
sender: this.uuid,
sender: this.peerId,
recipient: recipient,
message: payload,
})
@ -298,7 +303,7 @@ export class SyncEngine {
* @returns {string|bool} the selected peer uuid, or False if none was found.
*/
_getRandomPeer() {
const otherPeers = this.peers.filter((p) => p !== this.uuid)
const otherPeers = Object.keys(this.peers).filter((p) => p !== this.peerId)
if (otherPeers.length > 0) {
const random = Math.floor(Math.random() * otherPeers.length)
return otherPeers[random]
@ -484,7 +489,7 @@ export class Operations {
return (
Utils.deepEqual(local.subject, remote.subject) &&
Utils.deepEqual(local.metadata, remote.metadata) &&
(!shouldCheckKey || (shouldCheckKey && local.key == remote.key))
(!shouldCheckKey || (shouldCheckKey && local.key === remote.key))
)
}
}

View file

@ -3,13 +3,13 @@ const PING_INTERVAL = 30000
const FIRST_CONNECTION_TIMEOUT = 2000
export class WebSocketTransport {
constructor(webSocketURI, authToken, messagesReceiver) {
constructor(webSocketURI, authToken, messagesReceiver, peerId, username) {
this.receiver = messagesReceiver
this.websocket = new WebSocket(webSocketURI)
this.websocket.onopen = () => {
this.send('JoinRequest', { token: authToken })
this.send('JoinRequest', { token: authToken, peer: peerId, username })
this.receiver.onConnection()
}
this.websocket.addEventListener('message', this.onMessage.bind(this))
@ -21,6 +21,10 @@ export class WebSocketTransport {
}
}
this.websocket.onerror = (error) => {
console.log('WS ERROR', error)
}
this.ensureOpen = setInterval(() => {
if (this.websocket.readyState !== WebSocket.OPEN) {
this.websocket.close()
@ -34,6 +38,7 @@ export class WebSocketTransport {
// See https://making.close.com/posts/reliable-websockets/ for more details.
this.pingInterval = setInterval(() => {
if (this.websocket.readyState === WebSocket.OPEN) {
console.log('sending ping')
this.websocket.send('ping')
this.pongReceived = false
setTimeout(() => {
@ -63,6 +68,7 @@ export class WebSocketTransport {
}
close() {
console.log('Closing')
this.receiver.closeRequested = true
this.websocket.close()
}

0
umap/sync/__init__.py Normal file
View file

181
umap/sync/app.py Normal file
View file

@ -0,0 +1,181 @@
import asyncio
import logging
import redis.asyncio as redis
from django.conf import settings
from django.core.signing import TimestampSigner
from django.urls import path
from pydantic import ValidationError
from .payloads import (
JoinRequest,
JoinResponse,
ListPeersResponse,
OperationMessage,
PeerMessage,
Request,
)
async def application(scope, receive, send):
path = scope["path"].lstrip("/")
for pattern in urlpatterns:
if matched := pattern.resolve(path):
await matched.func(scope, receive, send, **matched.kwargs)
break
else:
await send({"type": "websocket.close"})
async def sync(scope, receive, send, **kwargs):
peer = Peer(kwargs["map_id"])
peer._send = send
while True:
event = await receive()
if event["type"] == "websocket.connect":
try:
await peer.connect()
await send({"type": "websocket.accept"})
except ValueError:
await send({"type": "websocket.close"})
if event["type"] == "websocket.disconnect":
await peer.disconnect()
break
if event["type"] == "websocket.receive":
if event["text"] == "ping":
await send({"type": "websocket.send", "text": "pong"})
else:
await peer.receive(event["text"])
class Peer:
def __init__(self, map_id, username=None):
self.username = username or ""
self.map_id = map_id
self.is_authenticated = False
self._subscriptions = []
@property
def room_key(self):
return f"umap:{self.map_id}"
@property
def peer_key(self):
return f"user:{self.map_id}:{self.peer_id}"
async def get_peers(self):
known = await self.client.hgetall(self.room_key)
active = await self.client.pubsub_channels(f"user:{self.map_id}:*")
if not active:
# Poor man way of deleting stale usernames from the store
# HEXPIRE command is not in the open source Redis version
await self.client.delete(self.room_key)
await self.store_username()
active = [name.split(b":")[-1] for name in active]
if self.peer_id.encode() not in active:
# Our connection may not yet be active
active.append(self.peer_id.encode())
return {k: v for k, v in known.items() if k in active}
async def store_username(self):
await self.client.hset(self.room_key, self.peer_id, self.username)
async def listen_to_channel(self, channel_name):
async def reader(pubsub):
await pubsub.subscribe(channel_name)
while True:
if pubsub.connection is None:
# It has been unsubscribed/closed.
break
try:
message = await pubsub.get_message(ignore_subscribe_messages=True)
except Exception as err:
print(err)
break
if message is not None:
await self.send(message["data"].decode())
await asyncio.sleep(0.001) # Be nice with the server
async with self.client.pubsub() as pubsub:
self._subscriptions.append(pubsub)
asyncio.create_task(reader(pubsub))
async def listen(self):
await self.listen_to_channel(self.room_key)
await self.listen_to_channel(self.peer_key)
async def connect(self):
self.client = redis.from_url(settings.REDIS_URL)
async def disconnect(self):
await self.client.hdel(self.room_key, self.peer_id)
for pubsub in self._subscriptions:
await pubsub.unsubscribe()
await pubsub.close()
await self.send_peers_list()
await self.client.aclose()
async def send_peers_list(self):
message = ListPeersResponse(peers=await self.get_peers())
await self.broadcast(message.model_dump_json())
async def broadcast(self, message):
print("BROADCASTING", message)
# Send to all channels (including sender!)
await self.client.publish(self.room_key, message)
async def send_to(self, peer_id, message):
print("SEND TO", peer_id, message)
# Send to one given channel
await self.client.publish(f"user:{self.map_id}:{peer_id}", message)
async def receive(self, text_data):
if not self.is_authenticated:
print("AUTHENTICATING", text_data)
message = JoinRequest.model_validate_json(text_data)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
assert str(map_id) == self.map_id
if "edit" not in permissions:
return await self.disconnect()
self.peer_id = message.peer
self.username = message.username
print("AUTHENTICATED", self.peer_id)
await self.store_username()
await self.listen()
response = JoinResponse(peer=self.peer_id, peers=await self.get_peers())
await self.send(response.model_dump_json())
await self.send_peers_list()
self.is_authenticated = True
return
try:
incoming = Request.model_validate_json(text_data)
except ValidationError as error:
message = (
f"An error occurred when receiving the following message: {text_data!r}"
)
logging.error(message, error)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case OperationMessage():
await self.broadcast(text_data)
# Send peer messages to the proper peer
case PeerMessage():
await self.send_to(incoming.root.recipient, text_data)
async def send(self, text):
print(" FORWARDING TO", self.peer_id, text)
try:
await self._send({"type": "websocket.send", "text": text})
except Exception as err:
print("Error sending message:", text)
print(err)
urlpatterns = [path("ws/sync/<str:map_id>", name="ws_sync", view=sync)]

49
umap/sync/payloads.py Normal file
View file

@ -0,0 +1,49 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field, RootModel
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
peer: str
username: Optional[str] = ""
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind")
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: dict
peer: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: dict

View file

@ -1,12 +1,13 @@
import os
import re
import subprocess
import time
from pathlib import Path
import pytest
from daphne.testing import DaphneProcess
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from playwright.sync_api import expect
from umap.asgi import application
from ..base import mock_tiles
@ -67,23 +68,15 @@ def login(new_page, settings, live_server):
return do_login
@pytest.fixture
def websocket_server():
# Find the test-settings, and put them in the current environment
settings_path = (Path(__file__).parent.parent / "settings.py").absolute().as_posix()
os.environ["UMAP_SETTINGS"] = settings_path
@pytest.fixture(scope="function")
def asgi_live_server(request, live_server):
server = DaphneProcess("localhost", lambda: ASGIStaticFilesHandler(application))
server.start()
server.ready.wait()
port = server.port.value
server.url = f"http://localhost:{port}"
ds_proc = subprocess.Popen(
[
"umap",
"run_websocket_server",
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
time.sleep(2)
# Ensure it started properly before yielding
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
yield ds_proc
# Shut it down at the end of the pytest session
ds_proc.terminate()
yield server
server.terminate()
server.join()

View file

@ -1,6 +1,8 @@
import re
import pytest
import redis
from django.conf import settings
from playwright.sync_api import expect
from umap.models import DataLayer, Map
@ -9,11 +11,21 @@ from ..base import DataLayerFactory, MapFactory
DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*")
pytestmark = pytest.mark.django_db
def setup_function():
# Sync client to prevent headache with pytest / pytest-asyncio and async
client = redis.from_url(settings.REDIS_URL)
# Make sure there are no dead peers in the Redis hash, otherwise asking for
# operations from another peer may never be answered
# FIXME this should not happen in an ideal world
assert client.connection_pool.connection_kwargs["db"] == 15
client.flushdb()
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_markers(
new_page, live_server, websocket_server, tilelayer
):
def test_websocket_connection_can_sync_markers(new_page, asgi_live_server, tilelayer):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
map.save()
@ -21,9 +33,9 @@ def test_websocket_connection_can_sync_markers(
# Create two tabs
peerA = new_page("Page A")
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = new_page("Page B")
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
a_marker_pane = peerA.locator(".leaflet-marker-pane > div")
b_marker_pane = peerB.locator(".leaflet-marker-pane > div")
@ -79,9 +91,7 @@ def test_websocket_connection_can_sync_markers(
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_polygons(
context, live_server, websocket_server, tilelayer
):
def test_websocket_connection_can_sync_polygons(context, asgi_live_server, tilelayer):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
map.save()
@ -89,9 +99,9 @@ def test_websocket_connection_can_sync_polygons(
# Create two tabs
peerA = context.new_page()
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = context.new_page()
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
b_map_el = peerB.locator("#map")
@ -164,7 +174,7 @@ def test_websocket_connection_can_sync_polygons(
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_map_properties(
new_page, live_server, websocket_server, tilelayer
new_page, asgi_live_server, tilelayer
):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
@ -173,9 +183,9 @@ def test_websocket_connection_can_sync_map_properties(
# Create two tabs
peerA = new_page()
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = new_page()
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
# Name change is synced
peerA.get_by_role("link", name="Edit map name and caption").click()
@ -198,7 +208,7 @@ def test_websocket_connection_can_sync_map_properties(
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_datalayer_properties(
new_page, live_server, websocket_server, tilelayer
new_page, asgi_live_server, tilelayer
):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
@ -207,9 +217,9 @@ def test_websocket_connection_can_sync_datalayer_properties(
# Create two tabs
peerA = new_page()
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = new_page()
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
# Layer addition, name and type are synced
peerA.get_by_role("link", name="Manage layers").click()
@ -227,7 +237,7 @@ def test_websocket_connection_can_sync_datalayer_properties(
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_cloned_polygons(
context, live_server, websocket_server, tilelayer
context, asgi_live_server, tilelayer
):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
@ -236,9 +246,9 @@ def test_websocket_connection_can_sync_cloned_polygons(
# Create two tabs
peerA = context.new_page()
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = context.new_page()
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
b_map_el = peerB.locator("#map")
@ -290,7 +300,7 @@ def test_websocket_connection_can_sync_cloned_polygons(
@pytest.mark.xdist_group(name="websockets")
def test_websocket_connection_can_sync_late_joining_peer(
new_page, live_server, websocket_server, tilelayer
new_page, asgi_live_server, tilelayer
):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
@ -299,7 +309,7 @@ def test_websocket_connection_can_sync_late_joining_peer(
# Create first peer (A) and have it join immediately
peerA = new_page("Page A")
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
# Add a marker from peer A
a_create_marker = peerA.get_by_title("Draw a marker")
@ -326,7 +336,7 @@ def test_websocket_connection_can_sync_late_joining_peer(
# Now create peer B and have it join
peerB = new_page("Page B")
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
# Check if peer B has received all the updates
b_marker_pane = peerB.locator(".leaflet-marker-pane > div")
@ -351,7 +361,7 @@ def test_websocket_connection_can_sync_late_joining_peer(
@pytest.mark.xdist_group(name="websockets")
def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelayer):
def test_should_sync_datalayers(new_page, asgi_live_server, tilelayer):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
map.save()
@ -360,9 +370,9 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay
# Create two tabs
peerA = new_page("Page A")
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = new_page("Page B")
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
# Create a new layer from peerA
peerA.get_by_role("link", name="Manage layers").click()
@ -423,9 +433,7 @@ def test_should_sync_datalayers(new_page, live_server, websocket_server, tilelay
@pytest.mark.xdist_group(name="websockets")
def test_should_sync_datalayers_delete(
new_page, live_server, websocket_server, tilelayer
):
def test_should_sync_datalayers_delete(new_page, asgi_live_server, tilelayer):
map = MapFactory(name="sync", edit_status=Map.ANONYMOUS)
map.settings["properties"]["syncEnabled"] = True
map.save()
@ -464,9 +472,9 @@ def test_should_sync_datalayers_delete(
# Create two tabs
peerA = new_page("Page A")
peerA.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerA.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerB = new_page("Page B")
peerB.goto(f"{live_server.url}{map.get_absolute_url()}?edit")
peerB.goto(f"{asgi_live_server.url}{map.get_absolute_url()}?edit")
peerA.get_by_role("button", name="Open browser").click()
expect(peerA.get_by_text("datalayer 1")).to_be_visible()
@ -489,12 +497,10 @@ def test_should_sync_datalayers_delete(
@pytest.mark.xdist_group(name="websockets")
def test_create_and_sync_map(
new_page, live_server, websocket_server, tilelayer, login, user
):
def test_create_and_sync_map(new_page, asgi_live_server, tilelayer, login, user):
# Create a syncable map with peerA
peerA = login(user, prefix="Page A")
peerA.goto(f"{live_server.url}/en/map/new/")
peerA.goto(f"{asgi_live_server.url}/en/map/new/")
with peerA.expect_response(re.compile("./map/create/.*")):
peerA.get_by_role("button", name="Save Draft").click()
peerA.get_by_role("link", name="Map advanced properties").click()

View file

@ -29,3 +29,5 @@ PASSWORD_HASHERS = [
WEBSOCKET_ENABLED = True
WEBSOCKET_BACK_PORT = "8010"
WEBSOCKET_FRONT_URI = "ws://localhost:8010"
REDIS_URL = "redis://localhost:6379/15"

View file

@ -1,22 +0,0 @@
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
def test_messages_are_parsed_correctly():
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
assert type(server) is ServerRequest
operation = Request.model_validate(
dict(
kind="OperationMessage",
verb="upsert",
subject="map",
metadata={},
key="key",
)
).root
assert type(operation) is OperationMessage
peer_message = Request.model_validate(
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
).root
assert type(peer_message) is PeerMessage

View file

@ -7,23 +7,36 @@ from django.core.serializers.json import DjangoJSONEncoder
from django.urls import URLPattern, URLResolver, get_resolver
def _urls_for_js(urls=None):
def _get_url_names(module):
def _get_names(resolver):
names = []
for pattern in resolver.url_patterns:
if getattr(pattern, "url_patterns", None):
# Do not add "admin" and other third party apps urls.
if not pattern.namespace:
names.extend(_get_names(pattern))
elif getattr(pattern, "name", None):
names.append(pattern.name)
return names
return _get_names(get_resolver(module))
def _urls_for_js():
"""
Return templated URLs prepared for javascript.
"""
if urls is None:
# prevent circular import
from .urls import i18n_urls, urlpatterns
urls = [
url.name for url in urlpatterns + i18n_urls if getattr(url, "name", None)
]
urls = dict(zip(urls, [get_uri_template(url) for url in urls]))
urls = {}
for module in ["umap.urls", "umap.sync.app"]:
names = _get_url_names(module)
urls.update(
dict(zip(names, [get_uri_template(url, module=module) for url in names]))
)
urls.update(getattr(settings, "UMAP_EXTRA_URLS", {}))
return urls
def get_uri_template(urlname, args=None, prefix=""):
def get_uri_template(urlname, args=None, prefix="", module=None):
"""
Utility function to return an URI Template from a named URL in django
Copied from django-digitalpaper.
@ -45,7 +58,7 @@ def get_uri_template(urlname, args=None, prefix=""):
paths = template % dict([p, "{%s}" % p] for p in args)
return "%s/%s" % (prefix, paths)
resolver = get_resolver(None)
resolver = get_resolver(module)
parts = urlname.split(":")
if len(parts) > 1 and parts[0] in resolver.namespace_dict:
namespace = parts[0]

View file

@ -609,7 +609,6 @@ class MapDetailMixin(SessionMixin):
"umap_version": VERSION,
"featuresHaveOwner": settings.UMAP_DEFAULT_FEATURES_HAVE_OWNERS,
"websocketEnabled": settings.WEBSOCKET_ENABLED,
"websocketURI": settings.WEBSOCKET_FRONT_URI,
"importers": settings.UMAP_IMPORTERS,
"defaultLabelKeys": settings.UMAP_LABEL_KEYS,
}

View file

@ -1,202 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
import uuid
from collections import defaultdict
from typing import Literal, Optional, Union
import websockets
from django.conf import settings
from django.core.signing import TimestampSigner
from pydantic import BaseModel, Field, RootModel, ValidationError
from websockets import WebSocketClientProtocol
from websockets.server import serve
class Connections:
def __init__(self) -> None:
self._connections: set[WebSocketClientProtocol] = set()
self._ids: dict[WebSocketClientProtocol, str] = dict()
def join(self, websocket: WebSocketClientProtocol) -> str:
self._connections.add(websocket)
_id = str(uuid.uuid4())
self._ids[websocket] = _id
return _id
def leave(self, websocket: WebSocketClientProtocol) -> None:
self._connections.remove(websocket)
del self._ids[websocket]
def get(self, id) -> WebSocketClientProtocol:
# use an iterator to stop iterating as soon as we found
return next(k for k, v in self._ids.items() if v == id)
def get_id(self, websocket: WebSocketClientProtocol):
return self._ids[websocket]
def get_other_peers(
self, websocket: WebSocketClientProtocol
) -> set[WebSocketClientProtocol]:
return self._connections - {websocket}
def get_all_peers(self) -> set[WebSocketClientProtocol]:
return self._connections
# Contains the list of websocket connections handled by this process.
# It's a mapping of map_id to a set of the active websocket connections
CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections)
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class ServerRequest(BaseModel):
"""A request towards the server"""
kind: Literal["Server"] = "Server"
action: Literal["list-peers"]
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[ServerRequest, PeerMessage, OperationMessage] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: list
uuid: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: list
async def join_and_listen(
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
):
"""Join a "room" with other connected peers, and wait for messages."""
logging.debug(f"{user} joined room #{map_id}")
connections: Connections = CONNECTIONS[map_id]
_id: str = connections.join(websocket)
# Assign an ID to the joining peer and return it the list of connected peers.
peers: list[WebSocketClientProtocol] = [
connections.get_id(p) for p in connections.get_all_peers()
]
response = JoinResponse(uuid=_id, peers=peers)
await websocket.send(response.model_dump_json())
# Notify all other peers of the new list of connected peers.
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
try:
async for raw_message in websocket:
if raw_message == "ping":
await websocket.send("pong")
continue
# recompute the peers list at the time of message-sending.
# as doing so beforehand would miss new connections
other_peers = connections.get_other_peers(websocket)
try:
incoming = Request.model_validate_json(raw_message)
except ValidationError as e:
error = f"An error occurred when receiving the following message: {raw_message!r}"
logging.error(error, e)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case OperationMessage():
websockets.broadcast(other_peers, raw_message)
# Send peer messages to the proper peer
case PeerMessage(recipient=_id):
peer = connections.get(_id)
if peer:
await peer.send(raw_message)
finally:
# On disconnect, remove the connection from the pool
connections.leave(websocket)
# TODO: refactor this in a separate method.
# Notify all other peers of the new list of connected peers.
peers = [connections.get_id(p) for p in connections.get_all_peers()]
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
async def handler(websocket: WebSocketClientProtocol):
"""Main WebSocket handler.
Check if the permission is granted and let the peer enter a room.
"""
raw_message = await websocket.recv()
# The first event should always be 'join'
message: JoinRequest = JoinRequest.model_validate_json(raw_message)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
# Check if permissions for this map have been granted by the server
if "edit" in signed["permissions"]:
await join_and_listen(map_id, permissions, user, websocket)
def run(host: str, port: int):
if not settings.WEBSOCKET_ENABLED:
msg = (
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
"See the documentation at "
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
"for more information."
)
print(msg)
exit(1)
async def _serve():
async with serve(handler, host, port):
logging.debug(f"Waiting for connections on {host}:{port}")
await asyncio.Future() # run forever
try:
asyncio.run(_serve())
except KeyboardInterrupt:
print("Closing WebSocket server")