Feat/campaign enhancements (#163)

* feat: add circuit breaker to safeguard

* feat: Add Circuit breaker in campaigns to safeguard against telephony failures

* feat: add schedules in campaigns
This commit is contained in:
Abhishek 2026-02-17 21:04:15 +05:30 committed by GitHub
parent 7552b6c819
commit fe4ea648e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2037 additions and 149 deletions

View file

@ -104,6 +104,15 @@ DEFAULT_CAMPAIGN_RETRY_CONFIG = {
}
# Baseline settings for the campaign circuit breaker (call failure detection).
# Callers merge per-campaign overrides on top of this mapping.
DEFAULT_CIRCUIT_BREAKER_CONFIG = dict(
    enabled=True,
    failure_threshold=0.5,  # a 50% failure rate trips the breaker
    window_seconds=120,  # sliding window of 2 minutes
    min_calls_in_window=5,  # need at least 5 outcomes before tripping
)
TURN_SECRET = os.getenv("TURN_SECRET")
TURN_HOST = os.getenv("TURN_HOST", "localhost")
TURN_PORT = int(os.getenv("TURN_PORT", "3478"))

View file

@ -21,6 +21,7 @@ class CampaignClient(BaseDBClient):
organization_id: int,
retry_config: Optional[dict] = None,
max_concurrency: Optional[int] = None,
schedule_config: Optional[dict] = None,
) -> CampaignModel:
"""Create a new campaign"""
async with self.async_session() as session:
@ -28,6 +29,8 @@ class CampaignClient(BaseDBClient):
orchestrator_metadata = {}
if max_concurrency is not None:
orchestrator_metadata["max_concurrency"] = max_concurrency
if schedule_config is not None:
orchestrator_metadata["schedule_config"] = schedule_config
campaign = CampaignModel(
name=name,

View file

@ -1,9 +1,10 @@
import json
from datetime import datetime
from typing import List, Optional
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator, model_validator
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
from api.db import db_client
@ -46,6 +47,28 @@ async def _get_from_numbers_count(organization_id: int) -> int:
return 0
async def _validate_max_concurrency(max_concurrency: int, organization_id: int) -> None:
    """Check a requested ``max_concurrency`` against what the organization allows.

    The effective limit is the organization limit, additionally capped by the
    number of configured from-numbers when that count is greater than zero.

    Raises:
        HTTPException: status 400 when ``max_concurrency`` is above the
            effective limit.
    """
    org_cap = await _get_org_concurrent_limit(organization_id)
    from_numbers_count = await _get_from_numbers_count(organization_id)

    # A count of zero means "no phone-number cap", not "cap of zero".
    if from_numbers_count > 0:
        effective_limit = min(org_cap, from_numbers_count)
    else:
        effective_limit = org_cap

    if max_concurrency <= effective_limit:
        return

    # Give a more actionable message when the phone numbers are the bottleneck.
    limited_by_numbers = 0 < from_numbers_count < org_cap
    if limited_by_numbers:
        detail = f"max_concurrency ({max_concurrency}) cannot exceed {effective_limit}. You have {from_numbers_count} phone number(s) configured. Add more CLIs in telephony configuration to increase concurrency."
    else:
        detail = f"max_concurrency ({max_concurrency}) cannot exceed organization limit ({effective_limit})"
    raise HTTPException(status_code=400, detail=detail)
class RetryConfigRequest(BaseModel):
enabled: bool = True
max_retries: int = Field(default=2, ge=0, le=10)
@ -64,6 +87,45 @@ class RetryConfigResponse(BaseModel):
retry_on_voicemail: bool
class TimeSlotRequest(BaseModel):
    """One weekly calling window: a weekday plus start/end wall-clock times.

    ``day_of_week`` is an integer from 0 to 6. Times are zero-padded 24-hour
    ``HH:MM`` strings; ``end_time`` may also be ``"24:00"`` so that a window
    can extend to the very end of the day.
    """

    day_of_week: int = Field(..., ge=0, le=6)
    # Only real clock values (00:00-23:59, ASCII digits). A looser "two digits,
    # colon, two digits" pattern would admit impossible times such as "25:99".
    start_time: str = Field(..., pattern=r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]$")
    # "24:00" is accepted as an exclusive end-of-day marker.
    end_time: str = Field(
        ..., pattern=r"^(?:(?:[01][0-9]|2[0-3]):[0-5][0-9]|24:00)$"
    )

    @model_validator(mode="after")
    def validate_times(self):
        """Require a non-empty window: ``start_time`` strictly before ``end_time``.

        Raises:
            ValueError: when ``start_time`` is not earlier than ``end_time``.
        """
        # Zero-padded HH:MM strings sort lexicographically in chronological
        # order, so plain string comparison is sufficient here.
        if self.start_time >= self.end_time:
            raise ValueError("start_time must be before end_time")
        return self
class ScheduleConfigRequest(BaseModel):
    """Calling schedule for a campaign: a timezone plus 1-50 weekly time slots."""

    enabled: bool = True
    timezone: str = "UTC"
    slots: List[TimeSlotRequest] = Field(..., min_length=1, max_length=50)

    @field_validator("timezone")
    @classmethod
    def validate_timezone(cls, v: str) -> str:
        """Accept only timezone keys that ``ZoneInfo`` can load.

        Raises:
            ValueError: when ``v`` is not a loadable timezone key.
        """
        try:
            ZoneInfo(v)
        # Unknown keys raise ZoneInfoNotFoundError (a KeyError); malformed keys
        # raise ValueError; keys resolving to odd filesystem entries can raise
        # OSError. Anything else is a genuine bug and should not be masked.
        except (KeyError, ValueError, OSError) as exc:
            raise ValueError(f"Invalid timezone: {v}") from exc
        return v
class TimeSlotResponse(BaseModel):
    """Response shape of a single schedule time slot."""

    # Weekday index of the slot.
    day_of_week: int
    # Window boundaries, carried as strings.
    start_time: str
    end_time: str
class ScheduleConfigResponse(BaseModel):
    """Response shape of a campaign's schedule configuration."""

    # Whether the schedule is switched on.
    enabled: bool
    # Timezone name associated with the slots.
    timezone: str
    # The configured time slots.
    slots: List[TimeSlotResponse]
class CreateCampaignRequest(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
workflow_id: int
@ -71,6 +133,14 @@ class CreateCampaignRequest(BaseModel):
source_id: str # Google Sheet URL or CSV file key
retry_config: Optional[RetryConfigRequest] = None
max_concurrency: Optional[int] = Field(default=None, ge=1, le=100)
schedule_config: Optional[ScheduleConfigRequest] = None
class UpdateCampaignRequest(BaseModel):
    """Request body for updating a campaign; every field is optional."""

    # Same length bounds as on campaign creation.
    name: Optional[str] = Field(default=None, min_length=1, max_length=255)
    retry_config: Optional[RetryConfigRequest] = None
    # Bounded to 1..100 here; callers may apply further limits.
    max_concurrency: Optional[int] = Field(default=None, ge=1, le=100)
    schedule_config: Optional[ScheduleConfigRequest] = None
class CampaignResponse(BaseModel):
@ -89,6 +159,7 @@ class CampaignResponse(BaseModel):
completed_at: Optional[datetime]
retry_config: RetryConfigResponse
max_concurrency: Optional[int] = None
schedule_config: Optional[ScheduleConfigResponse] = None
class CampaignsResponse(BaseModel):
@ -138,10 +209,18 @@ def _build_campaign_response(campaign, workflow_name: str) -> CampaignResponse:
else DEFAULT_CAMPAIGN_RETRY_CONFIG
)
# Get max_concurrency from orchestrator_metadata
# Get max_concurrency and schedule_config from orchestrator_metadata
max_concurrency = None
schedule_config = None
if campaign.orchestrator_metadata:
max_concurrency = campaign.orchestrator_metadata.get("max_concurrency")
sc = campaign.orchestrator_metadata.get("schedule_config")
if sc:
schedule_config = ScheduleConfigResponse(
enabled=sc.get("enabled", False),
timezone=sc.get("timezone", "UTC"),
slots=[TimeSlotResponse(**slot) for slot in sc.get("slots", [])],
)
return CampaignResponse(
id=campaign.id,
@ -159,6 +238,7 @@ def _build_campaign_response(campaign, workflow_name: str) -> CampaignResponse:
completed_at=campaign.completed_at,
retry_config=RetryConfigResponse(**retry_config),
max_concurrency=max_concurrency,
schedule_config=schedule_config,
)
@ -181,31 +261,21 @@ async def create_campaign(
if not validation_result.is_valid:
raise HTTPException(status_code=400, detail=validation_result.error.message)
# Validate max_concurrency against effective limit (min of org limit and from_numbers count)
if request.max_concurrency is not None:
org_limit = await _get_org_concurrent_limit(user.selected_organization_id)
from_numbers_count = await _get_from_numbers_count(
user.selected_organization_id
await _validate_max_concurrency(
request.max_concurrency, user.selected_organization_id
)
effective_limit = (
min(org_limit, from_numbers_count) if from_numbers_count > 0 else org_limit
)
if request.max_concurrency > effective_limit:
if from_numbers_count > 0 and from_numbers_count < org_limit:
raise HTTPException(
status_code=400,
detail=f"max_concurrency ({request.max_concurrency}) cannot exceed {effective_limit}. You have {from_numbers_count} phone number(s) configured. Add more CLIs in telephony configuration to increase concurrency.",
)
raise HTTPException(
status_code=400,
detail=f"max_concurrency ({request.max_concurrency}) cannot exceed organization limit ({effective_limit})",
)
# Build retry_config dict if provided
retry_config = None
if request.retry_config:
retry_config = request.retry_config.model_dump()
# Build schedule_config dict if provided
schedule_config = None
if request.schedule_config:
schedule_config = request.schedule_config.model_dump()
campaign = await db_client.create_campaign(
name=request.name,
workflow_id=request.workflow_id,
@ -215,6 +285,7 @@ async def create_campaign(
organization_id=user.selected_organization_id,
retry_config=retry_config,
max_concurrency=request.max_concurrency,
schedule_config=schedule_config,
)
return _build_campaign_response(campaign, workflow_name)
@ -322,6 +393,62 @@ async def pause_campaign(
return _build_campaign_response(campaign, workflow_name or "Unknown")
@router.patch("/{campaign_id}")
async def update_campaign(
    campaign_id: int,
    request: UpdateCampaignRequest,
    user: UserModel = Depends(get_user),
) -> CampaignResponse:
    """Update campaign settings (name, retry config, max concurrency, schedule)

    Only fields present in the request (not ``None``) are changed.
    ``max_concurrency`` and ``schedule_config`` are stored inside the
    campaign's ``orchestrator_metadata``; other metadata keys are preserved.

    Raises:
        HTTPException: 404 when the campaign does not exist for the user's
            selected organization (including when it disappears during the
            update); 400 when the campaign is completed/failed or when
            ``max_concurrency`` fails validation.
    """
    campaign = await db_client.get_campaign(campaign_id, user.selected_organization_id)
    if not campaign:
        raise HTTPException(status_code=404, detail="Campaign not found")

    if campaign.state in ("completed", "failed"):
        raise HTTPException(
            status_code=400,
            detail=f"Cannot update a {campaign.state} campaign",
        )

    if request.max_concurrency is not None:
        await _validate_max_concurrency(
            request.max_concurrency, user.selected_organization_id
        )

    # Build update kwargs
    update_kwargs = {}
    if request.name is not None:
        update_kwargs["name"] = request.name
    if request.retry_config is not None:
        update_kwargs["retry_config"] = request.retry_config.model_dump()

    # Merge max_concurrency and schedule_config into orchestrator_metadata.
    # Work on a copy so the dict owned by the loaded campaign object is never
    # mutated in place; the new mapping is handed to the DB layer explicitly.
    metadata = dict(campaign.orchestrator_metadata or {})
    metadata_changed = False
    if request.max_concurrency is not None:
        metadata["max_concurrency"] = request.max_concurrency
        metadata_changed = True
    if request.schedule_config is not None:
        metadata["schedule_config"] = request.schedule_config.model_dump()
        metadata_changed = True
    if metadata_changed:
        update_kwargs["orchestrator_metadata"] = metadata

    if update_kwargs:
        await db_client.update_campaign(campaign_id=campaign_id, **update_kwargs)

    # Re-fetch to return updated data
    campaign = await db_client.get_campaign(campaign_id, user.selected_organization_id)
    if not campaign:
        # The campaign vanished between the update and the re-fetch; report it
        # as missing instead of failing on attribute access below.
        raise HTTPException(status_code=404, detail="Campaign not found")
    workflow_name = await db_client.get_workflow_name(campaign.workflow_id, user.id)
    return _build_campaign_response(campaign, workflow_name or "Unknown")
@router.get("/{campaign_id}/runs")
async def get_campaign_runs(
campaign_id: int,

View file

@ -32,6 +32,7 @@ from api.errors.telephony_errors import TelephonyError
from api.services.auth.depends import get_user
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
from api.services.campaign.circuit_breaker import circuit_breaker
from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.factory import (
@ -760,6 +761,9 @@ async def _process_status_update(workflow_run_id: int, status: StatusCallbackReq
# Release concurrent slot if this was a campaign call
if workflow_run.campaign_id:
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
await circuit_breaker.record_and_evaluate(
workflow_run.campaign_id, is_failure=False
)
# Mark workflow run as completed
await db_client.update_workflow_run(
@ -776,6 +780,9 @@ async def _process_status_update(workflow_run_id: int, status: StatusCallbackReq
# Release concurrent slot for terminal statuses if this was a campaign call
if workflow_run.campaign_id:
await campaign_call_dispatcher.release_call_slot(workflow_run_id)
await circuit_breaker.record_and_evaluate(
workflow_run.campaign_id, is_failure=True
)
# Check if retry is needed for campaign calls (busy/no-answer)
if status.status in ["busy", "no-answer"] and workflow_run.campaign_id:

View file

@ -33,6 +33,9 @@ class CampaignEventType(str, Enum):
RETRY_SCHEDULED = "retry_scheduled"
RETRY_FAILED = "retry_failed"
# Circuit breaker events
CIRCUIT_BREAKER_TRIPPED = "circuit_breaker_tripped"
class RetryReason(str, Enum):
"""Reasons for retry."""
@ -218,6 +221,18 @@ class RetryFailedEvent(BaseCampaignEvent):
last_reason: str = "" # RetryReason value
@dataclass
class CircuitBreakerTrippedEvent(BaseCampaignEvent):
    """Event sent when the circuit breaker trips and pauses a campaign."""

    type: str = CampaignEventType.CIRCUIT_BREAKER_TRIPPED

    # Failure statistics reported with the trip.
    failure_rate: float = 0.0
    failure_count: int = 0
    success_count: int = 0

    # Breaker settings reported with the trip.
    threshold: float = 0.0
    window_seconds: int = 0
def parse_campaign_event(data: str) -> Any:
"""Parse a campaign event message."""
try:
@ -239,6 +254,7 @@ def parse_campaign_event(data: str) -> Any:
CampaignEventType.RETRY_NEEDED: RetryNeededEvent,
CampaignEventType.RETRY_SCHEDULED: RetryScheduledEvent,
CampaignEventType.RETRY_FAILED: RetryFailedEvent,
CampaignEventType.CIRCUIT_BREAKER_TRIPPED: CircuitBreakerTrippedEvent,
}
event_class = event_class_map.get(event_type)

View file

@ -14,6 +14,7 @@ from api.services.campaign.campaign_event_protocol import (
BatchCompletedEvent,
BatchFailedEvent,
CampaignCompletedEvent,
CircuitBreakerTrippedEvent,
RetryNeededEvent,
SyncCompletedEvent,
)
@ -123,6 +124,32 @@ class CampaignEventPublisher:
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_circuit_breaker_tripped(
    self,
    campaign_id: int,
    failure_rate: float,
    failure_count: int,
    success_count: int,
    threshold: float,
    window_seconds: int,
):
    """Publish circuit breaker tripped event.

    Builds a ``CircuitBreakerTrippedEvent`` from the arguments, publishes its
    JSON form on the campaign events channel and logs a warning.
    """
    tripped_event = CircuitBreakerTrippedEvent(
        campaign_id=campaign_id,
        failure_rate=failure_rate,
        failure_count=failure_count,
        success_count=success_count,
        threshold=threshold,
        window_seconds=window_seconds,
    )
    channel = RedisChannel.CAMPAIGN_EVENTS.value
    await self.redis.publish(channel, tripped_event.to_json())

    logger.warning(
        f"Published circuit breaker tripped event for campaign {campaign_id}: "
        f"failure_rate={failure_rate:.2%} ({failure_count} failures)"
    )
# Global publisher instance with lazy Redis connection
async def get_campaign_event_publisher() -> CampaignEventPublisher:

View file

@ -14,6 +14,7 @@ import asyncio
import signal
from datetime import UTC, datetime, timedelta
from typing import Dict
from zoneinfo import ZoneInfo
import redis.asyncio as aioredis
from loguru import logger
@ -25,11 +26,13 @@ from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
BatchCompletedEvent,
BatchFailedEvent,
CircuitBreakerTrippedEvent,
RetryNeededEvent,
SyncCompletedEvent,
parse_campaign_event,
)
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
from api.services.campaign.circuit_breaker import circuit_breaker
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
@ -165,6 +168,14 @@ class CampaignOrchestrator:
await self._schedule_next_batch(campaign_id)
self._last_activity[campaign_id] = datetime.now(UTC)
elif isinstance(event, CircuitBreakerTrippedEvent):
# Circuit breaker tripped - clear state for this campaign
logger.warning(
f"campaign_id: {campaign_id} - Circuit breaker tripped event received: "
f"failure_rate={event.failure_rate:.2%}"
)
self._clear_campaign_state(campaign_id)
async def _handle_retry_event(self, event: RetryNeededEvent):
"""Process retry event and schedule if eligible (from campaign_retry_manager)."""
@ -274,6 +285,53 @@ class CampaignOrchestrator:
f"last reason: {reason}"
)
def _is_within_schedule(self, campaign: CampaignModel) -> bool:
    """Check if the current time falls within the campaign's schedule windows.

    Returns True (allow scheduling) if:
    - No schedule_config in metadata
    - Schedule is disabled
    - No slots configured
    - Invalid or missing timezone value (fail open)
    - Current time matches a slot

    Slots whose start/end times are missing or not strings are ignored rather
    than raising, so malformed stored metadata cannot break scheduling.
    """
    metadata = campaign.orchestrator_metadata
    if not metadata:
        return True

    schedule_config = metadata.get("schedule_config")
    if not schedule_config or not schedule_config.get("enabled", False):
        return True

    slots = schedule_config.get("slots")
    if not slots:
        return True

    # A stored null timezone falls back to UTC, same as a missing key.
    timezone_str = schedule_config.get("timezone") or "UTC"
    try:
        tz = ZoneInfo(timezone_str)
    except Exception:
        # Deliberately broad: any failure to resolve the zone fails open.
        logger.warning(
            f"campaign_id: {campaign.id} - Invalid timezone '{timezone_str}' in schedule_config, "
            f"failing open (allowing scheduling)"
        )
        return True

    now = datetime.now(tz)
    current_day = now.weekday()  # 0=Monday through 6=Sunday
    current_time = now.strftime("%H:%M")

    for slot in slots:
        if slot.get("day_of_week") != current_day:
            continue
        start = slot.get("start_time")
        end = slot.get("end_time")
        if not isinstance(start, str) or not isinstance(end, str):
            # Comparing None/non-strings with a str would raise TypeError.
            continue
        # Zero-padded HH:MM strings compare chronologically; end is exclusive.
        if start <= current_time < end:
            return True

    return False
async def _schedule_next_batch(self, campaign_id: int):
"""Schedule next batch immediately if work available."""
@ -302,6 +360,40 @@ class CampaignOrchestrator:
)
return
# Check schedule window before scheduling
if not self._is_within_schedule(campaign):
logger.info(
f"campaign_id: {campaign_id} - Outside scheduled time window, skipping batch"
)
return
# Safety net: check circuit breaker before scheduling
cb_config = None
if campaign.orchestrator_metadata:
cb_config = campaign.orchestrator_metadata.get("circuit_breaker")
is_open, stats = await circuit_breaker.is_circuit_open(
campaign_id=campaign_id,
config=cb_config,
)
if is_open and stats:
logger.warning(
f"campaign_id: {campaign_id} - Circuit breaker is open, "
f"pausing campaign. Stats: {stats}"
)
await db_client.update_campaign(campaign_id=campaign_id, state="paused")
await self.publisher.publish_circuit_breaker_tripped(
campaign_id=campaign_id,
failure_rate=stats["failure_rate"],
failure_count=stats["failure_count"],
success_count=stats["success_count"],
threshold=stats["threshold"],
window_seconds=stats["window_seconds"],
)
self._clear_campaign_state(campaign_id)
return
# Check for available work (queued runs + due retries)
has_work = await self._has_pending_work(campaign_id)
@ -399,6 +491,12 @@ class CampaignOrchestrator:
if campaign_id not in self._batch_in_progress:
has_work = await self._has_pending_work(campaign_id)
if has_work:
if not self._is_within_schedule(campaign):
logger.info(
f"campaign_id: {campaign_id} - Found orphaned work but outside "
f"schedule window, skipping"
)
continue
logger.info(
f"campaign_id: {campaign_id} - Found orphaned work (likely new retries), "
f"scheduling batch to process"
@ -428,6 +526,12 @@ class CampaignOrchestrator:
# Check for any pending work
has_work = await self._has_pending_work(campaign_id)
if has_work:
# If outside schedule window, don't mark complete — work remains for next window
if not self._is_within_schedule(campaign):
logger.debug(
f"campaign_id: {campaign_id} - Outside schedule window with pending work, "
f"not marking complete"
)
return False
# Check in-memory last activity

View file

@ -0,0 +1,301 @@
"""Campaign circuit breaker for automatic pause on high failure rates.
Uses two Redis sorted sets (ZSETs) per campaign — one for failures, one for
successes — as sliding windows. ZCARD gives O(1) counts without iterating
members, keeping the Lua scripts simple.
"""
import time
from typing import Optional, Tuple
import redis.asyncio as aioredis
from loguru import logger
from api.constants import DEFAULT_CIRCUIT_BREAKER_CONFIG, REDIS_URL
from api.db import db_client
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
class CircuitBreaker:
    """Sliding window circuit breaker for campaign call failures.

    Each campaign has two Redis sorted sets, one for failures and one for
    successes, whose scores are outcome timestamps. Entries older than the
    configured window are trimmed on every evaluation, and the breaker is
    considered tripped/open when at least ``min_calls_in_window`` outcomes
    remain and the failure share is at or above ``failure_threshold``.

    Redis or configuration errors never trip the breaker (fail open).
    """

    # Records one outcome and evaluates the trip condition atomically.
    # ARGV: now, window_start, is_failure, threshold, min_calls, ttl, member
    _RECORD_SCRIPT = """
local fail_key = KEYS[1]
local succ_key = KEYS[2]
local now = tonumber(ARGV[1])
local window_start = tonumber(ARGV[2])
local is_failure = tonumber(ARGV[3])
local threshold = tonumber(ARGV[4])
local min_calls = tonumber(ARGV[5])
local ttl = tonumber(ARGV[6])
local member = ARGV[7]
-- Trim both sets to the sliding window
redis.call('ZREMRANGEBYSCORE', fail_key, 0, window_start)
redis.call('ZREMRANGEBYSCORE', succ_key, 0, window_start)
-- Add the new outcome to the appropriate set
if is_failure == 1 then
redis.call('ZADD', fail_key, now, member)
else
redis.call('ZADD', succ_key, now, member)
end
-- Refresh TTL on both keys
redis.call('EXPIRE', fail_key, ttl)
redis.call('EXPIRE', succ_key, ttl)
-- Count via ZCARD (O(1))
local failures = redis.call('ZCARD', fail_key)
local successes = redis.call('ZCARD', succ_key)
local total = failures + successes
-- Check trip condition
if total >= min_calls and (failures / total) >= threshold then
return {1, failures, successes, total}
end
return {0, failures, successes, total}
"""

    # Evaluates the trip condition without recording anything.
    # ARGV: window_start, threshold, min_calls
    _CHECK_SCRIPT = """
local fail_key = KEYS[1]
local succ_key = KEYS[2]
local window_start = tonumber(ARGV[1])
local threshold = tonumber(ARGV[2])
local min_calls = tonumber(ARGV[3])
-- Trim both sets
redis.call('ZREMRANGEBYSCORE', fail_key, 0, window_start)
redis.call('ZREMRANGEBYSCORE', succ_key, 0, window_start)
-- Count via ZCARD
local failures = redis.call('ZCARD', fail_key)
local successes = redis.call('ZCARD', succ_key)
local total = failures + successes
if total >= min_calls and (failures / total) >= threshold then
return {1, failures, successes, total}
end
return {0, failures, successes, total}
"""

    def __init__(self):
        # Created lazily on first use so importing the module does no I/O.
        self.redis_client: Optional[aioredis.Redis] = None

    async def _get_redis(self) -> aioredis.Redis:
        """Get or create Redis connection."""
        if self.redis_client is None:
            self.redis_client = await aioredis.from_url(
                REDIS_URL, decode_responses=True
            )
        return self.redis_client

    @staticmethod
    def _keys(campaign_id: int) -> Tuple[str, str]:
        """Return (failures_key, successes_key) for a campaign."""
        return f"cb_failures:{campaign_id}", f"cb_successes:{campaign_id}"

    @staticmethod
    def _merge_config(config: Optional[dict]) -> dict:
        """Overlay a per-campaign config on DEFAULT_CIRCUIT_BREAKER_CONFIG."""
        return {**DEFAULT_CIRCUIT_BREAKER_CONFIG, **(config or {})}

    @staticmethod
    def _parse_result(result, threshold, window_seconds) -> Tuple[bool, dict]:
        """Turn a script reply ``[flag, failures, successes, total]`` into (flag, stats)."""
        failure_count = int(result[1])
        success_count = int(result[2])
        total = int(result[3])
        stats = {
            "failure_rate": failure_count / total if total > 0 else 0.0,
            "failure_count": failure_count,
            "success_count": success_count,
            "threshold": threshold,
            "window_seconds": window_seconds,
        }
        return bool(result[0]), stats

    async def record_call_outcome(
        self,
        campaign_id: int,
        is_failure: bool,
        config: Optional[dict] = None,
    ) -> Tuple[bool, Optional[dict]]:
        """Record a call outcome and check if the circuit breaker should trip.

        Args:
            campaign_id: The campaign ID.
            is_failure: True if the call failed, False if succeeded.
            config: Optional per-campaign circuit breaker config override.
                Falls back to DEFAULT_CIRCUIT_BREAKER_CONFIG.

        Returns:
            Tuple of (tripped: bool, stats: dict or None).
            stats contains failure_rate, failure_count, success_count,
            threshold, window_seconds. It is None when the breaker is
            disabled or when an error occurred (fail open).
        """
        # Local import: uuid is only needed to build unique ZSET members.
        import uuid

        cb_config = self._merge_config(config)
        if not cb_config.get("enabled", True):
            return False, None

        try:
            window_seconds = cb_config["window_seconds"]
            threshold = cb_config["failure_threshold"]
            min_calls = cb_config["min_calls_in_window"]
            now = time.time()
            window_start = now - window_seconds
            fail_key, succ_key = self._keys(campaign_id)
            # ZSET members are unique, so two outcomes sharing a timestamp
            # would collapse into one entry if the timestamp were the member.
            member = f"{now!r}:{uuid.uuid4().hex}"

            redis_client = await self._get_redis()
            result = await redis_client.eval(
                self._RECORD_SCRIPT,
                2,
                fail_key,
                succ_key,
                now,
                window_start,
                1 if is_failure else 0,
                threshold,
                min_calls,
                window_seconds + 60,  # TTL with buffer
                member,
            )
            tripped, stats = self._parse_result(result, threshold, window_seconds)
            if tripped:
                failure_rate = stats["failure_rate"]
                failure_count = stats["failure_count"]
                total = failure_count + stats["success_count"]
                logger.warning(
                    f"Circuit breaker TRIPPED for campaign {campaign_id}: "
                    f"failure_rate={failure_rate:.2%} ({failure_count}/{total}) "
                    f"threshold={threshold:.2%} window={window_seconds}s"
                )
            return tripped, stats
        except Exception as e:
            logger.error(f"Circuit breaker error for campaign {campaign_id}: {e}")
            # Fail open - do NOT trip on errors
            return False, None

    async def is_circuit_open(
        self,
        campaign_id: int,
        config: Optional[dict] = None,
    ) -> Tuple[bool, Optional[dict]]:
        """Check if the circuit breaker is in open (tripped) state without recording.

        Used as a safety net check before scheduling batches.

        Returns:
            Tuple of (is_open: bool, stats: dict or None). stats is None when
            the breaker is disabled or when an error occurred (fail open).
        """
        cb_config = self._merge_config(config)
        if not cb_config.get("enabled", True):
            return False, None

        try:
            window_seconds = cb_config["window_seconds"]
            threshold = cb_config["failure_threshold"]
            min_calls = cb_config["min_calls_in_window"]
            window_start = time.time() - window_seconds
            fail_key, succ_key = self._keys(campaign_id)

            redis_client = await self._get_redis()
            result = await redis_client.eval(
                self._CHECK_SCRIPT,
                2,
                fail_key,
                succ_key,
                window_start,
                threshold,
                min_calls,
            )
            return self._parse_result(result, threshold, window_seconds)
        except Exception as e:
            logger.error(f"Circuit breaker check error for campaign {campaign_id}: {e}")
            return False, None

    async def record_and_evaluate(self, campaign_id: int, is_failure: bool) -> None:
        """Record a call outcome, and if the breaker trips, pause the campaign.

        This is the main entry point called from telephony status callbacks.
        It handles fetching campaign config, recording the outcome, and
        pausing + publishing an event if the breaker trips. Outcomes for
        campaigns that are missing or not in the "running" state are ignored.
        Exceptions are caught internally so this never disrupts the caller.
        """
        try:
            campaign = await db_client.get_campaign_by_id(campaign_id)
            if not campaign or campaign.state != "running":
                return

            cb_config = {}
            if campaign.orchestrator_metadata:
                cb_config = campaign.orchestrator_metadata.get("circuit_breaker", {})

            tripped, stats = await self.record_call_outcome(
                campaign_id=campaign_id,
                is_failure=is_failure,
                config=cb_config,
            )
            if tripped and stats:
                logger.warning(
                    f"Circuit breaker tripped for campaign {campaign_id}, "
                    f"pausing campaign. Stats: {stats}"
                )
                await db_client.update_campaign(campaign_id=campaign_id, state="paused")
                publisher = await get_campaign_event_publisher()
                await publisher.publish_circuit_breaker_tripped(
                    campaign_id=campaign_id,
                    failure_rate=stats["failure_rate"],
                    failure_count=stats["failure_count"],
                    success_count=stats["success_count"],
                    threshold=stats["threshold"],
                    window_seconds=stats["window_seconds"],
                )
        except Exception as e:
            logger.error(f"Error in circuit breaker for campaign {campaign_id}: {e}")

    async def reset(self, campaign_id: int) -> bool:
        """Reset the circuit breaker state for a campaign.

        Called when a campaign is resumed to give it a clean slate.

        Returns:
            True when both keys were deleted without error, False otherwise
            (including when the Redis client could not be obtained).
        """
        fail_key, succ_key = self._keys(campaign_id)
        try:
            redis_client = await self._get_redis()
            await redis_client.delete(fail_key, succ_key)
            logger.info(f"Circuit breaker reset for campaign {campaign_id}")
            return True
        except Exception as e:
            logger.error(
                f"Error resetting circuit breaker for campaign {campaign_id}: {e}"
            )
            return False

    async def close(self):
        """Close Redis connection."""
        if self.redis_client:
            await self.redis_client.close()
            self.redis_client = None
# Global circuit breaker instance
circuit_breaker = CircuitBreaker()

View file

@ -4,6 +4,7 @@ from typing import Any, Dict
from loguru import logger
from api.db import db_client
from api.services.campaign.circuit_breaker import circuit_breaker
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
@ -67,6 +68,9 @@ class CampaignRunnerService:
# stale campaign checker would do that if there are pending work.
await db_client.update_campaign(campaign_id=campaign_id, state="running")
# Reset circuit breaker so the resumed campaign starts with a clean slate
await circuit_breaker.reset(campaign_id)
logger.info(f"Campaign {campaign_id} resumed")
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:

View file

@ -0,0 +1,504 @@
"""
Tests for Campaign Circuit Breaker.
These tests verify:
1. Circuit breaker records call outcomes (success/failure)
2. Circuit breaker trips when failure rate exceeds threshold
3. Circuit breaker does NOT trip when below threshold or min_calls
4. Circuit breaker reset clears state
5. Integration: _process_status_update pauses campaign on circuit breaker trip
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# =============================================================================
# Unit tests for CircuitBreaker class
# =============================================================================
class TestCircuitBreakerRecordOutcome:
    """Tests for recording call outcomes and trip detection."""

    @staticmethod
    def _breaker_with_redis(**eval_kwargs):
        """Return ``(breaker, fake_redis)`` with an AsyncMock Redis client attached.

        Any ``eval_kwargs`` (``return_value`` / ``side_effect``) configure the
        mock standing in for ``redis.eval``.
        """
        from api.services.campaign.circuit_breaker import CircuitBreaker

        fake_redis = AsyncMock()
        if eval_kwargs:
            fake_redis.eval = AsyncMock(**eval_kwargs)
        breaker = CircuitBreaker()
        breaker.redis_client = fake_redis
        return breaker, fake_redis

    @pytest.mark.asyncio
    async def test_no_trip_below_min_calls(self):
        """Circuit breaker should NOT trip when total calls < min_calls_in_window."""
        # 3 failures out of 3 total: 100% failure rate, but below min_calls=5.
        # Reply layout: [not_tripped, failures, successes, total]
        breaker, _ = self._breaker_with_redis(return_value=[0, 3, 0, 3])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is False
        assert stats is not None
        assert stats["failure_count"] == 3
        assert stats["success_count"] == 0

    @pytest.mark.asyncio
    async def test_trip_when_threshold_exceeded(self):
        """Circuit breaker should trip when failure rate >= threshold and total >= min_calls."""
        # 4 failures out of 6 total = 66% > 50% threshold.
        # Reply layout: [tripped, failures, successes, total]
        breaker, _ = self._breaker_with_redis(return_value=[1, 4, 2, 6])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is True
        assert stats is not None
        assert stats["failure_rate"] == pytest.approx(4 / 6)
        assert stats["failure_count"] == 4
        assert stats["success_count"] == 2

    @pytest.mark.asyncio
    async def test_no_trip_below_threshold(self):
        """Circuit breaker should NOT trip when failure rate < threshold."""
        # 2 failures out of 8 total = 25% < 50% threshold.
        breaker, _ = self._breaker_with_redis(return_value=[0, 2, 6, 8])

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=False
        )

        assert tripped is False
        assert stats["failure_rate"] == pytest.approx(2 / 8)

    @pytest.mark.asyncio
    async def test_disabled_circuit_breaker(self):
        """Circuit breaker should not record or trip when disabled."""
        breaker, fake_redis = self._breaker_with_redis()

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1,
            is_failure=True,
            config={"enabled": False},
        )

        assert tripped is False
        assert stats is None
        # Redis should not have been called
        fake_redis.eval.assert_not_called()

    @pytest.mark.asyncio
    async def test_custom_config_override(self):
        """Per-campaign config should override defaults."""
        # With a custom threshold of 0.8, 4/6 = 66% should NOT trip; the mocked
        # reply stands for the Lua script respecting the threshold we pass.
        breaker, _ = self._breaker_with_redis(return_value=[0, 4, 2, 6])
        overrides = {"failure_threshold": 0.8, "min_calls_in_window": 3}

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1,
            is_failure=True,
            config=overrides,
        )

        assert tripped is False

    @pytest.mark.asyncio
    async def test_redis_error_fails_open(self):
        """On Redis error, circuit breaker should fail open (not trip)."""
        breaker, _ = self._breaker_with_redis(
            side_effect=Exception("Redis connection lost")
        )

        tripped, stats = await breaker.record_call_outcome(
            campaign_id=1, is_failure=True
        )

        assert tripped is False
        assert stats is None
class TestCircuitBreakerIsOpen:
    """Tests for read-only circuit state check."""

    @pytest.mark.asyncio
    async def test_is_open_when_threshold_exceeded(self):
        """is_circuit_open should return True when failure rate exceeds threshold."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        # Reply layout: [is_open, failures, successes, total]
        fake_redis = AsyncMock()
        fake_redis.eval = AsyncMock(return_value=[1, 5, 2, 7])
        breaker = CircuitBreaker()
        breaker.redis_client = fake_redis

        is_open, stats = await breaker.is_circuit_open(campaign_id=1)

        assert is_open is True
        assert stats["failure_count"] == 5

    @pytest.mark.asyncio
    async def test_is_not_open_when_healthy(self):
        """is_circuit_open should return False when failure rate is low."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        # 1 failure out of 10 outcomes.
        fake_redis = AsyncMock()
        fake_redis.eval = AsyncMock(return_value=[0, 1, 9, 10])
        breaker = CircuitBreaker()
        breaker.redis_client = fake_redis

        is_open, stats = await breaker.is_circuit_open(campaign_id=1)

        assert is_open is False
        assert stats["failure_rate"] == pytest.approx(0.1)
class TestCircuitBreakerReset:
    """Covers ``CircuitBreaker.reset`` for both the happy path and Redis errors."""

    @pytest.mark.asyncio
    async def test_reset_deletes_redis_keys(self):
        """Reset issues one ``delete`` for the campaign's failure and success keys."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        redis_stub = AsyncMock()
        redis_stub.delete = AsyncMock(return_value=2)
        breaker.redis_client = redis_stub

        was_reset = await breaker.reset(campaign_id=42)

        assert was_reset is True
        redis_stub.delete.assert_called_once_with("cb_failures:42", "cb_successes:42")

    @pytest.mark.asyncio
    async def test_reset_on_redis_error(self):
        """Reset reports ``False`` when the Redis ``delete`` call raises."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        redis_stub = AsyncMock()
        redis_stub.delete = AsyncMock(side_effect=Exception("Redis down"))
        breaker.redis_client = redis_stub

        was_reset = await breaker.reset(campaign_id=42)

        assert was_reset is False
# =============================================================================
# Tests for record_and_evaluate (the high-level method on CircuitBreaker)
# =============================================================================
class TestRecordAndEvaluate:
    """Exercises ``CircuitBreaker.record_and_evaluate``, the high-level entry
    point: look up the campaign, record the outcome, and on a trip pause the
    campaign and publish an event."""

    @pytest.mark.asyncio
    async def test_trips_and_pauses_campaign(self):
        """A tripped outcome pauses the campaign and publishes the trip event."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        running_campaign = MagicMock(id=42, state="running", orchestrator_metadata={})
        trip_stats = {
            "failure_rate": 0.6,
            "failure_count": 6,
            "success_count": 4,
            "threshold": 0.5,
            "window_seconds": 120,
        }

        with (
            patch("api.services.campaign.circuit_breaker.db_client") as mock_db,
            patch(
                "api.services.campaign.circuit_breaker.get_campaign_event_publisher"
            ) as mock_get_publisher,
        ):
            publisher = AsyncMock()
            mock_get_publisher.return_value = publisher
            mock_db.get_campaign_by_id = AsyncMock(return_value=running_campaign)
            mock_db.update_campaign = AsyncMock()

            # Force the low-level recorder to report a trip.
            breaker.record_call_outcome = AsyncMock(return_value=(True, trip_stats))

            await breaker.record_and_evaluate(campaign_id=42, is_failure=True)

            # The campaign must have been moved to "paused" ...
            mock_db.update_campaign.assert_called_once_with(
                campaign_id=42, state="paused"
            )
            # ... and the trip event published exactly once.
            publisher.publish_circuit_breaker_tripped.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_pause_when_not_tripped(self):
        """A non-tripping outcome leaves the campaign state untouched."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        running_campaign = MagicMock(id=42, state="running", orchestrator_metadata={})

        with patch("api.services.campaign.circuit_breaker.db_client") as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=running_campaign)
            breaker.record_call_outcome = AsyncMock(return_value=(False, None))

            await breaker.record_and_evaluate(campaign_id=42, is_failure=False)

            mock_db.update_campaign.assert_not_called()

    @pytest.mark.asyncio
    async def test_skips_when_campaign_not_running(self):
        """No outcome is recorded for a campaign whose state is not running."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        paused_campaign = MagicMock(id=42, state="paused")

        with patch("api.services.campaign.circuit_breaker.db_client") as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=paused_campaign)
            breaker.record_call_outcome = AsyncMock()

            await breaker.record_and_evaluate(campaign_id=42, is_failure=True)

            # Recording must not even be attempted.
            breaker.record_call_outcome.assert_not_called()

    @pytest.mark.asyncio
    async def test_reads_config_from_orchestrator_metadata(self):
        """The ``circuit_breaker`` entry of orchestrator_metadata is forwarded
        as ``config`` to ``record_call_outcome``."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()
        custom_config = {"failure_threshold": 0.3, "min_calls_in_window": 10}
        running_campaign = MagicMock(
            id=42,
            state="running",
            orchestrator_metadata={"circuit_breaker": custom_config},
        )

        with patch("api.services.campaign.circuit_breaker.db_client") as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(return_value=running_campaign)
            breaker.record_call_outcome = AsyncMock(return_value=(False, None))

            await breaker.record_and_evaluate(campaign_id=42, is_failure=True)

            breaker.record_call_outcome.assert_called_once_with(
                campaign_id=42,
                is_failure=True,
                config=custom_config,
            )

    @pytest.mark.asyncio
    async def test_error_is_swallowed(self):
        """An exception from the campaign lookup must not escape
        ``record_and_evaluate``."""
        from api.services.campaign.circuit_breaker import CircuitBreaker

        breaker = CircuitBreaker()

        with patch("api.services.campaign.circuit_breaker.db_client") as mock_db:
            mock_db.get_campaign_by_id = AsyncMock(side_effect=Exception("DB exploded"))

            # Completing without raising is the whole assertion here.
            await breaker.record_and_evaluate(campaign_id=42, is_failure=True)
# =============================================================================
# Integration tests: _process_status_update calls circuit_breaker
# =============================================================================
class TestProcessStatusUpdateCircuitBreaker:
    """Checks how ``_process_status_update`` drives
    ``circuit_breaker.record_and_evaluate`` for campaign and non-campaign
    calls."""

    @pytest.mark.asyncio
    async def test_failure_status_calls_record_and_evaluate(self):
        """A "failed" callback on a campaign call records ``is_failure=True``."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        workflow_run = MagicMock(
            id=100,
            campaign_id=42,
            queued_run_id=10,
            state="running",
            logs={"telephony_status_callbacks": []},
            gathered_context={},
        )
        callback = StatusCallbackRequest(call_id="call-123", status="failed")

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.campaign_call_dispatcher") as mock_dispatcher,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
            patch(
                "api.routes.telephony.get_campaign_event_publisher"
            ) as mock_get_publisher,
        ):
            mock_get_publisher.return_value = AsyncMock()
            mock_cb.record_and_evaluate = AsyncMock()
            mock_dispatcher.release_call_slot = AsyncMock(return_value=True)
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            await _process_status_update(100, callback)

            mock_cb.record_and_evaluate.assert_called_once_with(42, is_failure=True)

    @pytest.mark.asyncio
    async def test_success_status_calls_record_and_evaluate(self):
        """A "completed" callback on a campaign call records ``is_failure=False``."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        workflow_run = MagicMock(
            id=100,
            campaign_id=42,
            state="running",
            logs={"telephony_status_callbacks": []},
            gathered_context={},
        )
        callback = StatusCallbackRequest(call_id="call-456", status="completed")

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.campaign_call_dispatcher") as mock_dispatcher,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
        ):
            mock_cb.record_and_evaluate = AsyncMock()
            mock_dispatcher.release_call_slot = AsyncMock(return_value=True)
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            await _process_status_update(100, callback)

            mock_cb.record_and_evaluate.assert_called_once_with(42, is_failure=False)

    @pytest.mark.asyncio
    async def test_non_campaign_call_skips_circuit_breaker(self):
        """A workflow run with ``campaign_id=None`` never reaches the breaker."""
        from api.routes.telephony import StatusCallbackRequest, _process_status_update

        workflow_run = MagicMock(
            id=100,
            campaign_id=None,  # marks this run as a non-campaign call
            state="running",
            logs={"telephony_status_callbacks": []},
            gathered_context={},
        )
        callback = StatusCallbackRequest(call_id="call-789", status="failed")

        with (
            patch("api.routes.telephony.db_client") as mock_db,
            patch("api.routes.telephony.circuit_breaker") as mock_cb,
        ):
            mock_db.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
            mock_db.update_workflow_run = AsyncMock()

            await _process_status_update(100, callback)

            # No breaker interaction is expected for non-campaign calls.
            mock_cb.record_and_evaluate.assert_not_called()
# =============================================================================
# Integration test: resume_campaign resets circuit breaker
# =============================================================================
class TestResumeCampaignResetsCircuitBreaker:
    """Checks that ``CampaignRunnerService.resume_campaign`` resets the breaker."""

    @pytest.mark.asyncio
    async def test_resume_resets_circuit_breaker(self):
        """Resuming a paused campaign resets its breaker and sets state to running."""
        from api.services.campaign.runner import CampaignRunnerService

        paused_campaign = MagicMock(id=42, state="paused")

        with (
            patch("api.services.campaign.runner.db_client") as mock_db,
            patch("api.services.campaign.runner.circuit_breaker") as mock_cb,
        ):
            mock_cb.reset = AsyncMock(return_value=True)
            mock_db.get_campaign_by_id = AsyncMock(return_value=paused_campaign)
            mock_db.update_campaign = AsyncMock()

            service = CampaignRunnerService()
            await service.resume_campaign(42)

            # Breaker state for the campaign was cleared ...
            mock_cb.reset.assert_called_once_with(42)
            # ... and the campaign was flipped back to "running".
            mock_db.update_campaign.assert_called_once_with(
                campaign_id=42, state="running"
            )

View file

@ -48,7 +48,6 @@ from pipecat.turns.user_stop import ExternalUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.time import time_now_iso8601
# Short timeout for faster tests
STOP_STRATEGY_TIMEOUT = 0.15
# Delay to allow async processing
@ -115,7 +114,15 @@ def _build_components(llm_steps=None):
turn_controller = user_agg._user_turn_controller
return injector, user_agg, stop_strategy, turn_controller, mock_llm, context, pipeline
return (
injector,
user_agg,
stop_strategy,
turn_controller,
mock_llm,
context,
pipeline,
)
async def _run_scenario(pipeline, inject_fn):
@ -193,7 +200,9 @@ class TestUserTurnStopScenarios:
await asyncio.sleep(0)
await injector.inject(UserStartedSpeakingFrame())
await asyncio.sleep(0)
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("hello", "user-1", time_now_iso8601())
)
await asyncio.sleep(0)
await injector.inject(UserStoppedSpeakingFrame())
await asyncio.sleep(0)
@ -217,7 +226,9 @@ class TestUserTurnStopScenarios:
assert stop_strategy._text == "", (
f"Expected empty _text after clean turn, got '{stop_strategy._text}'"
)
assert not turn_ctrl._user_turn, "Expected _user_turn to be False after turn"
assert not turn_ctrl._user_turn, (
"Expected _user_turn to be False after turn"
)
assert mock_llm.get_current_step() == 1, (
f"Expected 1 LLM call (turn 2 only), got {mock_llm.get_current_step()}"
)
@ -257,7 +268,9 @@ class TestUserTurnStopScenarios:
await asyncio.sleep(0)
await injector.inject(UserStartedSpeakingFrame())
await asyncio.sleep(0)
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("hello", "user-1", time_now_iso8601())
)
await asyncio.sleep(ASYNC_DELAY)
# Bot stops -> unmuted
@ -329,7 +342,9 @@ class TestUserTurnStopScenarios:
await asyncio.sleep(ASYNC_DELAY)
# TranscriptionFrame arrives AFTER unmute -> reaches stop strategy
await injector.inject(TranscriptionFrame("hello", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("hello", "user-1", time_now_iso8601())
)
await asyncio.sleep(ASYNC_DELAY)
# Install spy on trigger_user_turn_stopped to track every call
@ -408,7 +423,9 @@ class TestUserTurnStopScenarios:
# _aggregation is a separate concern from the stop strategy's _text.
messages = context.messages
user_messages = [m for m in messages if m.get("role") == "user"]
assert len(user_messages) == 1, f"Expected 1 user message, got {len(user_messages)}"
assert len(user_messages) == 1, (
f"Expected 1 user message, got {len(user_messages)}"
)
user_text = user_messages[0]["content"]
assert "hello" in user_text, (
f"Expected 'hello' (from aggregator) in user message, got: '{user_text}'"
@ -519,7 +536,9 @@ class TestUserTurnStopScenarios:
await injector.inject(VADUserStoppedSpeakingFrame())
await asyncio.sleep(0)
# Late transcription - but still during bot speaking
await injector.inject(TranscriptionFrame("late hello", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("late hello", "user-1", time_now_iso8601())
)
await asyncio.sleep(ASYNC_DELAY)
await injector.inject(BotStoppedSpeakingFrame())
@ -651,7 +670,9 @@ class TestUserTurnStopScenarios:
# The LLM received both "late hello" (dangling in aggregator from turn 1)
# and "real speech" (from turn 2).
user_messages = [m for m in context.messages if m.get("role") == "user"]
assert len(user_messages) == 1, f"Expected 1 user message, got {len(user_messages)}"
assert len(user_messages) == 1, (
f"Expected 1 user message, got {len(user_messages)}"
)
user_text = user_messages[0]["content"]
assert "late hello" in user_text, (
f"Expected 'late hello' (from aggregator) in user message, got: '{user_text}'"
@ -867,7 +888,9 @@ class TestUserTurnStopScenarios:
await asyncio.sleep(ASYNC_DELAY)
# Late transcription after unmute
await injector.inject(TranscriptionFrame("first", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("first", "user-1", time_now_iso8601())
)
await asyncio.sleep(0)
await injector.inject(UserStoppedSpeakingFrame())
await asyncio.sleep(TIMEOUT_WAIT)
@ -890,7 +913,9 @@ class TestUserTurnStopScenarios:
await injector.inject(BotStoppedSpeakingFrame())
await asyncio.sleep(ASYNC_DELAY)
await injector.inject(TranscriptionFrame("second", "user-1", time_now_iso8601()))
await injector.inject(
TranscriptionFrame("second", "user-1", time_now_iso8601())
)
await asyncio.sleep(0)
await injector.inject(UserStoppedSpeakingFrame())
await asyncio.sleep(TIMEOUT_WAIT)
@ -926,7 +951,9 @@ class TestUserTurnStopScenarios:
user_text = user_messages[0]["content"]
assert "first" in user_text, f"Expected 'first' in '{user_text}'"
assert "second" in user_text, f"Expected 'second' in '{user_text}'"
assert "actual speech" in user_text, f"Expected 'actual speech' in '{user_text}'"
assert "actual speech" in user_text, (
f"Expected 'actual speech' in '{user_text}'"
)
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)