Compare commits

..

No commits in common. "fd8cc2a2cb790bae0419a7bc5700ffaed2743146" and "e865c31d69a6993b7f2463e5cfa62c2e89cd3a00" have entirely different histories.

5 changed files with 48 additions and 16 deletions

View file

@ -17,6 +17,8 @@ urlpatterns = (re_path(r"ws/sync/(?P<map_id>\w+)/$", consumers.SyncConsumer.as_a
application = ProtocolTypeRouter(
{
"http": django_asgi_app,
"websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)),
"websocket": consumers.TokenMiddleware(
AllowedHostsOriginValidator(URLRouter(urlpatterns))
),
}
)

View file

@ -7,7 +7,7 @@ export class WebSocketTransport {
this.receiver = messagesReceiver
this.closeRequested = false
this.websocket = new WebSocket(`${webSocketURI}`)
this.websocket = new WebSocket(`${webSocketURI}?${authToken}`)
this.websocket.onopen = () => {
this.send('JoinRequest', { token: authToken })

View file

@ -14,6 +14,21 @@ from .payloads import (
)
class TokenMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
signed = TimestampSigner().unsign_object(
scope["query_string"].decode(), max_age=30
)
user, map_id, permissions = signed.values()
if "edit" not in permissions:
raise ValueError("Invalid Token")
scope["user"] = user
return await self.app(scope, receive, send)
class SyncConsumer(AsyncWebsocketConsumer):
@property
def peers(self):
@ -26,7 +41,7 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.channel_layer.group_add(self.map_id, self.channel_name)
await self.accept()
self.is_authenticated = False
await self.send_peers_list()
async def disconnect(self, close_code):
await self.channel_layer.group_discard(self.map_id, self.channel_name)
@ -53,18 +68,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.send(event["message"])
async def receive(self, text_data):
if not self.is_authenticated:
message = JoinRequest.model_validate_json(text_data)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
if "edit" not in permissions:
return await self.disconnect()
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
await self.send(response.model_dump_json())
await self.send_peers_list()
self.is_authenticated = True
return
if text_data == "ping":
return await self.send("pong")
@ -78,6 +81,9 @@ class SyncConsumer(AsyncWebsocketConsumer):
else:
match incoming.root:
# Broadcast all operation messages to connected peers
case JoinRequest():
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
await self.send(response.model_dump_json())
case OperationMessage():
await self.broadcast(text_data)

View file

@ -31,7 +31,9 @@ class PeerMessage(BaseModel):
class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind")
root: Union[PeerMessage, OperationMessage, JoinRequest] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel):

View file

@ -0,0 +1,22 @@
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
def test_messages_are_parsed_correctly():
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
assert type(server) is ServerRequest
operation = Request.model_validate(
dict(
kind="OperationMessage",
verb="upsert",
subject="map",
metadata={},
key="key",
)
).root
assert type(operation) is OperationMessage
peer_message = Request.model_validate(
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
).root
assert type(peer_message) is PeerMessage