dograh/api/routes/public_embed.py
nuthalapativarun dace4a7efc
fix: add CORS preflight handler and ACAO header for embed config endpoint (#403)
* fix: add CORS preflight handler and ACAO header for embed config endpoint

The GET /public/embed/config/{token} endpoint is fetched by external
websites (third-party embed sites). The global CORSMiddleware only covers
first-party origins, so external origins received no Access-Control-Allow-
Origin header, causing browser preflight failures.

Add an OPTIONS /config/{token} handler that validates the origin against the
token's allowed_domains list and returns the appropriate CORS headers.
Also inject Access-Control-Allow-Origin into the GET response via FastAPI's
response parameter so the actual request succeeds cross-origin.

Closes #383

* fix: complete public embed CORS handling

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
2026-06-03 21:27:44 +05:30

482 lines
16 KiB
Python

"""Public API endpoints for workflow embedding.
These endpoints are accessible without authentication but require valid embed tokens.
They handle CORS, domain validation, and session management for embedded workflows.
"""
import secrets
from datetime import UTC, datetime, timedelta
from typing import Optional
from urllib.parse import urlsplit
from fastapi import (
APIRouter,
HTTPException,
Request,
Response,
)
from loguru import logger
from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from api.db import db_client
from api.enums import WorkflowRunMode
from api.routes.turn_credentials import (
TURN_SECRET,
TurnCredentialsResponse,
generate_turn_credentials,
)
# Router for the public embed endpoints; every route declared on it is
# prefixed with /public/embed.
router = APIRouter(prefix="/public/embed")
# Value sent in the Access-Control-Allow-Headers header of embed CORS responses.
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
# Value sent in the Access-Control-Max-Age header (seconds; 86400 = 24 hours).
EMBED_CORS_MAX_AGE = "86400"
class InitEmbedRequest(BaseModel):
    """Body accepted by ``POST /init`` when a widget starts a session."""

    # Embed token identifying which workflow the widget wants to run.
    token: str
    # Optional values passed to the workflow run as its initial context.
    context_variables: Optional[dict] = None
class InitEmbedResponse(BaseModel):
    """Payload returned by ``POST /init`` once a session has been created."""

    # Temporary token generated for the new embed session.
    session_token: str
    # Identifier of the workflow run created for this session.
    workflow_run_id: int
    # Workflow/run identifiers merged with the embed token's settings.
    config: dict
class EmbedConfigResponse(BaseModel):
    """Widget display configuration returned by ``GET /config/{token}``."""

    # Workflow the embed token is bound to.
    workflow_id: int
    # Raw settings stored on the embed token (empty dict when none are set).
    settings: dict
    # The remaining fields are read from ``settings`` with fallback defaults.
    theme: str
    position: str
    button_text: str
    button_color: str
    size: str
    auto_start: bool
def validate_origin(origin: str, allowed_domains: list) -> bool:
    """Decide whether a request origin may use an embed token.

    Args:
        origin: Origin value taken from the request headers.
        allowed_domains: Allow-list entries. Each entry may be a bare host,
            a ``host:port`` pair, a full URL, a ``*.example.com`` wildcard,
            or ``*``.

    Returns:
        True when the allow-list is empty or one of its entries matches the
        origin, False otherwise.
    """
    # An empty allow-list means the token is not restricted to any domain.
    if not allowed_domains:
        return True

    host, port = _parse_origin_host_port(origin)
    if not host:
        return False

    def www_pair(name: str) -> tuple[str, str]:
        """Return *name* together with its www / non-www counterpart."""
        if name.startswith("www."):
            return (name, name[4:])
        return (name, f"www.{name}")

    host_forms = www_pair(host)

    for entry in allowed_domains:
        pattern = str(entry).strip().lower()
        if pattern == "*":
            return True

        pattern_host, pattern_port = _parse_origin_host_port(pattern)
        if not pattern_host:
            continue
        # A port on the allow-list entry pins the match to that exact port.
        if pattern_port is not None and pattern_port != port:
            continue

        if pattern_host.startswith("*."):
            # Wildcard entry: accept the base domain and any subdomain of it.
            suffix = pattern_host[2:]
            matched = host == suffix or host.endswith("." + suffix)
        else:
            # Plain entry: www and non-www spellings are interchangeable.
            matched = not set(host_forms).isdisjoint(www_pair(pattern_host))

        if matched:
            return True

    return False
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
candidate = value.strip().lower()
if not candidate:
return "", None
if "://" not in candidate and not candidate.startswith("//"):
candidate = f"//{candidate}"
parsed = urlsplit(candidate)
try:
parsed_port = parsed.port
except ValueError:
parsed_port = None
port = str(parsed_port) if parsed_port is not None else None
return (parsed.hostname or "").rstrip("."), port
def generate_session_token() -> str:
    """Return a new random embed session token prefixed with ``emb_session_``."""
    random_part = secrets.token_urlsafe(32)
    return "emb_session_" + random_part
def get_request_origin(request: Request) -> str:
    """Return the caller's origin, derived from Referer when Origin is absent.

    The Origin header is returned unchanged when present. Otherwise the
    Referer header is reduced to ``scheme://host[:port]``: callers echo the
    result into ``Access-Control-Allow-Origin``, error details and logs, and
    none of those should carry the referring page's path or query string.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        The origin string, or ``""`` when neither header is present. A
        Referer that cannot be reduced to an origin is returned as received.
    """
    origin = request.headers.get("origin", "")
    if origin:
        return origin

    referer = request.headers.get("referer", "")
    if not referer:
        return ""

    try:
        parts = urlsplit(referer)
    except ValueError:
        # Malformed Referer: hand it back untouched and let validation decide.
        return referer
    if not parts.scheme or not parts.netloc:
        return referer
    # Drop any userinfo so only host[:port] remains.
    host = parts.netloc.rpartition("@")[2]
    return f"{parts.scheme}://{host}"
def _cors_response(origin: str, methods: str) -> Response:
    """Build an empty response carrying the embed CORS headers.

    Args:
        origin: Value for ``Access-Control-Allow-Origin``.
        methods: Value for ``Access-Control-Allow-Methods``.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Methods": methods,
        "Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
        "Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
        # The allowed origin is echoed per request, so mark the response as
        # varying on the Origin header.
        "Vary": "Origin",
    }
    return Response(headers=cors_headers)
def _allow_embed_origin(response: Response, origin: str) -> None:
    """Mark *response* as readable by *origin* and make it vary on Origin.

    Sets ``Access-Control-Allow-Origin`` and ensures ``Origin`` is listed in
    ``Vary`` without duplicating it or dropping values already present.
    """
    headers = response.headers
    headers["Access-Control-Allow-Origin"] = origin

    current_vary = headers.get("Vary")
    if current_vary:
        listed = [item.strip().lower() for item in current_vary.split(",")]
        if "origin" not in listed:
            headers["Vary"] = f"{current_vary}, Origin"
    else:
        headers["Vary"] = "Origin"
async def _config_preflight_response(token: str, origin: str) -> Response:
    """Answer an OPTIONS request for the embed config endpoint.

    Returns the CORS headers for GET when the token exists, is active and
    accepts *origin*; otherwise returns a bare 403.
    """
    record = await db_client.get_embed_token_by_token(token)
    if (
        record
        and record.is_active
        and validate_origin(origin, record.allowed_domains or [])
    ):
        return _cors_response(origin, "GET, OPTIONS")
    return Response(status_code=403)
async def _turn_credentials_preflight_response(
    session_token: str, origin: str
) -> Response:
    """Answer an OPTIONS request for the TURN credentials endpoint.

    Returns the CORS headers for GET when the session exists, has not
    expired, and its embed token accepts *origin*; otherwise a bare 403.
    """

    def denied() -> Response:
        return Response(status_code=403)

    session = await db_client.get_embed_session_by_token(session_token)
    if not session:
        return denied()

    expiry = session.expires_at
    if expiry and expiry < datetime.now(UTC):
        return denied()

    # The allow-list lives on the embed token that created the session.
    owner = await db_client.get_embed_token_by_id(session.embed_token_id)
    if owner and validate_origin(origin, owner.allowed_domains or []):
        return _cors_response(origin, "GET, OPTIONS")
    return denied()
async def build_public_embed_preflight_response(
    path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
) -> Response | None:
    """Build the preflight answer for a public embed path, if *path* is one.

    Args:
        path: Request path of the preflight.
        origin: Value of the Origin header.
        requested_method: Value of Access-Control-Request-Method.
        api_prefix: Prefix under which the API is mounted.

    Returns:
        A response for embed paths (405 when the requested method does not
        match the route), or ``None`` when *path* is not an embed endpoint.
    """
    base = f"{api_prefix.rstrip('/')}/public/embed"

    if path == f"{base}/init":
        # The token travels in the POST body, which a preflight does not
        # carry, so only the requested method is checked here.
        if requested_method.upper() == "POST":
            return _cors_response(origin, "POST, OPTIONS")
        return Response(status_code=405)

    # Routes whose first path segment after the prefix is the token to check.
    token_routes = (
        (f"{base}/config/", _config_preflight_response),
        (f"{base}/turn-credentials/", _turn_credentials_preflight_response),
    )
    for prefix, handler in token_routes:
        if not path.startswith(prefix):
            continue
        if requested_method.upper() != "GET":
            return Response(status_code=405)
        token_value = path[len(prefix) :].split("/", 1)[0]
        return await handler(token_value, origin)

    return None
class PublicEmbedCORSMiddleware:
    """ASGI middleware that answers CORS preflights for public embed routes.

    Only HTTP OPTIONS requests carrying both Origin and
    Access-Control-Request-Method headers are considered; everything else is
    passed to the wrapped application untouched.
    """

    def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
        self.app = app
        self.api_prefix = api_prefix

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        is_http_options = (
            scope["type"] == "http" and scope.get("method") == "OPTIONS"
        )
        if is_http_options:
            preflight = await self._embed_preflight(scope)
            if preflight is not None:
                await preflight(scope, receive, send)
                return
        await self.app(scope, receive, send)

    async def _embed_preflight(self, scope: Scope) -> Response | None:
        """Return the embed preflight response for *scope*, or None."""
        request_headers = Headers(scope=scope)
        origin = request_headers.get("origin")
        requested_method = request_headers.get("access-control-request-method")
        if not (origin and requested_method):
            return None
        return await build_public_embed_preflight_response(
            scope.get("path", ""), origin, requested_method, self.api_prefix
        )
@router.post("/init", response_model=InitEmbedResponse)
async def initialize_embed_session(
    request: Request, init_request: InitEmbedRequest, response: Response
):
    """Start an embedded widget session for a validated embed token.

    In order: look up the embed token; reject it when inactive, expired or
    over its usage limit; check the caller's origin against the token's
    allowed domains; create a workflow run; store an embed session with a
    one-hour expiry; increment the token's usage count; and return the
    session token together with the widget configuration.

    Raises:
        HTTPException: 404 for an unknown token; 403 for an inactive,
            expired, exhausted or origin-rejected token; 500 when the
            workflow run or the session cannot be created.
    """
    origin = get_request_origin(request)

    embed_token = await db_client.get_embed_token_by_token(init_request.token)
    if not embed_token:
        raise HTTPException(status_code=404, detail="Invalid embed token")

    def forbidden(detail: str) -> HTTPException:
        return HTTPException(status_code=403, detail=detail)

    if not embed_token.is_active:
        raise forbidden("Embed token is inactive")

    if embed_token.expires_at and embed_token.expires_at < datetime.now(UTC):
        raise forbidden("Embed token has expired")

    if embed_token.usage_limit and embed_token.usage_count >= embed_token.usage_limit:
        raise forbidden("Embed token usage limit exceeded")

    if not validate_origin(origin, embed_token.allowed_domains or []):
        logger.warning(
            f"Domain validation failed: {origin} not in {embed_token.allowed_domains}"
        )
        raise forbidden(f"Domain not allowed: {origin}")

    # Echo the accepted origin so the embedding page can read the response.
    if origin:
        _allow_embed_origin(response, origin)

    try:
        workflow_run = await db_client.create_workflow_run(
            name=f"Embed Run - {datetime.now(UTC).isoformat()}",
            workflow_id=embed_token.workflow_id,
            mode=WorkflowRunMode.SMALLWEBRTC.value,
            # The run is owned by whoever created the embed token.
            user_id=embed_token.created_by,
            initial_context=init_request.context_variables,
        )
    except Exception as e:
        logger.error(f"Failed to create workflow run: {e}")
        raise HTTPException(status_code=500, detail="Failed to create workflow run")

    session_token = generate_session_token()

    try:
        session_fields = {
            "session_token": session_token,
            "embed_token_id": embed_token.id,
            "workflow_run_id": workflow_run.id,
            "client_ip": request.client.host if request.client else None,
            # Header-derived values are truncated before being stored.
            "user_agent": request.headers.get("user-agent", "")[:500],
            "origin": origin[:255],
            # Sessions are valid for one hour.
            "expires_at": datetime.now(UTC) + timedelta(hours=1),
        }
        await db_client.create_embed_session(**session_fields)
    except Exception as e:
        logger.error(f"Failed to create embed session: {e}")
        raise HTTPException(status_code=500, detail="Failed to create session")

    await db_client.increment_embed_token_usage(embed_token.id)

    # Token settings are merged last, so they override the two id keys.
    config = {
        "workflow_id": embed_token.workflow_id,
        "workflow_run_id": workflow_run.id,
    }
    config.update(embed_token.settings or {})

    return InitEmbedResponse(
        session_token=session_token, workflow_run_id=workflow_run.id, config=config
    )
@router.options("/config/{token}")
async def options_embed_config(token: str, request: Request):
    """Route-level OPTIONS handler for the embed config endpoint.

    Preflights that carry Access-Control-Request-Method are answered by
    PublicEmbedCORSMiddleware; OPTIONS requests that reach this route go
    through the same token and origin validation.
    """
    origin_header = request.headers.get("origin", "")
    return await _config_preflight_response(token, origin_header)
@router.get("/config/{token}", response_model=EmbedConfigResponse)
async def get_embed_config(token: str, request: Request, response: Response):
    """Return the widget configuration for an embed token.

    No session or workflow run is created; the token only has to exist, be
    active, and accept the caller's origin.

    Raises:
        HTTPException: 404 for an unknown token; 403 for an inactive token
            or an origin that is not allowed.
    """
    origin = get_request_origin(request)

    embed_token = await db_client.get_embed_token_by_token(token)
    if not embed_token:
        raise HTTPException(status_code=404, detail="Invalid embed token")
    if not embed_token.is_active:
        raise HTTPException(status_code=403, detail="Embed token is inactive")
    if not validate_origin(origin, embed_token.allowed_domains or []):
        raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")

    # The global CORS middleware only covers first-party origins, so the
    # header for external embed sites is set on this response directly.
    if origin:
        _allow_embed_origin(response, origin)

    settings = embed_token.settings or {}
    # Display options fall back to these defaults when the token omits them.
    display_options = {
        "theme": settings.get("theme", "light"),
        "position": settings.get("position", "bottom-right"),
        "button_text": settings.get("buttonText", "Start Voice Call"),
        "button_color": settings.get("buttonColor", "#3B82F6"),
        "size": settings.get("size", "medium"),
        "auto_start": settings.get("autoStart", False),
    }
    return EmbedConfigResponse(
        workflow_id=embed_token.workflow_id,
        settings=settings,
        **display_options,
    )
@router.options("/init")
async def options_init(request: Request):
    """Route-level OPTIONS handler for the init endpoint.

    The embed token is sent in the POST body, which an OPTIONS request does
    not carry, so this handler is permissive and leaves validation to the
    POST itself. When no Origin header is present, ``*`` is used.
    """
    allowed_origin = request.headers.get("origin", "*")
    return _cors_response(allowed_origin, "POST, OPTIONS")
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
async def get_public_turn_credentials(
    session_token: str, request: Request, response: Response
):
    """Issue TURN credentials to an embedded widget session.

    Args:
        session_token: Session token returned by embed initialization.

    Returns:
        A TurnCredentialsResponse built from the generated credentials.

    Raises:
        HTTPException: 404 for an unknown session or embed token; 403 for an
            expired session or a disallowed origin; 503 when TURN is not
            configured; 500 when credential generation fails.
    """
    origin = get_request_origin(request)

    embed_session = await db_client.get_embed_session_by_token(session_token)
    if not embed_session:
        raise HTTPException(status_code=404, detail="Invalid session token")

    session_expiry = embed_session.expires_at
    if session_expiry and session_expiry < datetime.now(UTC):
        raise HTTPException(status_code=403, detail="Session expired")

    # The allowed-domain list lives on the token that created the session.
    embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
    if not embed_token:
        raise HTTPException(status_code=404, detail="Invalid embed token")

    origin_allowed = validate_origin(origin, embed_token.allowed_domains or [])
    if not origin_allowed:
        logger.warning(
            f"Domain validation failed for TURN credentials: {origin} not in {embed_token.allowed_domains}"
        )
        raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")

    if origin:
        _allow_embed_origin(response, origin)

    if not TURN_SECRET:
        raise HTTPException(
            status_code=503,
            detail="TURN server not configured",
        )

    # Only the first 16 characters of the session token are used as the
    # identifier passed to the credential generator.
    turn_identity = f"embed:{session_token[:16]}"
    try:
        credentials = generate_turn_credentials(turn_identity)
        return TurnCredentialsResponse(**credentials)
    except Exception as e:
        logger.error(f"Failed to generate TURN credentials for embed session: {e}")
        raise HTTPException(
            status_code=500,
            detail="Failed to generate TURN credentials",
        )
@router.options("/turn-credentials/{session_token}")
async def options_turn_credentials(request: Request, session_token: str):
    """Route-level OPTIONS handler for the TURN credentials endpoint.

    Preflights that carry Access-Control-Request-Method are answered by
    PublicEmbedCORSMiddleware; OPTIONS requests that reach this route are
    validated against the session and its embed token in the same way.
    """
    origin_header = request.headers.get("origin", "")
    return await _turn_credentials_preflight_response(session_token, origin_header)