diff --git a/umap/asgi.py b/umap/asgi.py index 5b130b65..1f9d618a 100644 --- a/umap/asgi.py +++ b/umap/asgi.py @@ -17,8 +17,6 @@ urlpatterns = (re_path(r"ws/sync/(?P\w+)/$", consumers.SyncConsumer.as_a application = ProtocolTypeRouter( { "http": django_asgi_app, - "websocket": consumers.TokenMiddleware( - AllowedHostsOriginValidator(URLRouter(urlpatterns)) - ), + "websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)), } ) diff --git a/umap/static/umap/js/modules/sync/websocket.js b/umap/static/umap/js/modules/sync/websocket.js index 1ea9419c..c49b7953 100644 --- a/umap/static/umap/js/modules/sync/websocket.js +++ b/umap/static/umap/js/modules/sync/websocket.js @@ -7,7 +7,7 @@ export class WebSocketTransport { this.receiver = messagesReceiver this.closeRequested = false - this.websocket = new WebSocket(`${webSocketURI}?${authToken}`) + this.websocket = new WebSocket(`${webSocketURI}`) this.websocket.onopen = () => { this.send('JoinRequest', { token: authToken }) diff --git a/umap/sync/consumers.py b/umap/sync/consumers.py index f9c3492b..48912820 100644 --- a/umap/sync/consumers.py +++ b/umap/sync/consumers.py @@ -14,21 +14,6 @@ from .payloads import ( ) -class TokenMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - signed = TimestampSigner().unsign_object( - scope["query_string"].decode(), max_age=30 - ) - user, map_id, permissions = signed.values() - if "edit" not in permissions: - raise ValueError("Invalid Token") - scope["user"] = user - return await self.app(scope, receive, send) - - class SyncConsumer(AsyncWebsocketConsumer): @property def peers(self): @@ -41,7 +26,7 @@ class SyncConsumer(AsyncWebsocketConsumer): await self.channel_layer.group_add(self.map_id, self.channel_name) await self.accept() - await self.send_peers_list() + self.is_authenticated = False async def disconnect(self, close_code): await self.channel_layer.group_discard(self.map_id, self.channel_name) @@ -68,6 +53,18 @@ class SyncConsumer(AsyncWebsocketConsumer): await self.send(event["message"]) async def receive(self, text_data): + if not self.is_authenticated: + message = JoinRequest.model_validate_json(text_data) + signed = TimestampSigner().unsign_object(message.token, max_age=30) + user, map_id, permissions = signed.values() + if "edit" not in permissions: + return await self.disconnect() + response = JoinResponse(uuid=self.channel_name, peers=self.peers) + await self.send(response.model_dump_json()) + await self.send_peers_list() + self.is_authenticated = True + return + if text_data == "ping": return await self.send("pong") @@ -81,9 +78,6 @@ class SyncConsumer(AsyncWebsocketConsumer): else: match incoming.root: # Broadcast all operation messages to connected peers - case JoinRequest(): - response = JoinResponse(uuid=self.channel_name, peers=self.peers) - await self.send(response.model_dump_json()) case OperationMessage(): await self.broadcast(text_data) diff --git a/umap/sync/payloads.py b/umap/sync/payloads.py index cfe2a003..6a15a3f1 100644 --- a/umap/sync/payloads.py +++ b/umap/sync/payloads.py @@ -31,9 +31,7 @@ class PeerMessage(BaseModel): class Request(RootModel): """Any message coming from the websocket should be one of these, and will be rejected otherwise.""" - root: Union[PeerMessage, OperationMessage, JoinRequest] = Field( - discriminator="kind" - ) + root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind") class JoinResponse(BaseModel):