Skip to content

Commit 2825ebb

Browse files
committed
tests(sync): Change the way the websocket server is run in the tests
Using [pytest-xprocess](https://pytest-xprocess.readthedocs.io/) proved not being as useful as I thought at first, because it was causing intermitent failures when starting the process. The code now directly uses `subprocess.popen` calls to start the server. The tests are grouped together using the following decorator: `@pytest.mark.xdist_group(name="websockets")` Tests now need to be run with the `pytest --dist loadgroup` so that all tests of the same group happen on the same process. More details on this blogpost: https://blog.notmyidea.org/start-a-process-when-using-pytest-xdist.html
1 parent 3300319 commit 2825ebb

File tree

6 files changed

+122
-21
lines changed

6 files changed

+122
-21
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ publish: ## Publish the Python package to Pypi
5858
test: testpy testjs
5959

6060
testpy:
61-
pytest -vv umap/tests/
61+
pytest -vv umap/tests/ --dist=loadgroup
6262

6363
test-integration:
64-
pytest -xv umap/tests/integration/
64+
pytest -xv umap/tests/integration/ --dist=loadgroup
6565

6666
clean:
6767
rm -f dist/*

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ test = [
6262
"pytest-django==4.8.0",
6363
"pytest-playwright==0.5.0",
6464
"pytest-xdist>=3.5.0,<4",
65-
"pytest-xprocess>=1.0.1",
6665
]
6766
docker = [
6867
"uwsgi==2.0.26",

umap/tests/integration/conftest.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2+
import subprocess
3+
import time
24
from pathlib import Path
35

46
import pytest
57
from playwright.sync_api import expect
6-
from xprocess import ProcessStarter
78

89

910
@pytest.fixture(autouse=True)
@@ -37,19 +38,23 @@ def do_login(user):
3738
return do_login
3839

3940

40-
@pytest.fixture()
41-
def websocket_server(xprocess):
42-
class Starter(ProcessStarter):
43-
settings_path = (
44-
(Path(__file__).parent.parent / "settings.py").absolute().as_posix()
45-
)
46-
os.environ["UMAP_SETTINGS"] = settings_path
47-
# env = {"UMAP_SETTINGS": settings_path}
48-
pattern = "Waiting for connections*"
49-
args = ["python", "-m", "umap.ws"]
50-
timeout = 1
51-
terminate_on_interrupt = True
52-
53-
xprocess.ensure("websocket_server", Starter)
54-
yield
55-
xprocess.getinfo("websocket_server").terminate()
41+
@pytest.fixture
42+
def websocket_server():
43+
# Find the test-settings, and put them in the current environment
44+
settings_path = (Path(__file__).parent.parent / "settings.py").absolute().as_posix()
45+
os.environ["UMAP_SETTINGS"] = settings_path
46+
47+
ds_proc = subprocess.Popen(
48+
[
49+
"umap",
50+
"run_websocket_server",
51+
],
52+
stdout=subprocess.PIPE,
53+
stderr=subprocess.STDOUT,
54+
)
55+
time.sleep(2)
56+
# Ensure it started properly before yielding
57+
assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8")
58+
yield ds_proc
59+
# Shut it down at the end of the pytest session
60+
ds_proc.terminate()

umap/tests/integration/test_websocket_sync.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22

3+
import pytest
34
from playwright.sync_api import expect
45

56
from umap.models import Map
@@ -9,6 +10,7 @@
910
DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*")
1011

1112

13+
@pytest.mark.xdist_group(name="websockets")
1214
def test_websocket_connection_can_sync_markers(
1315
context, live_server, websocket_server, tilelayer
1416
):
@@ -73,6 +75,7 @@ def test_websocket_connection_can_sync_markers(
7375
expect(b_marker_pane).to_have_count(1)
7476

7577

78+
@pytest.mark.xdist_group(name="websockets")
7679
def test_websocket_connection_can_sync_polygons(
7780
context, live_server, websocket_server, tilelayer
7881
):
@@ -156,6 +159,7 @@ def test_websocket_connection_can_sync_polygons(
156159
expect(b_polygons).to_have_count(0)
157160

158161

162+
@pytest.mark.xdist_group(name="websockets")
159163
def test_websocket_connection_can_sync_map_properties(
160164
context, live_server, websocket_server, tilelayer
161165
):
@@ -187,6 +191,7 @@ def test_websocket_connection_can_sync_map_properties(
187191
expect(peerA.locator(".leaflet-control-zoom")).to_be_hidden()
188192

189193

194+
@pytest.mark.xdist_group(name="websockets")
190195
def test_websocket_connection_can_sync_datalayer_properties(
191196
context, live_server, websocket_server, tilelayer
192197
):
@@ -215,6 +220,7 @@ def test_websocket_connection_can_sync_datalayer_properties(
215220
expect(peerB.get_by_role("combobox")).to_have_value("Choropleth")
216221

217222

223+
@pytest.mark.xdist_group(name="websockets")
218224
def test_websocket_connection_can_sync_cloned_polygons(
219225
context, live_server, websocket_server, tilelayer
220226
):

umap/tests/test_datalayer_views.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import time
32
from copy import deepcopy
43
from pathlib import Path
54

umap/websocket_server.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python
2+
3+
import asyncio
4+
from collections import defaultdict
5+
from typing import Literal, Optional
6+
7+
import websockets
8+
from django.conf import settings
9+
from django.core.signing import TimestampSigner
10+
from pydantic import BaseModel, ValidationError
11+
from websockets import WebSocketClientProtocol
12+
from websockets.server import serve
13+
14+
from umap.models import Map, User # NOQA
15+
16+
# Contains the list of websocket connections handled by this process.
17+
# It's a mapping of map_id to a set of the active websocket connections
18+
CONNECTIONS = defaultdict(set)
19+
20+
21+
class JoinMessage(BaseModel):
22+
kind: str = "join"
23+
token: str
24+
25+
26+
class OperationMessage(BaseModel):
27+
kind: str = "operation"
28+
verb: str = Literal["upsert", "update", "delete"]
29+
subject: str = Literal["map", "layer", "feature"]
30+
metadata: Optional[dict] = None
31+
key: Optional[str] = None
32+
33+
34+
async def join_and_listen(
35+
map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol
36+
):
37+
"""Join a "room" whith other connected peers.
38+
39+
New messages will be broadcasted to other connected peers.
40+
"""
41+
print(f"{user} joined room #{map_id}")
42+
CONNECTIONS[map_id].add(websocket)
43+
try:
44+
async for raw_message in websocket:
45+
# recompute the peers-list at the time of message-sending.
46+
# as doing so beforehand would miss new connections
47+
peers = CONNECTIONS[map_id] - {websocket}
48+
# Only relay valid "operation" messages
49+
try:
50+
OperationMessage.model_validate_json(raw_message)
51+
websockets.broadcast(peers, raw_message)
52+
except ValidationError as e:
53+
error = f"An error occurred when receiving this message: {raw_message}"
54+
print(error, e)
55+
finally:
56+
CONNECTIONS[map_id].remove(websocket)
57+
58+
59+
async def handler(websocket):
60+
"""Main WebSocket handler.
61+
62+
If permissions are granted, let the peer enter a room.
63+
"""
64+
raw_message = await websocket.recv()
65+
66+
# The first event should always be 'join'
67+
message: JoinMessage = JoinMessage.model_validate_json(raw_message)
68+
signed = TimestampSigner().unsign_object(message.token, max_age=30)
69+
user, map_id, permissions = signed.values()
70+
71+
# Check if permissions for this map have been granted by the server
72+
if "edit" in signed["permissions"]:
73+
await join_and_listen(map_id, permissions, user, websocket)
74+
75+
76+
def run(host, port):
77+
if not settings.WEBSOCKET_ENABLED:
78+
msg = (
79+
"WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. "
80+
"See the documentation at "
81+
"https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled "
82+
"for more information."
83+
)
84+
print(msg)
85+
exit(1)
86+
87+
async def _serve():
88+
async with serve(handler, host, port):
89+
print(f"Waiting for connections on {host}:{port}")
90+
await asyncio.Future() # run forever
91+
92+
asyncio.run(_serve())

0 commit comments

Comments
 (0)