From 51ab9303ec880c0890566a0055777f5419a442b1 Mon Sep 17 00:00:00 2001 From: shiminshen Date: Tue, 2 Jun 2026 15:40:30 +0800 Subject: [PATCH] fix(webrtc): enforce embed allowed-domain policy on public signaling websocket (#388) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The public WebRTC signaling WebSocket (`/public/signaling/{session_token}`) validated only the session token and its expiry, not the embed token's allowed-domain policy that the HTTP embed endpoints already enforce. A leaked or replayed session token could therefore attach to the signaling path from an arbitrary origin. Validate the request origin against `embed_token.allowed_domains` (reusing the existing `validate_origin` helper) before the signaling handoff, rejecting disallowed origins with a 1008 close — mirroring the HTTP embed endpoints. Closes #330 Co-authored-by: shiminshen <16914659+shiminshen@users.noreply.github.com> --- api/routes/webrtc_signaling.py | 14 +++++ api/tests/test_public_signaling_origin.py | 77 +++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 api/tests/test_public_signaling_origin.py diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index f4be425f..f7b4eeb3 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -545,6 +545,20 @@ async def public_signaling_websocket( await websocket.close(code=1008, reason="Invalid embed token") return + # Enforce the embed token's allowed-domain policy on the public signaling + # path, mirroring the HTTP embed endpoints (issue #330). Without this a + # leaked or replayed session token could attach from an arbitrary origin. + from api.routes.public_embed import validate_origin + + origin = websocket.headers.get("origin") or websocket.headers.get("referer", "") + if not validate_origin(origin, embed_token.allowed_domains or []): + logger.warning( + f"Domain validation failed for public signaling: {origin} " + f"not in {embed_token.allowed_domains}" + ) + await websocket.close(code=1008, reason="Domain not allowed") + return + # Create a minimal user object for compatibility with signaling manager # Use the embed token creator as the user user = await db_client.get_user_by_id(embed_token.created_by) diff --git a/api/tests/test_public_signaling_origin.py b/api/tests/test_public_signaling_origin.py new file mode 100644 index 00000000..ad893c60 --- /dev/null +++ b/api/tests/test_public_signaling_origin.py @@ -0,0 +1,77 @@ +"""Tests for public WebRTC signaling allowed-domain enforcement. + +Regression for issue #330: the public signaling WebSocket +(`/public/signaling/{session_token}`) must enforce the embed token's +allowed-domain policy, mirroring the HTTP embed endpoints. Before the fix it +validated only the session token and expiry, so a leaked or replayed session +token could attach to the signaling path from an arbitrary origin. +""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + + +class _FakeWebSocket: + """Minimal WebSocket double exposing handshake headers and close().""" + + def __init__(self, origin: str): + self.headers = {"origin": origin} + self.close = AsyncMock() + + +def _embed_session(): + return SimpleNamespace(expires_at=None, embed_token_id=1, workflow_run_id=42) + + +def _embed_token(allowed_domains): + return SimpleNamespace( + allowed_domains=allowed_domains, created_by=7, workflow_id=3 + ) + + +def _patch_deps(): + """Patch db_client + signaling_manager for a valid, non-expired session.""" + db = patch("api.routes.webrtc_signaling.db_client").start() + mgr = patch("api.routes.webrtc_signaling.signaling_manager").start() + db.get_embed_session_by_token = AsyncMock(return_value=_embed_session()) + db.get_embed_token_by_id = AsyncMock( + return_value=_embed_token(["example.com"]) + ) + db.get_user_by_id = AsyncMock(return_value=SimpleNamespace(id=7)) + mgr.handle_websocket = AsyncMock() + return db, mgr + + +@pytest.mark.asyncio +async def test_public_signaling_rejects_disallowed_origin(): + from api.routes.webrtc_signaling import public_signaling_websocket + + ws = _FakeWebSocket("https://evil.example") + _db, mgr = _patch_deps() + try: + await public_signaling_websocket(ws, "emb_session_tok") + finally: + patch.stopall() + + # Regression (issue #330): a valid session token presented from an origin + # outside the embed allowlist must be rejected before the signaling handoff. + ws.close.assert_awaited_once() + assert ws.close.await_args.kwargs.get("code") == 1008 + mgr.handle_websocket.assert_not_called() + + +@pytest.mark.asyncio +async def test_public_signaling_accepts_allowed_origin(): + from api.routes.webrtc_signaling import public_signaling_websocket + + ws = _FakeWebSocket("https://example.com") + _db, mgr = _patch_deps() + try: + await public_signaling_websocket(ws, "emb_session_tok") + finally: + patch.stopall() + + # An origin within the allowlist proceeds to the signaling handoff. + mgr.handle_websocket.assert_awaited_once()