diff --git a/umap/asgi.py b/umap/asgi.py index 2ca12ddc..6668a94c 100644 --- a/umap/asgi.py +++ b/umap/asgi.py @@ -1,15 +1,24 @@ import os -from channels.routing import ProtocolTypeRouter +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application +from django.urls import re_path + +from . import consumers os.environ.setdefault("DJANGO_SETTINGS_MODULE", "umap.settings") # Initialize Django ASGI application early to ensure the AppRegistry # is populated before importing code that may import ORM models. django_asgi_app = get_asgi_application() +urlpatterns = (re_path(r"ws/sync/(?P\w+)/$", consumers.SyncConsumer.as_asgi()),) + application = ProtocolTypeRouter( { "http": django_asgi_app, + "websocket": consumers.TokenMiddleware( + AllowedHostsOriginValidator(URLRouter(urlpatterns)) + ), } ) diff --git a/umap/consumers.py b/umap/consumers.py new file mode 100644 index 00000000..abcdbd0a --- /dev/null +++ b/umap/consumers.py @@ -0,0 +1,117 @@ +from channels.generic.websocket import AsyncWebsocketConsumer +from django.core.signing import TimestampSigner + +from .websocket_server import ( + JoinRequest, + JoinResponse, + OperationMessage, + Request, + ValidationError, +) + + +class TokenMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + signed = TimestampSigner().unsign_object( + scope["query_string"].decode(), max_age=30 + ) + user, map_id, permissions = signed.values() + if "edit" not in permissions: + raise ValueError("Invalid Token") + scope["user"] = user + return await self.app(scope, receive, send) + + +class SyncConsumer(AsyncWebsocketConsumer): + async def connect(self): + print("connect") + self.map_id = self.scope["url_route"]["kwargs"]["map_id"] + + # Join room group + await self.channel_layer.group_add(self.map_id, self.channel_name) + + await self.accept() + + async def disconnect(self, close_code): + print("disconnect") + await self.channel_layer.group_discard(self.map_id, self.channel_name) + + async def broadcast(self, event): + print(event) + await self.send(event["message"]) + + async def receive(self, text_data): + print("receive") + print(text_data) + if text_data == "ping": + return await self.send("pong") + + # recompute the peers list at the time of message-sending. + # as doing so beforehand would miss new connections + # other_peers = connections.get_other_peers(websocket) + # await self.send("pong" + self.channel_name) + # await self.channel_layer.group_send( + # self.map_id, + # {"message": "pouet " + self.channel_name, "type": "broadcast"}, + # ) + try: + incoming = Request.model_validate_json(text_data) + except ValidationError: + error = ( + f"An error occurred when receiving the following message: {text_data!r}" + ) + print(error) + # logging.error(error, e) + else: + match incoming.root: + # Broadcast all operation messages to connected peers + case JoinRequest(): + response = JoinResponse( + uuid=self.channel_name, + peers=self.channel_layer.groups[self.map_id].keys(), + ) + await self.send(response.model_dump_json()) + case OperationMessage(): + await self.channel_layer.group_send( + self.map_id, + {"message": text_data, "type": "broadcast"}, + ) + + # Send peer messages to the proper peer + # case PeerMessage(recipient=_id): + # peer = connections.get(_id) + # if peer: + # await peer.send(raw_message) + # websockets.broadcast(other_peers, text_data) + + # Send peer messages to the proper peer + # case PeerMessage(recipient=_id): + # peer = connections.get(_id) + # if peer: + # await peer.send(text_data) + + # message = JoinRequest.model_validate_json(text_data) + # signed = TimestampSigner().unsign_object(message.token, max_age=30) + # user, map_id, permissions = signed.values() + + # # Check if permissions for this map have been granted by the server + # if "edit" in signed["permissions"]: + # connections = CONNECTIONS[map_id] + # _id = connections.join(self) + + # # Assign an ID to the joining peer and return it the list of connected peers. + # peers: list[WebSocketClientProtocol] = [ + # connections.get_id(p) for p in connections.get_all_peers() + # ] + # response = JoinResponse(uuid=_id, peers=peers) + # await self.send(response.model_dump_json()) + + # # await join_and_listen(map_id, permissions, user, websocket) + + # # text_data_json = json.loads(text_data) + # # message = text_data_json["message"] + + # # self.send(text_data=json.dumps({"message": message})) diff --git a/umap/settings/base.py b/umap/settings/base.py index f47ad236..fedefe88 100644 --- a/umap/settings/base.py +++ b/umap/settings/base.py @@ -343,3 +343,4 @@ WEBSOCKET_ENABLED = env.bool("WEBSOCKET_ENABLED", default=False) WEBSOCKET_BACK_HOST = env("WEBSOCKET_BACK_HOST", default="localhost") WEBSOCKET_BACK_PORT = env.int("WEBSOCKET_BACK_PORT", default=8001) WEBSOCKET_FRONT_URI = env("WEBSOCKET_FRONT_URI", default="ws://localhost:8001") +CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} diff --git a/umap/static/umap/js/modules/sync/engine.js b/umap/static/umap/js/modules/sync/engine.js index 9094b482..9df767af 100644 --- a/umap/static/umap/js/modules/sync/engine.js +++ b/umap/static/umap/js/modules/sync/engine.js @@ -150,6 +150,7 @@ export class SyncEngine { * and dispatches the different "on*" methods. */ receive({ kind, ...payload }) { + console.log(kind, payload) if (kind === 'OperationMessage') { this.onOperationMessage(payload) } else if (kind === 'JoinResponse') { diff --git a/umap/static/umap/js/modules/sync/websocket.js b/umap/static/umap/js/modules/sync/websocket.js index ce346ad7..1ea9419c 100644 --- a/umap/static/umap/js/modules/sync/websocket.js +++ b/umap/static/umap/js/modules/sync/websocket.js @@ -7,7 +7,7 @@ export class WebSocketTransport { this.receiver = messagesReceiver this.closeRequested = false - this.websocket = new WebSocket(webSocketURI) + this.websocket = new WebSocket(`${webSocketURI}?${authToken}`) this.websocket.onopen = () => { this.send('JoinRequest', { token: authToken }) @@ -49,6 +49,7 @@ export class WebSocketTransport { } onMessage(wsMessage) { + console.log(wsMessage) if (wsMessage.data === 'pong') { this.pongReceived = true } else { diff --git a/umap/websocket_server.py b/umap/websocket_server.py index 6483d648..346fe960 100644 --- a/umap/websocket_server.py +++ b/umap/websocket_server.py @@ -85,7 +85,7 @@ class ServerRequest(BaseModel): class Request(RootModel): """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" - root: Union[ServerRequest, PeerMessage, OperationMessage] = Field( + root: Union[ServerRequest, PeerMessage, OperationMessage, JoinRequest] = Field( discriminator="kind" )