diff --git a/api/routes/public_embed.py b/api/routes/public_embed.py index 058def54..91247351 100644 --- a/api/routes/public_embed.py +++ b/api/routes/public_embed.py @@ -204,8 +204,35 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ ) +@router.options("/config/{token}") +async def options_embed_config(token: str, request: Request): + """Handle CORS preflight for the embed config endpoint. + + External sites fetch /config/{token} before calling Start Voice Call. + The global CORSMiddleware only covers first-party origins, so we handle + CORS explicitly here, gating on the token's allowed_domains list. + """ + origin = request.headers.get("origin", "") + + embed_token = await db_client.get_embed_token_by_token(token) + if not embed_token or not embed_token.is_active: + return Response(status_code=403) + + if not validate_origin(origin, embed_token.allowed_domains or []): + return Response(status_code=403) + + return Response( + headers={ + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": "GET, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Origin", + "Access-Control-Max-Age": "86400", + } + ) + + @router.get("/config/{token}", response_model=EmbedConfigResponse) -async def get_embed_config(token: str, request: Request): +async def get_embed_config(token: str, request: Request, response: Response): """Get embed configuration without creating a session. This endpoint is used to fetch widget configuration for display purposes @@ -226,6 +253,11 @@ async def get_embed_config(token: str, request: Request): if not validate_origin(origin, embed_token.allowed_domains or []): raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}") + # Set CORS header explicitly — the global CORSMiddleware covers only + # first-party origins; this endpoint is fetched by external embed sites. + if origin: + response.headers["Access-Control-Allow-Origin"] = origin + # Extract settings with defaults settings = embed_token.settings or {} diff --git a/api/tests/test_public_embed_cors.py b/api/tests/test_public_embed_cors.py new file mode 100644 index 00000000..605ef59f --- /dev/null +++ b/api/tests/test_public_embed_cors.py @@ -0,0 +1,85 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from api.routes.public_embed import router + +app = FastAPI() +app.include_router(router, prefix="/api/v1") +client = TestClient(app, raise_server_exceptions=False) + +_ACTIVE_TOKEN = SimpleNamespace( + is_active=True, + expires_at=None, + allowed_domains=[], + workflow_id=1, + settings={}, +) + +_RESTRICTED_TOKEN = SimpleNamespace( + is_active=True, + expires_at=None, + allowed_domains=["allowed.example.com"], + workflow_id=2, + settings={}, +) + + +@pytest.fixture(autouse=True) +def _patch_db(monkeypatch): + async def _get_token(token): + if token == "valid": + return _ACTIVE_TOKEN + if token == "restricted": + return _RESTRICTED_TOKEN + return None + + monkeypatch.setattr( + "api.routes.public_embed.db_client.get_embed_token_by_token", + _get_token, + ) + + +def test_options_config_returns_acao_for_allowed_origin(): + resp = client.options( + "/api/v1/public/embed/config/valid", + headers={"Origin": "https://mysite.vercel.app"}, + ) + assert resp.status_code == 200 + assert resp.headers.get("access-control-allow-origin") == "https://mysite.vercel.app" + + +def test_options_config_rejects_unknown_token(): + resp = client.options( + "/api/v1/public/embed/config/unknown", + headers={"Origin": "https://mysite.vercel.app"}, + ) + assert resp.status_code == 403 + + +def test_options_config_rejects_disallowed_origin(): + resp = client.options( + "/api/v1/public/embed/config/restricted", + headers={"Origin": "https://notallowed.example.com"}, + ) + assert resp.status_code == 403 + + +def test_get_config_includes_acao_header(): + resp = client.get( + "/api/v1/public/embed/config/valid", + headers={"Origin": "https://mysite.vercel.app"}, + ) + assert resp.status_code == 200 + assert resp.headers.get("access-control-allow-origin") == "https://mysite.vercel.app" + + +def test_get_config_rejects_disallowed_origin(): + resp = client.get( + "/api/v1/public/embed/config/restricted", + headers={"Origin": "https://notallowed.example.com"}, + ) + assert resp.status_code == 403