fix: add CORS preflight handler and ACAO header for embed config endpoint (#403)

* fix: add CORS preflight handler and ACAO header for embed config endpoint

The GET /public/embed/config/{token} endpoint is fetched by external
websites (third-party embed sites). The global CORSMiddleware only covers
first-party origins, so external origins received no Access-Control-Allow-
Origin header, causing browser preflight failures.

Add an OPTIONS /config/{token} handler that validates the origin against the
token's allowed_domains list and returns the appropriate CORS headers.
Also inject Access-Control-Allow-Origin into the GET response via FastAPI's
response parameter so the actual request succeeds cross-origin.

Closes #383

* fix: complete public embed CORS handling

---------

Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
nuthalapativarun 2026-06-03 08:57:44 -07:00 committed by GitHub
parent cdb27c1d4f
commit dace4a7efc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 466 additions and 79 deletions

View file

@ -0,0 +1,274 @@
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from api.routes.public_embed import PublicEmbedCORSMiddleware, router
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.dograh.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix="/api/v1")
app.include_router(router, prefix="/api/v1")
client = TestClient(app, raise_server_exceptions=False)
_ACTIVE_TOKEN = SimpleNamespace(
id=10,
is_active=True,
expires_at=None,
allowed_domains=[],
workflow_id=1,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_RESTRICTED_TOKEN = SimpleNamespace(
id=20,
is_active=True,
expires_at=None,
allowed_domains=["allowed.example.com"],
workflow_id=2,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_LOCALHOST_TOKEN = SimpleNamespace(
id=30,
is_active=True,
expires_at=None,
allowed_domains=["localhost:3000", "localhost:3020"],
workflow_id=3,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
@pytest.fixture(autouse=True)
def _patch_db(monkeypatch):
async def _get_token(token):
if token == "valid":
return _ACTIVE_TOKEN
if token == "restricted":
return _RESTRICTED_TOKEN
if token == "localhost":
return _LOCALHOST_TOKEN
return None
async def _get_token_by_id(token_id):
if token_id == _ACTIVE_TOKEN.id:
return _ACTIVE_TOKEN
if token_id == _RESTRICTED_TOKEN.id:
return _RESTRICTED_TOKEN
if token_id == _LOCALHOST_TOKEN.id:
return _LOCALHOST_TOKEN
return None
async def _get_session(session_token):
if session_token == "session-valid":
return SimpleNamespace(embed_token_id=_ACTIVE_TOKEN.id, expires_at=None)
if session_token == "session-restricted":
return SimpleNamespace(embed_token_id=_RESTRICTED_TOKEN.id, expires_at=None)
return None
async def _create_workflow_run(**_kwargs):
return SimpleNamespace(id=123)
async def _noop(*_args, **_kwargs):
return None
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_token",
_get_token,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_id",
_get_token_by_id,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_session_by_token",
_get_session,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_workflow_run",
_create_workflow_run,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_embed_session",
_noop,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.increment_embed_token_usage",
_noop,
)
monkeypatch.setattr("api.routes.public_embed.TURN_SECRET", "test-secret")
monkeypatch.setattr(
"api.routes.public_embed.generate_turn_credentials",
lambda _user_id: {
"username": "turn-user",
"password": "turn-password",
"ttl": 3600,
"uris": ["turn:example.com:3478"],
},
)
def _assert_embed_cors(resp, origin: str):
assert resp.headers.get("access-control-allow-origin") == origin
assert "origin" in {
value.strip().lower() for value in resp.headers.get("vary", "").split(",")
}
def test_options_config_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/config/valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.options(
"/api/v1/public/embed/config/localhost",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_rejects_unknown_token():
resp = client.options(
"/api/v1/public/embed/config/unknown",
headers={
"Origin": "https://mysite.vercel.app",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_options_config_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/config/restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_get_config_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/config/valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_rejects_unlisted_localhost_port():
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": "http://localhost:3021"},
)
assert resp.status_code == 403
def test_get_config_rejects_disallowed_origin():
resp = client.get(
"/api/v1/public/embed/config/restricted",
headers={"Origin": "https://notallowed.example.com"},
)
assert resp.status_code == 403
def test_init_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.post(
"/api/v1/public/embed/init",
headers={"Origin": origin},
json={"token": "valid"},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_turn_credentials_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_init_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/init",
headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403