fix(webrtc): enforce embed allowed-domain policy on public signaling websocket (#388)

The public WebRTC signaling WebSocket (`/public/signaling/{session_token}`)
validated only the session token and its expiry, not the embed token's
allowed-domain policy that the HTTP embed endpoints already enforce. A leaked
or replayed session token could therefore attach to the signaling path from
an arbitrary origin.

Validate the request origin against `embed_token.allowed_domains` (reusing the
existing `validate_origin` helper) before the signaling handoff, rejecting
disallowed origins with a 1008 close — mirroring the HTTP embed endpoints.

Closes #330

Co-authored-by: shiminshen <16914659+shiminshen@users.noreply.github.com>
This commit is contained in:
shiminshen 2026-06-02 15:40:30 +08:00 committed by GitHub
parent 37e7f4d2e6
commit 51ab9303ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 91 additions and 0 deletions

View file

@ -545,6 +545,20 @@ async def public_signaling_websocket(
await websocket.close(code=1008, reason="Invalid embed token")
return
# Enforce the embed token's allowed-domain policy on the public signaling
# path, mirroring the HTTP embed endpoints (issue #330). Without this a
# leaked or replayed session token could attach from an arbitrary origin.
from api.routes.public_embed import validate_origin
origin = websocket.headers.get("origin") or websocket.headers.get("referer", "")
if not validate_origin(origin, embed_token.allowed_domains or []):
logger.warning(
f"Domain validation failed for public signaling: {origin} "
f"not in {embed_token.allowed_domains}"
)
await websocket.close(code=1008, reason="Domain not allowed")
return
# Create a minimal user object for compatibility with signaling manager
# Use the embed token creator as the user
user = await db_client.get_user_by_id(embed_token.created_by)

View file

@ -0,0 +1,77 @@
"""Tests for public WebRTC signaling allowed-domain enforcement.
Regression for issue #330: the public signaling WebSocket
(`/public/signaling/{session_token}`) must enforce the embed token's
allowed-domain policy, mirroring the HTTP embed endpoints. Before the fix it
validated only the session token and expiry, so a leaked or replayed session
token could attach to the signaling path from an arbitrary origin.
"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
class _FakeWebSocket:
"""Minimal WebSocket double exposing handshake headers and close()."""
def __init__(self, origin: str):
self.headers = {"origin": origin}
self.close = AsyncMock()
def _embed_session():
return SimpleNamespace(expires_at=None, embed_token_id=1, workflow_run_id=42)
def _embed_token(allowed_domains):
return SimpleNamespace(
allowed_domains=allowed_domains, created_by=7, workflow_id=3
)
def _patch_deps():
"""Patch db_client + signaling_manager for a valid, non-expired session."""
db = patch("api.routes.webrtc_signaling.db_client").start()
mgr = patch("api.routes.webrtc_signaling.signaling_manager").start()
db.get_embed_session_by_token = AsyncMock(return_value=_embed_session())
db.get_embed_token_by_id = AsyncMock(
return_value=_embed_token(["example.com"])
)
db.get_user_by_id = AsyncMock(return_value=SimpleNamespace(id=7))
mgr.handle_websocket = AsyncMock()
return db, mgr
@pytest.mark.asyncio
async def test_public_signaling_rejects_disallowed_origin():
from api.routes.webrtc_signaling import public_signaling_websocket
ws = _FakeWebSocket("https://evil.example")
_db, mgr = _patch_deps()
try:
await public_signaling_websocket(ws, "emb_session_tok")
finally:
patch.stopall()
# Regression (issue #330): a valid session token presented from an origin
# outside the embed allowlist must be rejected before the signaling handoff.
ws.close.assert_awaited_once()
assert ws.close.await_args.kwargs.get("code") == 1008
mgr.handle_websocket.assert_not_called()
@pytest.mark.asyncio
async def test_public_signaling_accepts_allowed_origin():
from api.routes.webrtc_signaling import public_signaling_websocket
ws = _FakeWebSocket("https://example.com")
_db, mgr = _patch_deps()
try:
await public_signaling_websocket(ws, "emb_session_tok")
finally:
patch.stopall()
# An origin within the allowlist proceeds to the signaling handoff.
mgr.handle_websocket.assert_awaited_once()