fix: add CORS preflight handler and ACAO header for embed config endpoint (#403)

* fix: add CORS preflight handler and ACAO header for embed config endpoint

The GET /public/embed/config/{token} endpoint is fetched by external
websites (third-party embed sites). The global CORSMiddleware only covers
first-party origins, so external origins received no Access-Control-Allow-
Origin header, causing browser preflight failures.

Add an OPTIONS /config/{token} handler that validates the origin against the
token's allowed_domains list and returns the appropriate CORS headers.
Also inject Access-Control-Allow-Origin into the GET response via FastAPI's
response parameter so the actual request succeeds cross-origin.

Closes #383

* fix: complete public embed CORS handling

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
nuthalapativarun 2026-06-03 08:57:44 -07:00 committed by GitHub
parent cdb27c1d4f
commit dace4a7efc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 466 additions and 79 deletions

View file

@ -117,6 +117,15 @@ app.add_middleware(
allow_headers=["*"],
)
def _add_public_embed_cors_middleware() -> None:
from api.routes.public_embed import PublicEmbedCORSMiddleware
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix=API_PREFIX)
_add_public_embed_cors_middleware()
api_router = APIRouter()
# include subrouters here

View file

@ -7,6 +7,7 @@ They handle CORS, domain validation, and session management for embedded workflo
import secrets
from datetime import UTC, datetime, timedelta
from typing import Optional
from urllib.parse import urlsplit
from fastapi import (
APIRouter,
@ -16,6 +17,8 @@ from fastapi import (
)
from loguru import logger
from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from api.db import db_client
from api.enums import WorkflowRunMode
@ -27,6 +30,9 @@ from api.routes.turn_credentials import (
router = APIRouter(prefix="/public/embed")
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
EMBED_CORS_MAX_AGE = "86400"
class InitEmbedRequest(BaseModel):
"""Request model for initializing an embed session"""
@ -70,11 +76,9 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
# If no domains specified, allow all origins
return True
# Extract domain from origin (remove protocol)
if "://" in origin:
domain = origin.split("://")[1].split("/")[0].split(":")[0]
else:
domain = origin
domain, origin_port = _parse_origin_host_port(origin)
if not domain:
return False
# Normalize domain for www matching
def normalize_www(d: str) -> tuple[str, str]:
@ -87,16 +91,23 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
domain_variants = normalize_www(domain)
for allowed in allowed_domains:
allowed = str(allowed).strip().lower()
if allowed == "*":
return True
elif allowed.startswith("*."):
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
if not allowed_domain:
continue
if allowed_port is not None and allowed_port != origin_port:
continue
if allowed_domain.startswith("*."):
# Wildcard subdomain matching
base_domain = allowed[2:]
base_domain = allowed_domain[2:]
if domain == base_domain or domain.endswith("." + base_domain):
return True
else:
# Check both www and non-www versions
allowed_variants = normalize_www(allowed)
allowed_variants = normalize_www(allowed_domain)
# If any variant of domain matches any variant of allowed, it's valid
if any(
dv in allowed_variants or av in domain_variants
@ -108,6 +119,24 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
return False
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
candidate = value.strip().lower()
if not candidate:
return "", None
if "://" not in candidate and not candidate.startswith("//"):
candidate = f"//{candidate}"
parsed = urlsplit(candidate)
try:
parsed_port = parsed.port
except ValueError:
parsed_port = None
port = str(parsed_port) if parsed_port is not None else None
return (parsed.hostname or "").rstrip("."), port
def generate_session_token() -> str:
"""Generate a cryptographically secure session token"""
return f"emb_session_{secrets.token_urlsafe(32)}"
@ -121,8 +150,120 @@ def get_request_origin(request: Request) -> str:
return origin
def _cors_response(origin: str, methods: str) -> Response:
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": methods,
"Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
"Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
"Vary": "Origin",
}
)
def _allow_embed_origin(response: Response, origin: str) -> None:
response.headers["Access-Control-Allow-Origin"] = origin
vary = response.headers.get("Vary")
if not vary:
response.headers["Vary"] = "Origin"
return
vary_values = {value.strip().lower() for value in vary.split(",")}
if "origin" not in vary_values:
response.headers["Vary"] = f"{vary}, Origin"
async def _config_preflight_response(token: str, origin: str) -> Response:
embed_token = await db_client.get_embed_token_by_token(token)
if not embed_token or not embed_token.is_active:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def _turn_credentials_preflight_response(
session_token: str, origin: str
) -> Response:
embed_session = await db_client.get_embed_session_by_token(session_token)
if not embed_session:
return Response(status_code=403)
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
return Response(status_code=403)
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
if not embed_token:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def build_public_embed_preflight_response(
path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
) -> Response | None:
"""Handle embed preflights before global CORSMiddleware rejects external sites."""
public_embed_prefix = f"{api_prefix.rstrip('/')}/public/embed"
if path == f"{public_embed_prefix}/init":
if requested_method.upper() != "POST":
return Response(status_code=405)
return _cors_response(origin, "POST, OPTIONS")
config_prefix = f"{public_embed_prefix}/config/"
if path.startswith(config_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
token = path[len(config_prefix) :].split("/", 1)[0]
return await _config_preflight_response(token, origin)
turn_credentials_prefix = f"{public_embed_prefix}/turn-credentials/"
if path.startswith(turn_credentials_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
session_token = path[len(turn_credentials_prefix) :].split("/", 1)[0]
return await _turn_credentials_preflight_response(session_token, origin)
return None
class PublicEmbedCORSMiddleware:
"""Allow token-gated embed CORS before global SaaS CORS rejects preflights."""
def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
self.app = app
self.api_prefix = api_prefix
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http" or scope.get("method") != "OPTIONS":
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin")
requested_method = headers.get("access-control-request-method")
if origin and requested_method:
response = await build_public_embed_preflight_response(
scope.get("path", ""), origin, requested_method, self.api_prefix
)
if response is not None:
await response(scope, receive, send)
return
await self.app(scope, receive, send)
@router.post("/init", response_model=InitEmbedResponse)
async def initialize_embed_session(request: Request, init_request: InitEmbedRequest):
async def initialize_embed_session(
request: Request, init_request: InitEmbedRequest, response: Response
):
"""Initialize an embed session with token validation and domain checking.
This endpoint:
@ -158,6 +299,9 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Create workflow run
try:
workflow_run = await db_client.create_workflow_run(
@ -204,8 +348,19 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
@router.options("/config/{token}")
async def options_embed_config(token: str, request: Request):
"""Fallback OPTIONS handler for the embed config endpoint.
Browser preflights include Access-Control-Request-Method and are handled by
PublicEmbedCORSMiddleware before global CORS. This keeps non-conformant
OPTIONS requests on the same validation path.
"""
return await _config_preflight_response(token, request.headers.get("origin", ""))
@router.get("/config/{token}", response_model=EmbedConfigResponse)
async def get_embed_config(token: str, request: Request):
async def get_embed_config(token: str, request: Request, response: Response):
"""Get embed configuration without creating a session.
This endpoint is used to fetch widget configuration for display purposes
@ -226,6 +381,11 @@ async def get_embed_config(token: str, request: Request):
if not validate_origin(origin, embed_token.allowed_domains or []):
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
# Set CORS header explicitly; the global CORSMiddleware covers only
# first-party origins; this endpoint is fetched by external embed sites.
if origin:
_allow_embed_origin(response, origin)
# Extract settings with defaults
settings = embed_token.settings or {}
@ -243,24 +403,20 @@ async def get_embed_config(token: str, request: Request):
@router.options("/init")
async def options_init(request: Request):
"""Handle CORS preflight for init endpoint"""
"""Fallback OPTIONS handler for init endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
# For init endpoint, we need to check the token in the request body
# But OPTIONS requests don't have body, so we'll be permissive
# The actual validation happens in the POST request
origin = request.headers.get("origin", "*")
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Origin",
"Access-Control-Max-Age": "86400",
}
)
return _cors_response(origin, "POST, OPTIONS")
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
async def get_public_turn_credentials(session_token: str, request: Request):
async def get_public_turn_credentials(
session_token: str, request: Request, response: Response
):
"""Get TURN credentials for an embed session.
This endpoint allows embedded widgets to obtain TURN server credentials
@ -295,6 +451,9 @@ async def get_public_turn_credentials(session_token: str, request: Request):
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Check if TURN is configured
if not TURN_SECRET:
raise HTTPException(
@ -316,63 +475,8 @@ async def get_public_turn_credentials(session_token: str, request: Request):
@router.options("/turn-credentials/{session_token}")
async def options_turn_credentials(request: Request, session_token: str):
"""Handle CORS preflight for TURN credentials endpoint"""
origin = request.headers.get("origin", "*")
# Try to validate the session token and get allowed domains
allowed_origin = origin
try:
embed_session = await db_client.get_embed_session_by_token(session_token)
if embed_session:
embed_token = await db_client.get_embed_token_by_id(
embed_session.embed_token_id
)
if embed_token:
# Check if origin is in allowed domains (empty means allow all)
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
)
@router.options("/config/{token}")
async def options_config(request: Request, token: str):
"""Handle CORS preflight for config endpoint"""
# Get origin header
origin = request.headers.get("origin", "*")
# Try to validate the token and get allowed domains
allowed_origin = origin
try:
embed_token = await db_client.get_embed_token_by_token(token)
if embed_token and embed_token.is_active:
# Check if origin is in allowed domains
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
# If not allowed, don't include the origin
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
"""Fallback OPTIONS handler for TURN credentials endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
return await _turn_credentials_preflight_response(
session_token, request.headers.get("origin", "")
)

View file

@ -0,0 +1,274 @@
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from api.routes.public_embed import PublicEmbedCORSMiddleware, router
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.dograh.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix="/api/v1")
app.include_router(router, prefix="/api/v1")
client = TestClient(app, raise_server_exceptions=False)
_ACTIVE_TOKEN = SimpleNamespace(
id=10,
is_active=True,
expires_at=None,
allowed_domains=[],
workflow_id=1,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_RESTRICTED_TOKEN = SimpleNamespace(
id=20,
is_active=True,
expires_at=None,
allowed_domains=["allowed.example.com"],
workflow_id=2,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_LOCALHOST_TOKEN = SimpleNamespace(
id=30,
is_active=True,
expires_at=None,
allowed_domains=["localhost:3000", "localhost:3020"],
workflow_id=3,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
@pytest.fixture(autouse=True)
def _patch_db(monkeypatch):
async def _get_token(token):
if token == "valid":
return _ACTIVE_TOKEN
if token == "restricted":
return _RESTRICTED_TOKEN
if token == "localhost":
return _LOCALHOST_TOKEN
return None
async def _get_token_by_id(token_id):
if token_id == _ACTIVE_TOKEN.id:
return _ACTIVE_TOKEN
if token_id == _RESTRICTED_TOKEN.id:
return _RESTRICTED_TOKEN
if token_id == _LOCALHOST_TOKEN.id:
return _LOCALHOST_TOKEN
return None
async def _get_session(session_token):
if session_token == "session-valid":
return SimpleNamespace(embed_token_id=_ACTIVE_TOKEN.id, expires_at=None)
if session_token == "session-restricted":
return SimpleNamespace(embed_token_id=_RESTRICTED_TOKEN.id, expires_at=None)
return None
async def _create_workflow_run(**_kwargs):
return SimpleNamespace(id=123)
async def _noop(*_args, **_kwargs):
return None
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_token",
_get_token,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_id",
_get_token_by_id,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_session_by_token",
_get_session,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_workflow_run",
_create_workflow_run,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_embed_session",
_noop,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.increment_embed_token_usage",
_noop,
)
monkeypatch.setattr("api.routes.public_embed.TURN_SECRET", "test-secret")
monkeypatch.setattr(
"api.routes.public_embed.generate_turn_credentials",
lambda _user_id: {
"username": "turn-user",
"password": "turn-password",
"ttl": 3600,
"uris": ["turn:example.com:3478"],
},
)
def _assert_embed_cors(resp, origin: str):
assert resp.headers.get("access-control-allow-origin") == origin
assert "origin" in {
value.strip().lower() for value in resp.headers.get("vary", "").split(",")
}
def test_options_config_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/config/valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.options(
"/api/v1/public/embed/config/localhost",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_rejects_unknown_token():
resp = client.options(
"/api/v1/public/embed/config/unknown",
headers={
"Origin": "https://mysite.vercel.app",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_options_config_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/config/restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_get_config_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/config/valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_rejects_unlisted_localhost_port():
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": "http://localhost:3021"},
)
assert resp.status_code == 403
def test_get_config_rejects_disallowed_origin():
resp = client.get(
"/api/v1/public/embed/config/restricted",
headers={"Origin": "https://notallowed.example.com"},
)
assert resp.status_code == 403
def test_init_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.post(
"/api/v1/public/embed/init",
headers={"Origin": origin},
json={"token": "valid"},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_turn_credentials_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_init_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/init",
headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403