fix: make campaign process batch thread safe (#141)

* fix: dont schedule new batch on resume

* fix: make process_batch thread safe
This commit is contained in:
Abhishek 2026-01-30 14:48:00 +05:30 committed by GitHub
parent e9c5da16c5
commit 6827744327
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1012 additions and 230 deletions

View file

@ -5,28 +5,31 @@ Revises: d1dac4c93e61
Create Date: 2026-01-29 20:36:57.924887
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '02ffd7f23d1d'
down_revision: Union[str, None] = 'd1dac4c93e61'
revision: str = "02ffd7f23d1d"
down_revision: Union[str, None] = "d1dac4c93e61"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index('idx_workflow_runs_campaign_id', 'workflow_runs', ['campaign_id'], unique=False)
op.create_index('idx_workflow_runs_workflow_id', 'workflow_runs', ['workflow_id'], unique=False)
op.create_index(
"idx_workflow_runs_campaign_id", "workflow_runs", ["campaign_id"], unique=False
)
op.create_index(
"idx_workflow_runs_workflow_id", "workflow_runs", ["workflow_id"], unique=False
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('idx_workflow_runs_workflow_id', table_name='workflow_runs')
op.drop_index('idx_workflow_runs_campaign_id', table_name='workflow_runs')
op.drop_index("idx_workflow_runs_workflow_id", table_name="workflow_runs")
op.drop_index("idx_workflow_runs_campaign_id", table_name="workflow_runs")
# ### end Alembic commands ###

View file

@ -0,0 +1,54 @@
"""add processing state in queued runs
Revision ID: 34c8537dfde5
Revises: 02ffd7f23d1d
Create Date: 2026-01-30 14:40:29.905325
"""
from typing import Sequence, Union
from alembic import op
from alembic_postgresql_enum import TableReference
# revision identifiers, used by Alembic.
revision: str = "34c8537dfde5"
down_revision: Union[str, None] = "02ffd7f23d1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply revision 34c8537dfde5.

    Two schema changes are made:

    * ``ix_api_keys_key_hash`` on ``api_keys.key_hash`` is dropped and
      recreated as a non-unique index.
    * The ``queued_run_state`` enum gains the ``processing`` value, which is
      used by the ``queued_runs.state`` column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): the index goes from unique (see downgrade) to non-unique
    # here - confirm that relaxing uniqueness of key_hash is intended.
    op.drop_index(op.f("ix_api_keys_key_hash"), table_name="api_keys")
    op.create_index("ix_api_keys_key_hash", "api_keys", ["key_hash"], unique=False)

    # The only column declared as depending on the enum.
    state_column = TableReference(
        table_schema="public", table_name="queued_runs", column_name="state"
    )
    op.sync_enum_values(
        enum_schema="public",
        enum_name="queued_run_state",
        new_values=["queued", "processed", "processing", "failed"],
        affected_columns=[state_column],
        enum_values_to_rename=[],
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert revision 34c8537dfde5.

    Removes the ``processing`` value from the ``queued_run_state`` enum and
    restores ``ix_api_keys_key_hash`` as a unique index on
    ``api_keys.key_hash``.

    Rows that are still in the ``processing`` state are moved back to
    ``queued`` first: that value does not exist in the reduced enum, so
    leaving such rows in place would make the enum change fail.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # 'queued' is the state the dispatcher itself reverts unfinished runs to,
    # so the rows simply become eligible for processing again.
    op.execute(
        "UPDATE public.queued_runs SET state = 'queued' WHERE state = 'processing'"
    )
    op.sync_enum_values(
        enum_schema="public",
        enum_name="queued_run_state",
        new_values=["queued", "processed", "failed"],
        affected_columns=[
            TableReference(
                table_schema="public", table_name="queued_runs", column_name="state"
            )
        ],
        enum_values_to_rename=[],
    )
    op.drop_index("ix_api_keys_key_hash", table_name="api_keys")
    op.create_index(op.f("ix_api_keys_key_hash"), "api_keys", ["key_hash"], unique=True)
    # ### end Alembic commands ###

View file

@ -209,31 +209,6 @@ class CampaignClient(BaseDBClient):
await session.rollback()
raise e
async def get_queued_runs(
self,
campaign_id: int,
state: str = "queued",
limit: int = 10,
scheduled_for: Optional[bool] = None,
) -> list[QueuedRunModel]:
"""Get queued runs for processing, optionally filtering by scheduled status"""
async with self.async_session() as session:
query = select(QueuedRunModel).where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == state,
)
# Filter by scheduled status if specified
if scheduled_for is True:
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
elif scheduled_for is False:
query = query.where(QueuedRunModel.scheduled_for.is_(None))
query = query.order_by(QueuedRunModel.created_at).limit(limit)
result = await session.execute(query)
return list(result.scalars().all())
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
"""Update queued run"""
async with self.async_session() as session:
@ -284,26 +259,6 @@ class CampaignClient(BaseDBClient):
result = await session.execute(query)
return list(result.scalars().all())
# New methods for retry support
async def get_scheduled_queued_runs(
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
) -> list[QueuedRunModel]:
"""Get scheduled queued runs that are due for processing"""
async with self.async_session() as session:
query = (
select(QueuedRunModel)
.where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.isnot(None),
QueuedRunModel.scheduled_for <= scheduled_before,
)
.order_by(QueuedRunModel.scheduled_for)
.limit(limit)
)
result = await session.execute(query)
return list(result.scalars().all())
async def create_queued_run(
self,
campaign_id: int,
@ -388,3 +343,82 @@ class CampaignClient(BaseDBClient):
query = select(func.count(QueuedRunModel.id)).where(*conditions)
result = await session.execute(query)
return result.scalar() or 0
async def claim_queued_runs_for_processing(
    self,
    campaign_id: int,
    scheduled_before: datetime,
    limit: int = 10,
) -> list[QueuedRunModel]:
    """
    Claim up to ``limit`` queued runs of a campaign and mark them 'processing'.

    Both selects are issued with ``FOR UPDATE SKIP LOCKED``, so rows already
    locked by another transaction are skipped instead of being waited on;
    concurrent callers therefore do not claim the same rows.

    Selection order:
    1. Scheduled retries whose ``scheduled_for`` is at or before
       ``scheduled_before``, earliest first.
    2. If slots remain, unscheduled queued runs, oldest ``created_at`` first.

    All selected rows get ``state = "processing"`` and the change is
    committed in a single transaction.

    Args:
        campaign_id: Campaign whose queued runs are claimed.
        scheduled_before: Cut-off for scheduled retries to be considered due.
        limit: Maximum number of runs to claim in total.

    Returns: List of claimed QueuedRunModel objects (retries first).
    """
    async with self.async_session() as session:
        # Step 1: due retries, locked for this transaction
        due_retries_stmt = (
            select(QueuedRunModel)
            .where(
                QueuedRunModel.campaign_id == campaign_id,
                QueuedRunModel.state == "queued",
                QueuedRunModel.scheduled_for.isnot(None),
                QueuedRunModel.scheduled_for <= scheduled_before,
            )
            .order_by(QueuedRunModel.scheduled_for)
            .limit(limit)
            .with_for_update(skip_locked=True)
        )
        due_retries = list((await session.execute(due_retries_stmt)).scalars().all())
        for retry_run in due_retries:
            retry_run.state = "processing"
        claimed = list(due_retries)

        # Step 2: fill whatever capacity is left with regular queued runs
        slots_left = limit - len(due_retries)
        if slots_left > 0:
            fresh_stmt = (
                select(QueuedRunModel)
                .where(
                    QueuedRunModel.campaign_id == campaign_id,
                    QueuedRunModel.state == "queued",
                    QueuedRunModel.scheduled_for.is_(None),
                )
                .order_by(QueuedRunModel.created_at)
                .limit(slots_left)
                .with_for_update(skip_locked=True)
            )
            fresh_runs = list((await session.execute(fresh_stmt)).scalars().all())
            for fresh_run in fresh_runs:
                fresh_run.state = "processing"
            claimed.extend(fresh_runs)

        # Persist the state transition; the row locks are released on commit
        try:
            await session.commit()
        except Exception:
            await session.rollback()
            raise

        # Reload each row so the returned objects carry the committed values
        for claimed_run in claimed:
            await session.refresh(claimed_run)

        return claimed

View file

@ -598,7 +598,7 @@ class QueuedRunModel(Base):
source_uuid = Column(String, nullable=False)
context_variables = Column(JSON, nullable=False, default=dict)
state = Column(
Enum("queued", "processed", "failed", name="queued_run_state"),
Enum("queued", "processed", "processing", "failed", name="queued_run_state"),
nullable=False,
default="queued",
)

View file

@ -23,7 +23,7 @@ from api.db.workflow_run_client import WorkflowRunClient
from api.enums import CallType, OrganizationConfigurationKey, WorkflowRunState
from api.errors.telephony_errors import TelephonyError
from api.services.auth.depends import get_user
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id
from api.services.telephony.factory import (

View file

@ -9,6 +9,7 @@ from api.constants import DEFAULT_ORG_CONCURRENCY_LIMIT
from api.db import db_client
from api.db.models import QueuedRunModel, WorkflowRunModel
from api.enums import OrganizationConfigurationKey, WorkflowRunState
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
from api.services.campaign.rate_limiter import rate_limiter
from api.services.telephony.base import TelephonyProvider
from api.services.telephony.factory import get_telephony_provider
@ -42,7 +43,8 @@ class CampaignCallDispatcher:
async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
"""
Processes a batch of queued runs with priority for scheduled retries
Processes a batch of queued runs with priority for scheduled retries.
Thread-safe: uses SELECT FOR UPDATE SKIP LOCKED to prevent concurrent processing.
Returns: number of processed runs
"""
# Get campaign details
@ -57,41 +59,34 @@ class CampaignCallDispatcher:
)
return 0
# First, get any scheduled retries that are due
scheduled_runs = await db_client.get_scheduled_queued_runs(
# Atomically claim queued runs for processing (thread-safe)
# This uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions
queued_runs = await db_client.claim_queued_runs_for_processing(
campaign_id=campaign_id,
scheduled_before=datetime.now(UTC),
limit=batch_size,
)
remaining_slots = batch_size - len(scheduled_runs)
# Then get regular queued runs
regular_runs = []
if remaining_slots > 0:
regular_runs = await db_client.get_queued_runs(
campaign_id=campaign_id,
state="queued",
scheduled_for=False, # Exclude scheduled runs
limit=remaining_slots,
)
queued_runs = scheduled_runs + regular_runs
if not queued_runs:
logger.info(f"No more queued runs for campaign {campaign_id}")
return 0
processed_count = 0
for queued_run in queued_runs:
for i, queued_run in enumerate(queued_runs):
try:
# Apply rate limiting
# Apply rate limiting, i.e. do not initiate more than rate_limit_per_second
# calls per second. This is separate from the concurrency limit.
await self.apply_rate_limit(
campaign.organization_id, campaign.rate_limit_per_second
)
# Acquire concurrent slot - waits until a slot is available
slot_id = await self.acquire_concurrent_slot(
campaign.organization_id, campaign
)
# Dispatch the call
workflow_run = await self.dispatch_call(queued_run, campaign)
workflow_run = await self.dispatch_call(queued_run, campaign, slot_id)
# Update queued run as processed
await db_client.update_queued_run(
@ -108,6 +103,25 @@ class CampaignCallDispatcher:
campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1
)
except ConcurrentSlotAcquisitionError:
# Revert all unprocessed runs (current and remaining) back to queued
# so they can be picked up again when campaign is resumed
for unprocessed_run in queued_runs[i:]:
try:
await db_client.update_queued_run(
queued_run_id=unprocessed_run.id,
state="queued",
)
logger.info(
f"Reverted queued run {unprocessed_run.id} back to queued state"
)
except Exception as revert_error:
logger.error(
f"Failed to revert queued run {unprocessed_run.id}: {revert_error}"
)
# Re-raise to propagate to process_campaign_batch
raise
except Exception as e:
logger.warning(f"Error processing queued run {queued_run.id}: {e}")
@ -129,54 +143,9 @@ class CampaignCallDispatcher:
return processed_count
async def dispatch_call(
self, queued_run: QueuedRunModel, campaign: any
self, queued_run: QueuedRunModel, campaign: any, slot_id: str
) -> Optional[WorkflowRunModel]:
"""Creates workflow run and initiates call with concurrent limiting"""
# Get concurrent limit for organization
org_concurrent_limit = await self.get_org_concurrent_limit(
campaign.organization_id
)
# Check for campaign-level max_concurrency in orchestrator_metadata
campaign_max_concurrency = None
if campaign.orchestrator_metadata:
campaign_max_concurrency = campaign.orchestrator_metadata.get(
"max_concurrency"
)
# Use the lower of campaign limit and org limit
if campaign_max_concurrency is not None:
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
else:
max_concurrent = org_concurrent_limit
# Track wait time for alerting
wait_start = time.time()
slot_id = None
# Wait until we can acquire a concurrent slot
while True:
slot_id = await rate_limiter.try_acquire_concurrent_slot(
campaign.organization_id, max_concurrent
)
if slot_id:
break
# Check if we've been waiting too long
wait_time = time.time() - wait_start
if wait_time > 600: # 10 minutes
logger.error(
f"Waiting for concurrent slot for {wait_time:.1f}s, "
f"org: {campaign.organization_id}, campaign: {campaign.id}"
)
logger.debug(
f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
)
# Wait before retrying
await asyncio.sleep(1)
"""Creates workflow run and initiates call. Requires a pre-acquired slot_id."""
# Get workflow details
workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
if not workflow:
@ -351,6 +320,66 @@ class CampaignCallDispatcher:
# Wait for next available slot
await asyncio.sleep(wait_time)
async def acquire_concurrent_slot(
    self, organization_id: int, campaign: any, timeout: float = 600
) -> str:
    """
    Acquires a concurrent call slot - waits if necessary until a slot is available.

    The effective limit is the organization's concurrent limit, lowered to the
    campaign's ``orchestrator_metadata["max_concurrency"]`` when that is set
    and smaller. Acquisition is retried once per second.

    Args:
        organization_id: The organization ID
        campaign: The campaign object
        timeout: Maximum time to wait for a slot (default 10 minutes)

    Returns the slot_id which must be released when the call completes.

    Raises:
        ConcurrentSlotAcquisitionError: If slot cannot be acquired within timeout
    """
    # Get concurrent limit for organization
    org_concurrent_limit = await self.get_org_concurrent_limit(organization_id)

    # Check for campaign-level max_concurrency in orchestrator_metadata
    campaign_max_concurrency = None
    if campaign.orchestrator_metadata:
        campaign_max_concurrency = campaign.orchestrator_metadata.get(
            "max_concurrency"
        )

    # Use the lower of campaign limit and org limit
    if campaign_max_concurrency is not None:
        max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
    else:
        max_concurrent = org_concurrent_limit

    # Measure elapsed time on the monotonic clock: unlike time.time(), it
    # cannot jump when the system clock is adjusted, so the timeout can
    # neither fire early nor be postponed by such an adjustment.
    wait_start = time.monotonic()

    # Wait until we can acquire a concurrent slot
    while True:
        slot_id = await rate_limiter.try_acquire_concurrent_slot(
            organization_id, max_concurrent
        )
        if slot_id:
            return slot_id

        # Check if we've been waiting too long
        wait_time = time.monotonic() - wait_start
        if wait_time > timeout:
            raise ConcurrentSlotAcquisitionError(
                organization_id=organization_id,
                campaign_id=campaign.id,
                wait_time=wait_time,
            )

        logger.debug(
            f"Attempting to get a slot for {organization_id} {campaign.id}, "
            f"waited {wait_time:.1f}s"
        )

        # Wait before retrying
        await asyncio.sleep(1)
async def release_call_slot(self, workflow_run_id: int) -> bool:
"""
Release concurrent slot when a call completes.

View file

@ -12,6 +12,7 @@ from api.constants import REDIS_URL
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
BatchCompletedEvent,
BatchFailedEvent,
CampaignCompletedEvent,
RetryNeededEvent,
SyncCompletedEvent,
@ -43,6 +44,23 @@ class CampaignEventPublisher:
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_batch_failed(
    self,
    campaign_id: int,
    error: str,
    processed_count: int = 0,
    metadata: Optional[Dict] = None,
):
    """Build a BatchFailedEvent and publish its JSON on the campaign events channel.

    Args:
        campaign_id: Campaign the failed batch belongs to.
        error: Description of the failure.
        processed_count: Value reported in the event's ``processed_count`` field.
        metadata: Optional extra data attached to the event.
    """
    failure_event = BatchFailedEvent(
        campaign_id=campaign_id,
        error=error,
        processed_count=processed_count,
        metadata=metadata,
    )
    channel = RedisChannel.CAMPAIGN_EVENTS.value
    payload = failure_event.to_json()
    await self.redis.publish(channel, payload)
async def publish_sync_completed(
self,
campaign_id: int,

View file

@ -23,11 +23,13 @@ from api.db import db_client
from api.db.models import CampaignModel, QueuedRunModel
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
CampaignCompletedEvent,
CampaignEventType,
BatchCompletedEvent,
BatchFailedEvent,
RetryNeededEvent,
SyncCompletedEvent,
parse_campaign_event,
)
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
@ -37,6 +39,7 @@ class CampaignOrchestrator:
def __init__(self, redis_client: aioredis.Redis):
self.redis = redis_client
self.publisher = CampaignEventPublisher(redis_client)
self.completion_check_interval = 60 # 1 minute
self.completion_timeout = 3600 # 1 hour
self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling
@ -97,22 +100,21 @@ class CampaignOrchestrator:
async def _handle_event(self, event):
"""Handle campaign events including retry events."""
# Handle RetryNeededEvent
if isinstance(event, RetryNeededEvent):
await self._handle_retry_event(event)
return
# All events should have campaign_id
if not hasattr(event, "campaign_id") or not event.campaign_id:
logger.warning(f"Event missing campaign_id: {type(event).__name__}")
return
campaign_id = event.campaign_id
event_type = event.type
logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")
logger.debug(
f"campaign_id: {campaign_id} - Received event: {type(event).__name__}"
)
if event_type == CampaignEventType.BATCH_COMPLETED:
if isinstance(event, RetryNeededEvent):
await self._handle_retry_event(event)
elif isinstance(event, BatchCompletedEvent):
# Clear the batch in progress flag
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
@ -120,11 +122,42 @@ class CampaignOrchestrator:
f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
)
# Check campaign state before scheduling next batch
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
self._clear_campaign_state(campaign_id)
return
if campaign.state != "running":
logger.info(
f"campaign_id: {campaign_id} - Campaign not in running state ({campaign.state}), "
f"not scheduling next batch"
)
self._clear_campaign_state(campaign_id)
return
# Immediately schedule next batch
await self._schedule_next_batch(campaign_id)
self._last_activity[campaign_id] = datetime.now(UTC)
elif event_type == CampaignEventType.SYNC_COMPLETED:
elif isinstance(event, BatchFailedEvent):
# Clear the batch in progress flag
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
logger.warning(
f"campaign_id: {campaign_id} - Batch failed: {event.error}, "
f"scheduling next batch to continue processing"
)
# Do not schedule another batch here: process_campaign_batch marks the
# campaign as failed whenever a batch fails, to be on the safe side.
self._last_activity[campaign_id] = datetime.now(UTC)
elif isinstance(event, SyncCompletedEvent):
# Start processing after sync
logger.info(
f"campaign_id: {campaign_id} - Sync completed, starting processing"
@ -309,6 +342,16 @@ class CampaignOrchestrator:
del self._processing_locks[campaign_id]
logger.debug(f"campaign_id: {campaign_id} - Released processing lock")
def _clear_campaign_state(self, campaign_id: int):
    """Forget every in-memory entry kept for ``campaign_id``.

    Removes the campaign from the last-activity, processing-lock and
    batch-in-progress maps; a map that has no entry for the campaign is
    left untouched.
    """
    per_campaign_maps = (
        self._last_activity,
        self._processing_locks,
        self._batch_in_progress,
    )
    for state_map in per_campaign_maps:
        # pop with a default is a no-op when the key is absent
        state_map.pop(campaign_id, None)
    logger.debug(f"campaign_id: {campaign_id} - Cleared all in-memory state")
async def _monitor_completion(self):
"""Periodically check for campaigns that should be marked complete."""
while self._running:
@ -457,30 +500,22 @@ class CampaignOrchestrator:
logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")
# Publish completion event using typed event
completion_event = CampaignCompletedEvent(
# Calculate duration if started_at is available
duration = None
if campaign.started_at:
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
# Publish completion event
await self.publisher.publish_campaign_completed(
campaign_id=campaign_id,
total_rows=campaign.total_rows or 0,
processed_rows=campaign.processed_rows,
failed_rows=campaign.failed_rows,
)
# Calculate duration if started_at is available
if campaign.started_at:
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
completion_event.duration_seconds = duration
await self.redis.publish(
RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
duration_seconds=duration,
)
# Clean up in-memory state
if campaign_id in self._last_activity:
del self._last_activity[campaign_id]
if campaign_id in self._processing_locks:
del self._processing_locks[campaign_id]
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
self._clear_campaign_state(campaign_id)
except Exception as e:
logger.error(

View file

@ -0,0 +1,16 @@
"""
Campaign service exceptions.
"""
class ConcurrentSlotAcquisitionError(Exception):
    """Signals that no concurrent call slot became available before the timeout.

    Attributes:
        organization_id: Organization the slot was requested for.
        campaign_id: Campaign the slot was requested for.
        wait_time: Seconds spent waiting before giving up.
    """

    def __init__(self, organization_id: int, campaign_id: int, wait_time: float):
        message = (
            f"Failed to acquire concurrent slot for org {organization_id}, "
            f"campaign {campaign_id} after waiting {wait_time:.1f}s"
        )
        super().__init__(message)
        # Keep the raw values so handlers can inspect them without parsing
        # the message.
        self.organization_id = organization_id
        self.campaign_id = campaign_id
        self.wait_time = wait_time

View file

@ -0,0 +1,18 @@
### campaign_orchestrator.py (CampaignOrchestrator)
- Listens for retry, batch-completed, and sync-completed events on Redis pub/sub and schedules batches
- Monitors stale campaigns and schedules a batch if one is not already scheduled
- Marks campaign as completed if no more tasks pending
### runner.py (CampaignRunnerService)
- Service layer to handle router requests, like run campaign, pause campaign, resume campaign, get campaign status etc.
### campaign_call_dispatcher.py (CampaignCallDispatcher)
- Enforces rate and concurrency limits and dispatches calls through the telephony provider
### campaign_tasks.py
- sync campaign from source
- process campaign batch

View file

@ -63,12 +63,10 @@ class CampaignRunnerService:
f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
)
# Update state to running
# Update state to running. Do not enqueue a batch here: the campaign orchestrator's
# stale-campaign checker schedules one if there is pending work.
await db_client.update_campaign(campaign_id=campaign_id, state="running")
# Enqueue process batch task to continue processing
await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)
logger.info(f"Campaign {campaign_id} resumed")
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:

View file

@ -2,7 +2,7 @@ from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunState
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.in_memory_buffers import (
InMemoryAudioBuffer,

View file

@ -42,7 +42,6 @@ REDIS_SETTINGS = RedisSettings(
)
from api.tasks.campaign_tasks import (
monitor_campaign_progress,
process_campaign_batch,
sync_campaign_source,
)
@ -62,7 +61,6 @@ class WorkerSettings:
process_workflow_completion,
sync_campaign_source,
process_campaign_batch,
monitor_campaign_progress,
process_knowledge_base_document,
]
cron_jobs = []

View file

@ -4,12 +4,11 @@ from typing import Dict
from loguru import logger
from api.db import db_client
from api.enums import RedisChannel
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_event_protocol import BatchFailedEvent
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_event_publisher import (
get_campaign_event_publisher,
)
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
from api.services.campaign.source_sync_factory import get_sync_service
@ -95,6 +94,10 @@ async def process_campaign_batch(
- Updates queued_run state to 'processed'
- Updates campaign.processed_rows counter
- Publishes batch_completed event for orchestrator
# TODO: May be not fail the campaign immediately on a single batch failure
# and propagate the error to campaign orchestrator which can fail the campaign
# on some consecutive batch failures.
"""
logger.info(f"Processing batch for campaign {campaign_id}, batch_size={batch_size}")
@ -119,81 +122,34 @@ async def process_campaign_batch(
f"failed={failed_count}"
)
except Exception as e:
logger.error(f"Error processing batch for campaign {campaign_id}: {e}")
# Publish batch failed event
publisher = await get_campaign_event_publisher()
event = BatchFailedEvent(
campaign_id=campaign_id,
error=str(e),
processed_count=0,
except ConcurrentSlotAcquisitionError as e:
logger.warning(
f"Failed to acquire concurrent slot for campaign {campaign_id}: {e}"
)
await publisher.redis.publish(
RedisChannel.CAMPAIGN_EVENTS.value, event.to_json()
# Publish batch failed event with specific error
publisher = await get_campaign_event_publisher()
await publisher.publish_batch_failed(
campaign_id=campaign_id,
error=f"Concurrent slot acquisition timeout: {e}",
processed_count=0,
)
# Update campaign state to failed
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
raise
async def monitor_campaign_progress(ctx: Dict, campaign_id: int) -> None:
"""
Phase 3: Monitors campaign completion
- Checks if all queued runs are in 'processed' state
- Queries workflow_runs for final call statistics
- Updates campaign state to 'completed'
- Calculates total calls made, successful, failed
- Triggers post-campaign integrations
"""
logger.info(f"Monitoring progress for campaign {campaign_id}")
try:
# Get campaign
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
raise ValueError(f"Campaign {campaign_id} not found")
# Check if all runs are processed
pending_runs = await db_client.count_queued_runs(
campaign_id=campaign_id, state="queued"
)
if pending_runs > 0:
logger.info(f"Campaign {campaign_id} still has {pending_runs} pending runs")
return
# All runs processed, mark campaign as completed
await db_client.update_campaign(
campaign_id=campaign_id, state="completed", completed_at=datetime.now(UTC)
)
# Calculate statistics
workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id)
total_calls = len(workflow_runs)
successful_calls = 0
failed_calls = 0
for run in workflow_runs:
callbacks = run.logs.get("telephony_status_callbacks", [])
if callbacks:
final_status = callbacks[-1].get("status", "").lower()
if final_status == "completed":
successful_calls += 1
elif final_status in ["failed", "busy", "no-answer"]:
failed_calls += 1
logger.info(
f"Campaign {campaign_id} completed: "
f"Total calls: {total_calls}, "
f"Successful: {successful_calls}, "
f"Failed: {failed_calls}"
)
# TODO: Trigger post-campaign integrations if configured
except Exception as e:
logger.error(f"Error monitoring campaign {campaign_id}: {e}")
logger.error(f"Error processing batch for campaign {campaign_id}: {e}")
# Publish batch failed event
publisher = await get_campaign_event_publisher()
await publisher.publish_batch_failed(
campaign_id=campaign_id,
error=str(e),
processed_count=0,
)
# Update campaign state to failed
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
raise

View file

@ -5,5 +5,4 @@ class FunctionNames:
UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3"
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"

View file

@ -3,7 +3,9 @@ from typing import Any, Dict, Optional
from unittest.mock import Mock
import pytest
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from api.constants import DATABASE_URL
from api.services.workflow.dto import (
EdgeDataDTO,
ExtractionVariableDTO,
@ -549,3 +551,22 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
],
)
return WorkflowGraph(dto)
# =============================================================================
# Database fixtures for integration tests
# =============================================================================
@pytest.fixture(scope="session")
async def db_engine():
    """Session-scoped async SQLAlchemy engine bound to ``DATABASE_URL``.

    The engine is created once (SQL echo disabled), handed to the tests, and
    disposed during fixture teardown.
    """
    async_engine = create_async_engine(DATABASE_URL, echo=False)
    yield async_engine
    # Teardown: release the engine's connection pool
    await async_engine.dispose()
@pytest.fixture(scope="session")
async def db_session_factory(db_engine):
    """Session-scoped ``async_sessionmaker`` bound to the test engine.

    ``expire_on_commit=False`` keeps already-loaded attributes readable on
    objects after the session that loaded them has committed.
    """
    session_factory = async_sessionmaker(bind=db_engine, expire_on_commit=False)
    return session_factory

View file

@ -0,0 +1,603 @@
"""
Tests for CampaignCallDispatcher.process_batch method.
These tests verify:
1. Basic batch processing functionality
2. Thread-safety via SELECT FOR UPDATE SKIP LOCKED
3. Race condition handling when multiple workers process concurrently
"""
import asyncio
import uuid
from dataclasses import dataclass
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy import delete, text
from api.db.models import (
CampaignModel,
OrganizationModel,
QueuedRunModel,
UserModel,
WorkflowModel,
WorkflowRunModel,
)
from api.services.campaign.campaign_call_dispatcher import CampaignCallDispatcher
# =============================================================================
# Test-specific fixtures
# =============================================================================
@dataclass
class CampaignTestData:
    """Primary keys of the rows seeded for one campaign test."""

    # Parent rows, in the order the fixture creates them
    organization_id: int
    user_id: int
    workflow_id: int
    campaign_id: int
    # Ids of the queued runs attached to the campaign, in creation order
    queued_run_ids: List[int]
@pytest.fixture
async def campaign_test_data(db_session_factory) -> CampaignTestData:
    """
    Seed the database with a campaign that is ready for batch processing.

    Rows are inserted in dependency order, flushing after each insert so the
    generated primary key is available to the next row:
    - Organization
    - User (with the organization selected)
    - Workflow (single ``startCall`` node, no edges)
    - Campaign (in 'running' state)
    - 10 QueuedRuns (in 'queued' state)

    Yields a CampaignTestData holding the ids, then deletes everything that
    was created once the test has finished.
    """

    def short_id() -> str:
        # Random suffix so names do not collide between test runs
        return uuid.uuid4().hex[:8]

    async with db_session_factory() as session:
        # Organization
        org = OrganizationModel(provider_id=f"test-org-{short_id()}")
        session.add(org)
        await session.flush()

        # User
        user = UserModel(
            provider_id=f"test-user-{short_id()}",
            selected_organization_id=org.id,
        )
        session.add(user)
        await session.flush()

        # Workflow
        start_node = {
            "id": "1",
            "type": "startCall",
            "position": {"x": 0, "y": 0},
            "data": {"name": "Start", "prompt": "Hello"},
        }
        workflow = WorkflowModel(
            name=f"test-workflow-{short_id()}",
            user_id=user.id,
            organization_id=org.id,
            workflow_definition={"nodes": [start_node], "edges": []},
            template_context_variables={},
        )
        session.add(workflow)
        await session.flush()

        # Campaign
        campaign = CampaignModel(
            name=f"test-campaign-{short_id()}",
            organization_id=org.id,
            workflow_id=workflow.id,
            created_by=user.id,
            source_type="test",
            source_id="test-source",
            state="running",
            rate_limit_per_second=100,  # High limit to avoid rate limiting in tests
        )
        session.add(campaign)
        await session.flush()

        # Queued runs - flushed one at a time so each id can be recorded
        queued_run_ids = []
        for index in range(10):
            pending_run = QueuedRunModel(
                campaign_id=campaign.id,
                source_uuid=f"test-uuid-{index}",
                context_variables={"phone_number": f"+1555000{index:04d}"},
                state="queued",
            )
            session.add(pending_run)
            await session.flush()
            queued_run_ids.append(pending_run.id)

        await session.commit()

        test_data = CampaignTestData(
            organization_id=org.id,
            user_id=user.id,
            workflow_id=workflow.id,
            campaign_id=campaign.id,
            queued_run_ids=queued_run_ids,
        )

    yield test_data

    # Cleanup: delete in reverse order of dependencies
    cleanup_statements = (
        delete(QueuedRunModel).where(
            QueuedRunModel.campaign_id == test_data.campaign_id
        ),
        delete(WorkflowRunModel).where(
            WorkflowRunModel.campaign_id == test_data.campaign_id
        ),
        delete(CampaignModel).where(CampaignModel.id == test_data.campaign_id),
        delete(WorkflowModel).where(WorkflowModel.id == test_data.workflow_id),
        delete(UserModel).where(UserModel.id == test_data.user_id),
        delete(OrganizationModel).where(
            OrganizationModel.id == test_data.organization_id
        ),
    )
    async with db_session_factory() as cleanup_session:
        for statement in cleanup_statements:
            await cleanup_session.execute(statement)
        await cleanup_session.commit()
@pytest.fixture
def mock_dispatch_call():
    """Provide a stand-in for ``dispatch_call`` and the list it records into.

    Returns a ``(dispatch, dispatched_ids)`` tuple. ``dispatch`` is a
    coroutine function that appends the id of every queued run it receives
    to ``dispatched_ids``, so tests can inspect which runs were handled.
    """
    dispatched_ids = []

    async def fake_dispatch(queued_run, campaign, slot_id):
        # Short sleep stands in for processing time and yields to the loop.
        await asyncio.sleep(0.01)
        dispatched_ids.append(queued_run.id)
        # Mock workflow run whose id is the number of dispatches so far.
        return MagicMock(id=len(dispatched_ids))

    return fake_dispatch, dispatched_ids
@pytest.fixture
def mock_rate_limiter():
    """Provide permissive coroutine stubs for the rate limiter, keyed by name.

    Token acquisition and slot release report success, slot acquisition
    returns a fresh ``slot-<hex>`` identifier, and the workflow/slot mapping
    operations do nothing and return ``None``.
    """

    async def allow(*args, **kwargs):
        return True

    async def new_slot(*args, **kwargs):
        return f"slot-{uuid.uuid4().hex[:8]}"

    async def no_result(*args, **kwargs):
        return None

    return {
        "acquire_token": allow,
        "try_acquire_concurrent_slot": new_slot,
        "release_concurrent_slot": allow,
        "store_workflow_slot_mapping": no_result,
        "get_workflow_slot_mapping": no_result,
        "delete_workflow_slot_mapping": no_result,
    }
# =============================================================================
# Tests
# =============================================================================
class TestProcessBatchBasic:
    """Happy-path checks for ``process_batch``."""

    @pytest.mark.asyncio
    async def test_process_batch_processes_queued_runs(
        self, campaign_test_data, mock_dispatch_call, mock_rate_limiter
    ):
        """A batch of 5 reports 5 processed runs and dispatches 5 runs."""
        fake_dispatch, dispatched_ids = mock_dispatch_call
        limiter_methods = (
            "acquire_token",
            "try_acquire_concurrent_slot",
            "release_concurrent_slot",
            "store_workflow_slot_mapping",
            "get_workflow_slot_mapping",
            "delete_workflow_slot_mapping",
        )

        with patch(
            "api.services.campaign.campaign_call_dispatcher.rate_limiter"
        ) as limiter:
            # Back every rate limiter call with the permissive fixture stub.
            for method in limiter_methods:
                setattr(
                    limiter,
                    method,
                    AsyncMock(side_effect=mock_rate_limiter[method]),
                )

            dispatcher = CampaignCallDispatcher()

            # Replace the real dispatch with the recording stand-in.
            with patch.object(
                dispatcher, "dispatch_call", side_effect=fake_dispatch
            ):
                count = await dispatcher.process_batch(
                    campaign_id=campaign_test_data.campaign_id, batch_size=5
                )

        assert count == 5
        assert len(dispatched_ids) == 5
class TestProcessBatchConcurrency:
    """Tests for concurrent batch processing and database locking.

    The module-level rate limiter is patched exactly once per test, around
    the whole ``asyncio.gather`` call. Patching it separately inside each
    worker coroutine is unsafe: the patches are entered in task order but
    exited in completion order, so they unwind non-LIFO. A late finisher then
    "restores" the attribute to another worker's mock, which stays installed
    after the test, and workers still running may see the real rate limiter.
    """

    _RATE_LIMITER_TARGET = (
        "api.services.campaign.campaign_call_dispatcher.rate_limiter"
    )
    _RATE_LIMITER_METHODS = (
        "acquire_token",
        "try_acquire_concurrent_slot",
        "release_concurrent_slot",
        "store_workflow_slot_mapping",
        "get_workflow_slot_mapping",
        "delete_workflow_slot_mapping",
    )

    @classmethod
    def _stub_rate_limiter(cls, mock_rl, mock_rate_limiter):
        """Back each rate limiter method on ``mock_rl`` with its fixture stub."""
        for method in cls._RATE_LIMITER_METHODS:
            setattr(
                mock_rl, method, AsyncMock(side_effect=mock_rate_limiter[method])
            )

    @staticmethod
    async def _reset_queued_runs(db_session_factory, campaign_id):
        """Put every queued run of the campaign back into the 'queued' state."""
        async with db_session_factory() as session:
            await session.execute(
                text(
                    "UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
                ),
                {"campaign_id": campaign_id},
            )
            await session.commit()

    @staticmethod
    async def _run_worker(campaign_id, batch_size, dispatch):
        """Run one ``process_batch`` call on its own dispatcher instance.

        ``dispatch_call`` is patched on the instance only, so concurrent
        workers do not share or disturb each other's patch.

        Returns whatever ``process_batch`` returns (the processed count).
        """
        dispatcher = CampaignCallDispatcher()
        with patch.object(dispatcher, "dispatch_call", side_effect=dispatch):
            return await dispatcher.process_batch(
                campaign_id=campaign_id, batch_size=batch_size
            )

    @pytest.mark.asyncio
    async def test_concurrent_process_batch_no_duplicate_processing(
        self,
        campaign_test_data,
        mock_dispatch_call,
        mock_rate_limiter,
        db_session_factory,
    ):
        """
        Test that two concurrent process_batch calls don't process the same runs.

        This verifies the SELECT FOR UPDATE SKIP LOCKED mechanism works correctly.
        """
        mock_dispatch, processed_runs = mock_dispatch_call
        campaign_id = campaign_test_data.campaign_id

        await self._reset_queued_runs(db_session_factory, campaign_id)

        # One patch for both workers; see the class docstring for why.
        with patch(self._RATE_LIMITER_TARGET) as mock_rl:
            self._stub_rate_limiter(mock_rl, mock_rate_limiter)
            results = await asyncio.gather(
                self._run_worker(campaign_id, 5, mock_dispatch),
                self._run_worker(campaign_id, 5, mock_dispatch),
            )

        # Total processed should be 10 (all queued runs)
        total_processed = sum(results)
        assert total_processed == 10, f"Expected 10 total, got {total_processed}"

        # Each run should be processed exactly once (no duplicates)
        assert len(processed_runs) == 10, f"Expected 10 runs, got {len(processed_runs)}"
        assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"

    @pytest.mark.asyncio
    async def test_concurrent_process_batch_with_different_batch_sizes(
        self,
        campaign_test_data,
        mock_dispatch_call,
        mock_rate_limiter,
        db_session_factory,
    ):
        """
        Test concurrent processing with different batch sizes.

        Worker 1 requests 3 runs, Worker 2 requests 7 runs.
        Total should still be 10 with no duplicates.
        """
        mock_dispatch, processed_runs = mock_dispatch_call
        campaign_id = campaign_test_data.campaign_id

        await self._reset_queued_runs(db_session_factory, campaign_id)

        with patch(self._RATE_LIMITER_TARGET) as mock_rl:
            self._stub_rate_limiter(mock_rl, mock_rate_limiter)
            results = await asyncio.gather(
                self._run_worker(campaign_id, 3, mock_dispatch),
                self._run_worker(campaign_id, 7, mock_dispatch),
            )

        total_processed = sum(results)
        assert total_processed == 10

        # Verify no duplicates
        assert len(set(processed_runs)) == len(processed_runs)

    @pytest.mark.asyncio
    async def test_multiple_concurrent_workers(
        self,
        campaign_test_data,
        mock_dispatch_call,
        mock_rate_limiter,
        db_session_factory,
    ):
        """
        Test with many concurrent workers (simulating production scenario).

        5 workers each requesting 4 runs from a pool of 10.
        Should process all 10 exactly once.
        """
        mock_dispatch, processed_runs = mock_dispatch_call
        campaign_id = campaign_test_data.campaign_id

        await self._reset_queued_runs(db_session_factory, campaign_id)

        # Workers that find nothing left finish before the others, which is
        # exactly the out-of-order completion a per-worker patch cannot handle.
        with patch(self._RATE_LIMITER_TARGET) as mock_rl:
            self._stub_rate_limiter(mock_rl, mock_rate_limiter)
            results = await asyncio.gather(
                *[
                    self._run_worker(campaign_id, 4, mock_dispatch)
                    for _ in range(5)
                ]
            )

        total_processed = sum(results)
        assert total_processed == 10

        # Verify no duplicates
        assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"

    @pytest.mark.asyncio
    async def test_processing_state_transition(
        self,
        campaign_test_data,
        mock_dispatch_call,
        mock_rate_limiter,
        db_session_factory,
    ):
        """
        Test that runs transition through processing -> processed states correctly.

        Only the final state is inspected: after one batch of 10, every run
        must be 'processed' and none may remain 'queued' or 'processing'.
        """
        mock_dispatch, _processed_runs = mock_dispatch_call
        campaign_id = campaign_test_data.campaign_id

        await self._reset_queued_runs(db_session_factory, campaign_id)

        with patch(self._RATE_LIMITER_TARGET) as mock_rl:
            self._stub_rate_limiter(mock_rl, mock_rate_limiter)
            await self._run_worker(campaign_id, 10, mock_dispatch)

        # Count the campaign's queued runs per state.
        async with db_session_factory() as session:
            result = await session.execute(
                text(
                    "SELECT state, COUNT(*) as count FROM queued_runs "
                    "WHERE campaign_id = :campaign_id GROUP BY state"
                ),
                {"campaign_id": campaign_id},
            )
            states = {row[0]: row[1] for row in result.fetchall()}

        assert states.get("processed", 0) == 10
        assert states.get("queued", 0) == 0
        assert states.get("processing", 0) == 0
class TestProcessBatchEdgeCases:
    """Checks for ``process_batch`` when there is nothing it should process."""

    @pytest.mark.asyncio
    async def test_empty_queue(
        self, campaign_test_data, mock_rate_limiter, db_session_factory
    ):
        """With every run already 'processed', process_batch returns 0."""
        campaign_id = campaign_test_data.campaign_id

        # Leave no run in the 'queued' state.
        async with db_session_factory() as session:
            await session.execute(
                text(
                    "UPDATE queued_runs SET state = 'processed' WHERE campaign_id = :campaign_id"
                ),
                {"campaign_id": campaign_id},
            )
            await session.commit()

        with patch(
            "api.services.campaign.campaign_call_dispatcher.rate_limiter"
        ) as limiter:
            # Only the two acquisition calls are stubbed for this test.
            for method in ("acquire_token", "try_acquire_concurrent_slot"):
                setattr(
                    limiter,
                    method,
                    AsyncMock(side_effect=mock_rate_limiter[method]),
                )

            dispatcher = CampaignCallDispatcher()
            outcome = await dispatcher.process_batch(
                campaign_id=campaign_id, batch_size=5
            )

        assert outcome == 0

    @pytest.mark.asyncio
    async def test_campaign_not_running(
        self, campaign_test_data, mock_rate_limiter, db_session_factory
    ):
        """With the campaign 'paused', process_batch returns 0."""
        params = {"campaign_id": campaign_test_data.campaign_id}

        async def apply(statement):
            # Execute one UPDATE in its own session and commit it.
            async with db_session_factory() as session:
                await session.execute(text(statement), params)
                await session.commit()

        await apply("UPDATE campaigns SET state = 'paused' WHERE id = :campaign_id")

        try:
            dispatcher = CampaignCallDispatcher()
            outcome = await dispatcher.process_batch(
                campaign_id=campaign_test_data.campaign_id, batch_size=5
            )
            assert outcome == 0
        finally:
            # Put the campaign back to 'running' whether or not the check passed.
            await apply(
                "UPDATE campaigns SET state = 'running' WHERE id = :campaign_id"
            )