wip(sync): move sync related code to a dedicated module

This commit is contained in:
Yohan Boniface 2024-12-26 15:51:05 +01:00
parent f7572c4893
commit e865c31d69
6 changed files with 60 additions and 277 deletions

View file

@ -5,7 +5,7 @@ from channels.security.websocket import AllowedHostsOriginValidator
from django.core.asgi import get_asgi_application
from django.urls import re_path
from . import consumers
from .sync import consumers
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
# Initialize Django ASGI application early to ensure the AppRegistry

View file

@ -1,23 +0,0 @@
from django.conf import settings
from django.core.management.base import BaseCommand
from umap import websocket_server
class Command(BaseCommand):
help = "Run the websocket server"
def add_arguments(self, parser):
parser.add_argument(
"--host",
help="The server host to bind to.",
default=settings.WEBSOCKET_BACK_HOST,
)
parser.add_argument(
"--port",
help="The server port to bind to.",
default=settings.WEBSOCKET_BACK_PORT,
)
def handle(self, *args, **options):
websocket_server.run(options["host"], options["port"])

0
umap/sync/__init__.py Normal file
View file

View file

@ -1,14 +1,16 @@
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
from django.core.signing import TimestampSigner
from pydantic import ValidationError
from .websocket_server import (
from .payloads import (
JoinRequest,
JoinResponse,
ListPeersResponse,
OperationMessage,
PeerMessage,
Request,
ValidationError,
)
@ -33,7 +35,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
return self.channel_layer.groups[self.map_id].keys()
async def connect(self):
print("connect")
self.map_id = self.scope["url_route"]["kwargs"]["map_id"]
# Join room group
@ -43,7 +44,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.send_peers_list()
async def disconnect(self, close_code):
print("disconnect")
await self.channel_layer.group_discard(self.map_id, self.channel_name)
await self.send_peers_list()
@ -52,43 +52,32 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.broadcast(message.model_dump_json())
async def broadcast(self, message):
print("broadcast", message)
# Send to one all channels
await self.channel_layer.group_send(
self.map_id, {"message": message, "type": "on_message"}
)
async def send_to(self, channel, message):
print("pair to pair", channel, message)
# Send to one given channel
await self.channel_layer.send(
channel, {"message": message, "type": "on_message"}
)
async def on_message(self, event):
# This is what the consummers does for a single channel
# Send to self channel
await self.send(event["message"])
async def receive(self, text_data):
print("receive")
print(text_data)
if text_data == "ping":
return await self.send("pong")
# recompute the peers list at the time of message-sending.
# as doing so beforehand would miss new connections
# other_peers = connections.get_other_peers(websocket)
# await self.send("pong" + self.channel_name)
# await self.channel_layer.group_send(
# self.map_id,
# {"message": "pouet " + self.channel_name, "type": "broadcast"},
# )
try:
incoming = Request.model_validate_json(text_data)
except ValidationError:
error = (
except ValidationError as error:
message = (
f"An error occurred when receiving the following message: {text_data!r}"
)
print(error)
# logging.error(error, e)
logging.error(message, error)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
@ -100,34 +89,4 @@ class SyncConsumer(AsyncWebsocketConsumer):
# Send peer messages to the proper peer
case PeerMessage():
print("Received peermessage", incoming.root)
await self.send_to(incoming.root.recipient, text_data)
# Send peer messages to the proper peer
# case PeerMessage(recipient=_id):
# peer = connections.get(_id)
# if peer:
# await peer.send(text_data)
# message = JoinRequest.model_validate_json(text_data)
# signed = TimestampSigner().unsign_object(message.token, max_age=30)
# user, map_id, permissions = signed.values()
# # Check if permissions for this map have been granted by the server
# if "edit" in signed["permissions"]:
# connections = CONNECTIONS[map_id]
# _id = connections.join(self)
# # Assign an ID to the joining peer and return it the list of connected peers.
# peers: list[WebSocketClientProtocol] = [
# connections.get_id(p) for p in connections.get_all_peers()
# ]
# response = JoinResponse(uuid=_id, peers=peers)
# await self.send(response.model_dump_json())
# # await join_and_listen(map_id, permissions, user, websocket)
# # text_data_json = json.loads(text_data)
# # message = text_data_json["message"]
# # self.send(text_data=json.dumps({"message": message}))

49
umap/sync/payloads.py Normal file
View file

@ -0,0 +1,49 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field, RootModel
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[PeerMessage, OperationMessage, JoinRequest] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: list
uuid: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: list

View file

@ -1,202 +0,0 @@
#!/usr/bin/env python
import asyncio
import logging
import uuid
from collections import defaultdict
from typing import Literal, Optional, Union
import websockets
from django.conf import settings
from django.core.signing import TimestampSigner
from pydantic import BaseModel, Field, RootModel, ValidationError
from websockets import WebSocketClientProtocol
from websockets.server import serve
class Connections:
def __init__(self) -> None:
self._connections: set[WebSocketClientProtocol] = set()
self._ids: dict[WebSocketClientProtocol, str] = dict()
def join(self, websocket: WebSocketClientProtocol) -> str:
self._connections.add(websocket)
_id = str(uuid.uuid4())
self._ids[websocket] = _id
return _id
def leave(self, websocket: WebSocketClientProtocol) -> None:
self._connections.remove(websocket)
del self._ids[websocket]
def get(self, id) -> WebSocketClientProtocol:
# use an iterator to stop iterating as soon as we found
return next(k for k, v in self._ids.items() if v == id)
def get_id(self, websocket: WebSocketClientProtocol):
return self._ids[websocket]
def get_other_peers(
self, websocket: WebSocketClientProtocol
) -> set[WebSocketClientProtocol]:
return self._connections - {websocket}
def get_all_peers(self) -> set[WebSocketClientProtocol]:
return self._connections
# Contains the list of websocket connections handled by this process.
# It's a mapping of map_id to a set of the active websocket connections
CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections)
class JoinRequest(BaseModel):
kind: Literal["JoinRequest"] = "JoinRequest"
token: str
class OperationMessage(BaseModel):
"""Message sent from one peer to all the others"""
kind: Literal["OperationMessage"] = "OperationMessage"
verb: Literal["upsert", "update", "delete"]
subject: Literal["map", "datalayer", "feature"]
metadata: Optional[dict] = None
key: Optional[str] = None
class PeerMessage(BaseModel):
"""Message sent from a specific peer to another one"""
kind: Literal["PeerMessage"] = "PeerMessage"
sender: str
recipient: str
# The message can be whatever the peers want. It's not checked by the server.
message: dict
class ServerRequest(BaseModel):
"""A request towards the server"""
kind: Literal["Server"] = "Server"
action: Literal["list-peers"]
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[ServerRequest, PeerMessage, OperationMessage, JoinRequest] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel):
"""Server response containing the list of peers"""
kind: Literal["JoinResponse"] = "JoinResponse"
peers: list
uuid: str
class ListPeersResponse(BaseModel):
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
peers: list
async def join_and_listen(
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
):
"""Join a "room" with other connected peers, and wait for messages."""
logging.debug(f"{user} joined room #{map_id}")
connections: Connections = CONNECTIONS[map_id]
_id: str = connections.join(websocket)
# Assign an ID to the joining peer and return it the list of connected peers.
peers: list[WebSocketClientProtocol] = [
connections.get_id(p) for p in connections.get_all_peers()
]
response = JoinResponse(uuid=_id, peers=peers)
await websocket.send(response.model_dump_json())
# Notify all other peers of the new list of connected peers.
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
try:
async for raw_message in websocket:
if raw_message == "ping":
await websocket.send("pong")
continue
# recompute the peers list at the time of message-sending.
# as doing so beforehand would miss new connections
other_peers = connections.get_other_peers(websocket)
try:
incoming = Request.model_validate_json(raw_message)
except ValidationError as e:
error = f"An error occurred when receiving the following message: {raw_message!r}"
logging.error(error, e)
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case OperationMessage():
websockets.broadcast(other_peers, raw_message)
# Send peer messages to the proper peer
case PeerMessage(recipient=_id):
peer = connections.get(_id)
if peer:
await peer.send(raw_message)
finally:
# On disconnect, remove the connection from the pool
connections.leave(websocket)
# TODO: refactor this in a separate method.
# Notify all other peers of the new list of connected peers.
peers = [connections.get_id(p) for p in connections.get_all_peers()]
message = ListPeersResponse(peers=peers)
websockets.broadcast(
connections.get_other_peers(websocket), message.model_dump_json()
)
async def handler(websocket: WebSocketClientProtocol):
"""Main WebSocket handler.
Check if the permission is granted and let the peer enter a room.
"""
raw_message = await websocket.recv()
# The first event should always be 'join'
message: JoinRequest = JoinRequest.model_validate_json(raw_message)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
# Check if permissions for this map have been granted by the server
if "edit" in signed["permissions"]:
await join_and_listen(map_id, permissions, user, websocket)
def run(host: str, port: int):
if not settings.WEBSOCKET_ENABLED:
msg = (
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
"See the documentation at "
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
"for more information."
)
print(msg)
exit(1)
async def _serve():
async with serve(handler, host, port):
logging.debug(f"Waiting for connections on {host}:{port}")
await asyncio.Future() # run forever
try:
asyncio.run(_serve())
except KeyboardInterrupt:
print("Closing WebSocket server")