mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
fix(webrtc): enforce embed allowed-domain policy on public signaling websocket (#388)
The public WebRTC signaling WebSocket (`/public/signaling/{session_token}`)
validated only the session token and its expiry, not the embed token's
allowed-domain policy that the HTTP embed endpoints already enforce. A leaked
or replayed session token could therefore attach to the signaling path from
an arbitrary origin.
Validate the request origin against `embed_token.allowed_domains` (reusing the
existing `validate_origin` helper) before the signaling handoff, rejecting
disallowed origins with a 1008 close — mirroring the HTTP embed endpoints.
Closes #330
Co-authored-by: shiminshen <16914659+shiminshen@users.noreply.github.com>
This commit is contained in:
parent
37e7f4d2e6
commit
51ab9303ec
2 changed files with 91 additions and 0 deletions
|
|
@ -545,6 +545,20 @@ async def public_signaling_websocket(
|
|||
await websocket.close(code=1008, reason="Invalid embed token")
|
||||
return
|
||||
|
||||
# Enforce the embed token's allowed-domain policy on the public signaling
|
||||
# path, mirroring the HTTP embed endpoints (issue #330). Without this a
|
||||
# leaked or replayed session token could attach from an arbitrary origin.
|
||||
from api.routes.public_embed import validate_origin
|
||||
|
||||
origin = websocket.headers.get("origin") or websocket.headers.get("referer", "")
|
||||
if not validate_origin(origin, embed_token.allowed_domains or []):
|
||||
logger.warning(
|
||||
f"Domain validation failed for public signaling: {origin} "
|
||||
f"not in {embed_token.allowed_domains}"
|
||||
)
|
||||
await websocket.close(code=1008, reason="Domain not allowed")
|
||||
return
|
||||
|
||||
# Create a minimal user object for compatibility with signaling manager
|
||||
# Use the embed token creator as the user
|
||||
user = await db_client.get_user_by_id(embed_token.created_by)
|
||||
|
|
|
|||
77
api/tests/test_public_signaling_origin.py
Normal file
77
api/tests/test_public_signaling_origin.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Tests for public WebRTC signaling allowed-domain enforcement.
|
||||
|
||||
Regression for issue #330: the public signaling WebSocket
|
||||
(`/public/signaling/{session_token}`) must enforce the embed token's
|
||||
allowed-domain policy, mirroring the HTTP embed endpoints. Before the fix it
|
||||
validated only the session token and expiry, so a leaked or replayed session
|
||||
token could attach to the signaling path from an arbitrary origin.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
"""Minimal WebSocket double exposing handshake headers and close()."""
|
||||
|
||||
def __init__(self, origin: str):
|
||||
self.headers = {"origin": origin}
|
||||
self.close = AsyncMock()
|
||||
|
||||
|
||||
def _embed_session():
|
||||
return SimpleNamespace(expires_at=None, embed_token_id=1, workflow_run_id=42)
|
||||
|
||||
|
||||
def _embed_token(allowed_domains):
|
||||
return SimpleNamespace(
|
||||
allowed_domains=allowed_domains, created_by=7, workflow_id=3
|
||||
)
|
||||
|
||||
|
||||
def _patch_deps():
|
||||
"""Patch db_client + signaling_manager for a valid, non-expired session."""
|
||||
db = patch("api.routes.webrtc_signaling.db_client").start()
|
||||
mgr = patch("api.routes.webrtc_signaling.signaling_manager").start()
|
||||
db.get_embed_session_by_token = AsyncMock(return_value=_embed_session())
|
||||
db.get_embed_token_by_id = AsyncMock(
|
||||
return_value=_embed_token(["example.com"])
|
||||
)
|
||||
db.get_user_by_id = AsyncMock(return_value=SimpleNamespace(id=7))
|
||||
mgr.handle_websocket = AsyncMock()
|
||||
return db, mgr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_signaling_rejects_disallowed_origin():
|
||||
from api.routes.webrtc_signaling import public_signaling_websocket
|
||||
|
||||
ws = _FakeWebSocket("https://evil.example")
|
||||
_db, mgr = _patch_deps()
|
||||
try:
|
||||
await public_signaling_websocket(ws, "emb_session_tok")
|
||||
finally:
|
||||
patch.stopall()
|
||||
|
||||
# Regression (issue #330): a valid session token presented from an origin
|
||||
# outside the embed allowlist must be rejected before the signaling handoff.
|
||||
ws.close.assert_awaited_once()
|
||||
assert ws.close.await_args.kwargs.get("code") == 1008
|
||||
mgr.handle_websocket.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_signaling_accepts_allowed_origin():
|
||||
from api.routes.webrtc_signaling import public_signaling_websocket
|
||||
|
||||
ws = _FakeWebSocket("https://example.com")
|
||||
_db, mgr = _patch_deps()
|
||||
try:
|
||||
await public_signaling_websocket(ws, "emb_session_tok")
|
||||
finally:
|
||||
patch.stopall()
|
||||
|
||||
# An origin within the allowlist proceeds to the signaling handoff.
|
||||
mgr.handle_websocket.assert_awaited_once()
|
||||
Loading…
Add table
Add a link
Reference in a new issue