mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
chore: drain active calls before rolling updates (#474)
* chore: drain active calls before rolling updates * fix: add a devops secret header * fix: implement PR review
This commit is contained in:
parent
327ec561d5
commit
b192d4ada7
12 changed files with 572 additions and 17 deletions
|
|
@ -12,6 +12,11 @@ UI_APP_URL="http://localhost:3000"
|
|||
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
||||
REDIS_URL="redis://:redissecret@localhost:6379"
|
||||
|
||||
# Internal devops secret for deployment scripts and lifecycle hooks.
|
||||
# scripts/rolling_update.sh sends this to protected operational endpoints via
|
||||
# X-Dograh-Devops-Secret. Use a unique random value in production.
|
||||
DOGRAH_DEVOPS_SECRET="change-me-dograh-devops-secret"
|
||||
|
||||
# AWS S3 Configuration
|
||||
ENABLE_AWS_S3="false"
|
||||
# AWS_ACCESS_KEY_ID=""
|
||||
|
|
|
|||
|
|
@ -14,4 +14,6 @@ UI_APP_URL=http://localhost:3000
|
|||
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
||||
REDIS_URL="redis://:redissecret@localhost:6379/0"
|
||||
|
||||
DOGRAH_DEVOPS_SECRET="test-dograh-devops-secret"
|
||||
|
||||
MINIO_PUBLIC_ENDPOINT=http://localhost:9000
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ STACK_AUTH_PROJECT_ID = os.getenv("STACK_AUTH_PROJECT_ID")
|
|||
STACK_PUBLISHABLE_CLIENT_KEY = os.getenv("STACK_PUBLISHABLE_CLIENT_KEY")
|
||||
DOGRAH_MPS_SECRET_KEY = os.getenv("DOGRAH_MPS_SECRET_KEY", None)
|
||||
MPS_API_URL = os.getenv("MPS_API_URL", "https://services.dograh.com")
|
||||
DOGRAH_DEVOPS_SECRET = os.getenv("DOGRAH_DEVOPS_SECRET") or None
|
||||
|
||||
# Storage Configuration
|
||||
ENABLE_AWS_S3 = os.getenv("ENABLE_AWS_S3", "false").lower() == "true"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
import secrets
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -125,3 +128,51 @@ async def health() -> HealthResponse:
|
|||
STACK_PUBLISHABLE_CLIENT_KEY if is_stack else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ActiveCallsResponse(BaseModel):
|
||||
active_calls: int
|
||||
|
||||
|
||||
DOGRAH_DEVOPS_SECRET_HEADER = "X-Dograh-Devops-Secret"
|
||||
|
||||
|
||||
def _verify_devops_secret(
|
||||
configured_secret: str | None,
|
||||
provided_secret: str | None,
|
||||
) -> None:
|
||||
if not configured_secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Devops secret is not configured",
|
||||
)
|
||||
if not provided_secret or not secrets.compare_digest(
|
||||
provided_secret,
|
||||
configured_secret,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Forbidden",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health/active-calls", response_model=ActiveCallsResponse)
|
||||
async def active_calls(
|
||||
x_dograh_devops_secret: Annotated[
|
||||
str | None,
|
||||
Header(alias=DOGRAH_DEVOPS_SECRET_HEADER),
|
||||
] = None,
|
||||
) -> ActiveCallsResponse:
|
||||
"""In-flight call count for THIS worker — the drain signal for deploys.
|
||||
|
||||
A deploy orchestrator polls this per worker and waits for zero before
|
||||
sending SIGTERM, because uvicorn force-closes live call WebSockets (close
|
||||
code 1012) on SIGTERM and would cut calls mid-conversation otherwise. The
|
||||
count is per-process: one uvicorn per VM port (scripts/rolling_update.sh)
|
||||
or per Kubernetes pod (preStop hook). See api/services/pipecat/active_calls.py.
|
||||
"""
|
||||
from api.constants import DOGRAH_DEVOPS_SECRET
|
||||
from api.services.pipecat.active_calls import active_call_count
|
||||
|
||||
_verify_devops_secret(DOGRAH_DEVOPS_SECRET, x_dograh_devops_secret)
|
||||
return ActiveCallsResponse(active_calls=active_call_count())
|
||||
|
|
|
|||
35
api/services/pipecat/active_calls.py
Normal file
35
api/services/pipecat/active_calls.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""In-process registry of active pipeline runs (live voice calls).
|
||||
|
||||
Each uvicorn worker tracks the calls it is currently running so a deploy
|
||||
orchestrator can *drain* the worker before stopping it: poll the count, wait for
|
||||
zero, then send SIGTERM. Sending SIGTERM while calls are live makes uvicorn
|
||||
force-close their WebSockets (close code 1012), which cuts the calls instead of
|
||||
letting them finish — so the wait has to happen first.
|
||||
|
||||
The registry is deliberately per-process. That is exactly the unit that gets
|
||||
drained: one uvicorn process per VM port (see ``scripts/rolling_update.sh``) or
|
||||
one uvicorn process per Kubernetes pod (drained via a ``preStop`` hook). The
|
||||
count is exposed read-only at ``GET /api/v1/health/active-calls`` and is also a
|
||||
natural autoscaling signal (concurrent calls per worker).
|
||||
|
||||
Access is single-threaded (asyncio event loop), so no lock is needed. A set of
|
||||
run ids — rather than a bare counter — keeps register/unregister idempotent and
|
||||
makes the in-flight runs inspectable for debugging.
|
||||
"""
|
||||
|
||||
_active_run_ids: set[int] = set()
|
||||
|
||||
|
||||
def register_active_call(workflow_run_id: int) -> None:
|
||||
"""Mark a pipeline run as active in this worker."""
|
||||
_active_run_ids.add(workflow_run_id)
|
||||
|
||||
|
||||
def unregister_active_call(workflow_run_id: int) -> None:
|
||||
"""Mark a pipeline run as finished in this worker."""
|
||||
_active_run_ids.discard(workflow_run_id)
|
||||
|
||||
|
||||
def active_call_count() -> int:
|
||||
"""Number of pipeline runs currently active in this worker."""
|
||||
return len(_active_run_ids)
|
||||
|
|
@ -11,6 +11,10 @@ from api.services.integrations import (
|
|||
IntegrationRuntimeContext,
|
||||
create_runtime_sessions,
|
||||
)
|
||||
from api.services.pipecat.active_calls import (
|
||||
register_active_call,
|
||||
unregister_active_call,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
|
|
@ -163,6 +167,34 @@ async def run_pipeline_telephony(
|
|||
user_id: int,
|
||||
call_id: str,
|
||||
transport_kwargs: dict,
|
||||
) -> None:
|
||||
"""Run a pipeline for any telephony provider."""
|
||||
# Register before any async setup so deploy drains see calls that are still
|
||||
# resolving DB/config/transport state.
|
||||
register_active_call(workflow_run_id)
|
||||
try:
|
||||
await _run_pipeline_telephony_impl(
|
||||
websocket,
|
||||
provider_name=provider_name,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_id=user_id,
|
||||
call_id=call_id,
|
||||
transport_kwargs=transport_kwargs,
|
||||
)
|
||||
finally:
|
||||
unregister_active_call(workflow_run_id)
|
||||
|
||||
|
||||
async def _run_pipeline_telephony_impl(
|
||||
websocket,
|
||||
*,
|
||||
provider_name: str,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_id: str,
|
||||
transport_kwargs: dict,
|
||||
) -> None:
|
||||
"""Run a pipeline for any telephony provider.
|
||||
|
||||
|
|
@ -236,7 +268,7 @@ async def run_pipeline_telephony(
|
|||
)
|
||||
|
||||
try:
|
||||
await _run_pipeline(
|
||||
await _run_pipeline_impl(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
|
|
@ -260,6 +292,31 @@ async def run_pipeline_smallwebrtc(
|
|||
user_id: int,
|
||||
call_context_vars: dict = {},
|
||||
user_provider_id: str | None = None,
|
||||
) -> None:
|
||||
"""Run pipeline for WebRTC connections."""
|
||||
# Register before any async setup so deploy drains see calls that are still
|
||||
# resolving DB/config/transport state.
|
||||
register_active_call(workflow_run_id)
|
||||
try:
|
||||
await _run_pipeline_smallwebrtc_impl(
|
||||
webrtc_connection,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id,
|
||||
call_context_vars=call_context_vars,
|
||||
user_provider_id=user_provider_id,
|
||||
)
|
||||
finally:
|
||||
unregister_active_call(workflow_run_id)
|
||||
|
||||
|
||||
async def _run_pipeline_smallwebrtc_impl(
|
||||
webrtc_connection: SmallWebRTCConnection,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_context_vars: dict = {},
|
||||
user_provider_id: str | None = None,
|
||||
) -> None:
|
||||
"""Run pipeline for WebRTC connections"""
|
||||
logger.debug(
|
||||
|
|
@ -309,7 +366,7 @@ async def run_pipeline_smallwebrtc(
|
|||
ambient_noise_config,
|
||||
is_realtime=is_realtime,
|
||||
)
|
||||
await _run_pipeline(
|
||||
await _run_pipeline_impl(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
|
|
@ -332,6 +389,35 @@ async def _run_pipeline(
|
|||
user_provider_id: str | None = None,
|
||||
workflow_run=None,
|
||||
resolved_user_config=None,
|
||||
) -> None:
|
||||
"""Run the pipeline with active-call drain accounting."""
|
||||
register_active_call(workflow_run_id)
|
||||
try:
|
||||
await _run_pipeline_impl(
|
||||
transport,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user_id,
|
||||
call_context_vars=call_context_vars,
|
||||
audio_config=audio_config,
|
||||
user_provider_id=user_provider_id,
|
||||
workflow_run=workflow_run,
|
||||
resolved_user_config=resolved_user_config,
|
||||
)
|
||||
finally:
|
||||
unregister_active_call(workflow_run_id)
|
||||
|
||||
|
||||
async def _run_pipeline_impl(
|
||||
transport,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user_id: int,
|
||||
call_context_vars: dict = {},
|
||||
audio_config: AudioConfig = None,
|
||||
user_provider_id: str | None = None,
|
||||
workflow_run=None,
|
||||
resolved_user_config=None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the pipeline with the given transport and configuration
|
||||
|
|
|
|||
220
api/tests/test_active_calls.py
Normal file
220
api/tests/test_active_calls.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Unit tests for the per-worker active-call registry (deploy draining).
|
||||
|
||||
The registry backs GET /api/v1/health/active-calls, which scripts/rolling_update.sh
|
||||
(and a k8s preStop hook) polls to wait for live calls to finish before stopping a
|
||||
worker. The guarantees that matter for draining: register/unregister are
|
||||
idempotent, and the count only reaches zero when every registered run is gone.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes import main as main_routes
|
||||
from api.services.pipecat import active_calls
|
||||
from api.services.pipecat import run_pipeline as run_pipeline_module
|
||||
|
||||
|
||||
def setup_function():
|
||||
# Module-level state — start each test from an empty registry.
|
||||
active_calls._active_run_ids.clear()
|
||||
|
||||
|
||||
def _make_active_calls_client(
|
||||
monkeypatch,
|
||||
configured_secret: str | None = "test-dograh-devops-secret",
|
||||
) -> TestClient:
|
||||
monkeypatch.setattr("api.constants.DOGRAH_DEVOPS_SECRET", configured_secret)
|
||||
app = FastAPI()
|
||||
app.add_api_route(
|
||||
"/api/v1/health/active-calls",
|
||||
main_routes.active_calls,
|
||||
methods=["GET"],
|
||||
response_model=main_routes.ActiveCallsResponse,
|
||||
)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_starts_empty():
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
def test_register_counts_distinct_runs():
|
||||
active_calls.register_active_call(1)
|
||||
active_calls.register_active_call(2)
|
||||
assert active_calls.active_call_count() == 2
|
||||
|
||||
|
||||
def test_register_is_idempotent():
|
||||
# Registering the same run twice must not double-count, or the count could
|
||||
# never drain to zero.
|
||||
active_calls.register_active_call(1)
|
||||
active_calls.register_active_call(1)
|
||||
assert active_calls.active_call_count() == 1
|
||||
|
||||
|
||||
def test_unregister_removes_run():
|
||||
active_calls.register_active_call(1)
|
||||
active_calls.register_active_call(2)
|
||||
active_calls.unregister_active_call(1)
|
||||
assert active_calls.active_call_count() == 1
|
||||
|
||||
|
||||
def test_unregister_unknown_run_is_a_noop():
|
||||
# discard() semantics: unregistering a run that was never registered (or was
|
||||
# already removed) is safe and cannot push the count negative.
|
||||
active_calls.unregister_active_call(999)
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
def test_full_lifecycle_drains_to_zero():
|
||||
active_calls.register_active_call(42)
|
||||
assert active_calls.active_call_count() == 1
|
||||
active_calls.unregister_active_call(42)
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_pipeline_counts_call_during_setup(monkeypatch):
|
||||
entered_setup = asyncio.Event()
|
||||
release_setup = asyncio.Event()
|
||||
|
||||
async def fake_get_workflow_run(*args, **kwargs):
|
||||
entered_setup.set()
|
||||
await release_setup.wait()
|
||||
raise RuntimeError("setup failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
run_pipeline_module.db_client,
|
||||
"get_workflow_run",
|
||||
fake_get_workflow_run,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_pipeline_module._run_pipeline(
|
||||
transport=object(),
|
||||
workflow_id=1,
|
||||
workflow_run_id=42,
|
||||
user_id=7,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.wait_for(entered_setup.wait(), timeout=1.0)
|
||||
assert active_calls.active_call_count() == 1
|
||||
|
||||
release_setup.set()
|
||||
with pytest.raises(RuntimeError, match="setup failed"):
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webrtc_entrypoint_counts_call_during_setup(monkeypatch):
|
||||
entered_setup = asyncio.Event()
|
||||
release_setup = asyncio.Event()
|
||||
|
||||
async def fake_get_workflow(*args, **kwargs):
|
||||
entered_setup.set()
|
||||
await release_setup.wait()
|
||||
raise RuntimeError("setup failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
run_pipeline_module.db_client, "get_workflow", fake_get_workflow
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_pipeline_module.run_pipeline_smallwebrtc(
|
||||
webrtc_connection=object(),
|
||||
workflow_id=1,
|
||||
workflow_run_id=43,
|
||||
user_id=7,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.wait_for(entered_setup.wait(), timeout=1.0)
|
||||
assert active_calls.active_call_count() == 1
|
||||
|
||||
release_setup.set()
|
||||
with pytest.raises(RuntimeError, match="setup failed"):
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telephony_entrypoint_counts_call_during_setup(monkeypatch):
|
||||
entered_setup = asyncio.Event()
|
||||
release_setup = asyncio.Event()
|
||||
|
||||
async def fake_get_workflow(*args, **kwargs):
|
||||
entered_setup.set()
|
||||
await release_setup.wait()
|
||||
raise RuntimeError("setup failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
run_pipeline_module.db_client, "get_workflow", fake_get_workflow
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_pipeline_module.run_pipeline_telephony(
|
||||
websocket=object(),
|
||||
provider_name="twilio",
|
||||
workflow_id=1,
|
||||
workflow_run_id=44,
|
||||
user_id=7,
|
||||
call_id="call-1",
|
||||
transport_kwargs={},
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.wait_for(entered_setup.wait(), timeout=1.0)
|
||||
assert active_calls.active_call_count() == 1
|
||||
|
||||
release_setup.set()
|
||||
with pytest.raises(RuntimeError, match="setup failed"):
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
assert active_calls.active_call_count() == 0
|
||||
|
||||
|
||||
def test_active_calls_route_requires_configured_secret(monkeypatch):
|
||||
client = _make_active_calls_client(monkeypatch, configured_secret=None)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/health/active-calls",
|
||||
headers={"X-Dograh-Devops-Secret": "test-dograh-devops-secret"},
|
||||
)
|
||||
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
def test_active_calls_route_rejects_missing_secret_header(monkeypatch):
|
||||
client = _make_active_calls_client(monkeypatch)
|
||||
|
||||
response = client.get("/api/v1/health/active-calls")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_active_calls_route_rejects_wrong_secret(monkeypatch):
|
||||
client = _make_active_calls_client(monkeypatch)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/health/active-calls",
|
||||
headers={"X-Dograh-Devops-Secret": "wrong"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_active_calls_route_returns_count_with_secret(monkeypatch):
|
||||
active_calls.register_active_call(42)
|
||||
client = _make_active_calls_client(monkeypatch)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/health/active-calls",
|
||||
headers={"X-Dograh-Devops-Secret": "test-dograh-devops-secret"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"active_calls": 1}
|
||||
Loading…
Add table
Add a link
Reference in a new issue