mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
Merge branch 'feat/headless-widget' of https://github.com/dograh-hq/dograh into feat/headless-widget
This commit is contained in:
commit
85dd2f915b
36 changed files with 1063 additions and 250 deletions
|
|
@ -0,0 +1,35 @@
|
|||
"""add campaign logs column
|
||||
|
||||
Revision ID: 6499c608d0f6
|
||||
Revises: a2355fc6bdc1
|
||||
Create Date: 2026-05-05 17:25:49.235730
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6499c608d0f6"
|
||||
down_revision: Union[str, None] = "a2355fc6bdc1"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the ``logs`` column to the ``campaigns`` table.

    The column is a non-nullable JSON value with a server-side default of an
    empty JSON array.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    logs_column = sa.Column(
        "logs",
        sa.JSON(),
        server_default=sa.text("'[]'::json"),
        nullable=False,
    )
    op.add_column("campaigns", logs_column)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``logs`` column from the ``campaigns`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("campaigns", "logs")
    # ### end Alembic commands ###
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -134,35 +135,6 @@ class CampaignClient(BaseDBClient):
|
|||
await session.refresh(campaign)
|
||||
return campaign
|
||||
|
||||
async def update_campaign_progress(
    self,
    campaign_id: int,
    processed_rows: int,
    failed_rows: int,
    organization_id: int,
) -> None:
    """Overwrite a campaign's progress counters and bump ``updated_at``.

    Args:
        campaign_id: Primary key of the campaign to update.
        processed_rows: New value stored in ``processed_rows``.
        failed_rows: New value stored in ``failed_rows``.
        organization_id: Organization the campaign must belong to; the
            lookup filters on it together with ``campaign_id``.

    Raises:
        ValueError: If no campaign matches the id / organization pair.
        Exception: Any commit error is re-raised after the session has been
            rolled back.
    """
    async with self.async_session() as session:
        query = select(CampaignModel).where(
            CampaignModel.id == campaign_id,
            CampaignModel.organization_id == organization_id,
        )
        result = await session.execute(query)
        campaign = result.scalar_one_or_none()

        if campaign is None:
            raise ValueError(f"Campaign {campaign_id} not found")

        campaign.processed_rows = processed_rows
        campaign.failed_rows = failed_rows
        campaign.updated_at = datetime.now(UTC)

        try:
            await session.commit()
        except Exception:
            await session.rollback()
            # Bare ``raise`` re-raises the active exception with its
            # traceback untouched.
            raise
|
||||
|
||||
async def get_campaign_runs(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
@ -452,6 +424,48 @@ class CampaignClient(BaseDBClient):
|
|||
await session.refresh(campaign)
|
||||
return campaign
|
||||
|
||||
async def append_campaign_log(
    self,
    campaign_id: int,
    level: str,
    event: str,
    message: str,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Append a timestamped entry to the campaign's logs JSON array.

    Uses a SQL-side jsonb concat so concurrent writers do not clobber
    each other's entries.

    Args:
        campaign_id: Primary key of the campaign whose log is extended.
        level: Severity label stored under ``"level"``.
        event: Event identifier stored under ``"event"``.
        message: Human-readable text stored under ``"message"``.
        details: Optional extra payload stored under ``"details"``; omitted
            from the entry when ``None`` or empty. Must be JSON-serializable.

    Raises:
        TypeError: If the entry cannot be serialized by ``json.dumps``.
        Exception: Any database error is re-raised after a rollback.
    """
    # One timestamp for both the log entry and ``updated_at`` so the two
    # values written by this call always agree.
    now = datetime.now(UTC)

    entry: Dict[str, Any] = {
        "ts": now.isoformat(),
        "level": level,
        "event": event,
        "message": message,
    }
    if details:
        entry["details"] = details

    # The entry is wrapped in a one-element array: jsonb ``||`` between two
    # arrays concatenates them, which is what appends the entry.
    payload = json.dumps([entry])

    async with self.async_session() as session:
        try:
            # Execute and commit share the same try so a failure in either
            # one triggers the rollback.
            await session.execute(
                text(
                    "UPDATE campaigns "
                    "SET logs = (logs::jsonb || CAST(:entry AS jsonb))::json, "
                    "    updated_at = :now "
                    "WHERE id = :campaign_id"
                ),
                {
                    "entry": payload,
                    "now": now,
                    "campaign_id": campaign_id,
                },
            )
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
||||
|
||||
# QueuedRun methods
|
||||
async def bulk_create_queued_runs(self, queued_runs_data: list[dict]) -> None:
|
||||
"""Bulk create queued runs"""
|
||||
|
|
|
|||
|
|
@ -683,6 +683,16 @@ class CampaignModel(Base):
|
|||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
|
||||
# Append-only timestamped log entries for state transitions, failures,
|
||||
# and circuit-breaker events. Surfaced in the UI so operators can see
|
||||
# why a campaign moved to paused/failed without digging through logs.
|
||||
logs = Column(
|
||||
JSON,
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::json"),
|
||||
)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
|
@ -172,6 +172,20 @@ class UpdateCampaignRequest(BaseModel):
|
|||
circuit_breaker: Optional[CircuitBreakerConfigRequest] = None
|
||||
|
||||
|
||||
class CampaignLogEntryResponse(BaseModel):
    """A single timestamped entry from the campaign's append-only log.

    Surfaced in the UI so operators can see why a campaign moved to
    paused / failed without digging through server logs.
    """

    # Timestamp of the entry as a string.
    ts: str
    # Severity label of the entry (e.g. "warning", "error").
    level: str
    # Event identifier (e.g. "circuit_breaker_tripped", "source_sync_failed").
    event: str
    # Human-readable description of what happened.
    message: str
    # Optional structured payload attached to the entry; absent when the
    # stored entry carries no "details" key.
    details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class CampaignResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
|
@ -196,6 +210,7 @@ class CampaignResponse(BaseModel):
|
|||
redialed_campaign_id: Optional[int] = None
|
||||
telephony_configuration_id: Optional[int] = None
|
||||
telephony_configuration_name: Optional[str] = None
|
||||
logs: List[CampaignLogEntryResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CampaignsResponse(BaseModel):
|
||||
|
|
@ -298,6 +313,11 @@ def _build_campaign_response(
|
|||
redialed_campaign_id=redialed_campaign_id,
|
||||
telephony_configuration_id=campaign.telephony_configuration_id,
|
||||
telephony_configuration_name=telephony_configuration_name,
|
||||
logs=[
|
||||
CampaignLogEntryResponse(**entry)
|
||||
for entry in (campaign.logs or [])
|
||||
if isinstance(entry, dict)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from fastapi import (
|
|||
)
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from api.db import db_client
|
||||
|
|
@ -30,7 +30,6 @@ from api.services.telephony.call_transfer_manager import get_call_transfer_manag
|
|||
from api.services.telephony.factory import (
|
||||
get_all_telephony_providers,
|
||||
get_default_telephony_provider,
|
||||
get_telephony_provider,
|
||||
get_telephony_provider_by_id,
|
||||
get_telephony_provider_for_run,
|
||||
)
|
||||
|
|
@ -874,110 +873,6 @@ async def handle_inbound_telephony(
|
|||
return generic_hangup_response()
|
||||
|
||||
|
||||
class TransferCallRequest(BaseModel):
    """Payload accepted by the endpoint that starts a call transfer."""

    destination: str  # E.164 format phone number (required)
    organization_id: int  # Organization ID for provider configuration
    transfer_id: str  # Unique identifier for tracking this transfer
    conference_name: str  # Conference name for the transfer
    timeout: Optional[int] = 20  # seconds to wait for answer

    @field_validator("destination")
    @classmethod
    def validate_destination(cls, destination: str) -> str:
        """Return the trimmed destination, rejecting non-E.164 values."""
        import re

        candidate = destination.strip() if destination else ""
        if not candidate:
            raise ValueError("Destination phone number is required")

        # "+" followed by a non-zero digit and 1 to 14 further digits.
        e164_pattern = r"^\+[1-9]\d{1,14}$"
        if re.match(e164_pattern, candidate) is None:
            raise ValueError(
                f"Invalid phone number format: {destination}. Must be E.164 format (e.g., +1234567890)"
            )

        return candidate
|
||||
|
||||
|
||||
@router.post("/call-transfer")
async def initiate_call_transfer(request: TransferCallRequest):
    """Initiate a call transfer via the telephony provider.

    This endpoint only initiates the outbound call. Transfer context
    (original_call_sid, etc.) is stored by the caller
    before invoking this endpoint.

    Args:
        request: Validated transfer request (destination, organization,
            transfer id, conference name and answer timeout).

    Returns:
        A dict with ``status``, ``call_id``, ``message``, ``transfer_id``
        and ``provider`` keys describing the initiated transfer.

    Raises:
        HTTPException: 400 when the provider cannot be resolved, does not
            support transfers, or is misconfigured; 500 when the provider
            call fails or any other unexpected error occurs.
    """
    logger.info(
        f"Starting call transfer to {request.destination} with transfer_id: {request.transfer_id}"
    )

    try:
        # Resolve the organization's provider; a ValueError from the lookup
        # is reported to the client as a 400.
        try:
            provider = await get_telephony_provider(request.organization_id)
        except ValueError as e:
            logger.error(f"Transfer provider validation failed: {e}")
            raise HTTPException(
                status_code=400, detail=f"Call transfer not supported: {str(e)}"
            )

        # Reject providers that do not advertise transfer support.
        if not provider.supports_transfers():
            raise HTTPException(
                status_code=400,
                detail=f"Provider '{provider.PROVIDER_NAME}' does not support call transfers",
            )

        # Reject providers whose configuration fails validation.
        if not provider.validate_config():
            logger.error(f"Provider {provider.PROVIDER_NAME} configuration is invalid")
            raise HTTPException(
                status_code=400,
                detail=f"Telephony provider '{provider.PROVIDER_NAME}' is not properly configured for transfers",
            )

        logger.info(f"Initiating transfer call via {provider.PROVIDER_NAME} provider")
        try:
            transfer_result = await provider.transfer_call(
                destination=request.destination,
                transfer_id=request.transfer_id,
                conference_name=request.conference_name,
                timeout=request.timeout,
            )
        except NotImplementedError as e:
            # NotImplementedError from the provider is reported as a 400.
            logger.error(
                f"Provider {provider.PROVIDER_NAME} doesn't support transfers: {e}"
            )
            raise HTTPException(
                status_code=400,
                detail=f"Provider '{provider.PROVIDER_NAME}' does not support call transfers",
            )
        except Exception as e:
            # Any other provider failure is reported as a 500.
            logger.error(f"Provider transfer call failed: {e}")
            raise HTTPException(
                status_code=500, detail=f"Transfer call failed: {str(e)}"
            )

        call_sid = transfer_result.get("call_sid")
        logger.info(f"Transfer call initiated successfully: {call_sid}")
        logger.debug(f"Transfer result: {transfer_result}")

        return {
            "status": "transfer_initiated",
            "call_id": call_sid,
            "message": f"Calling {request.destination}...",
            "transfer_id": request.transfer_id,
            "provider": provider.PROVIDER_NAME,
        }

    except HTTPException:
        # Let the HTTP errors raised above pass through unchanged instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Unexpected error during transfer call: {e}")
        raise HTTPException(
            status_code=500, detail=f"Internal error during transfer: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/transfer-result/{transfer_id}")
|
||||
async def complete_transfer_function_call(transfer_id: str, request: Request):
|
||||
"""Webhook endpoint to complete the function call with transfer result.
|
||||
|
|
|
|||
|
|
@ -345,7 +345,12 @@ class CampaignCallDispatcher:
|
|||
)
|
||||
|
||||
# Record call initiation failure in circuit breaker
|
||||
await circuit_breaker.record_and_evaluate(campaign.id, is_failure=True)
|
||||
await circuit_breaker.record_and_evaluate(
|
||||
campaign.id,
|
||||
is_failure=True,
|
||||
workflow_run_id=workflow_run.id,
|
||||
reason="call_initiation_failed",
|
||||
)
|
||||
|
||||
# Release concurrent slot on failure
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
|
||||
|
|
@ -459,13 +464,18 @@ class CampaignCallDispatcher:
|
|||
await asyncio.sleep(1)
|
||||
|
||||
async def acquire_from_number(
|
||||
self, organization_id: int, timeout: float = 60
|
||||
self, organization_id: int, timeout: float = 600
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Acquire a from_number from the pool with retry.
|
||||
Waits up to timeout seconds, polling every 1s.
|
||||
|
||||
Returns the phone number or None if timeout is exceeded.
|
||||
Args:
|
||||
organization_id: ID of the organization for which to acquire the from_number.
|
||||
timeout: Maximum time in seconds to wait for a from_number before giving up.
|
||||
|
||||
Returns:
|
||||
The acquired phone number as a string, or None if timeout is exceeded.
|
||||
"""
|
||||
wait_start = time.time()
|
||||
|
||||
|
|
|
|||
|
|
@ -383,6 +383,20 @@ class CampaignOrchestrator:
|
|||
f"pausing campaign. Stats: {stats}"
|
||||
)
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="paused")
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="warning",
|
||||
event="circuit_breaker_tripped",
|
||||
message=(
|
||||
f"Paused at scheduling: failure rate "
|
||||
f"{stats['failure_rate']:.2%} "
|
||||
f"({stats['failure_count']}/"
|
||||
f"{stats['failure_count'] + stats['success_count']}) "
|
||||
f"exceeded threshold {stats['threshold']:.2%} "
|
||||
f"in {stats['window_seconds']}s window"
|
||||
),
|
||||
details=stats,
|
||||
)
|
||||
await self.publisher.publish_circuit_breaker_tripped(
|
||||
campaign_id=campaign_id,
|
||||
failure_rate=stats["failure_rate"],
|
||||
|
|
|
|||
|
|
@ -3,10 +3,15 @@
|
|||
Uses two Redis sorted sets (ZSETs) per campaign — one for failures, one for
|
||||
successes — as sliding windows. ZCARD gives O(1) counts without iterating
|
||||
members, keeping the Lua scripts simple.
|
||||
|
||||
A separate capped Redis list (``cb_recent_failures:{campaign_id}``) stores the
|
||||
last N failing ``{workflow_run_id, reason, ts}`` entries so the campaign log
|
||||
written when the breaker trips can show *which* calls pushed it over.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
|
@ -15,6 +20,11 @@ from api.constants import DEFAULT_CIRCUIT_BREAKER_CONFIG, REDIS_URL
|
|||
from api.db import db_client
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
|
||||
# Cap on the number of recent failure entries kept per campaign — large enough
|
||||
# to be useful for debugging a trip, small enough that the JSON details stay
|
||||
# bounded.
|
||||
MAX_RECENT_FAILURES = 20
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""Sliding window circuit breaker for campaign call failures."""
|
||||
|
|
@ -35,6 +45,60 @@ class CircuitBreaker:
|
|||
"""Return (failures_key, successes_key) for a campaign."""
|
||||
return f"cb_failures:{campaign_id}", f"cb_successes:{campaign_id}"
|
||||
|
||||
@staticmethod
|
||||
def _recent_failures_key(campaign_id: int) -> str:
|
||||
"""Return the Redis key used for the capped recent-failures list."""
|
||||
return f"cb_recent_failures:{campaign_id}"
|
||||
|
||||
async def _push_recent_failure(
    self,
    campaign_id: int,
    workflow_run_id: int,
    reason: Optional[str],
) -> None:
    """Push a failure entry onto the capped recent-failures list.

    The entry is a JSON object with ``workflow_run_id``, ``reason`` and
    ``ts`` (``time.time()`` at the moment of the call). The list is trimmed
    to ``MAX_RECENT_FAILURES`` entries and its TTL is refreshed on every
    push.

    Args:
        campaign_id: Campaign whose list receives the entry.
        workflow_run_id: Workflow run that failed.
        reason: Short failure reason, or ``None`` when unknown.

    This is best-effort bookkeeping: every error, including failure to
    obtain the Redis client, is logged and swallowed.
    """
    key = self._recent_failures_key(campaign_id)
    entry = json.dumps(
        {
            "workflow_run_id": workflow_run_id,
            "reason": reason,
            "ts": time.time(),
        }
    )
    try:
        # Acquire the client inside the try so a connection problem is
        # swallowed like any other bookkeeping error.
        redis_client = await self._get_redis()
        # Send push + trim + expire as one MULTI/EXEC transaction so the list
        # can never be left untrimmed or without a TTL.
        async with redis_client.pipeline(transaction=True) as pipe:
            pipe.lpush(key, entry)
            pipe.ltrim(key, 0, MAX_RECENT_FAILURES - 1)
            # Keep this list around as long as the sliding window plus a buffer.
            pipe.expire(
                key,
                DEFAULT_CIRCUIT_BREAKER_CONFIG["window_seconds"] + 60,
            )
            await pipe.execute()
    except Exception as e:
        # Never let recent-failure bookkeeping disrupt the call path.
        logger.error(
            f"Failed to record recent failure for campaign {campaign_id}: {e}"
        )
|
||||
|
||||
async def _get_recent_failures(self, campaign_id: int) -> List[Dict[str, Any]]:
|
||||
"""Return the recent-failures list (most-recent first)."""
|
||||
redis_client = await self._get_redis()
|
||||
key = self._recent_failures_key(campaign_id)
|
||||
try:
|
||||
entries = await redis_client.lrange(key, 0, -1)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to read recent failures for campaign {campaign_id}: {e}"
|
||||
)
|
||||
return []
|
||||
decoded: List[Dict[str, Any]] = []
|
||||
for raw in entries:
|
||||
try:
|
||||
decoded.append(json.loads(raw))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return decoded
|
||||
|
||||
async def record_call_outcome(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
@ -227,13 +291,25 @@ class CircuitBreaker:
|
|||
logger.error(f"Circuit breaker check error for campaign {campaign_id}: {e}")
|
||||
return False, None
|
||||
|
||||
async def record_and_evaluate(self, campaign_id: int, is_failure: bool) -> None:
|
||||
async def record_and_evaluate(
|
||||
self,
|
||||
campaign_id: int,
|
||||
is_failure: bool,
|
||||
*,
|
||||
workflow_run_id: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Record a call outcome, and if the breaker trips, pause the campaign.
|
||||
|
||||
This is the main entry point called from telephony status callbacks.
|
||||
It handles fetching campaign config, recording the outcome, and
|
||||
pausing + publishing an event if the breaker trips.
|
||||
|
||||
``workflow_run_id`` and ``reason`` are optional but should be supplied
|
||||
on failures: they are appended to a capped Redis list so the campaign
|
||||
log entry written on trip can name the calls that pushed the breaker
|
||||
over the threshold.
|
||||
|
||||
Exceptions are caught internally so this never disrupts the caller.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -245,6 +321,13 @@ class CircuitBreaker:
|
|||
if campaign.orchestrator_metadata:
|
||||
cb_config = campaign.orchestrator_metadata.get("circuit_breaker", {})
|
||||
|
||||
if is_failure and workflow_run_id is not None:
|
||||
await self._push_recent_failure(
|
||||
campaign_id=campaign_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
tripped, stats = await self.record_call_outcome(
|
||||
campaign_id=campaign_id,
|
||||
is_failure=is_failure,
|
||||
|
|
@ -257,7 +340,22 @@ class CircuitBreaker:
|
|||
f"pausing campaign. Stats: {stats}"
|
||||
)
|
||||
|
||||
recent_failures = await self._get_recent_failures(campaign_id)
|
||||
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="paused")
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="warning",
|
||||
event="circuit_breaker_tripped",
|
||||
message=(
|
||||
f"Paused: failure rate {stats['failure_rate']:.2%} "
|
||||
f"({stats['failure_count']}/"
|
||||
f"{stats['failure_count'] + stats['success_count']}) "
|
||||
f"exceeded threshold {stats['threshold']:.2%} "
|
||||
f"in {stats['window_seconds']}s window"
|
||||
),
|
||||
details={**stats, "recent_failures": recent_failures},
|
||||
)
|
||||
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_circuit_breaker_tripped(
|
||||
|
|
@ -275,13 +373,16 @@ class CircuitBreaker:
|
|||
async def reset(self, campaign_id: int) -> bool:
|
||||
"""Reset the circuit breaker state for a campaign.
|
||||
|
||||
Called when a campaign is resumed to give it a clean slate.
|
||||
Called when a campaign is resumed to give it a clean slate. Also clears
|
||||
the recent-failures list so log entries from the next trip reference
|
||||
only post-resume failures.
|
||||
"""
|
||||
redis_client = await self._get_redis()
|
||||
fail_key, succ_key = self._keys(campaign_id)
|
||||
recent_key = self._recent_failures_key(campaign_id)
|
||||
|
||||
try:
|
||||
await redis_client.delete(fail_key, succ_key)
|
||||
await redis_client.delete(fail_key, succ_key, recent_key)
|
||||
logger.info(f"Circuit breaker reset for campaign {campaign_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,12 @@ def register_event_handlers(
|
|||
"initial_response_triggered": False,
|
||||
}
|
||||
|
||||
async def queue_initial_llm_context():
    """Queue an ``LLMContextFrame`` built from ``engine.context`` on ``engine.llm``."""
    # Queue LLMContextFrame after the VoicemailDetector since the detector
    # gates LLMContextFrames until voicemail detection completes. We also
    # don't want to trigger the Voicemail LLM with this initial frame.
    await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
async def maybe_trigger_initial_response():
|
||||
"""Start the conversation after both pipeline_started and client_connected events.
|
||||
|
||||
|
|
@ -185,7 +191,7 @@ def register_event_handlers(
|
|||
f"Failed to fetch audio greeting {greeting_value}, "
|
||||
"falling back to LLM generation"
|
||||
)
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
await queue_initial_llm_context()
|
||||
else:
|
||||
logger.debug("Playing text greeting via TTS")
|
||||
# append_to_context=True so the assistant aggregator commits
|
||||
|
|
@ -198,7 +204,7 @@ def register_event_handlers(
|
|||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - triggering initial LLM generation"
|
||||
)
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
await queue_initial_llm_context()
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _participant):
|
||||
|
|
@ -235,7 +241,10 @@ def register_event_handlers(
|
|||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if workflow_run and workflow_run.campaign_id:
|
||||
await circuit_breaker.record_and_evaluate(
|
||||
campaign_id=workflow_run.campaign_id, is_failure=True
|
||||
campaign_id=workflow_run.campaign_id,
|
||||
is_failure=True,
|
||||
workflow_run_id=workflow_run_id,
|
||||
reason="pipeline_error",
|
||||
)
|
||||
asyncio.create_task(
|
||||
_capture_call_event(
|
||||
|
|
|
|||
|
|
@ -105,23 +105,53 @@ def build_realtime_pipeline(
|
|||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=None,
|
||||
):
|
||||
"""Build a pipeline for realtime (speech-to-speech) LLM services.
|
||||
|
||||
Realtime services (e.g. OpenAI Realtime, Gemini Live) handle STT+LLM+TTS
|
||||
internally, so no separate STT or TTS processors are needed.
|
||||
|
||||
Args:
|
||||
voicemail_detector: Optional VoicemailDetector. Placed *below* the
|
||||
realtime LLM. This is asymmetric with the non-realtime layout
|
||||
(where the detector sits between STT and the main user aggregator)
|
||||
because the realtime LLM is both the source of TranscriptionFrame
|
||||
(broadcast downstream) and the sink of LLMContextFrame (consumed
|
||||
by _handle_context without forwarding). Placing the detector below
|
||||
the realtime LLM means: downstream TranscriptionFrames reach the
|
||||
classifier branch, UserStartedSpeakingFrame /
|
||||
UserStoppedSpeakingFrame are forwarded through by the LLM, and the
|
||||
main aggregator's LLMContextFrame is absorbed by the realtime LLM
|
||||
and never leaks into the classifier (which would otherwise run a
|
||||
voicemail completion on the workflow's main context).
|
||||
|
||||
The TTS gate and LLM gate are intentionally not used: the realtime
|
||||
LLM reacts to audio directly, not to LLMContextFrames. On voicemail
|
||||
detection we drop the call via end_call_with_reason; the detector's
|
||||
ConversationGate also blocks downstream audio output until the call
|
||||
ends.
|
||||
"""
|
||||
processors = [
|
||||
transport.input(),
|
||||
user_context_aggregator,
|
||||
realtime_llm,
|
||||
pipeline_engine_callback_processor,
|
||||
transport.output(),
|
||||
audio_buffer,
|
||||
assistant_context_aggregator,
|
||||
pipeline_metrics_aggregator,
|
||||
]
|
||||
|
||||
if voicemail_detector:
|
||||
logger.info("Adding native voicemail detector to realtime pipeline")
|
||||
processors.append(voicemail_detector.detector())
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
pipeline_engine_callback_processor,
|
||||
transport.output(),
|
||||
audio_buffer,
|
||||
assistant_context_aggregator,
|
||||
pipeline_metrics_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
return Pipeline(processors)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -427,11 +427,13 @@ async def _run_pipeline(
|
|||
|
||||
# Configure turn strategies based on STT provider, model, and workflow configuration
|
||||
if is_realtime:
|
||||
# Realtime services have server-side VAD/turn detection.
|
||||
# For stop strategy, lets rely on SmartTurnAnalyzer which is
|
||||
# enabled by default
|
||||
# Realtime services do server-side turn detection for response generation,
|
||||
# but we still need a client-side stop strategy so the user aggregator emits
|
||||
# UserStoppedSpeakingFrame. Without it, downstream consumers (e.g. voicemail
|
||||
# detector) and Gemini Live's _finalize_pending flag never see a turn end.
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy()], stop=[]
|
||||
start=[VADUserTurnStartStrategy()],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
# Let's not start the pipeline muted in realtime mode
|
||||
|
|
@ -521,7 +523,6 @@ async def _run_pipeline(
|
|||
async def on_user_turn_started(aggregator, strategy):
|
||||
user_idle_handler.reset()
|
||||
|
||||
# Voicemail detection and recording router are not supported in realtime mode
|
||||
voicemail_detector = None
|
||||
recording_router = None
|
||||
|
||||
|
|
@ -533,58 +534,61 @@ async def _run_pipeline(
|
|||
)
|
||||
engine.set_fetch_recording_audio(fetch_audio)
|
||||
|
||||
if not is_realtime:
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
# Voicemail detection works in both modes. In realtime mode the detector sits
|
||||
# after the realtime LLM and consumes the TranscriptionFrames it broadcasts;
|
||||
# the LLM gate / TTS gate are not used (the realtime LLM responds to audio
|
||||
# directly, not LLMContextFrames), so on detection we rely on
|
||||
# end_call_with_reason to drop the call.
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(
|
||||
f"Voicemail detection enabled for workflow run {workflow_run_id}"
|
||||
)
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
if has_recordings:
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
# Recording router is only meaningful in non-realtime mode (it routes between
|
||||
# pre-recorded audio playback and dynamic TTS; realtime LLMs produce audio
|
||||
# directly).
|
||||
if not is_realtime and has_recordings:
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Build the pipeline
|
||||
if is_realtime:
|
||||
|
|
@ -596,6 +600,7 @@ async def _run_pipeline(
|
|||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
)
|
||||
else:
|
||||
pipeline = build_pipeline(
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ class TelnyxProvider(TelephonyProvider):
|
|||
status="initiated",
|
||||
caller_number=from_number,
|
||||
provider_metadata={
|
||||
"call_id": call_control_id,
|
||||
"call_control_id": call_control_id,
|
||||
"call_leg_id": call_leg_id,
|
||||
"call_session_id": call_session_id,
|
||||
|
|
@ -321,6 +322,15 @@ class TelnyxProvider(TelephonyProvider):
|
|||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect as e:
|
||||
# Telnyx opens the WebSocket during `bridging` (pre-answer) but only
|
||||
# sends the `start` event on `call.answered`. If the call ends before
|
||||
# answer (no-answer timeout, busy, declined), Telnyx closes the
|
||||
# socket abruptly — surface this as an expected end-of-call.
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Telnyx WebSocket closed before stream start "
|
||||
f"(call ended pre-answer): code={e.code}, reason={e.reason!r}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Telnyx WebSocket handler: {e}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -140,7 +140,8 @@ class VonageProvider(TelephonyProvider):
|
|||
status=response_data.get("status", "started"),
|
||||
caller_number=from_number,
|
||||
provider_metadata={
|
||||
"call_uuid": response_data["uuid"]
|
||||
"call_id": response_data["uuid"],
|
||||
"call_uuid": response_data["uuid"],
|
||||
}, # Vonage needs UUID persisted for WebSocket
|
||||
raw_response=response_data,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -179,9 +179,12 @@ async def _process_status_update(workflow_run_id: int, status: StatusCallbackReq
|
|||
|
||||
if workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
is_failure = status.status in ("error", "failed")
|
||||
await circuit_breaker.record_and_evaluate(
|
||||
workflow_run.campaign_id,
|
||||
is_failure=status.status in ("error", "failed"),
|
||||
is_failure=is_failure,
|
||||
workflow_run_id=workflow_run_id if is_failure else None,
|
||||
reason=status.status if is_failure else None,
|
||||
)
|
||||
|
||||
if status.status in ["busy", "no-answer"] and workflow_run.campaign_id:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatc
|
|||
from api.services.campaign.campaign_event_publisher import (
|
||||
get_campaign_event_publisher,
|
||||
)
|
||||
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
|
||||
from api.services.campaign.errors import (
|
||||
ConcurrentSlotAcquisitionError,
|
||||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.services.campaign.source_sync_factory import get_sync_service
|
||||
|
||||
|
||||
|
|
@ -80,6 +83,13 @@ async def sync_campaign_source(ctx: Dict, campaign_id: int) -> None:
|
|||
source_sync_status="failed",
|
||||
source_sync_error=str(e),
|
||||
)
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="error",
|
||||
event="source_sync_failed",
|
||||
message=f"Source sync failed: {e}",
|
||||
details={"error": str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -137,6 +147,39 @@ async def process_campaign_batch(
|
|||
|
||||
# Update campaign state to failed
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="error",
|
||||
event="batch_failed",
|
||||
message=f"Concurrent slot acquisition timeout: {e}",
|
||||
details={"error": str(e), "reason": "concurrent_slot_timeout"},
|
||||
)
|
||||
raise
|
||||
|
||||
except PhoneNumberPoolExhaustedError as e:
|
||||
logger.warning(f"Phone number pool exhausted for campaign {campaign_id}: {e}")
|
||||
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_batch_failed(
|
||||
campaign_id=campaign_id,
|
||||
error=f"Phone number pool exhausted: {e}",
|
||||
processed_count=0,
|
||||
)
|
||||
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="error",
|
||||
event="phone_number_pool_exhausted",
|
||||
message=(
|
||||
f"Phone number pool exhausted for org {e.organization_id}: "
|
||||
"no free from_number available to dispatch outbound calls"
|
||||
),
|
||||
details={
|
||||
"error": str(e),
|
||||
"organization_id": e.organization_id,
|
||||
},
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -152,4 +195,11 @@ async def process_campaign_batch(
|
|||
|
||||
# Update campaign state to failed
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
await db_client.append_campaign_log(
|
||||
campaign_id=campaign_id,
|
||||
level="error",
|
||||
event="batch_failed",
|
||||
message=f"Batch processing failed: {e}",
|
||||
details={"error": str(e)},
|
||||
)
|
||||
raise
|
||||
|
|
|
|||
87
api/tests/test_campaign_tasks.py
Normal file
87
api/tests/test_campaign_tasks.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""
|
||||
Tests for api.tasks.campaign_tasks failure handling.
|
||||
|
||||
Specifically: each kind of failure that pauses or fails a campaign should
|
||||
write a specific, identifiable entry into the campaign log so operators
|
||||
can tell at a glance why a campaign stopped.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.campaign.errors import (
|
||||
ConcurrentSlotAcquisitionError,
|
||||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.tasks.campaign_tasks import process_campaign_batch
|
||||
|
||||
|
||||
class TestProcessCampaignBatchFailureLogs:
    """Campaign-log expectations for ``process_campaign_batch`` failures.

    Each failure mode exercised here must leave exactly one identifiable
    entry in the campaign log, so the cause of a stopped campaign can be
    read straight from that entry.
    """

    @pytest.mark.asyncio
    async def test_phone_number_pool_exhausted_logs_specific_event(self):
        """A ``PhoneNumberPoolExhaustedError`` raised by ``process_batch``
        is re-raised, marks the campaign ``failed`` and is logged under the
        dedicated ``phone_number_pool_exhausted`` event."""
        pool_error = PhoneNumberPoolExhaustedError(organization_id=7)

        with (
            patch("api.tasks.campaign_tasks.campaign_call_dispatcher") as dispatcher,
            patch("api.tasks.campaign_tasks.db_client") as db,
            patch(
                "api.tasks.campaign_tasks.get_campaign_event_publisher"
            ) as publisher_factory,
        ):
            # Dispatcher fails immediately; DB and publisher are inert stubs.
            dispatcher.process_batch = AsyncMock(side_effect=pool_error)
            db.update_campaign = AsyncMock()
            db.append_campaign_log = AsyncMock()
            publisher = AsyncMock()
            publisher_factory.return_value = publisher

            with pytest.raises(PhoneNumberPoolExhaustedError):
                await process_campaign_batch({}, campaign_id=42)

            db.update_campaign.assert_called_once_with(
                campaign_id=42, state="failed"
            )

            db.append_campaign_log.assert_called_once()
            logged = db.append_campaign_log.call_args.kwargs
            assert logged["campaign_id"] == 42
            assert logged["event"] == "phone_number_pool_exhausted"
            assert logged["level"] == "error"
            assert "phone number" in logged["message"].lower()
            assert logged["details"]["organization_id"] == 7

    @pytest.mark.asyncio
    async def test_concurrent_slot_timeout_still_logs_specific_event(self):
        """Regression guard for the ``ConcurrentSlotAcquisitionError`` branch:
        it keeps the ``batch_failed`` event and tags the entry with the
        ``concurrent_slot_timeout`` reason."""
        slot_error = ConcurrentSlotAcquisitionError(
            organization_id=7, campaign_id=42, wait_time=30.0
        )

        with (
            patch("api.tasks.campaign_tasks.campaign_call_dispatcher") as dispatcher,
            patch("api.tasks.campaign_tasks.db_client") as db,
            patch(
                "api.tasks.campaign_tasks.get_campaign_event_publisher"
            ) as publisher_factory,
        ):
            dispatcher.process_batch = AsyncMock(side_effect=slot_error)
            db.update_campaign = AsyncMock()
            db.append_campaign_log = AsyncMock()
            publisher = AsyncMock()
            publisher_factory.return_value = publisher

            with pytest.raises(ConcurrentSlotAcquisitionError):
                await process_campaign_batch({}, campaign_id=42)

            db.append_campaign_log.assert_called_once()
            logged = db.append_campaign_log.call_args.kwargs
            assert logged["event"] == "batch_failed"
            assert logged["details"]["reason"] == "concurrent_slot_timeout"
|
||||
|
|
@ -198,7 +198,9 @@ class TestCircuitBreakerReset:
|
|||
result = await cb.reset(campaign_id=42)
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with("cb_failures:42", "cb_successes:42")
|
||||
mock_redis.delete.assert_called_once_with(
|
||||
"cb_failures:42", "cb_successes:42", "cb_recent_failures:42"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_on_redis_error(self):
|
||||
|
|
@ -253,6 +255,7 @@ class TestRecordAndEvaluate:
|
|||
):
|
||||
mock_db.get_campaign_by_id = AsyncMock(return_value=mock_campaign)
|
||||
mock_db.update_campaign = AsyncMock()
|
||||
mock_db.append_campaign_log = AsyncMock()
|
||||
|
||||
mock_publisher = AsyncMock()
|
||||
mock_get_publisher.return_value = mock_publisher
|
||||
|
|
@ -352,6 +355,206 @@ class TestRecordAndEvaluate:
|
|||
await cb.record_and_evaluate(campaign_id=42, is_failure=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for recent-failures tracking (workflow_run_id + reason)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCircuitBreakerRecentFailures:
    """Recent-failure bookkeeping of the campaign circuit breaker.

    A failed call should be remembered (workflow_run_id + reason) in a capped
    Redis list, and those entries should appear in the campaign log entry
    written when the breaker trips.
    """

    @staticmethod
    def _running_campaign():
        """Return a mock campaign with id 42 in the ``running`` state."""
        campaign = MagicMock()
        campaign.id = 42
        campaign.state = "running"
        campaign.orchestrator_metadata = {}
        return campaign

    @pytest.mark.asyncio
    async def test_failure_pushes_recent_failure_entry(self):
        """A failure reported with run id + reason is pushed to the
        recent-failures list."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        campaign = self._running_campaign()

        with patch("api.services.campaign.circuit_breaker.db_client") as db:
            db.get_campaign_by_id = AsyncMock(return_value=campaign)
            db.append_campaign_log = AsyncMock()
            # Breaker does not trip; only the push bookkeeping is observed.
            breaker.record_call_outcome = AsyncMock(return_value=(False, None))
            breaker._push_recent_failure = AsyncMock()
            breaker._get_recent_failures = AsyncMock(return_value=[])

            await breaker.record_and_evaluate(
                campaign_id=42,
                is_failure=True,
                workflow_run_id=100,
                reason="failed",
            )

            breaker._push_recent_failure.assert_called_once_with(
                campaign_id=42, workflow_run_id=100, reason="failed"
            )

    @pytest.mark.asyncio
    async def test_success_does_not_push_recent_failure(self):
        """A successful outcome must leave the recent-failures list alone."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        campaign = self._running_campaign()

        with patch("api.services.campaign.circuit_breaker.db_client") as db:
            db.get_campaign_by_id = AsyncMock(return_value=campaign)
            breaker.record_call_outcome = AsyncMock(return_value=(False, None))
            breaker._push_recent_failure = AsyncMock()
            breaker._get_recent_failures = AsyncMock(return_value=[])

            await breaker.record_and_evaluate(
                campaign_id=42,
                is_failure=False,
                workflow_run_id=100,
                reason=None,
            )

            breaker._push_recent_failure.assert_not_called()

    @pytest.mark.asyncio
    async def test_trip_log_includes_recent_failures_in_details(self):
        """When the breaker trips, the campaign log entry's ``details`` carry
        the ``recent_failures`` returned by ``_get_recent_failures``."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        campaign = self._running_campaign()

        trip_stats = {
            "failure_rate": 0.6,
            "failure_count": 6,
            "success_count": 4,
            "threshold": 0.5,
            "window_seconds": 120,
        }
        recent_failures = [
            {"workflow_run_id": 100, "reason": "failed", "ts": 1700000010.0},
            {"workflow_run_id": 99, "reason": "error", "ts": 1700000000.0},
        ]

        with (
            patch("api.services.campaign.circuit_breaker.db_client") as db,
            patch(
                "api.services.campaign.circuit_breaker.get_campaign_event_publisher"
            ) as publisher_factory,
        ):
            db.get_campaign_by_id = AsyncMock(return_value=campaign)
            db.update_campaign = AsyncMock()
            db.append_campaign_log = AsyncMock()

            publisher = AsyncMock()
            publisher_factory.return_value = publisher

            # Force a trip and hand back a known recent-failures payload.
            breaker.record_call_outcome = AsyncMock(return_value=(True, trip_stats))
            breaker._push_recent_failure = AsyncMock()
            breaker._get_recent_failures = AsyncMock(return_value=recent_failures)

            await breaker.record_and_evaluate(
                campaign_id=42,
                is_failure=True,
                workflow_run_id=100,
                reason="failed",
            )

            db.append_campaign_log.assert_called_once()
            logged = db.append_campaign_log.call_args.kwargs
            assert logged["campaign_id"] == 42
            assert logged["event"] == "circuit_breaker_tripped"
            assert logged["details"]["recent_failures"] == recent_failures

    @pytest.mark.asyncio
    async def test_push_recent_failure_uses_lpush_and_ltrim(self):
        """``_push_recent_failure`` LPUSHes a JSON entry and LTRIMs the list
        to indices 0..19, i.e. the 20 most recent entries."""
        import json

        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        redis = AsyncMock()
        redis.lpush = AsyncMock(return_value=1)
        redis.ltrim = AsyncMock(return_value=True)
        redis.expire = AsyncMock(return_value=True)
        breaker.redis_client = redis

        await breaker._push_recent_failure(
            campaign_id=42, workflow_run_id=100, reason="failed"
        )

        # Key: first positional argument of the single LPUSH call.
        redis.lpush.assert_called_once()
        key, payload = redis.lpush.call_args.args[0], redis.lpush.call_args.args[1]
        assert key == "cb_recent_failures:42"

        # Payload: JSON carrying the run id, the reason and a timestamp.
        pushed = json.loads(payload)
        assert pushed["workflow_run_id"] == 100
        assert pushed["reason"] == "failed"
        assert "ts" in pushed

        # Cap: LTRIM 0 19 keeps 20 entries.
        redis.ltrim.assert_called_once_with("cb_recent_failures:42", 0, 19)

    @pytest.mark.asyncio
    async def test_get_recent_failures_decodes_lrange(self):
        """``_get_recent_failures`` LRANGEs the whole list and JSON-decodes
        every entry."""
        import json

        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        decoded = [
            {"workflow_run_id": 100, "reason": "failed", "ts": 1.0},
            {"workflow_run_id": 99, "reason": "error", "ts": 0.5},
        ]
        redis = AsyncMock()
        redis.lrange = AsyncMock(return_value=[json.dumps(item) for item in decoded])
        breaker.redis_client = redis

        fetched = await breaker._get_recent_failures(campaign_id=42)

        redis.lrange.assert_called_once_with("cb_recent_failures:42", 0, -1)
        assert fetched == [
            {"workflow_run_id": 100, "reason": "failed", "ts": 1.0},
            {"workflow_run_id": 99, "reason": "error", "ts": 0.5},
        ]

    @pytest.mark.asyncio
    async def test_reset_clears_recent_failures_key(self):
        """``reset()`` must also delete ``cb_recent_failures:{campaign_id}``."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        redis = AsyncMock()
        redis.delete = AsyncMock(return_value=3)
        breaker.redis_client = redis

        await breaker.reset(campaign_id=42)

        redis.delete.assert_called_once_with(
            "cb_failures:42", "cb_successes:42", "cb_recent_failures:42"
        )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration tests: _process_status_update calls circuit_breaker
|
||||
# =============================================================================
|
||||
|
|
@ -405,7 +608,12 @@ class TestProcessStatusUpdateCircuitBreaker:
|
|||
|
||||
await _process_status_update(100, status)
|
||||
|
||||
mock_cb.record_and_evaluate.assert_called_once_with(42, is_failure=True)
|
||||
mock_cb.record_and_evaluate.assert_called_once_with(
|
||||
42,
|
||||
is_failure=True,
|
||||
workflow_run_id=100,
|
||||
reason="failed",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_status_calls_record_and_evaluate(self):
|
||||
|
|
|
|||
|
|
@ -720,9 +720,9 @@ class TestCustomToolManagerUnit:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_returns_correct_format(self):
|
||||
"""Test that get_tool_schemas returns FunctionSchema objects."""
|
||||
# Create a mock engine
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
|
||||
# Create a mock engine
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue