mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Feat/campaign enhancements (#163)
* feat: add circuit breaker to safeguard * feat: Add Circuit breaker in campaigns to safeguard against telephony failures * feat: add schedules in campaigns
This commit is contained in:
parent
7552b6c819
commit
fe4ea648e4
17 changed files with 2037 additions and 149 deletions
|
|
@ -104,6 +104,15 @@ DEFAULT_CAMPAIGN_RETRY_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
# Circuit breaker defaults for campaign call failure detection
|
||||
DEFAULT_CIRCUIT_BREAKER_CONFIG = dict(
    # Master switch; when False the breaker neither records nor trips.
    enabled=True,
    # Fraction of failed calls in the window that trips the breaker (0.5 = 50%).
    failure_threshold=0.5,
    # Length of the sliding window, in seconds (2 minutes).
    window_seconds=120,
    # Minimum number of outcomes in the window before the breaker may trip.
    min_calls_in_window=5,
)
|
||||
|
||||
|
||||
TURN_SECRET = os.getenv("TURN_SECRET")
|
||||
TURN_HOST = os.getenv("TURN_HOST", "localhost")
|
||||
TURN_PORT = int(os.getenv("TURN_PORT", "3478"))
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class CampaignClient(BaseDBClient):
|
|||
organization_id: int,
|
||||
retry_config: Optional[dict] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
schedule_config: Optional[dict] = None,
|
||||
) -> CampaignModel:
|
||||
"""Create a new campaign"""
|
||||
async with self.async_session() as session:
|
||||
|
|
@ -28,6 +29,8 @@ class CampaignClient(BaseDBClient):
|
|||
orchestrator_metadata = {}
|
||||
if max_concurrency is not None:
|
||||
orchestrator_metadata["max_concurrency"] = max_concurrency
|
||||
if schedule_config is not None:
|
||||
orchestrator_metadata["schedule_config"] = schedule_config
|
||||
|
||||
campaign = CampaignModel(
|
||||
name=name,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.db import db_client
|
||||
|
|
@ -46,6 +47,28 @@ async def _get_from_numbers_count(organization_id: int) -> int:
|
|||
return 0
|
||||
|
||||
|
||||
async def _validate_max_concurrency(max_concurrency: int, organization_id: int) -> None:
    """Reject a max_concurrency that exceeds what the organization can support.

    The effective ceiling is the organization's concurrency limit, further
    capped by the number of configured from-numbers when at least one exists.

    Raises:
        HTTPException: 400 when ``max_concurrency`` is above the effective limit.
    """
    org_limit = await _get_org_concurrent_limit(organization_id)
    from_numbers_count = await _get_from_numbers_count(organization_id)

    # With no phone numbers configured, only the organization limit applies.
    if from_numbers_count > 0:
        effective_limit = min(org_limit, from_numbers_count)
    else:
        effective_limit = org_limit

    if max_concurrency <= effective_limit:
        return

    # Phone numbers are the bottleneck: tell the user how to raise the ceiling.
    limited_by_numbers = 0 < from_numbers_count < org_limit
    if limited_by_numbers:
        raise HTTPException(
            status_code=400,
            detail=f"max_concurrency ({max_concurrency}) cannot exceed {effective_limit}. You have {from_numbers_count} phone number(s) configured. Add more CLIs in telephony configuration to increase concurrency.",
        )

    raise HTTPException(
        status_code=400,
        detail=f"max_concurrency ({max_concurrency}) cannot exceed organization limit ({effective_limit})",
    )
|
||||
|
||||
|
||||
class RetryConfigRequest(BaseModel):
|
||||
enabled: bool = True
|
||||
max_retries: int = Field(default=2, ge=0, le=10)
|
||||
|
|
@ -64,6 +87,45 @@ class RetryConfigResponse(BaseModel):
|
|||
retry_on_voicemail: bool
|
||||
|
||||
|
||||
class TimeSlotRequest(BaseModel):
    """One weekly calling window: a weekday plus a start and end wall-clock time.

    Times are zero-padded 24-hour ``HH:MM`` strings. They are compared as
    strings, which is only correct for real, zero-padded clock values, so the
    patterns reject impossible times such as ``"99:99"`` or ``"25:61"``.
    ``end_time`` additionally accepts ``"24:00"`` so a window can extend to the
    very end of the day.
    """

    # Compared against datetime.weekday() by the orchestrator (0=Monday .. 6=Sunday).
    day_of_week: int = Field(..., ge=0, le=6)
    # Inclusive start of the window, 00:00 through 23:59.
    start_time: str = Field(..., pattern=r"^([01]\d|2[0-3]):[0-5]\d$")
    # Exclusive end of the window, 00:00 through 23:59, or 24:00 for end of day.
    end_time: str = Field(..., pattern=r"^(([01]\d|2[0-3]):[0-5]\d|24:00)$")

    @model_validator(mode="after")
    def validate_times(self):
        """Ensure the window is non-empty and does not wrap past midnight.

        Raises:
            ValueError: if ``start_time`` is not strictly before ``end_time``.
        """
        # Zero-padded HH:MM strings order lexicographically like clock times.
        if self.start_time >= self.end_time:
            raise ValueError("start_time must be before end_time")
        return self
|
||||
|
||||
|
||||
class ScheduleConfigRequest(BaseModel):
    """Request payload describing when a campaign is allowed to place calls."""

    # When False the schedule is stored but not enforced by the orchestrator.
    enabled: bool = True
    # IANA timezone name used to interpret every slot's day and time.
    timezone: str = "UTC"
    # Between 1 and 50 weekly windows.
    slots: List[TimeSlotRequest] = Field(..., min_length=1, max_length=50)

    @field_validator("timezone")
    @classmethod
    def validate_timezone(cls, v: str) -> str:
        """Accept only timezone names that ``ZoneInfo`` can load.

        Raises:
            ValueError: if constructing ``ZoneInfo(v)`` fails for any reason.
        """
        try:
            ZoneInfo(v)
        except Exception:
            # Any failure (unknown key, malformed key, OS error) means "invalid".
            raise ValueError(f"Invalid timezone: {v}")
        else:
            return v
|
||||
|
||||
|
||||
class TimeSlotResponse(BaseModel):
    """A stored weekly calling window as returned to API clients."""

    # Weekday index of the window.
    day_of_week: int
    # Window start, "HH:MM".
    start_time: str
    # Window end, "HH:MM".
    end_time: str
|
||||
|
||||
|
||||
class ScheduleConfigResponse(BaseModel):
    """A campaign's stored schedule configuration as returned to API clients."""

    # Whether the schedule is enforced.
    enabled: bool
    # Timezone name the slots are expressed in.
    timezone: str
    # The weekly calling windows.
    slots: List[TimeSlotResponse]
|
||||
|
||||
|
||||
class CreateCampaignRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
workflow_id: int
|
||||
|
|
@ -71,6 +133,14 @@ class CreateCampaignRequest(BaseModel):
|
|||
source_id: str # Google Sheet URL or CSV file key
|
||||
retry_config: Optional[RetryConfigRequest] = None
|
||||
max_concurrency: Optional[int] = Field(default=None, ge=1, le=100)
|
||||
schedule_config: Optional[ScheduleConfigRequest] = None
|
||||
|
||||
|
||||
class UpdateCampaignRequest(BaseModel):
    """Partial-update payload for a campaign; omitted (None) fields are left unchanged."""

    # New display name, 1-255 characters.
    name: Optional[str] = Field(default=None, min_length=1, max_length=255)
    # Replacement retry configuration.
    retry_config: Optional[RetryConfigRequest] = None
    # New concurrency cap, 1-100; validated against the organization's limits by the route.
    max_concurrency: Optional[int] = Field(default=None, ge=1, le=100)
    # Replacement schedule configuration.
    schedule_config: Optional[ScheduleConfigRequest] = None
|
||||
|
||||
|
||||
class CampaignResponse(BaseModel):
|
||||
|
|
@ -89,6 +159,7 @@ class CampaignResponse(BaseModel):
|
|||
completed_at: Optional[datetime]
|
||||
retry_config: RetryConfigResponse
|
||||
max_concurrency: Optional[int] = None
|
||||
schedule_config: Optional[ScheduleConfigResponse] = None
|
||||
|
||||
|
||||
class CampaignsResponse(BaseModel):
|
||||
|
|
@ -138,10 +209,18 @@ def _build_campaign_response(campaign, workflow_name: str) -> CampaignResponse:
|
|||
else DEFAULT_CAMPAIGN_RETRY_CONFIG
|
||||
)
|
||||
|
||||
# Get max_concurrency from orchestrator_metadata
|
||||
# Get max_concurrency and schedule_config from orchestrator_metadata
|
||||
max_concurrency = None
|
||||
schedule_config = None
|
||||
if campaign.orchestrator_metadata:
|
||||
max_concurrency = campaign.orchestrator_metadata.get("max_concurrency")
|
||||
sc = campaign.orchestrator_metadata.get("schedule_config")
|
||||
if sc:
|
||||
schedule_config = ScheduleConfigResponse(
|
||||
enabled=sc.get("enabled", False),
|
||||
timezone=sc.get("timezone", "UTC"),
|
||||
slots=[TimeSlotResponse(**slot) for slot in sc.get("slots", [])],
|
||||
)
|
||||
|
||||
return CampaignResponse(
|
||||
id=campaign.id,
|
||||
|
|
@ -159,6 +238,7 @@ def _build_campaign_response(campaign, workflow_name: str) -> CampaignResponse:
|
|||
completed_at=campaign.completed_at,
|
||||
retry_config=RetryConfigResponse(**retry_config),
|
||||
max_concurrency=max_concurrency,
|
||||
schedule_config=schedule_config,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -181,31 +261,21 @@ async def create_campaign(
|
|||
if not validation_result.is_valid:
|
||||
raise HTTPException(status_code=400, detail=validation_result.error.message)
|
||||
|
||||
# Validate max_concurrency against effective limit (min of org limit and from_numbers count)
|
||||
if request.max_concurrency is not None:
|
||||
org_limit = await _get_org_concurrent_limit(user.selected_organization_id)
|
||||
from_numbers_count = await _get_from_numbers_count(
|
||||
user.selected_organization_id
|
||||
await _validate_max_concurrency(
|
||||
request.max_concurrency, user.selected_organization_id
|
||||
)
|
||||
effective_limit = (
|
||||
min(org_limit, from_numbers_count) if from_numbers_count > 0 else org_limit
|
||||
)
|
||||
if request.max_concurrency > effective_limit:
|
||||
if from_numbers_count > 0 and from_numbers_count < org_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"max_concurrency ({request.max_concurrency}) cannot exceed {effective_limit}. You have {from_numbers_count} phone number(s) configured. Add more CLIs in telephony configuration to increase concurrency.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"max_concurrency ({request.max_concurrency}) cannot exceed organization limit ({effective_limit})",
|
||||
)
|
||||
|
||||
# Build retry_config dict if provided
|
||||
retry_config = None
|
||||
if request.retry_config:
|
||||
retry_config = request.retry_config.model_dump()
|
||||
|
||||
# Build schedule_config dict if provided
|
||||
schedule_config = None
|
||||
if request.schedule_config:
|
||||
schedule_config = request.schedule_config.model_dump()
|
||||
|
||||
campaign = await db_client.create_campaign(
|
||||
name=request.name,
|
||||
workflow_id=request.workflow_id,
|
||||
|
|
@ -215,6 +285,7 @@ async def create_campaign(
|
|||
organization_id=user.selected_organization_id,
|
||||
retry_config=retry_config,
|
||||
max_concurrency=request.max_concurrency,
|
||||
schedule_config=schedule_config,
|
||||
)
|
||||
|
||||
return _build_campaign_response(campaign, workflow_name)
|
||||
|
|
@ -322,6 +393,62 @@ async def pause_campaign(
|
|||
return _build_campaign_response(campaign, workflow_name or "Unknown")
|
||||
|
||||
|
||||
@router.patch("/{campaign_id}")
async def update_campaign(
    campaign_id: int,
    request: UpdateCampaignRequest,
    user: UserModel = Depends(get_user),
) -> CampaignResponse:
    """Update campaign settings (name, retry config, max concurrency, schedule).

    Only fields present in the request (not None) are changed.
    ``max_concurrency`` and ``schedule_config`` are merged into the campaign's
    existing ``orchestrator_metadata`` so other metadata keys are preserved.

    Raises:
        HTTPException: 404 if the campaign does not exist for the user's
            organization (including when it disappears during the update);
            400 if the campaign is completed/failed or ``max_concurrency``
            exceeds the effective limit.
    """
    organization_id = user.selected_organization_id

    campaign = await db_client.get_campaign(campaign_id, organization_id)
    if not campaign:
        raise HTTPException(status_code=404, detail="Campaign not found")

    # Terminal campaigns are immutable.
    if campaign.state in ["completed", "failed"]:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot update a {campaign.state} campaign",
        )

    if request.max_concurrency is not None:
        await _validate_max_concurrency(request.max_concurrency, organization_id)

    # Build update kwargs
    update_kwargs = {}

    if request.name is not None:
        update_kwargs["name"] = request.name

    if request.retry_config is not None:
        update_kwargs["retry_config"] = request.retry_config.model_dump()

    # Merge max_concurrency and schedule_config into orchestrator_metadata.
    # Work on a copy so the dict attached to the loaded model is never
    # mutated in place; the update always receives a distinct object.
    metadata = dict(campaign.orchestrator_metadata or {})
    metadata_changed = False

    if request.max_concurrency is not None:
        metadata["max_concurrency"] = request.max_concurrency
        metadata_changed = True

    if request.schedule_config is not None:
        metadata["schedule_config"] = request.schedule_config.model_dump()
        metadata_changed = True

    if metadata_changed:
        update_kwargs["orchestrator_metadata"] = metadata

    if update_kwargs:
        await db_client.update_campaign(campaign_id=campaign_id, **update_kwargs)

    # Re-fetch to return updated data
    campaign = await db_client.get_campaign(campaign_id, organization_id)
    if not campaign:
        # The campaign vanished between the first lookup and the re-fetch
        # (e.g. deleted concurrently); report it instead of failing with a 500.
        raise HTTPException(status_code=404, detail="Campaign not found")

    workflow_name = await db_client.get_workflow_name(campaign.workflow_id, user.id)

    return _build_campaign_response(campaign, workflow_name or "Unknown")
|
||||
|
||||
|
||||
@router.get("/{campaign_id}/runs")
|
||||
async def get_campaign_runs(
|
||||
campaign_id: int,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from api.errors.telephony_errors import TelephonyError
|
|||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
from api.services.campaign.circuit_breaker import circuit_breaker
|
||||
from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import (
|
||||
|
|
@ -760,6 +761,9 @@ async def _process_status_update(workflow_run_id: int, status: StatusCallbackReq
|
|||
# Release concurrent slot if this was a campaign call
|
||||
if workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
await circuit_breaker.record_and_evaluate(
|
||||
workflow_run.campaign_id, is_failure=False
|
||||
)
|
||||
|
||||
# Mark workflow run as completed
|
||||
await db_client.update_workflow_run(
|
||||
|
|
@ -776,6 +780,9 @@ async def _process_status_update(workflow_run_id: int, status: StatusCallbackReq
|
|||
# Release concurrent slot for terminal statuses if this was a campaign call
|
||||
if workflow_run.campaign_id:
|
||||
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
|
||||
await circuit_breaker.record_and_evaluate(
|
||||
workflow_run.campaign_id, is_failure=True
|
||||
)
|
||||
|
||||
# Check if retry is needed for campaign calls (busy/no-answer)
|
||||
if status.status in ["busy", "no-answer"] and workflow_run.campaign_id:
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ class CampaignEventType(str, Enum):
|
|||
RETRY_SCHEDULED = "retry_scheduled"
|
||||
RETRY_FAILED = "retry_failed"
|
||||
|
||||
# Circuit breaker events
|
||||
CIRCUIT_BREAKER_TRIPPED = "circuit_breaker_tripped"
|
||||
|
||||
|
||||
class RetryReason(str, Enum):
|
||||
"""Reasons for retry."""
|
||||
|
|
@ -218,6 +221,18 @@ class RetryFailedEvent(BaseCampaignEvent):
|
|||
last_reason: str = "" # RetryReason value
|
||||
|
||||
|
||||
@dataclass
class CircuitBreakerTrippedEvent(BaseCampaignEvent):
    """Event sent when the circuit breaker trips and pauses a campaign."""

    # Discriminator used by parse_campaign_event to select this class.
    type: str = CampaignEventType.CIRCUIT_BREAKER_TRIPPED
    # Observed failure ratio in the window (0.0 - 1.0).
    failure_rate: float = 0.0
    # Number of failed outcomes counted in the window.
    failure_count: int = 0
    # Number of successful outcomes counted in the window.
    success_count: int = 0
    # Configured failure ratio at which the breaker trips.
    threshold: float = 0.0
    # Length of the sliding window, in seconds.
    window_seconds: int = 0
|
||||
|
||||
|
||||
def parse_campaign_event(data: str) -> Any:
|
||||
"""Parse a campaign event message."""
|
||||
try:
|
||||
|
|
@ -239,6 +254,7 @@ def parse_campaign_event(data: str) -> Any:
|
|||
CampaignEventType.RETRY_NEEDED: RetryNeededEvent,
|
||||
CampaignEventType.RETRY_SCHEDULED: RetryScheduledEvent,
|
||||
CampaignEventType.RETRY_FAILED: RetryFailedEvent,
|
||||
CampaignEventType.CIRCUIT_BREAKER_TRIPPED: CircuitBreakerTrippedEvent,
|
||||
}
|
||||
|
||||
event_class = event_class_map.get(event_type)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from api.services.campaign.campaign_event_protocol import (
|
|||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
CampaignCompletedEvent,
|
||||
CircuitBreakerTrippedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
)
|
||||
|
|
@ -123,6 +124,32 @@ class CampaignEventPublisher:
|
|||
|
||||
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
|
||||
|
||||
async def publish_circuit_breaker_tripped(
    self,
    campaign_id: int,
    failure_rate: float,
    failure_count: int,
    success_count: int,
    threshold: float,
    window_seconds: int,
):
    """Publish a circuit-breaker-tripped event on the campaign events channel.

    Builds a ``CircuitBreakerTrippedEvent`` from the given statistics, publishes
    its JSON form to ``RedisChannel.CAMPAIGN_EVENTS``, then logs a warning.
    """
    stats = {
        "failure_rate": failure_rate,
        "failure_count": failure_count,
        "success_count": success_count,
        "threshold": threshold,
        "window_seconds": window_seconds,
    }
    tripped_event = CircuitBreakerTrippedEvent(campaign_id=campaign_id, **stats)

    channel = RedisChannel.CAMPAIGN_EVENTS.value
    await self.redis.publish(channel, tripped_event.to_json())

    logger.warning(
        f"Published circuit breaker tripped event for campaign {campaign_id}: "
        f"failure_rate={failure_rate:.2%} ({failure_count} failures)"
    )
|
||||
|
||||
|
||||
# Global publisher instance with lazy Redis connection
|
||||
async def get_campaign_event_publisher() -> CampaignEventPublisher:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import asyncio
|
|||
import signal
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Dict
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
|
@ -25,11 +26,13 @@ from api.enums import RedisChannel
|
|||
from api.services.campaign.campaign_event_protocol import (
|
||||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
CircuitBreakerTrippedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
parse_campaign_event,
|
||||
)
|
||||
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
|
||||
from api.services.campaign.circuit_breaker import circuit_breaker
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
|
@ -165,6 +168,14 @@ class CampaignOrchestrator:
|
|||
await self._schedule_next_batch(campaign_id)
|
||||
self._last_activity[campaign_id] = datetime.now(UTC)
|
||||
|
||||
elif isinstance(event, CircuitBreakerTrippedEvent):
|
||||
# Circuit breaker tripped - clear state for this campaign
|
||||
logger.warning(
|
||||
f"campaign_id: {campaign_id} - Circuit breaker tripped event received: "
|
||||
f"failure_rate={event.failure_rate:.2%}"
|
||||
)
|
||||
self._clear_campaign_state(campaign_id)
|
||||
|
||||
async def _handle_retry_event(self, event: RetryNeededEvent):
|
||||
"""Process retry event and schedule if eligible (from campaign_retry_manager)."""
|
||||
|
||||
|
|
@ -274,6 +285,53 @@ class CampaignOrchestrator:
|
|||
f"last reason: {reason}"
|
||||
)
|
||||
|
||||
def _is_within_schedule(self, campaign: CampaignModel) -> bool:
    """Tell whether "now" falls inside one of the campaign's schedule windows.

    Scheduling is allowed (returns True) when any of these hold:
    - the campaign has no orchestrator metadata or no schedule_config
    - the schedule is disabled
    - the schedule has no slots
    - the configured timezone cannot be loaded (fail open, with a warning)
    - the current local day and time match a slot (start inclusive, end exclusive)

    Otherwise returns False.
    """
    metadata = campaign.orchestrator_metadata
    schedule = metadata.get("schedule_config") if metadata else None

    # Missing or disabled schedule: nothing to enforce.
    if not schedule or not schedule.get("enabled", False):
        return True

    windows = schedule.get("slots")
    if not windows:
        return True

    tz_name = schedule.get("timezone", "UTC")
    try:
        zone = ZoneInfo(tz_name)
    except Exception:
        logger.warning(
            f"campaign_id: {campaign.id} - Invalid timezone '{tz_name}' in schedule_config, "
            f"failing open (allowing scheduling)"
        )
        return True

    local_now = datetime.now(zone)
    weekday = local_now.weekday()  # 0=Monday through 6=Sunday
    hhmm = local_now.strftime("%H:%M")

    # "HH:MM" strings are zero-padded, so string comparison orders them as times.
    return any(
        slot.get("start_time", "") <= hhmm < slot.get("end_time", "")
        for slot in windows
        if slot.get("day_of_week") == weekday
    )
|
||||
|
||||
async def _schedule_next_batch(self, campaign_id: int):
|
||||
"""Schedule next batch immediately if work available."""
|
||||
|
||||
|
|
@ -302,6 +360,40 @@ class CampaignOrchestrator:
|
|||
)
|
||||
return
|
||||
|
||||
# Check schedule window before scheduling
|
||||
if not self._is_within_schedule(campaign):
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Outside scheduled time window, skipping batch"
|
||||
)
|
||||
return
|
||||
|
||||
# Safety net: check circuit breaker before scheduling
|
||||
cb_config = None
|
||||
if campaign.orchestrator_metadata:
|
||||
cb_config = campaign.orchestrator_metadata.get("circuit_breaker")
|
||||
|
||||
is_open, stats = await circuit_breaker.is_circuit_open(
|
||||
campaign_id=campaign_id,
|
||||
config=cb_config,
|
||||
)
|
||||
|
||||
if is_open and stats:
|
||||
logger.warning(
|
||||
f"campaign_id: {campaign_id} - Circuit breaker is open, "
|
||||
f"pausing campaign. Stats: {stats}"
|
||||
)
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="paused")
|
||||
await self.publisher.publish_circuit_breaker_tripped(
|
||||
campaign_id=campaign_id,
|
||||
failure_rate=stats["failure_rate"],
|
||||
failure_count=stats["failure_count"],
|
||||
success_count=stats["success_count"],
|
||||
threshold=stats["threshold"],
|
||||
window_seconds=stats["window_seconds"],
|
||||
)
|
||||
self._clear_campaign_state(campaign_id)
|
||||
return
|
||||
|
||||
# Check for available work (queued runs + due retries)
|
||||
has_work = await self._has_pending_work(campaign_id)
|
||||
|
||||
|
|
@ -399,6 +491,12 @@ class CampaignOrchestrator:
|
|||
if campaign_id not in self._batch_in_progress:
|
||||
has_work = await self._has_pending_work(campaign_id)
|
||||
if has_work:
|
||||
if not self._is_within_schedule(campaign):
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Found orphaned work but outside "
|
||||
f"schedule window, skipping"
|
||||
)
|
||||
continue
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Found orphaned work (likely new retries), "
|
||||
f"scheduling batch to process"
|
||||
|
|
@ -428,6 +526,12 @@ class CampaignOrchestrator:
|
|||
# Check for any pending work
|
||||
has_work = await self._has_pending_work(campaign_id)
|
||||
if has_work:
|
||||
# If outside schedule window, don't mark complete — work remains for next window
|
||||
if not self._is_within_schedule(campaign):
|
||||
logger.debug(
|
||||
f"campaign_id: {campaign_id} - Outside schedule window with pending work, "
|
||||
f"not marking complete"
|
||||
)
|
||||
return False
|
||||
|
||||
# Check in-memory last activity
|
||||
|
|
|
|||
301
api/services/campaign/circuit_breaker.py
Normal file
301
api/services/campaign/circuit_breaker.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
"""Campaign circuit breaker for automatic pause on high failure rates.
|
||||
|
||||
Uses two Redis sorted sets (ZSETs) per campaign — one for failures, one for
|
||||
successes — as sliding windows. ZCARD gives O(1) counts without iterating
|
||||
members, keeping the Lua scripts simple.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEFAULT_CIRCUIT_BREAKER_CONFIG, REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
|
||||
|
||||
class CircuitBreaker:
    """Sliding window circuit breaker for campaign call failures.

    Each campaign has two Redis sorted sets, one for failures and one for
    successes, scored by the wall-clock time of the outcome. Entries older
    than the window are trimmed on every access and the failure ratio is
    computed from the two set sizes.

    All Redis access fails open: if Redis cannot be reached the breaker
    reports "not tripped" rather than raising into the caller.
    """

    # Trim both windows, record one outcome under a unique member, refresh the
    # TTLs and evaluate the trip condition atomically.
    _RECORD_SCRIPT = """
    local fail_key = KEYS[1]
    local succ_key = KEYS[2]
    local now = tonumber(ARGV[1])
    local window_start = tonumber(ARGV[2])
    local is_failure = tonumber(ARGV[3])
    local threshold = tonumber(ARGV[4])
    local min_calls = tonumber(ARGV[5])
    local ttl = tonumber(ARGV[6])
    local member = ARGV[7]

    -- Trim both sets to the sliding window
    redis.call('ZREMRANGEBYSCORE', fail_key, 0, window_start)
    redis.call('ZREMRANGEBYSCORE', succ_key, 0, window_start)

    -- Add the new outcome to the appropriate set
    if is_failure == 1 then
        redis.call('ZADD', fail_key, now, member)
    else
        redis.call('ZADD', succ_key, now, member)
    end

    -- Refresh TTL on both keys
    redis.call('EXPIRE', fail_key, ttl)
    redis.call('EXPIRE', succ_key, ttl)

    -- Count via ZCARD (O(1))
    local failures = redis.call('ZCARD', fail_key)
    local successes = redis.call('ZCARD', succ_key)
    local total = failures + successes

    -- Check trip condition
    if total >= min_calls and (failures / total) >= threshold then
        return {1, failures, successes, total}
    end

    return {0, failures, successes, total}
    """

    # Trim both windows and evaluate the trip condition without recording.
    _CHECK_SCRIPT = """
    local fail_key = KEYS[1]
    local succ_key = KEYS[2]
    local window_start = tonumber(ARGV[1])
    local threshold = tonumber(ARGV[2])
    local min_calls = tonumber(ARGV[3])

    -- Trim both sets
    redis.call('ZREMRANGEBYSCORE', fail_key, 0, window_start)
    redis.call('ZREMRANGEBYSCORE', succ_key, 0, window_start)

    -- Count via ZCARD
    local failures = redis.call('ZCARD', fail_key)
    local successes = redis.call('ZCARD', succ_key)
    local total = failures + successes

    if total >= min_calls and (failures / total) >= threshold then
        return {1, failures, successes, total}
    end

    return {0, failures, successes, total}
    """

    def __init__(self):
        # Created lazily on first use so importing this module does no I/O.
        self.redis_client: Optional["aioredis.Redis"] = None

    async def _get_redis(self) -> "aioredis.Redis":
        """Get or create Redis connection."""
        if self.redis_client is None:
            self.redis_client = await aioredis.from_url(
                REDIS_URL, decode_responses=True
            )
        return self.redis_client

    @staticmethod
    def _keys(campaign_id: int) -> Tuple[str, str]:
        """Return (failures_key, successes_key) for a campaign."""
        return f"cb_failures:{campaign_id}", f"cb_successes:{campaign_id}"

    @staticmethod
    def _resolve_config(config: Optional[dict]) -> dict:
        """Overlay a per-campaign config on top of DEFAULT_CIRCUIT_BREAKER_CONFIG."""
        return {**DEFAULT_CIRCUIT_BREAKER_CONFIG, **(config or {})}

    @staticmethod
    def _build_stats(
        result, threshold: float, window_seconds: int
    ) -> Tuple[bool, dict]:
        """Decode a Lua reply ``{flag, failures, successes, total}`` into (flag, stats)."""
        failure_count = int(result[1])
        success_count = int(result[2])
        total = int(result[3])
        stats = {
            "failure_rate": failure_count / total if total > 0 else 0.0,
            "failure_count": failure_count,
            "success_count": success_count,
            "threshold": threshold,
            "window_seconds": window_seconds,
        }
        return bool(result[0]), stats

    async def record_call_outcome(
        self,
        campaign_id: int,
        is_failure: bool,
        config: Optional[dict] = None,
    ) -> Tuple[bool, Optional[dict]]:
        """Record a call outcome and check if the circuit breaker should trip.

        Args:
            campaign_id: The campaign ID.
            is_failure: True if the call failed, False if succeeded.
            config: Optional per-campaign circuit breaker config override.
                Falls back to DEFAULT_CIRCUIT_BREAKER_CONFIG.

        Returns:
            Tuple of (tripped: bool, stats: dict or None). ``stats`` contains
            failure_rate, failure_count, success_count, threshold and
            window_seconds whenever the outcome was recorded; it is None when
            the breaker is disabled or Redis failed (fail open).
        """
        import uuid  # local import: only needed to build unique ZSET members

        cb_config = self._resolve_config(config)

        if not cb_config.get("enabled", True):
            return False, None

        window_seconds = cb_config["window_seconds"]
        threshold = cb_config["failure_threshold"]
        min_calls = cb_config["min_calls_in_window"]

        now = time.time()
        window_start = now - window_seconds

        fail_key, succ_key = self._keys(campaign_id)

        # Sorted-set members are unique, so using the bare timestamp as the
        # member would silently merge two outcomes recorded at the same
        # instant. A random suffix guarantees every outcome is counted.
        member = f"{now}:{uuid.uuid4().hex}"

        try:
            redis_client = await self._get_redis()
            result = await redis_client.eval(
                self._RECORD_SCRIPT,
                2,
                fail_key,
                succ_key,
                now,
                window_start,
                1 if is_failure else 0,
                threshold,
                min_calls,
                window_seconds + 60,  # TTL with buffer
                member,
            )

            tripped, stats = self._build_stats(result, threshold, window_seconds)

            if tripped:
                failure_rate = stats["failure_rate"]
                failure_count = stats["failure_count"]
                total = failure_count + stats["success_count"]
                logger.warning(
                    f"Circuit breaker TRIPPED for campaign {campaign_id}: "
                    f"failure_rate={failure_rate:.2%} ({failure_count}/{total}) "
                    f"threshold={threshold:.2%} window={window_seconds}s"
                )

            return tripped, stats

        except Exception as e:
            logger.error(f"Circuit breaker error for campaign {campaign_id}: {e}")
            # Fail open - do NOT trip on errors
            return False, None

    async def is_circuit_open(
        self,
        campaign_id: int,
        config: Optional[dict] = None,
    ) -> Tuple[bool, Optional[dict]]:
        """Check if the circuit breaker is in open (tripped) state without recording.

        Used as a safety net check before scheduling batches. Expired entries
        are trimmed from both windows as a side effect, but no outcome is added.

        Returns:
            Tuple of (is_open: bool, stats: dict or None). ``stats`` is None
            when the breaker is disabled or Redis failed (fail open).
        """
        cb_config = self._resolve_config(config)

        if not cb_config.get("enabled", True):
            return False, None

        window_seconds = cb_config["window_seconds"]
        threshold = cb_config["failure_threshold"]
        min_calls = cb_config["min_calls_in_window"]

        window_start = time.time() - window_seconds

        fail_key, succ_key = self._keys(campaign_id)

        try:
            # Obtained inside the try so a failure to create the client also
            # fails open instead of propagating into the scheduler.
            redis_client = await self._get_redis()
            result = await redis_client.eval(
                self._CHECK_SCRIPT,
                2,
                fail_key,
                succ_key,
                window_start,
                threshold,
                min_calls,
            )
            return self._build_stats(result, threshold, window_seconds)

        except Exception as e:
            logger.error(f"Circuit breaker check error for campaign {campaign_id}: {e}")
            return False, None

    async def record_and_evaluate(self, campaign_id: int, is_failure: bool) -> None:
        """Record a call outcome, and if the breaker trips, pause the campaign.

        This is the main entry point called from telephony status callbacks.
        It handles fetching campaign config, recording the outcome, and
        pausing + publishing an event if the breaker trips. Outcomes for
        campaigns that are missing or not in the "running" state are ignored.

        Exceptions are caught internally so this never disrupts the caller.
        """
        try:
            campaign = await db_client.get_campaign_by_id(campaign_id)
            if not campaign or campaign.state != "running":
                return

            cb_config = {}
            if campaign.orchestrator_metadata:
                cb_config = campaign.orchestrator_metadata.get("circuit_breaker", {})

            tripped, stats = await self.record_call_outcome(
                campaign_id=campaign_id,
                is_failure=is_failure,
                config=cb_config,
            )

            if tripped and stats:
                logger.warning(
                    f"Circuit breaker tripped for campaign {campaign_id}, "
                    f"pausing campaign. Stats: {stats}"
                )

                await db_client.update_campaign(campaign_id=campaign_id, state="paused")

                publisher = await get_campaign_event_publisher()
                await publisher.publish_circuit_breaker_tripped(
                    campaign_id=campaign_id,
                    failure_rate=stats["failure_rate"],
                    failure_count=stats["failure_count"],
                    success_count=stats["success_count"],
                    threshold=stats["threshold"],
                    window_seconds=stats["window_seconds"],
                )

        except Exception as e:
            logger.error(f"Error in circuit breaker for campaign {campaign_id}: {e}")

    async def reset(self, campaign_id: int) -> bool:
        """Reset the circuit breaker state for a campaign.

        Called when a campaign is resumed to give it a clean slate.

        Returns:
            True if both window keys were deleted, False if Redis failed.
        """
        fail_key, succ_key = self._keys(campaign_id)

        try:
            # Obtained inside the try so the bool contract also covers a
            # failure to create the Redis client.
            redis_client = await self._get_redis()
            await redis_client.delete(fail_key, succ_key)
            logger.info(f"Circuit breaker reset for campaign {campaign_id}")
            return True
        except Exception as e:
            logger.error(
                f"Error resetting circuit breaker for campaign {campaign_id}: {e}"
            )
            return False

    async def close(self):
        """Close Redis connection."""
        if self.redis_client:
            await self.redis_client.close()
            self.redis_client = None
|
||||
|
||||
|
||||
# Global circuit breaker instance
|
||||
circuit_breaker = CircuitBreaker()
|
||||
|
|
@ -4,6 +4,7 @@ from typing import Any, Dict
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.campaign.circuit_breaker import circuit_breaker
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
|
@ -67,6 +68,9 @@ class CampaignRunnerService:
|
|||
# stale campaign checker would do that if there are pending work.
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="running")
|
||||
|
||||
# Reset circuit breaker so the resumed campaign starts with a clean slate
|
||||
await circuit_breaker.reset(campaign_id)
|
||||
|
||||
logger.info(f"Campaign {campaign_id} resumed")
|
||||
|
||||
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:
|
||||
|
|
|
|||
504
api/tests/test_circuit_breaker.py
Normal file
504
api/tests/test_circuit_breaker.py
Normal file
|
|
@ -0,0 +1,504 @@
|
|||
"""
|
||||
Tests for Campaign Circuit Breaker.
|
||||
|
||||
These tests verify:
|
||||
1. Circuit breaker records call outcomes (success/failure)
|
||||
2. Circuit breaker trips when failure rate exceeds threshold
|
||||
3. Circuit breaker does NOT trip when below threshold or min_calls
|
||||
4. Circuit breaker reset clears state
|
||||
5. Integration: _process_status_update reports campaign call outcomes to the circuit breaker
6. Integration: resume_campaign resets the circuit breaker
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# =============================================================================
|
||||
# Unit tests for CircuitBreaker class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCircuitBreakerRecordOutcome:
    """Unit tests for ``CircuitBreaker.record_call_outcome``.

    Redis is replaced by an ``AsyncMock``; the value returned by its ``eval``
    stands in for the script reply ``[tripped, failures, successes, total]``.
    """

    @staticmethod
    def _breaker_with_redis(**eval_kwargs):
        """Return ``(breaker, fake_redis)`` with the fake client attached.

        ``eval_kwargs`` (``return_value`` / ``side_effect``) configure
        ``fake_redis.eval``; when none are given ``eval`` is left as the
        auto-created child mock.
        """
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        fake_redis = AsyncMock()
        if eval_kwargs:
            fake_redis.eval = AsyncMock(**eval_kwargs)
        breaker.redis_client = fake_redis
        return breaker, fake_redis

    @pytest.mark.asyncio
    async def test_no_trip_below_min_calls(self):
        """Circuit breaker should NOT trip when total calls < min_calls_in_window."""
        # 3 failures out of 3 outcomes: a 100% failure rate, but still below
        # min_calls=5.
        breaker, _ = self._breaker_with_redis(return_value=[0, 3, 0, 3])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is False
        assert stats is not None
        assert stats["failure_count"] == 3
        assert stats["success_count"] == 0

    @pytest.mark.asyncio
    async def test_trip_when_threshold_exceeded(self):
        """Circuit breaker should trip when failure rate >= threshold and total >= min_calls."""
        # 4 failures out of 6 outcomes = 66%, above the 50% threshold.
        breaker, _ = self._breaker_with_redis(return_value=[1, 4, 2, 6])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is True
        assert stats is not None
        assert stats["failure_rate"] == pytest.approx(4 / 6)
        assert stats["failure_count"] == 4
        assert stats["success_count"] == 2

    @pytest.mark.asyncio
    async def test_no_trip_below_threshold(self):
        """Circuit breaker should NOT trip when failure rate < threshold."""
        # 2 failures out of 8 outcomes = 25%, below the 50% threshold.
        breaker, _ = self._breaker_with_redis(return_value=[0, 2, 6, 8])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=False
        )

        assert tripped is False
        assert stats["failure_rate"] == pytest.approx(2 / 8)

    @pytest.mark.asyncio
    async def test_disabled_circuit_breaker(self):
        """Circuit breaker should not record or trip when disabled."""
        breaker, fake_redis = self._breaker_with_redis()

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1,
            is_failure=True,
            config={"enabled": False},
        )

        assert tripped is False
        assert stats is None
        # A disabled breaker must never reach Redis.
        fake_redis.eval.assert_not_called()

    @pytest.mark.asyncio
    async def test_custom_config_override(self):
        """Per-campaign config should override defaults."""
        # With a custom threshold of 0.8, 4/6 = 66% should NOT trip; the
        # mocked reply models a script that respects the threshold passed in.
        breaker, _ = self._breaker_with_redis(return_value=[0, 4, 2, 6])
        overrides = {"failure_threshold": 0.8, "min_calls_in_window": 3}

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1,
            is_failure=True,
            config=overrides,
        )

        assert tripped is False

    @pytest.mark.asyncio
    async def test_redis_error_fails_open(self):
        """On Redis error, circuit breaker should fail open (not trip)."""
        breaker, _ = self._breaker_with_redis(
            side_effect=Exception("Redis connection lost")
        )

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is False
        assert stats is None
|
||||
|
||||
|
||||
class TestCircuitBreakerIsOpen:
    """Unit tests for the read-only ``CircuitBreaker.is_circuit_open`` check."""

    @staticmethod
    def _breaker_replying(script_reply):
        """Build a breaker whose mocked Redis ``eval`` returns ``script_reply``."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        fake_redis = AsyncMock()
        fake_redis.eval = AsyncMock(return_value=script_reply)
        breaker.redis_client = fake_redis
        return breaker

    @pytest.mark.asyncio
    async def test_is_open_when_threshold_exceeded(self):
        """is_circuit_open should return True when failure rate exceeds threshold."""
        # Reply layout: [is_open, failures, successes, total]
        breaker = self._breaker_replying([1, 5, 2, 7])

        is_open, stats = await breaker.is_circuit_open(campaign_id=1)

        assert is_open is True
        assert stats["failure_count"] == 5

    @pytest.mark.asyncio
    async def test_is_not_open_when_healthy(self):
        """is_circuit_open should return False when failure rate is low."""
        breaker = self._breaker_replying([0, 1, 9, 10])

        is_open, stats = await breaker.is_circuit_open(campaign_id=1)

        assert is_open is False
        assert stats["failure_rate"] == pytest.approx(0.1)
|
||||
|
||||
|
||||
class TestCircuitBreakerReset:
    """Unit tests for ``CircuitBreaker.reset``."""

    @staticmethod
    def _breaker_with_delete(**delete_kwargs):
        """Return ``(breaker, fake_redis)`` with ``fake_redis.delete`` configured.

        ``delete_kwargs`` are passed straight to ``AsyncMock`` (for example
        ``return_value`` or ``side_effect``).
        """
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        fake_redis = AsyncMock()
        fake_redis.delete = AsyncMock(**delete_kwargs)
        breaker.redis_client = fake_redis
        return breaker, fake_redis

    @pytest.mark.asyncio
    async def test_reset_deletes_redis_keys(self):
        """Reset should delete both failure and success keys for the campaign."""
        breaker, fake_redis = self._breaker_with_delete(return_value=2)

        outcome = await breaker.reset(campaign_id=42)

        assert outcome is True
        fake_redis.delete.assert_called_once_with(
            "cb_failures:42", "cb_successes:42"
        )

    @pytest.mark.asyncio
    async def test_reset_on_redis_error(self):
        """Reset should return False on Redis error."""
        breaker, _ = self._breaker_with_delete(side_effect=Exception("Redis down"))

        outcome = await breaker.reset(campaign_id=42)

        assert outcome is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for record_and_evaluate (the high-level method on CircuitBreaker)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRecordAndEvaluate:
    """Tests for ``CircuitBreaker.record_and_evaluate``.

    That method handles the full flow: record the outcome, check for a trip,
    pause the campaign and publish an event. ``db_client`` and the event
    publisher factory are patched inside the circuit breaker module, and
    ``record_call_outcome`` is replaced on the instance so that no Redis
    access is needed.
    """

    _DB_CLIENT_PATH = "api.services.campaign.circuit_breaker.db_client"
    _PUBLISHER_FACTORY_PATH = (
        "api.services.campaign.circuit_breaker.get_campaign_event_publisher"
    )

    @staticmethod
    def _campaign(state, **attrs):
        """Build a mock campaign with id 42, the given state and extra attributes."""
        campaign = MagicMock()
        campaign.id = 42
        campaign.state = state
        for attr_name, attr_value in attrs.items():
            setattr(campaign, attr_name, attr_value)
        return campaign

    @pytest.mark.asyncio
    async def test_trips_and_pauses_campaign(self):
        """When record_call_outcome returns tripped, campaign should be paused."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        cb = CircuitBreaker()
        mock_campaign = self._campaign("running", orchestrator_metadata={})

        stats = {
            "failure_rate": 0.6,
            "failure_count": 6,
            "success_count": 4,
            "threshold": 0.5,
            "window_seconds": 120,
        }

        with (
            patch(self._DB_CLIENT_PATH) as mock_db,
            patch(self._PUBLISHER_FACTORY_PATH) as mock_get_publisher,
        ):
            mock_db.get_campaign_by_id = AsyncMock(return_value=mock_campaign)
            mock_db.update_campaign = AsyncMock()

            mock_publisher = AsyncMock()
            mock_get_publisher.return_value = mock_publisher

            # Mock the internal record_call_outcome to return tripped
            cb.record_call_outcome = AsyncMock(return_value=(True, stats))

            await cb.record_and_evaluate(campaign_id=42, is_failure=True)

            # Verify campaign was paused
            mock_db.update_campaign.assert_called_once_with(
                campaign_id=42, state="paused"
            )

            # Verify the event was published with the recorded stats, not
            # merely that the publisher was invoked: every field of ``stats``
            # is forwarded as a keyword argument alongside the campaign id.
            mock_publisher.publish_circuit_breaker_tripped.assert_called_once_with(
                campaign_id=42, **stats
            )

    @pytest.mark.asyncio
    async def test_no_pause_when_not_tripped(self):
        """When record_call_outcome does NOT trip, campaign should not be paused."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        cb = CircuitBreaker()
        mock_campaign = self._campaign("running", orchestrator_metadata={})

        with patch(self._DB_CLIENT_PATH) as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=mock_campaign)

            cb.record_call_outcome = AsyncMock(return_value=(False, None))

            await cb.record_and_evaluate(campaign_id=42, is_failure=False)

            mock_db.update_campaign.assert_not_called()

    @pytest.mark.asyncio
    async def test_skips_when_campaign_not_running(self):
        """Should skip when campaign is not in running state."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        cb = CircuitBreaker()
        mock_campaign = self._campaign("paused")

        with patch(self._DB_CLIENT_PATH) as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=mock_campaign)

            cb.record_call_outcome = AsyncMock()

            await cb.record_and_evaluate(campaign_id=42, is_failure=True)

            # Should not even attempt to record
            cb.record_call_outcome.assert_not_called()

    @pytest.mark.asyncio
    async def test_reads_config_from_orchestrator_metadata(self):
        """Should pass circuit_breaker config from orchestrator_metadata."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        cb = CircuitBreaker()

        custom_config = {"failure_threshold": 0.3, "min_calls_in_window": 10}
        mock_campaign = self._campaign(
            "running", orchestrator_metadata={"circuit_breaker": custom_config}
        )

        with patch(self._DB_CLIENT_PATH) as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=mock_campaign)

            cb.record_call_outcome = AsyncMock(return_value=(False, None))

            await cb.record_and_evaluate(campaign_id=42, is_failure=True)

            cb.record_call_outcome.assert_called_once_with(
                campaign_id=42,
                is_failure=True,
                config=custom_config,
            )

    @pytest.mark.asyncio
    async def test_error_is_swallowed(self):
        """Errors inside record_and_evaluate should be caught, not raised."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        cb = CircuitBreaker()

        with patch(self._DB_CLIENT_PATH) as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(
                side_effect=Exception("DB exploded")
            )

            # Should NOT raise
            await cb.record_and_evaluate(campaign_id=42, is_failure=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration tests: _process_status_update calls circuit_breaker
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestProcessStatusUpdateCircuitBreaker:
    """Integration tests: ``_process_status_update`` must report outcomes of
    campaign calls through ``circuit_breaker.record_and_evaluate``."""

    @staticmethod
    def _workflow_run(campaign_id, **extra_attrs):
        """Build the mocked workflow run (id 100) that the DB patch returns.

        Only attributes named in ``extra_attrs`` are set in addition to the
        common ones; anything else stays an auto-created ``MagicMock``.
        """
        run = MagicMock()
        run.id = 100
        run.campaign_id = campaign_id
        for attr_name, attr_value in extra_attrs.items():
            setattr(run, attr_name, attr_value)
        run.state = "running"
        run.logs = {"telephony_status_callbacks": []}
        run.gathered_context = {}
        return run

    @pytest.mark.asyncio
    async def test_failure_status_calls_record_and_evaluate(self):
        """When a campaign call fails, record_and_evaluate should be called
        with is_failure=True."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        workflow_run = self._workflow_run(42, queued_run_id=10)
        callback = StatusCallbackRequest(
            call_id="call-123",
            status="failed",
        )

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.campaign_call_dispatcher") as mock_dispatcher,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
            patch(
                "api.routes.telephony.get_campaign_event_publisher"
            ) as mock_get_publisher,
        ):
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            mock_dispatcher.release_call_slot = AsyncMock(return_value=True)
            mock_cb.record_and_evaluate = AsyncMock()

            publisher = AsyncMock()
            mock_get_publisher.return_value = publisher

            await _process_status_update(100, callback)

            mock_cb.record_and_evaluate.assert_called_once_with(42, is_failure=True)

    @pytest.mark.asyncio
    async def test_success_status_calls_record_and_evaluate(self):
        """When a campaign call succeeds, record_and_evaluate should be called
        with is_failure=False."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        workflow_run = self._workflow_run(42)
        callback = StatusCallbackRequest(
            call_id="call-456",
            status="completed",
        )

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.campaign_call_dispatcher") as mock_dispatcher,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
        ):
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            mock_dispatcher.release_call_slot = AsyncMock(return_value=True)
            mock_cb.record_and_evaluate = AsyncMock()

            await _process_status_update(100, callback)

            mock_cb.record_and_evaluate.assert_called_once_with(42, is_failure=False)

    @pytest.mark.asyncio
    async def test_non_campaign_call_skips_circuit_breaker(self):
        """Calls without campaign_id should not interact with circuit breaker."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        # campaign_id=None marks this run as not belonging to a campaign.
        workflow_run = self._workflow_run(None)
        callback = StatusCallbackRequest(
            call_id="call-789",
            status="failed",
        )

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
        ):
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            await _process_status_update(100, callback)

            # Circuit breaker should NOT be called for non-campaign calls
            mock_cb.record_and_evaluate.assert_not_called()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration test: resume_campaign resets circuit breaker
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestResumeCampaignResetsCircuitBreaker:
    """Integration test: resuming a campaign must reset its circuit breaker."""

    @pytest.mark.asyncio
    async def test_resume_resets_circuit_breaker(self):
        """Resuming a paused campaign should reset the circuit breaker state."""
        from api.services.campaign.runner import CampaignRunnerService

        paused_campaign = MagicMock()
        paused_campaign.id = 42
        paused_campaign.state = "paused"

        db_patch = patch("api.services.campaign.runner.db_client")
        breaker_patch = patch("api.services.campaign.runner.circuit_breaker")

        with db_patch as mock_db, breaker_patch as mock_cb:
            mock_db.get_campaign_by_id = AsyncMock(return_value=paused_campaign)
            mock_db.update_campaign = AsyncMock()
            mock_cb.reset = AsyncMock(return_value=True)

            service = CampaignRunnerService()
            await service.resume_campaign(42)

            # The breaker state for this campaign was cleared ...
            mock_cb.reset.assert_called_once_with(42)

            # ... and the campaign was moved back to the running state.
            mock_db.update_campaign.assert_called_once_with(
                campaign_id=42, state="running"
            )
|
||||
|
|
@ -48,7 +48,6 @@ from pipecat.turns.user_stop import ExternalUserTurnStopStrategy
|
|||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
# Short timeout for faster tests
|
||||
STOP_STRATEGY_TIMEOUT = 0.15
|
||||
# Delay to allow async processing
|
||||
|
|
@ -115,7 +114,15 @@ def _build_components(llm_steps=None):
|
|||
|
||||
turn_controller = user_agg._user_turn_controller
|
||||
|
||||
return injector, user_agg, stop_strategy, turn_controller, mock_llm, context, pipeline
|
||||
return (
|
||||
injector,
|
||||
user_agg,
|
||||
stop_strategy,
|
||||
turn_controller,
|
||||
mock_llm,
|
||||
context,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
async def _run_scenario(pipeline, inject_fn):
|
||||
|
|
@ -193,7 +200,9 @@ class TestUserTurnStopScenarios:
|
|||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
|
|
@ -217,7 +226,9 @@ class TestUserTurnStopScenarios:
|
|||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after clean turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Expected _user_turn to be False after turn"
|
||||
assert not turn_ctrl._user_turn, (
|
||||
"Expected _user_turn to be False after turn"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call (turn 2 only), got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
|
@ -257,7 +268,9 @@ class TestUserTurnStopScenarios:
|
|||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Bot stops -> unmuted
|
||||
|
|
@ -329,7 +342,9 @@ class TestUserTurnStopScenarios:
|
|||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# TranscriptionFrame arrives AFTER unmute -> reaches stop strategy
|
||||
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Install spy on trigger_user_turn_stopped to track every call
|
||||
|
|
@ -408,7 +423,9 @@ class TestUserTurnStopScenarios:
|
|||
# _aggregation is a separate concern from the stop strategy's _text.
|
||||
messages = context.messages
|
||||
user_messages = [m for m in messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1, f"Expected 1 user message, got {len(user_messages)}"
|
||||
assert len(user_messages) == 1, (
|
||||
f"Expected 1 user message, got {len(user_messages)}"
|
||||
)
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "hello" in user_text, (
|
||||
f"Expected 'hello' (from aggregator) in user message, got: '{user_text}'"
|
||||
|
|
@ -519,7 +536,9 @@ class TestUserTurnStopScenarios:
|
|||
await injector.inject(VADUserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
# Late transcription - but still during bot speaking
|
||||
await injector.inject(TranscriptionFrame("late hello", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("late hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
|
|
@ -651,7 +670,9 @@ class TestUserTurnStopScenarios:
|
|||
# The LLM received both "late hello" (dangling in aggregator from turn 1)
|
||||
# and "real speech" (from turn 2).
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1, f"Expected 1 user message, got {len(user_messages)}"
|
||||
assert len(user_messages) == 1, (
|
||||
f"Expected 1 user message, got {len(user_messages)}"
|
||||
)
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "late hello" in user_text, (
|
||||
f"Expected 'late hello' (from aggregator) in user message, got: '{user_text}'"
|
||||
|
|
@ -867,7 +888,9 @@ class TestUserTurnStopScenarios:
|
|||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Late transcription after unmute
|
||||
await injector.inject(TranscriptionFrame("first", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("first", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
|
@ -890,7 +913,9 @@ class TestUserTurnStopScenarios:
|
|||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(TranscriptionFrame("second", "user-1", time_now_iso8601()))
|
||||
await injector.inject(
|
||||
TranscriptionFrame("second", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
|
@ -926,7 +951,9 @@ class TestUserTurnStopScenarios:
|
|||
user_text = user_messages[0]["content"]
|
||||
assert "first" in user_text, f"Expected 'first' in '{user_text}'"
|
||||
assert "second" in user_text, f"Expected 'second' in '{user_text}'"
|
||||
assert "actual speech" in user_text, f"Expected 'actual speech' in '{user_text}'"
|
||||
assert "actual speech" in user_text, (
|
||||
f"Expected 'actual speech' in '{user_text}'"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue