chore: drain active calls before rolling updates (#474)

* chore: drain active calls before rolling updates

* fix: add a devops secret header

* fix: implement PR review
This commit is contained in:
Abhishek 2026-06-29 06:00:31 +05:30 committed by GitHub
parent 327ec561d5
commit b192d4ada7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 572 additions and 17 deletions

View file

@ -12,6 +12,11 @@ UI_APP_URL="http://localhost:3000"
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
REDIS_URL="redis://:redissecret@localhost:6379"
# Internal devops secret for deployment scripts and lifecycle hooks.
# scripts/rolling_update.sh sends this to protected operational endpoints via
# X-Dograh-Devops-Secret. Use a unique random value in production.
DOGRAH_DEVOPS_SECRET="change-me-dograh-devops-secret"
# AWS S3 Configuration
ENABLE_AWS_S3="false"
# AWS_ACCESS_KEY_ID=""

View file

@ -14,4 +14,6 @@ UI_APP_URL=http://localhost:3000
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
REDIS_URL="redis://:redissecret@localhost:6379/0"
DOGRAH_DEVOPS_SECRET="test-dograh-devops-secret"
MINIO_PUBLIC_ENDPOINT=http://localhost:9000

View file

@ -54,6 +54,7 @@ STACK_AUTH_PROJECT_ID = os.getenv("STACK_AUTH_PROJECT_ID")
STACK_PUBLISHABLE_CLIENT_KEY = os.getenv("STACK_PUBLISHABLE_CLIENT_KEY")
DOGRAH_MPS_SECRET_KEY = os.getenv("DOGRAH_MPS_SECRET_KEY", None)
MPS_API_URL = os.getenv("MPS_API_URL", "https://services.dograh.com")
DOGRAH_DEVOPS_SECRET = os.getenv("DOGRAH_DEVOPS_SECRET") or None
# Storage Configuration
ENABLE_AWS_S3 = os.getenv("ENABLE_AWS_S3", "false").lower() == "true"

View file

@ -1,4 +1,7 @@
from fastapi import APIRouter
import secrets
from typing import Annotated
from fastapi import APIRouter, Header, HTTPException, status
from loguru import logger
from pydantic import BaseModel
@ -125,3 +128,51 @@ async def health() -> HealthResponse:
STACK_PUBLISHABLE_CLIENT_KEY if is_stack else None
),
)
# Response body of the active-calls health route (used as its response_model).
class ActiveCallsResponse(BaseModel):
    # Filled from active_call_count() by the route handler: the number of
    # pipeline runs registered as active in the process serving the request.
    active_calls: int
# Name of the HTTP request header that carries the devops secret. It is used as
# the Header alias of the active-calls route, so callers must send the secret
# under exactly this header name.
DOGRAH_DEVOPS_SECRET_HEADER = "X-Dograh-Devops-Secret"
def _verify_devops_secret(
configured_secret: str | None,
provided_secret: str | None,
) -> None:
if not configured_secret:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Devops secret is not configured",
)
if not provided_secret or not secrets.compare_digest(
provided_secret,
configured_secret,
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Forbidden",
)
@router.get("/health/active-calls", response_model=ActiveCallsResponse)
async def active_calls(
    x_dograh_devops_secret: Annotated[
        str | None,
        Header(alias=DOGRAH_DEVOPS_SECRET_HEADER),
    ] = None,
) -> ActiveCallsResponse:
    """Report how many calls THIS worker is running (the deploy drain signal).

    The request must carry the devops secret in the header named by
    ``DOGRAH_DEVOPS_SECRET_HEADER``; the check is delegated to
    ``_verify_devops_secret``, which raises before any count is returned.

    A deploy orchestrator polls this endpoint for each worker and only sends
    SIGTERM once the count is zero, because uvicorn force-closes live call
    WebSockets (close code 1012) on SIGTERM and would otherwise cut calls
    mid-conversation. The count is per-process: one uvicorn per VM port
    (scripts/rolling_update.sh) or per Kubernetes pod (preStop hook). See
    api/services/pipecat/active_calls.py.
    """
    # Resolved at call time rather than at module import, so the value
    # currently set on api.constants is the one that gets checked.
    from api.constants import DOGRAH_DEVOPS_SECRET
    from api.services.pipecat.active_calls import active_call_count

    _verify_devops_secret(DOGRAH_DEVOPS_SECRET, x_dograh_devops_secret)

    in_flight = active_call_count()
    return ActiveCallsResponse(active_calls=in_flight)

View file

@ -0,0 +1,35 @@
"""In-process registry of active pipeline runs (live voice calls).
Each uvicorn worker tracks the calls it is currently running so a deploy
orchestrator can *drain* the worker before stopping it: poll the count, wait for
zero, then send SIGTERM. Sending SIGTERM while calls are live makes uvicorn
force-close their WebSockets (close code 1012), which cuts the calls instead of
letting them finish, so the wait has to happen first.
The registry is deliberately per-process. That is exactly the unit that gets
drained: one uvicorn process per VM port (see ``scripts/rolling_update.sh``) or
one uvicorn process per Kubernetes pod (drained via a ``preStop`` hook). The
count is exposed read-only at ``GET /api/v1/health/active-calls`` and is also a
natural autoscaling signal (concurrent calls per worker).
Access is single-threaded (asyncio event loop), so no lock is needed. A set of
run ids rather than a bare counter keeps register/unregister idempotent and
makes the in-flight runs inspectable for debugging.
"""
# Ids of the pipeline runs this process currently considers active. Membership
# in a set (not a counter) is what makes register/unregister idempotent.
_active_run_ids: set[int] = set()


def register_active_call(workflow_run_id: int) -> None:
    """Record ``workflow_run_id`` as active in this worker.

    Registering an id that is already present leaves the registry unchanged.
    """
    _active_run_ids.add(workflow_run_id)


def unregister_active_call(workflow_run_id: int) -> None:
    """Record ``workflow_run_id`` as finished in this worker.

    Removing an id that is not present is a silent no-op.
    """
    _active_run_ids.discard(workflow_run_id)


def active_call_count() -> int:
    """Return how many distinct run ids are currently registered as active."""
    return len(_active_run_ids)

View file

@ -11,6 +11,10 @@ from api.services.integrations import (
IntegrationRuntimeContext,
create_runtime_sessions,
)
from api.services.pipecat.active_calls import (
register_active_call,
unregister_active_call,
)
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
@ -163,6 +167,34 @@ async def run_pipeline_telephony(
user_id: int,
call_id: str,
transport_kwargs: dict,
) -> None:
"""Run a pipeline for any telephony provider."""
# Register before any async setup so deploy drains see calls that are still
# resolving DB/config/transport state.
register_active_call(workflow_run_id)
try:
await _run_pipeline_telephony_impl(
websocket,
provider_name=provider_name,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
user_id=user_id,
call_id=call_id,
transport_kwargs=transport_kwargs,
)
finally:
unregister_active_call(workflow_run_id)
async def _run_pipeline_telephony_impl(
websocket,
*,
provider_name: str,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_id: str,
transport_kwargs: dict,
) -> None:
"""Run a pipeline for any telephony provider.
@ -236,7 +268,7 @@ async def run_pipeline_telephony(
)
try:
await _run_pipeline(
await _run_pipeline_impl(
transport,
workflow_id,
workflow_run_id,
@ -260,6 +292,31 @@ async def run_pipeline_smallwebrtc(
user_id: int,
call_context_vars: dict = {},
user_provider_id: str | None = None,
) -> None:
"""Run pipeline for WebRTC connections."""
# Register before any async setup so deploy drains see calls that are still
# resolving DB/config/transport state.
register_active_call(workflow_run_id)
try:
await _run_pipeline_smallwebrtc_impl(
webrtc_connection,
workflow_id,
workflow_run_id,
user_id,
call_context_vars=call_context_vars,
user_provider_id=user_provider_id,
)
finally:
unregister_active_call(workflow_run_id)
async def _run_pipeline_smallwebrtc_impl(
webrtc_connection: SmallWebRTCConnection,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_context_vars: dict = {},
user_provider_id: str | None = None,
) -> None:
"""Run pipeline for WebRTC connections"""
logger.debug(
@ -309,7 +366,7 @@ async def run_pipeline_smallwebrtc(
ambient_noise_config,
is_realtime=is_realtime,
)
await _run_pipeline(
await _run_pipeline_impl(
transport,
workflow_id,
workflow_run_id,
@ -332,6 +389,35 @@ async def _run_pipeline(
user_provider_id: str | None = None,
workflow_run=None,
resolved_user_config=None,
) -> None:
"""Run the pipeline with active-call drain accounting."""
register_active_call(workflow_run_id)
try:
await _run_pipeline_impl(
transport,
workflow_id,
workflow_run_id,
user_id,
call_context_vars=call_context_vars,
audio_config=audio_config,
user_provider_id=user_provider_id,
workflow_run=workflow_run,
resolved_user_config=resolved_user_config,
)
finally:
unregister_active_call(workflow_run_id)
async def _run_pipeline_impl(
transport,
workflow_id: int,
workflow_run_id: int,
user_id: int,
call_context_vars: dict = {},
audio_config: AudioConfig = None,
user_provider_id: str | None = None,
workflow_run=None,
resolved_user_config=None,
) -> None:
"""
Run the pipeline with the given transport and configuration

View file

@ -0,0 +1,220 @@
"""Unit tests for the per-worker active-call registry (deploy draining).
The registry backs GET /api/v1/health/active-calls, which scripts/rolling_update.sh
(and a k8s preStop hook) polls to wait for live calls to finish before stopping a
worker. The guarantees that matter for draining: register/unregister are
idempotent, and the count only reaches zero when every registered run is gone.
"""
import asyncio
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes import main as main_routes
from api.services.pipecat import active_calls
from api.services.pipecat import run_pipeline as run_pipeline_module
def setup_function():
    """Reset the shared registry before every test in this module."""
    # The registry is module-level state, so leftovers from a previous test
    # would otherwise leak into the next one.
    active_calls._active_run_ids.clear()
def _make_active_calls_client(
    monkeypatch,
    configured_secret: str | None = "test-dograh-devops-secret",
) -> TestClient:
    """Build a test client for a minimal app exposing only the active-calls route.

    ``configured_secret`` is patched onto ``api.constants`` so each test
    controls which secret (if any) the route treats as configured.
    """
    monkeypatch.setattr("api.constants.DOGRAH_DEVOPS_SECRET", configured_secret)

    minimal_app = FastAPI()
    minimal_app.add_api_route(
        "/api/v1/health/active-calls",
        main_routes.active_calls,
        methods=["GET"],
        response_model=main_routes.ActiveCallsResponse,
    )
    client = TestClient(minimal_app)
    return client
def test_starts_empty():
    """With nothing registered, the count is zero."""
    initial_count = active_calls.active_call_count()
    assert initial_count == 0
def test_register_counts_distinct_runs():
    """Two different run ids are counted as two active calls."""
    for run_id in (1, 2):
        active_calls.register_active_call(run_id)

    assert active_calls.active_call_count() == 2
def test_register_is_idempotent():
    """Registering one run id twice still counts it once."""
    # A double-counted run would keep the count above zero forever and block
    # the drain.
    run_id = 1
    active_calls.register_active_call(run_id)
    active_calls.register_active_call(run_id)

    assert active_calls.active_call_count() == 1
def test_unregister_removes_run():
    """Unregistering one of two runs leaves exactly one active."""
    finished_run, live_run = 1, 2
    active_calls.register_active_call(finished_run)
    active_calls.register_active_call(live_run)

    active_calls.unregister_active_call(finished_run)

    remaining = active_calls.active_call_count()
    assert remaining == 1
def test_unregister_unknown_run_is_a_noop():
    """Unregistering an id that was never registered changes nothing."""
    # The registry uses discard() semantics, so removing an unknown (or already
    # removed) run is safe and cannot drive the count below zero.
    never_registered = 999
    active_calls.unregister_active_call(never_registered)

    assert active_calls.active_call_count() == 0
def test_full_lifecycle_drains_to_zero():
    """A register followed by an unregister returns the count to zero."""
    run_id = 42

    active_calls.register_active_call(run_id)
    count_while_active = active_calls.active_call_count()
    assert count_while_active == 1

    active_calls.unregister_active_call(run_id)
    count_after_finish = active_calls.active_call_count()
    assert count_after_finish == 0
@pytest.mark.asyncio
async def test_run_pipeline_counts_call_during_setup(monkeypatch):
    """``_run_pipeline`` counts the run while setup is pending and releases it on failure."""
    setup_started = asyncio.Event()
    allow_setup_to_finish = asyncio.Event()

    async def blocking_get_workflow_run(*args, **kwargs):
        # Park inside the patched DB lookup so the registry can be inspected
        # mid-setup, then fail to exercise the cleanup path.
        setup_started.set()
        await allow_setup_to_finish.wait()
        raise RuntimeError("setup failed")

    monkeypatch.setattr(
        run_pipeline_module.db_client,
        "get_workflow_run",
        blocking_get_workflow_run,
    )

    pipeline_task = asyncio.create_task(
        run_pipeline_module._run_pipeline(
            transport=object(),
            workflow_id=1,
            workflow_run_id=42,
            user_id=7,
        )
    )

    # The run must already be counted while setup is still blocked.
    await asyncio.wait_for(setup_started.wait(), timeout=1.0)
    assert active_calls.active_call_count() == 1

    allow_setup_to_finish.set()
    with pytest.raises(RuntimeError, match="setup failed"):
        await asyncio.wait_for(pipeline_task, timeout=1.0)

    # The failure must not leave the run registered.
    assert active_calls.active_call_count() == 0
@pytest.mark.asyncio
async def test_webrtc_entrypoint_counts_call_during_setup(monkeypatch):
    """``run_pipeline_smallwebrtc`` counts the run during setup and releases it on failure."""
    setup_started = asyncio.Event()
    allow_setup_to_finish = asyncio.Event()

    async def blocking_get_workflow(*args, **kwargs):
        # Park inside the patched DB lookup so the registry can be inspected
        # mid-setup, then fail to exercise the cleanup path.
        setup_started.set()
        await allow_setup_to_finish.wait()
        raise RuntimeError("setup failed")

    monkeypatch.setattr(
        run_pipeline_module.db_client, "get_workflow", blocking_get_workflow
    )

    webrtc_task = asyncio.create_task(
        run_pipeline_module.run_pipeline_smallwebrtc(
            webrtc_connection=object(),
            workflow_id=1,
            workflow_run_id=43,
            user_id=7,
        )
    )

    # The run must already be counted while setup is still blocked.
    await asyncio.wait_for(setup_started.wait(), timeout=1.0)
    assert active_calls.active_call_count() == 1

    allow_setup_to_finish.set()
    with pytest.raises(RuntimeError, match="setup failed"):
        await asyncio.wait_for(webrtc_task, timeout=1.0)

    # The failure must not leave the run registered.
    assert active_calls.active_call_count() == 0
@pytest.mark.asyncio
async def test_telephony_entrypoint_counts_call_during_setup(monkeypatch):
    """``run_pipeline_telephony`` counts the run during setup and releases it on failure."""
    setup_started = asyncio.Event()
    allow_setup_to_finish = asyncio.Event()

    async def blocking_get_workflow(*args, **kwargs):
        # Park inside the patched DB lookup so the registry can be inspected
        # mid-setup, then fail to exercise the cleanup path.
        setup_started.set()
        await allow_setup_to_finish.wait()
        raise RuntimeError("setup failed")

    monkeypatch.setattr(
        run_pipeline_module.db_client, "get_workflow", blocking_get_workflow
    )

    telephony_task = asyncio.create_task(
        run_pipeline_module.run_pipeline_telephony(
            websocket=object(),
            provider_name="twilio",
            workflow_id=1,
            workflow_run_id=44,
            user_id=7,
            call_id="call-1",
            transport_kwargs={},
        )
    )

    # The run must already be counted while setup is still blocked.
    await asyncio.wait_for(setup_started.wait(), timeout=1.0)
    assert active_calls.active_call_count() == 1

    allow_setup_to_finish.set()
    with pytest.raises(RuntimeError, match="setup failed"):
        await asyncio.wait_for(telephony_task, timeout=1.0)

    # The failure must not leave the run registered.
    assert active_calls.active_call_count() == 0
def test_active_calls_route_requires_configured_secret(monkeypatch):
    """With no secret configured, the route answers 503 even if a header is sent."""
    client = _make_active_calls_client(monkeypatch, configured_secret=None)
    request_headers = {"X-Dograh-Devops-Secret": "test-dograh-devops-secret"}

    resp = client.get("/api/v1/health/active-calls", headers=request_headers)

    assert resp.status_code == 503
def test_active_calls_route_rejects_missing_secret_header(monkeypatch):
    """A request without the secret header is answered with 403."""
    client = _make_active_calls_client(monkeypatch)

    resp = client.get("/api/v1/health/active-calls")

    assert resp.status_code == 403
def test_active_calls_route_rejects_wrong_secret(monkeypatch):
    """A request carrying a non-matching secret is answered with 403."""
    client = _make_active_calls_client(monkeypatch)
    request_headers = {"X-Dograh-Devops-Secret": "wrong"}

    resp = client.get("/api/v1/health/active-calls", headers=request_headers)

    assert resp.status_code == 403
def test_active_calls_route_returns_count_with_secret(monkeypatch):
    """With the correct secret, the route returns the current registry count."""
    # One registered run, so the expected payload reports a count of 1.
    active_calls.register_active_call(42)
    client = _make_active_calls_client(monkeypatch)
    request_headers = {"X-Dograh-Devops-Secret": "test-dograh-devops-secret"}

    resp = client.get("/api/v1/health/active-calls", headers=request_headers)

    assert resp.status_code == 200
    assert resp.json() == {"active_calls": 1}