mirror of
https://github.com/umap-project/umap.git
synced 2025-05-04 05:31:50 +02:00
wip(sync): move sync related code to a dedicated module
This commit is contained in:
parent
f7572c4893
commit
e865c31d69
6 changed files with 60 additions and 277 deletions
|
@ -5,7 +5,7 @@ from channels.security.websocket import AllowedHostsOriginValidator
|
|||
from django.core.asgi import get_asgi_application
|
||||
from django.urls import re_path
|
||||
|
||||
from . import consumers
|
||||
from .sync import consumers
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings")
|
||||
# Initialize Django ASGI application early to ensure the AppRegistry
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from umap import websocket_server
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Run the websocket server"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
help="The server host to bind to.",
|
||||
default=settings.WEBSOCKET_BACK_HOST,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
help="The server port to bind to.",
|
||||
default=settings.WEBSOCKET_BACK_PORT,
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
websocket_server.run(options["host"], options["port"])
|
0
umap/sync/__init__.py
Normal file
0
umap/sync/__init__.py
Normal file
|
@ -1,14 +1,16 @@
|
|||
import logging
|
||||
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from django.core.signing import TimestampSigner
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .websocket_server import (
|
||||
from .payloads import (
|
||||
JoinRequest,
|
||||
JoinResponse,
|
||||
ListPeersResponse,
|
||||
OperationMessage,
|
||||
PeerMessage,
|
||||
Request,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
|
||||
|
@ -33,7 +35,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
|||
return self.channel_layer.groups[self.map_id].keys()
|
||||
|
||||
async def connect(self):
|
||||
print("connect")
|
||||
self.map_id = self.scope["url_route"]["kwargs"]["map_id"]
|
||||
|
||||
# Join room group
|
||||
|
@ -43,7 +44,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
|||
await self.send_peers_list()
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
print("disconnect")
|
||||
await self.channel_layer.group_discard(self.map_id, self.channel_name)
|
||||
await self.send_peers_list()
|
||||
|
||||
|
@ -52,43 +52,32 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
|||
await self.broadcast(message.model_dump_json())
|
||||
|
||||
async def broadcast(self, message):
|
||||
print("broadcast", message)
|
||||
# Send to one all channels
|
||||
await self.channel_layer.group_send(
|
||||
self.map_id, {"message": message, "type": "on_message"}
|
||||
)
|
||||
|
||||
async def send_to(self, channel, message):
|
||||
print("pair to pair", channel, message)
|
||||
# Send to one given channel
|
||||
await self.channel_layer.send(
|
||||
channel, {"message": message, "type": "on_message"}
|
||||
)
|
||||
|
||||
async def on_message(self, event):
|
||||
# This is what the consummers does for a single channel
|
||||
# Send to self channel
|
||||
await self.send(event["message"])
|
||||
|
||||
async def receive(self, text_data):
|
||||
print("receive")
|
||||
print(text_data)
|
||||
if text_data == "ping":
|
||||
return await self.send("pong")
|
||||
|
||||
# recompute the peers list at the time of message-sending.
|
||||
# as doing so beforehand would miss new connections
|
||||
# other_peers = connections.get_other_peers(websocket)
|
||||
# await self.send("pong" + self.channel_name)
|
||||
# await self.channel_layer.group_send(
|
||||
# self.map_id,
|
||||
# {"message": "pouet " + self.channel_name, "type": "broadcast"},
|
||||
# )
|
||||
try:
|
||||
incoming = Request.model_validate_json(text_data)
|
||||
except ValidationError:
|
||||
error = (
|
||||
except ValidationError as error:
|
||||
message = (
|
||||
f"An error occurred when receiving the following message: {text_data!r}"
|
||||
)
|
||||
print(error)
|
||||
# logging.error(error, e)
|
||||
logging.error(message, error)
|
||||
else:
|
||||
match incoming.root:
|
||||
# Broadcast all operation messages to connected peers
|
||||
|
@ -100,34 +89,4 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
# Send peer messages to the proper peer
|
||||
case PeerMessage():
|
||||
print("Received peermessage", incoming.root)
|
||||
await self.send_to(incoming.root.recipient, text_data)
|
||||
|
||||
# Send peer messages to the proper peer
|
||||
# case PeerMessage(recipient=_id):
|
||||
# peer = connections.get(_id)
|
||||
# if peer:
|
||||
# await peer.send(text_data)
|
||||
|
||||
# message = JoinRequest.model_validate_json(text_data)
|
||||
# signed = TimestampSigner().unsign_object(message.token, max_age=30)
|
||||
# user, map_id, permissions = signed.values()
|
||||
|
||||
# # Check if permissions for this map have been granted by the server
|
||||
# if "edit" in signed["permissions"]:
|
||||
# connections = CONNECTIONS[map_id]
|
||||
# _id = connections.join(self)
|
||||
|
||||
# # Assign an ID to the joining peer and return it the list of connected peers.
|
||||
# peers: list[WebSocketClientProtocol] = [
|
||||
# connections.get_id(p) for p in connections.get_all_peers()
|
||||
# ]
|
||||
# response = JoinResponse(uuid=_id, peers=peers)
|
||||
# await self.send(response.model_dump_json())
|
||||
|
||||
# # await join_and_listen(map_id, permissions, user, websocket)
|
||||
|
||||
# # text_data_json = json.loads(text_data)
|
||||
# # message = text_data_json["message"]
|
||||
|
||||
# # self.send(text_data=json.dumps({"message": message}))
|
49
umap/sync/payloads.py
Normal file
49
umap/sync/payloads.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class JoinRequest(BaseModel):
|
||||
kind: Literal["JoinRequest"] = "JoinRequest"
|
||||
token: str
|
||||
|
||||
|
||||
class OperationMessage(BaseModel):
|
||||
"""Message sent from one peer to all the others"""
|
||||
|
||||
kind: Literal["OperationMessage"] = "OperationMessage"
|
||||
verb: Literal["upsert", "update", "delete"]
|
||||
subject: Literal["map", "datalayer", "feature"]
|
||||
metadata: Optional[dict] = None
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
class PeerMessage(BaseModel):
|
||||
"""Message sent from a specific peer to another one"""
|
||||
|
||||
kind: Literal["PeerMessage"] = "PeerMessage"
|
||||
sender: str
|
||||
recipient: str
|
||||
# The message can be whatever the peers want. It's not checked by the server.
|
||||
message: dict
|
||||
|
||||
|
||||
class Request(RootModel):
|
||||
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
||||
|
||||
root: Union[PeerMessage, OperationMessage, JoinRequest] = Field(
|
||||
discriminator="kind"
|
||||
)
|
||||
|
||||
|
||||
class JoinResponse(BaseModel):
|
||||
"""Server response containing the list of peers"""
|
||||
|
||||
kind: Literal["JoinResponse"] = "JoinResponse"
|
||||
peers: list
|
||||
uuid: str
|
||||
|
||||
|
||||
class ListPeersResponse(BaseModel):
|
||||
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
|
||||
peers: list
|
|
@ -1,202 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import websockets
|
||||
from django.conf import settings
|
||||
from django.core.signing import TimestampSigner
|
||||
from pydantic import BaseModel, Field, RootModel, ValidationError
|
||||
from websockets import WebSocketClientProtocol
|
||||
from websockets.server import serve
|
||||
|
||||
|
||||
class Connections:
|
||||
def __init__(self) -> None:
|
||||
self._connections: set[WebSocketClientProtocol] = set()
|
||||
self._ids: dict[WebSocketClientProtocol, str] = dict()
|
||||
|
||||
def join(self, websocket: WebSocketClientProtocol) -> str:
|
||||
self._connections.add(websocket)
|
||||
_id = str(uuid.uuid4())
|
||||
self._ids[websocket] = _id
|
||||
return _id
|
||||
|
||||
def leave(self, websocket: WebSocketClientProtocol) -> None:
|
||||
self._connections.remove(websocket)
|
||||
del self._ids[websocket]
|
||||
|
||||
def get(self, id) -> WebSocketClientProtocol:
|
||||
# use an iterator to stop iterating as soon as we found
|
||||
return next(k for k, v in self._ids.items() if v == id)
|
||||
|
||||
def get_id(self, websocket: WebSocketClientProtocol):
|
||||
return self._ids[websocket]
|
||||
|
||||
def get_other_peers(
|
||||
self, websocket: WebSocketClientProtocol
|
||||
) -> set[WebSocketClientProtocol]:
|
||||
return self._connections - {websocket}
|
||||
|
||||
def get_all_peers(self) -> set[WebSocketClientProtocol]:
|
||||
return self._connections
|
||||
|
||||
|
||||
# Contains the list of websocket connections handled by this process.
|
||||
# It's a mapping of map_id to a set of the active websocket connections
|
||||
CONNECTIONS: defaultdict[int, Connections] = defaultdict(Connections)
|
||||
|
||||
|
||||
class JoinRequest(BaseModel):
|
||||
kind: Literal["JoinRequest"] = "JoinRequest"
|
||||
token: str
|
||||
|
||||
|
||||
class OperationMessage(BaseModel):
|
||||
"""Message sent from one peer to all the others"""
|
||||
|
||||
kind: Literal["OperationMessage"] = "OperationMessage"
|
||||
verb: Literal["upsert", "update", "delete"]
|
||||
subject: Literal["map", "datalayer", "feature"]
|
||||
metadata: Optional[dict] = None
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
class PeerMessage(BaseModel):
|
||||
"""Message sent from a specific peer to another one"""
|
||||
|
||||
kind: Literal["PeerMessage"] = "PeerMessage"
|
||||
sender: str
|
||||
recipient: str
|
||||
# The message can be whatever the peers want. It's not checked by the server.
|
||||
message: dict
|
||||
|
||||
|
||||
class ServerRequest(BaseModel):
|
||||
"""A request towards the server"""
|
||||
|
||||
kind: Literal["Server"] = "Server"
|
||||
action: Literal["list-peers"]
|
||||
|
||||
|
||||
class Request(RootModel):
|
||||
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
||||
|
||||
root: Union[ServerRequest, PeerMessage, OperationMessage, JoinRequest] = Field(
|
||||
discriminator="kind"
|
||||
)
|
||||
|
||||
|
||||
class JoinResponse(BaseModel):
|
||||
"""Server response containing the list of peers"""
|
||||
|
||||
kind: Literal["JoinResponse"] = "JoinResponse"
|
||||
peers: list
|
||||
uuid: str
|
||||
|
||||
|
||||
class ListPeersResponse(BaseModel):
|
||||
kind: Literal["ListPeersResponse"] = "ListPeersResponse"
|
||||
peers: list
|
||||
|
||||
|
||||
async def join_and_listen(
|
||||
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
|
||||
):
|
||||
"""Join a "room" with other connected peers, and wait for messages."""
|
||||
logging.debug(f"{user} joined room #{map_id}")
|
||||
connections: Connections = CONNECTIONS[map_id]
|
||||
_id: str = connections.join(websocket)
|
||||
|
||||
# Assign an ID to the joining peer and return it the list of connected peers.
|
||||
peers: list[WebSocketClientProtocol] = [
|
||||
connections.get_id(p) for p in connections.get_all_peers()
|
||||
]
|
||||
response = JoinResponse(uuid=_id, peers=peers)
|
||||
await websocket.send(response.model_dump_json())
|
||||
|
||||
# Notify all other peers of the new list of connected peers.
|
||||
message = ListPeersResponse(peers=peers)
|
||||
websockets.broadcast(
|
||||
connections.get_other_peers(websocket), message.model_dump_json()
|
||||
)
|
||||
|
||||
try:
|
||||
async for raw_message in websocket:
|
||||
if raw_message == "ping":
|
||||
await websocket.send("pong")
|
||||
continue
|
||||
|
||||
# recompute the peers list at the time of message-sending.
|
||||
# as doing so beforehand would miss new connections
|
||||
other_peers = connections.get_other_peers(websocket)
|
||||
try:
|
||||
incoming = Request.model_validate_json(raw_message)
|
||||
except ValidationError as e:
|
||||
error = f"An error occurred when receiving the following message: {raw_message!r}"
|
||||
logging.error(error, e)
|
||||
else:
|
||||
match incoming.root:
|
||||
# Broadcast all operation messages to connected peers
|
||||
case OperationMessage():
|
||||
websockets.broadcast(other_peers, raw_message)
|
||||
|
||||
# Send peer messages to the proper peer
|
||||
case PeerMessage(recipient=_id):
|
||||
peer = connections.get(_id)
|
||||
if peer:
|
||||
await peer.send(raw_message)
|
||||
|
||||
finally:
|
||||
# On disconnect, remove the connection from the pool
|
||||
connections.leave(websocket)
|
||||
|
||||
# TODO: refactor this in a separate method.
|
||||
# Notify all other peers of the new list of connected peers.
|
||||
peers = [connections.get_id(p) for p in connections.get_all_peers()]
|
||||
message = ListPeersResponse(peers=peers)
|
||||
websockets.broadcast(
|
||||
connections.get_other_peers(websocket), message.model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
async def handler(websocket: WebSocketClientProtocol):
|
||||
"""Main WebSocket handler.
|
||||
|
||||
Check if the permission is granted and let the peer enter a room.
|
||||
"""
|
||||
raw_message = await websocket.recv()
|
||||
|
||||
# The first event should always be 'join'
|
||||
message: JoinRequest = JoinRequest.model_validate_json(raw_message)
|
||||
signed = TimestampSigner().unsign_object(message.token, max_age=30)
|
||||
user, map_id, permissions = signed.values()
|
||||
|
||||
# Check if permissions for this map have been granted by the server
|
||||
if "edit" in signed["permissions"]:
|
||||
await join_and_listen(map_id, permissions, user, websocket)
|
||||
|
||||
|
||||
def run(host: str, port: int):
|
||||
if not settings.WEBSOCKET_ENABLED:
|
||||
msg = (
|
||||
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
|
||||
"See the documentation at "
|
||||
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
|
||||
"for more information."
|
||||
)
|
||||
print(msg)
|
||||
exit(1)
|
||||
|
||||
async def _serve():
|
||||
async with serve(handler, host, port):
|
||||
logging.debug(f"Waiting for connections on {host}:{port}")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
try:
|
||||
asyncio.run(_serve())
|
||||
except KeyboardInterrupt:
|
||||
print("Closing WebSocket server")
|
Loading…
Reference in a new issue