mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* fix: add CORS preflight handler and ACAO header for embed config endpoint
The GET /public/embed/config/{token} endpoint is fetched by external
websites (third-party embed sites). The global CORSMiddleware only covers
first-party origins, so external origins received no Access-Control-Allow-
Origin header, causing browser preflight failures.
Add an OPTIONS /config/{token} handler that validates the origin against the
token's allowed_domains list and returns the appropriate CORS headers.
Also inject Access-Control-Allow-Origin into the GET response via FastAPI's
response parameter so the actual request succeeds cross-origin.
Closes #383
* fix: complete public embed CORS handling
---------
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
482 lines
16 KiB
Python
482 lines
16 KiB
Python
"""Public API endpoints for workflow embedding.
|
|
|
|
These endpoints are accessible without authentication but require valid embed tokens.
|
|
They handle CORS, domain validation, and session management for embedded workflows.
|
|
"""
|
|
|
|
import secrets
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Optional
|
|
from urllib.parse import urlsplit
|
|
|
|
from fastapi import (
|
|
APIRouter,
|
|
HTTPException,
|
|
Request,
|
|
Response,
|
|
)
|
|
from loguru import logger
|
|
from pydantic import BaseModel
|
|
from starlette.datastructures import Headers
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
from api.db import db_client
|
|
from api.enums import WorkflowRunMode
|
|
from api.routes.turn_credentials import (
|
|
TURN_SECRET,
|
|
TurnCredentialsResponse,
|
|
generate_turn_credentials,
|
|
)
|
|
|
|
router = APIRouter(prefix="/public/embed")
|
|
|
|
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
|
|
EMBED_CORS_MAX_AGE = "86400"
|
|
|
|
|
|
class InitEmbedRequest(BaseModel):
|
|
"""Request model for initializing an embed session"""
|
|
|
|
token: str
|
|
context_variables: Optional[dict] = None
|
|
|
|
|
|
class InitEmbedResponse(BaseModel):
|
|
"""Response model for embed initialization"""
|
|
|
|
session_token: str
|
|
workflow_run_id: int
|
|
config: dict
|
|
|
|
|
|
class EmbedConfigResponse(BaseModel):
|
|
"""Response model for embed configuration"""
|
|
|
|
workflow_id: int
|
|
settings: dict
|
|
theme: str
|
|
position: str
|
|
button_text: str
|
|
button_color: str
|
|
size: str
|
|
auto_start: bool
|
|
|
|
|
|
def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|
"""Validate if the origin is in the allowed domains list.
|
|
|
|
Args:
|
|
origin: The origin header from the request
|
|
allowed_domains: List of allowed domain patterns
|
|
|
|
Returns:
|
|
True if origin is allowed, False otherwise
|
|
"""
|
|
if not allowed_domains:
|
|
# If no domains specified, allow all origins
|
|
return True
|
|
|
|
domain, origin_port = _parse_origin_host_port(origin)
|
|
if not domain:
|
|
return False
|
|
|
|
# Normalize domain for www matching
|
|
def normalize_www(d: str) -> tuple[str, str]:
|
|
"""Return both www and non-www versions of a domain"""
|
|
if d.startswith("www."):
|
|
return (d, d[4:]) # (www.x.com, x.com)
|
|
else:
|
|
return (d, f"www.{d}") # (x.com, www.x.com)
|
|
|
|
domain_variants = normalize_www(domain)
|
|
|
|
for allowed in allowed_domains:
|
|
allowed = str(allowed).strip().lower()
|
|
if allowed == "*":
|
|
return True
|
|
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
|
|
if not allowed_domain:
|
|
continue
|
|
if allowed_port is not None and allowed_port != origin_port:
|
|
continue
|
|
|
|
if allowed_domain.startswith("*."):
|
|
# Wildcard subdomain matching
|
|
base_domain = allowed_domain[2:]
|
|
if domain == base_domain or domain.endswith("." + base_domain):
|
|
return True
|
|
else:
|
|
# Check both www and non-www versions
|
|
allowed_variants = normalize_www(allowed_domain)
|
|
# If any variant of domain matches any variant of allowed, it's valid
|
|
if any(
|
|
dv in allowed_variants or av in domain_variants
|
|
for dv in domain_variants
|
|
for av in allowed_variants
|
|
):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
|
|
candidate = value.strip().lower()
|
|
if not candidate:
|
|
return "", None
|
|
|
|
if "://" not in candidate and not candidate.startswith("//"):
|
|
candidate = f"//{candidate}"
|
|
|
|
parsed = urlsplit(candidate)
|
|
try:
|
|
parsed_port = parsed.port
|
|
except ValueError:
|
|
parsed_port = None
|
|
|
|
port = str(parsed_port) if parsed_port is not None else None
|
|
return (parsed.hostname or "").rstrip("."), port
|
|
|
|
|
|
def generate_session_token() -> str:
|
|
"""Generate a cryptographically secure session token"""
|
|
return f"emb_session_{secrets.token_urlsafe(32)}"
|
|
|
|
|
|
def get_request_origin(request: Request) -> str:
|
|
"""Extract origin from request headers, falling back to referer if not present."""
|
|
origin = request.headers.get("origin", "")
|
|
if not origin:
|
|
origin = request.headers.get("referer", "")
|
|
return origin
|
|
|
|
|
|
def _cors_response(origin: str, methods: str) -> Response:
|
|
return Response(
|
|
headers={
|
|
"Access-Control-Allow-Origin": origin,
|
|
"Access-Control-Allow-Methods": methods,
|
|
"Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
|
|
"Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
|
|
"Vary": "Origin",
|
|
}
|
|
)
|
|
|
|
|
|
def _allow_embed_origin(response: Response, origin: str) -> None:
|
|
response.headers["Access-Control-Allow-Origin"] = origin
|
|
vary = response.headers.get("Vary")
|
|
if not vary:
|
|
response.headers["Vary"] = "Origin"
|
|
return
|
|
|
|
vary_values = {value.strip().lower() for value in vary.split(",")}
|
|
if "origin" not in vary_values:
|
|
response.headers["Vary"] = f"{vary}, Origin"
|
|
|
|
|
|
async def _config_preflight_response(token: str, origin: str) -> Response:
|
|
embed_token = await db_client.get_embed_token_by_token(token)
|
|
if not embed_token or not embed_token.is_active:
|
|
return Response(status_code=403)
|
|
|
|
if not validate_origin(origin, embed_token.allowed_domains or []):
|
|
return Response(status_code=403)
|
|
|
|
return _cors_response(origin, "GET, OPTIONS")
|
|
|
|
|
|
async def _turn_credentials_preflight_response(
|
|
session_token: str, origin: str
|
|
) -> Response:
|
|
embed_session = await db_client.get_embed_session_by_token(session_token)
|
|
if not embed_session:
|
|
return Response(status_code=403)
|
|
|
|
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
|
|
return Response(status_code=403)
|
|
|
|
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
|
|
if not embed_token:
|
|
return Response(status_code=403)
|
|
|
|
if not validate_origin(origin, embed_token.allowed_domains or []):
|
|
return Response(status_code=403)
|
|
|
|
return _cors_response(origin, "GET, OPTIONS")
|
|
|
|
|
|
async def build_public_embed_preflight_response(
|
|
path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
|
|
) -> Response | None:
|
|
"""Handle embed preflights before global CORSMiddleware rejects external sites."""
|
|
public_embed_prefix = f"{api_prefix.rstrip('/')}/public/embed"
|
|
|
|
if path == f"{public_embed_prefix}/init":
|
|
if requested_method.upper() != "POST":
|
|
return Response(status_code=405)
|
|
return _cors_response(origin, "POST, OPTIONS")
|
|
|
|
config_prefix = f"{public_embed_prefix}/config/"
|
|
if path.startswith(config_prefix):
|
|
if requested_method.upper() != "GET":
|
|
return Response(status_code=405)
|
|
token = path[len(config_prefix) :].split("/", 1)[0]
|
|
return await _config_preflight_response(token, origin)
|
|
|
|
turn_credentials_prefix = f"{public_embed_prefix}/turn-credentials/"
|
|
if path.startswith(turn_credentials_prefix):
|
|
if requested_method.upper() != "GET":
|
|
return Response(status_code=405)
|
|
session_token = path[len(turn_credentials_prefix) :].split("/", 1)[0]
|
|
return await _turn_credentials_preflight_response(session_token, origin)
|
|
|
|
return None
|
|
|
|
|
|
class PublicEmbedCORSMiddleware:
|
|
"""Allow token-gated embed CORS before global SaaS CORS rejects preflights."""
|
|
|
|
def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
|
|
self.app = app
|
|
self.api_prefix = api_prefix
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if scope["type"] != "http" or scope.get("method") != "OPTIONS":
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
headers = Headers(scope=scope)
|
|
origin = headers.get("origin")
|
|
requested_method = headers.get("access-control-request-method")
|
|
|
|
if origin and requested_method:
|
|
response = await build_public_embed_preflight_response(
|
|
scope.get("path", ""), origin, requested_method, self.api_prefix
|
|
)
|
|
if response is not None:
|
|
await response(scope, receive, send)
|
|
return
|
|
|
|
await self.app(scope, receive, send)
|
|
|
|
|
|
@router.post("/init", response_model=InitEmbedResponse)
|
|
async def initialize_embed_session(
|
|
request: Request, init_request: InitEmbedRequest, response: Response
|
|
):
|
|
"""Initialize an embed session with token validation and domain checking.
|
|
|
|
This endpoint:
|
|
1. Validates the embed token
|
|
2. Checks domain whitelist
|
|
3. Creates a workflow run
|
|
4. Generates a temporary session token
|
|
5. Returns configuration for the widget
|
|
"""
|
|
origin = get_request_origin(request)
|
|
|
|
# Validate embed token
|
|
embed_token = await db_client.get_embed_token_by_token(init_request.token)
|
|
if not embed_token:
|
|
raise HTTPException(status_code=404, detail="Invalid embed token")
|
|
|
|
# Check if token is active
|
|
if not embed_token.is_active:
|
|
raise HTTPException(status_code=403, detail="Embed token is inactive")
|
|
|
|
# Check expiration
|
|
if embed_token.expires_at and embed_token.expires_at < datetime.now(UTC):
|
|
raise HTTPException(status_code=403, detail="Embed token has expired")
|
|
|
|
# Check usage limit
|
|
if embed_token.usage_limit and embed_token.usage_count >= embed_token.usage_limit:
|
|
raise HTTPException(status_code=403, detail="Embed token usage limit exceeded")
|
|
|
|
# Validate domain
|
|
if not validate_origin(origin, embed_token.allowed_domains or []):
|
|
logger.warning(
|
|
f"Domain validation failed: {origin} not in {embed_token.allowed_domains}"
|
|
)
|
|
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
|
|
|
if origin:
|
|
_allow_embed_origin(response, origin)
|
|
|
|
# Create workflow run
|
|
try:
|
|
workflow_run = await db_client.create_workflow_run(
|
|
name=f"Embed Run - {datetime.now(UTC).isoformat()}",
|
|
workflow_id=embed_token.workflow_id,
|
|
mode=WorkflowRunMode.SMALLWEBRTC.value,
|
|
user_id=embed_token.created_by, # Use token creator as run owner
|
|
initial_context=init_request.context_variables,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to create workflow run: {e}")
|
|
raise HTTPException(status_code=500, detail="Failed to create workflow run")
|
|
|
|
# Generate session token
|
|
session_token = generate_session_token()
|
|
|
|
# Create embed session
|
|
try:
|
|
await db_client.create_embed_session(
|
|
session_token=session_token,
|
|
embed_token_id=embed_token.id,
|
|
workflow_run_id=workflow_run.id,
|
|
client_ip=request.client.host if request.client else None,
|
|
user_agent=request.headers.get("user-agent", "")[:500],
|
|
origin=origin[:255],
|
|
expires_at=datetime.now(UTC) + timedelta(hours=1), # 1 hour expiry
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to create embed session: {e}")
|
|
raise HTTPException(status_code=500, detail="Failed to create session")
|
|
|
|
# Increment usage count
|
|
await db_client.increment_embed_token_usage(embed_token.id)
|
|
|
|
# Prepare configuration
|
|
config = {
|
|
"workflow_id": embed_token.workflow_id,
|
|
"workflow_run_id": workflow_run.id,
|
|
**(embed_token.settings or {}),
|
|
}
|
|
|
|
return InitEmbedResponse(
|
|
session_token=session_token, workflow_run_id=workflow_run.id, config=config
|
|
)
|
|
|
|
|
|
@router.options("/config/{token}")
|
|
async def options_embed_config(token: str, request: Request):
|
|
"""Fallback OPTIONS handler for the embed config endpoint.
|
|
|
|
Browser preflights include Access-Control-Request-Method and are handled by
|
|
PublicEmbedCORSMiddleware before global CORS. This keeps non-conformant
|
|
OPTIONS requests on the same validation path.
|
|
"""
|
|
return await _config_preflight_response(token, request.headers.get("origin", ""))
|
|
|
|
|
|
@router.get("/config/{token}", response_model=EmbedConfigResponse)
|
|
async def get_embed_config(token: str, request: Request, response: Response):
|
|
"""Get embed configuration without creating a session.
|
|
|
|
This endpoint is used to fetch widget configuration for display purposes
|
|
without actually starting a call session.
|
|
"""
|
|
origin = get_request_origin(request)
|
|
|
|
# Validate embed token
|
|
embed_token = await db_client.get_embed_token_by_token(token)
|
|
if not embed_token:
|
|
raise HTTPException(status_code=404, detail="Invalid embed token")
|
|
|
|
# Check if token is active
|
|
if not embed_token.is_active:
|
|
raise HTTPException(status_code=403, detail="Embed token is inactive")
|
|
|
|
# Validate domain
|
|
if not validate_origin(origin, embed_token.allowed_domains or []):
|
|
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
|
|
|
# Set CORS header explicitly; the global CORSMiddleware covers only
|
|
# first-party origins; this endpoint is fetched by external embed sites.
|
|
if origin:
|
|
_allow_embed_origin(response, origin)
|
|
|
|
# Extract settings with defaults
|
|
settings = embed_token.settings or {}
|
|
|
|
return EmbedConfigResponse(
|
|
workflow_id=embed_token.workflow_id,
|
|
settings=settings,
|
|
theme=settings.get("theme", "light"),
|
|
position=settings.get("position", "bottom-right"),
|
|
button_text=settings.get("buttonText", "Start Voice Call"),
|
|
button_color=settings.get("buttonColor", "#3B82F6"),
|
|
size=settings.get("size", "medium"),
|
|
auto_start=settings.get("autoStart", False),
|
|
)
|
|
|
|
|
|
@router.options("/init")
|
|
async def options_init(request: Request):
|
|
"""Fallback OPTIONS handler for init endpoint."""
|
|
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
|
# For init endpoint, we need to check the token in the request body
|
|
# But OPTIONS requests don't have body, so we'll be permissive
|
|
# The actual validation happens in the POST request
|
|
origin = request.headers.get("origin", "*")
|
|
|
|
return _cors_response(origin, "POST, OPTIONS")
|
|
|
|
|
|
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
|
|
async def get_public_turn_credentials(
|
|
session_token: str, request: Request, response: Response
|
|
):
|
|
"""Get TURN credentials for an embed session.
|
|
|
|
This endpoint allows embedded widgets to obtain TURN server credentials
|
|
for WebRTC connections without requiring authentication.
|
|
|
|
Args:
|
|
session_token: The session token from embed initialization
|
|
|
|
Returns:
|
|
TurnCredentialsResponse with username, password, ttl, and TURN URIs
|
|
"""
|
|
origin = get_request_origin(request)
|
|
|
|
# Validate session token
|
|
embed_session = await db_client.get_embed_session_by_token(session_token)
|
|
if not embed_session:
|
|
raise HTTPException(status_code=404, detail="Invalid session token")
|
|
|
|
# Check if session is expired
|
|
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
|
|
raise HTTPException(status_code=403, detail="Session expired")
|
|
|
|
# Get the embed token to check allowed domains
|
|
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
|
|
if not embed_token:
|
|
raise HTTPException(status_code=404, detail="Invalid embed token")
|
|
|
|
# Validate domain (empty allowed_domains means allow all)
|
|
if not validate_origin(origin, embed_token.allowed_domains or []):
|
|
logger.warning(
|
|
f"Domain validation failed for TURN credentials: {origin} not in {embed_token.allowed_domains}"
|
|
)
|
|
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
|
|
|
if origin:
|
|
_allow_embed_origin(response, origin)
|
|
|
|
# Check if TURN is configured
|
|
if not TURN_SECRET:
|
|
raise HTTPException(
|
|
status_code=503,
|
|
detail="TURN server not configured",
|
|
)
|
|
|
|
try:
|
|
# Use session token as identifier for TURN credentials
|
|
credentials = generate_turn_credentials(f"embed:{session_token[:16]}")
|
|
return TurnCredentialsResponse(**credentials)
|
|
except Exception as e:
|
|
logger.error(f"Failed to generate TURN credentials for embed session: {e}")
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail="Failed to generate TURN credentials",
|
|
)
|
|
|
|
|
|
@router.options("/turn-credentials/{session_token}")
|
|
async def options_turn_credentials(request: Request, session_token: str):
|
|
"""Fallback OPTIONS handler for TURN credentials endpoint."""
|
|
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
|
return await _turn_credentials_preflight_response(
|
|
session_token, request.headers.get("origin", "")
|
|
)
|