mirror of
https://github.com/umap-project/umap.git
synced 2025-05-15 18:21:53 +02:00
Compare commits
No commits in common. "fd8cc2a2cb790bae0419a7bc5700ffaed2743146" and "e865c31d69a6993b7f2463e5cfa62c2e89cd3a00" have entirely different histories.
fd8cc2a2cb
...
e865c31d69
5 changed files with 48 additions and 16 deletions
|
@ -17,6 +17,8 @@ urlpatterns = (re_path(r"ws/sync/(?P<map_id>\w+)/$", consumers.SyncConsumer.as_a
|
||||||
application = ProtocolTypeRouter(
|
application = ProtocolTypeRouter(
|
||||||
{
|
{
|
||||||
"http": django_asgi_app,
|
"http": django_asgi_app,
|
||||||
"websocket": AllowedHostsOriginValidator(URLRouter(urlpatterns)),
|
"websocket": consumers.TokenMiddleware(
|
||||||
|
AllowedHostsOriginValidator(URLRouter(urlpatterns))
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,7 +7,7 @@ export class WebSocketTransport {
|
||||||
this.receiver = messagesReceiver
|
this.receiver = messagesReceiver
|
||||||
this.closeRequested = false
|
this.closeRequested = false
|
||||||
|
|
||||||
this.websocket = new WebSocket(`${webSocketURI}`)
|
this.websocket = new WebSocket(`${webSocketURI}?${authToken}`)
|
||||||
|
|
||||||
this.websocket.onopen = () => {
|
this.websocket.onopen = () => {
|
||||||
this.send('JoinRequest', { token: authToken })
|
this.send('JoinRequest', { token: authToken })
|
||||||
|
|
|
@ -14,6 +14,21 @@ from .payloads import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenMiddleware:
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
signed = TimestampSigner().unsign_object(
|
||||||
|
scope["query_string"].decode(), max_age=30
|
||||||
|
)
|
||||||
|
user, map_id, permissions = signed.values()
|
||||||
|
if "edit" not in permissions:
|
||||||
|
raise ValueError("Invalid Token")
|
||||||
|
scope["user"] = user
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class SyncConsumer(AsyncWebsocketConsumer):
|
class SyncConsumer(AsyncWebsocketConsumer):
|
||||||
@property
|
@property
|
||||||
def peers(self):
|
def peers(self):
|
||||||
|
@ -26,7 +41,7 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
||||||
await self.channel_layer.group_add(self.map_id, self.channel_name)
|
await self.channel_layer.group_add(self.map_id, self.channel_name)
|
||||||
|
|
||||||
await self.accept()
|
await self.accept()
|
||||||
self.is_authenticated = False
|
await self.send_peers_list()
|
||||||
|
|
||||||
async def disconnect(self, close_code):
|
async def disconnect(self, close_code):
|
||||||
await self.channel_layer.group_discard(self.map_id, self.channel_name)
|
await self.channel_layer.group_discard(self.map_id, self.channel_name)
|
||||||
|
@ -53,18 +68,6 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
||||||
await self.send(event["message"])
|
await self.send(event["message"])
|
||||||
|
|
||||||
async def receive(self, text_data):
|
async def receive(self, text_data):
|
||||||
if not self.is_authenticated:
|
|
||||||
message = JoinRequest.model_validate_json(text_data)
|
|
||||||
signed = TimestampSigner().unsign_object(message.token, max_age=30)
|
|
||||||
user, map_id, permissions = signed.values()
|
|
||||||
if "edit" not in permissions:
|
|
||||||
return await self.disconnect()
|
|
||||||
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
|
|
||||||
await self.send(response.model_dump_json())
|
|
||||||
await self.send_peers_list()
|
|
||||||
self.is_authenticated = True
|
|
||||||
return
|
|
||||||
|
|
||||||
if text_data == "ping":
|
if text_data == "ping":
|
||||||
return await self.send("pong")
|
return await self.send("pong")
|
||||||
|
|
||||||
|
@ -78,6 +81,9 @@ class SyncConsumer(AsyncWebsocketConsumer):
|
||||||
else:
|
else:
|
||||||
match incoming.root:
|
match incoming.root:
|
||||||
# Broadcast all operation messages to connected peers
|
# Broadcast all operation messages to connected peers
|
||||||
|
case JoinRequest():
|
||||||
|
response = JoinResponse(uuid=self.channel_name, peers=self.peers)
|
||||||
|
await self.send(response.model_dump_json())
|
||||||
case OperationMessage():
|
case OperationMessage():
|
||||||
await self.broadcast(text_data)
|
await self.broadcast(text_data)
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,9 @@ class PeerMessage(BaseModel):
|
||||||
class Request(RootModel):
|
class Request(RootModel):
|
||||||
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
"""Any message coming from the websocket should be one of these, and will be rejected otherwise."""
|
||||||
|
|
||||||
root: Union[PeerMessage, OperationMessage] = Field(discriminator="kind")
|
root: Union[PeerMessage, OperationMessage, JoinRequest] = Field(
|
||||||
|
discriminator="kind"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JoinResponse(BaseModel):
|
class JoinResponse(BaseModel):
|
||||||
|
|
22
umap/tests/test_websocket_server.py
Normal file
22
umap/tests/test_websocket_server.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from umap.websocket_server import OperationMessage, PeerMessage, Request, ServerRequest
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_are_parsed_correctly():
|
||||||
|
server = Request.model_validate(dict(kind="Server", action="list-peers")).root
|
||||||
|
assert type(server) is ServerRequest
|
||||||
|
|
||||||
|
operation = Request.model_validate(
|
||||||
|
dict(
|
||||||
|
kind="OperationMessage",
|
||||||
|
verb="upsert",
|
||||||
|
subject="map",
|
||||||
|
metadata={},
|
||||||
|
key="key",
|
||||||
|
)
|
||||||
|
).root
|
||||||
|
assert type(operation) is OperationMessage
|
||||||
|
|
||||||
|
peer_message = Request.model_validate(
|
||||||
|
dict(kind="PeerMessage", sender="Alice", recipient="Bob", message={})
|
||||||
|
).root
|
||||||
|
assert type(peer_message) is PeerMessage
|
Loading…
Reference in a new issue