Compare commits

..

No commits in common. "fd8cc2a2cb790bae0419a7bc5700ffaed2743146" and "e865c31d69a6993b7f2463e5cfa62c2e89cd3a00" have entirely different histories.

5 changed files with 48 additions and 16 deletions

View file

@ -17,6 +17,8 @@ urlpatterns = (re_path(r"ws/sync/(?P<map_id>\w+)/$", consumers.SyncConsumer.as_a
application = ProtocolTypeRouter( application = ProtocolTypeRouter(
{ {
"http": django_asgi_app, "http": django_asgi_app,
"websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)), "websocket": consumers.TokenMiddleware(
AllowedHostsOriginValidator(URLRouter(urlpatterns))
),
} }
) )

View file

@ -7,7 +7,7 @@ export class WebSocketTransport {
this.receiver = messagesReceiver this.receiver = messagesReceiver
this.closeRequested = false this.closeRequested = false
this.websocket = new WebSocket(`${webSocketURI}`) this.websocket = new WebSocket(`${webSocketURI}?${authToken}`)
this.websocket.onopen = () => { this.websocket.onopen = () => {
this.send('JoinRequest', { token: authToken }) this.send('JoinRequest', { token: authToken })

View file

@ -14,6 +14,21 @@ from .payloads import (
) )
class TokenMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
signed = TimestampSigner().unsign_object(
scope["query_string"].decode(), max_age=30
)
user, map_id, permissions = signed.values()
if "edit" not in permissions:
raise ValueError("Invalid Token")
scope["user"] = user
return await self.app(scope, receive, send)
class SyncConsumer(AsyncWebsocketConsumer): class SyncConsumer(AsyncWebsocketConsumer):
@property @property
def peers(self): def peers(self):
@ -26,7 +41,7 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.channel_layer.group_add(self.map_id, self.channel_name) await self.channel_layer.group_add(self.map_id, self.channel_name)
await self.accept() await self.accept()
self.is_authenticated = False await self.send_peers_list()
async def disconnect(self, close_code): async def disconnect(self, close_code):
await self.channel_layer.group_discard(self.map_id, self.channel_name) await self.channel_layer.group_discard(self.map_id, self.channel_name)
@ -53,18 +68,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
await self.send(event["message"]) await self.send(event["message"])
async def receive(self, text_data): async def receive(self, text_data):
if not self.is_authenticated:
message = JoinRequest.model_validate_json(text_data)
signed = TimestampSigner().unsign_object(message.token, max_age=30)
user, map_id, permissions = signed.values()
if "edit" not in permissions:
return await self.disconnect()
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
await self.send(response.model_dump_json())
await self.send_peers_list()
self.is_authenticated = True
return
if text_data == "ping": if text_data == "ping":
return await self.send("pong") return await self.send("pong")
@ -78,6 +81,9 @@ class SyncConsumer(AsyncWebsocketConsumer):
else: else:
match incoming.root: match incoming.root:
# Broadcast all operation messages to connected peers # Broadcast all operation messages to connected peers
case JoinRequest():
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
await self.send(response.model_dump_json())
case OperationMessage(): case OperationMessage():
await self.broadcast(text_data) await self.broadcast(text_data)

View file

@ -31,7 +31,9 @@ class PeerMessage(BaseModel):
class Request(RootModel): class Request(RootModel):
"""Any message coming from the websocket should be one of these, and will be rejected otherwise.""" """Any message coming from the websocket should be one of these, and will be rejected otherwise."""
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind") root: Union[PeerMessage, OperationMessage, JoinRequest] = Field(
discriminator="kind"
)
class JoinResponse(BaseModel): class JoinResponse(BaseModel):

View file

@ -0,0 +1,22 @@
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
def test_messages_are_parsed_correctly():
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
assert type(server) is ServerRequest
operation = Request.model_validate(
dict(
kind="OperationMessage",
verb="upsert",
subject="map",
metadata={},
key="key",
)
).root
assert type(operation) is OperationMessage
peer_message = Request.model_validate(
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
).root
assert type(peer_message) is PeerMessage