From 682774432778ecd72601d6993fb26cc3b179e1b3 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Fri, 30 Jan 2026 14:48:00 +0530 Subject: [PATCH] fix: make campaign process batch thread safe (#141) * fix: dont schedule new batch on resume * fix: make process_batch thread safe --- .../02ffd7f23d1d_add_index_in_workflow_run.py | 19 +- ...de5_add_processing_state_in_queued_runs.py | 54 ++ api/db/campaign_client.py | 124 ++-- api/db/models.py | 2 +- api/routes/telephony.py | 2 +- ...patcher.py => campaign_call_dispatcher.py} | 163 +++-- .../campaign/campaign_event_publisher.py | 18 + .../campaign/campaign_orchestrator.py | 91 ++- api/services/campaign/errors.py | 16 + api/services/campaign/readme.md | 18 + api/services/campaign/runner.py | 6 +- api/services/pipecat/event_handlers.py | 2 +- api/tasks/arq.py | 2 - api/tasks/campaign_tasks.py | 100 +-- api/tasks/function_names.py | 1 - api/tests/conftest.py | 21 + api/tests/test_campaign_call_dispatcher.py | 603 ++++++++++++++++++ 17 files changed, 1012 insertions(+), 230 deletions(-) create mode 100644 api/alembic/versions/34c8537dfde5_add_processing_state_in_queued_runs.py rename api/services/campaign/{call_dispatcher.py => campaign_call_dispatcher.py} (81%) create mode 100644 api/services/campaign/errors.py create mode 100644 api/services/campaign/readme.md create mode 100644 api/tests/test_campaign_call_dispatcher.py diff --git a/api/alembic/versions/02ffd7f23d1d_add_index_in_workflow_run.py b/api/alembic/versions/02ffd7f23d1d_add_index_in_workflow_run.py index eca83e7..8005751 100644 --- a/api/alembic/versions/02ffd7f23d1d_add_index_in_workflow_run.py +++ b/api/alembic/versions/02ffd7f23d1d_add_index_in_workflow_run.py @@ -5,28 +5,31 @@ Revises: d1dac4c93e61 Create Date: 2026-01-29 20:36:57.924887 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = '02ffd7f23d1d' -down_revision: Union[str, None] = 'd1dac4c93e61' +revision: str = "02ffd7f23d1d" +down_revision: Union[str, None] = "d1dac4c93e61" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_index('idx_workflow_runs_campaign_id', 'workflow_runs', ['campaign_id'], unique=False) - op.create_index('idx_workflow_runs_workflow_id', 'workflow_runs', ['workflow_id'], unique=False) + op.create_index( + "idx_workflow_runs_campaign_id", "workflow_runs", ["campaign_id"], unique=False + ) + op.create_index( + "idx_workflow_runs_workflow_id", "workflow_runs", ["workflow_id"], unique=False + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index('idx_workflow_runs_workflow_id', table_name='workflow_runs') - op.drop_index('idx_workflow_runs_campaign_id', table_name='workflow_runs') + op.drop_index("idx_workflow_runs_workflow_id", table_name="workflow_runs") + op.drop_index("idx_workflow_runs_campaign_id", table_name="workflow_runs") # ### end Alembic commands ### diff --git a/api/alembic/versions/34c8537dfde5_add_processing_state_in_queued_runs.py b/api/alembic/versions/34c8537dfde5_add_processing_state_in_queued_runs.py new file mode 100644 index 0000000..cb59ff0 --- /dev/null +++ b/api/alembic/versions/34c8537dfde5_add_processing_state_in_queued_runs.py @@ -0,0 +1,54 @@ +"""add processing state in queued runs + +Revision ID: 34c8537dfde5 +Revises: 02ffd7f23d1d +Create Date: 2026-01-30 14:40:29.905325 + +""" + +from typing import Sequence, Union + +from alembic import op +from alembic_postgresql_enum import TableReference + +# revision identifiers, used by Alembic. +revision: str = "34c8537dfde5" +down_revision: Union[str, None] = "02ffd7f23d1d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_api_keys_key_hash"), table_name="api_keys") + op.create_index("ix_api_keys_key_hash", "api_keys", ["key_hash"], unique=False) + op.sync_enum_values( + enum_schema="public", + enum_name="queued_run_state", + new_values=["queued", "processed", "processing", "failed"], + affected_columns=[ + TableReference( + table_schema="public", table_name="queued_runs", column_name="state" + ) + ], + enum_values_to_rename=[], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values( + enum_schema="public", + enum_name="queued_run_state", + new_values=["queued", "processed", "failed"], + affected_columns=[ + TableReference( + table_schema="public", table_name="queued_runs", column_name="state" + ) + ], + enum_values_to_rename=[], + ) + op.drop_index("ix_api_keys_key_hash", table_name="api_keys") + op.create_index(op.f("ix_api_keys_key_hash"), "api_keys", ["key_hash"], unique=True) + # ### end Alembic commands ### diff --git a/api/db/campaign_client.py b/api/db/campaign_client.py index 398c1ef..d58d0a7 100644 --- a/api/db/campaign_client.py +++ b/api/db/campaign_client.py @@ -209,31 +209,6 @@ class CampaignClient(BaseDBClient): await session.rollback() raise e - async def get_queued_runs( - self, - campaign_id: int, - state: str = "queued", - limit: int = 10, - scheduled_for: Optional[bool] = None, - ) -> list[QueuedRunModel]: - """Get queued runs for processing, optionally filtering by scheduled status""" - async with self.async_session() as session: - query = select(QueuedRunModel).where( - QueuedRunModel.campaign_id == campaign_id, - QueuedRunModel.state == state, - ) - - # Filter by scheduled status if specified - if scheduled_for is True: - query = query.where(QueuedRunModel.scheduled_for.isnot(None)) - elif scheduled_for is False: - query = query.where(QueuedRunModel.scheduled_for.is_(None)) - - query = query.order_by(QueuedRunModel.created_at).limit(limit) - - result = await session.execute(query) - return list(result.scalars().all()) - async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel: """Update queued run""" async with self.async_session() as session: @@ -284,26 +259,6 @@ class CampaignClient(BaseDBClient): result = await session.execute(query) return list(result.scalars().all()) - # New methods for retry support - async def get_scheduled_queued_runs( - self, campaign_id: int, scheduled_before: datetime, limit: int = 10 - ) -> list[QueuedRunModel]: - """Get scheduled queued runs that are due for processing""" - async with self.async_session() as session: - query = ( - select(QueuedRunModel) - .where( - QueuedRunModel.campaign_id == campaign_id, - QueuedRunModel.state == "queued", - QueuedRunModel.scheduled_for.isnot(None), - QueuedRunModel.scheduled_for <= scheduled_before, - ) - .order_by(QueuedRunModel.scheduled_for) - .limit(limit) - ) - result = await session.execute(query) - return list(result.scalars().all()) - async def create_queued_run( self, campaign_id: int, @@ -388,3 +343,82 @@ class CampaignClient(BaseDBClient): query = select(func.count(QueuedRunModel.id)).where(*conditions) result = await session.execute(query) return result.scalar() or 0 + + async def claim_queued_runs_for_processing( + self, + campaign_id: int, + scheduled_before: datetime, + limit: int = 10, + ) -> list[QueuedRunModel]: + """ + Atomically claim queued runs for processing using SELECT FOR UPDATE SKIP LOCKED. + + This method is thread-safe - multiple workers can call it concurrently without + processing the same runs. It: + 1. Prioritizes scheduled retries that are due + 2. Falls back to regular queued runs if more slots available + 3. Locks selected rows and marks them as 'processing' atomically + + Returns: List of claimed QueuedRunModel objects + """ + async with self.async_session() as session: + claimed_runs = [] + + # First, get scheduled retries that are due (with lock) + scheduled_query = ( + select(QueuedRunModel) + .where( + QueuedRunModel.campaign_id == campaign_id, + QueuedRunModel.state == "queued", + QueuedRunModel.scheduled_for.isnot(None), + QueuedRunModel.scheduled_for <= scheduled_before, + ) + .order_by(QueuedRunModel.scheduled_for) + .limit(limit) + .with_for_update(skip_locked=True) + ) + + scheduled_result = await session.execute(scheduled_query) + scheduled_runs = list(scheduled_result.scalars().all()) + + # Mark scheduled runs as processing + for run in scheduled_runs: + run.state = "processing" + claimed_runs.append(run) + + remaining_slots = limit - len(scheduled_runs) + + # Then get regular queued runs if we have remaining slots + if remaining_slots > 0: + regular_query = ( + select(QueuedRunModel) + .where( + QueuedRunModel.campaign_id == campaign_id, + QueuedRunModel.state == "queued", + QueuedRunModel.scheduled_for.is_(None), + ) + .order_by(QueuedRunModel.created_at) + .limit(remaining_slots) + .with_for_update(skip_locked=True) + ) + + regular_result = await session.execute(regular_query) + regular_runs = list(regular_result.scalars().all()) + + # Mark regular runs as processing + for run in regular_runs: + run.state = "processing" + claimed_runs.append(run) + + # Commit the state changes + try: + await session.commit() + except Exception as e: + await session.rollback() + raise e + + # Refresh to get updated state + for run in claimed_runs: + await session.refresh(run) + + return claimed_runs diff --git a/api/db/models.py b/api/db/models.py index ff544cb..409db0c 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -598,7 +598,7 @@ class QueuedRunModel(Base): source_uuid = Column(String, nullable=False) context_variables = Column(JSON, nullable=False, default=dict) state = Column( - Enum("queued", "processed", "failed", name="queued_run_state"), + Enum("queued", "processed", "processing", "failed", name="queued_run_state"), nullable=False, default="queued", ) diff --git a/api/routes/telephony.py b/api/routes/telephony.py index b031ec8..0ed6c87 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -23,7 +23,7 @@ from api.db.workflow_run_client import WorkflowRunClient from api.enums import CallType, OrganizationConfigurationKey, WorkflowRunState from api.errors.telephony_errors import TelephonyError from api.services.auth.depends import get_user -from api.services.campaign.call_dispatcher import campaign_call_dispatcher +from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id from api.services.telephony.factory import ( diff --git a/api/services/campaign/call_dispatcher.py b/api/services/campaign/campaign_call_dispatcher.py similarity index 81% rename from api/services/campaign/call_dispatcher.py rename to api/services/campaign/campaign_call_dispatcher.py index 523df77..fea1045 100644 --- a/api/services/campaign/call_dispatcher.py +++ b/api/services/campaign/campaign_call_dispatcher.py @@ -9,6 +9,7 @@ from api.constants import DEFAULT_ORG_CONCURRENCY_LIMIT from api.db import db_client from api.db.models import QueuedRunModel, WorkflowRunModel from api.enums import OrganizationConfigurationKey, WorkflowRunState +from api.services.campaign.errors import ConcurrentSlotAcquisitionError from api.services.campaign.rate_limiter import rate_limiter from api.services.telephony.base import TelephonyProvider from api.services.telephony.factory import get_telephony_provider @@ -42,7 +43,8 @@ class CampaignCallDispatcher: async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int: """ - Processes a batch of queued runs with priority for scheduled retries + Processes a batch of queued runs with priority for scheduled retries. + Thread-safe: uses SELECT FOR UPDATE SKIP LOCKED to prevent concurrent processing. Returns: number of processed runs """ # Get campaign details @@ -57,41 +59,34 @@ class CampaignCallDispatcher: ) return 0 - # First, get any scheduled retries that are due - scheduled_runs = await db_client.get_scheduled_queued_runs( + # Atomically claim queued runs for processing (thread-safe) + # This uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions + queued_runs = await db_client.claim_queued_runs_for_processing( campaign_id=campaign_id, scheduled_before=datetime.now(UTC), limit=batch_size, ) - remaining_slots = batch_size - len(scheduled_runs) - - # Then get regular queued runs - regular_runs = [] - if remaining_slots > 0: - regular_runs = await db_client.get_queued_runs( - campaign_id=campaign_id, - state="queued", - scheduled_for=False, # Exclude scheduled runs - limit=remaining_slots, - ) - - queued_runs = scheduled_runs + regular_runs - if not queued_runs: logger.info(f"No more queued runs for campaign {campaign_id}") return 0 processed_count = 0 - for queued_run in queued_runs: + for i, queued_run in enumerate(queued_runs): try: - # Apply rate limiting + # Apply rate limiting, i.e lets not initiate more than rate_limit_per_second + # calls per second. It is different than concurrency limit. await self.apply_rate_limit( campaign.organization_id, campaign.rate_limit_per_second ) + # Acquire concurrent slot - waits until a slot is available + slot_id = await self.acquire_concurrent_slot( + campaign.organization_id, campaign + ) + # Dispatch the call - workflow_run = await self.dispatch_call(queued_run, campaign) + workflow_run = await self.dispatch_call(queued_run, campaign, slot_id) # Update queued run as processed await db_client.update_queued_run( @@ -108,6 +103,25 @@ class CampaignCallDispatcher: campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1 ) + except ConcurrentSlotAcquisitionError: + # Revert all unprocessed runs (current and remaining) back to queued + # so they can be picked up again when campaign is resumed + for unprocessed_run in queued_runs[i:]: + try: + await db_client.update_queued_run( + queued_run_id=unprocessed_run.id, + state="queued", + ) + logger.info( + f"Reverted queued run {unprocessed_run.id} back to queued state" + ) + except Exception as revert_error: + logger.error( + f"Failed to revert queued run {unprocessed_run.id}: {revert_error}" + ) + # Re-raise to propagate to process_campaign_batch + raise + except Exception as e: logger.warning(f"Error processing queued run {queued_run.id}: {e}") @@ -129,54 +143,9 @@ class CampaignCallDispatcher: return processed_count async def dispatch_call( - self, queued_run: QueuedRunModel, campaign: any + self, queued_run: QueuedRunModel, campaign: any, slot_id: str ) -> Optional[WorkflowRunModel]: - """Creates workflow run and initiates call with concurrent limiting""" - # Get concurrent limit for organization - org_concurrent_limit = await self.get_org_concurrent_limit( - campaign.organization_id - ) - - # Check for campaign-level max_concurrency in orchestrator_metadata - campaign_max_concurrency = None - if campaign.orchestrator_metadata: - campaign_max_concurrency = campaign.orchestrator_metadata.get( - "max_concurrency" - ) - - # Use the lower of campaign limit and org limit - if campaign_max_concurrency is not None: - max_concurrent = min(campaign_max_concurrency, org_concurrent_limit) - else: - max_concurrent = org_concurrent_limit - - # Track wait time for alerting - wait_start = time.time() - slot_id = None - - # Wait until we can acquire a concurrent slot - while True: - slot_id = await rate_limiter.try_acquire_concurrent_slot( - campaign.organization_id, max_concurrent - ) - if slot_id: - break - - # Check if we've been waiting too long - wait_time = time.time() - wait_start - if wait_time > 600: # 10 minutes - logger.error( - f"Waiting for concurrent slot for {wait_time:.1f}s, " - f"org: {campaign.organization_id}, campaign: {campaign.id}" - ) - - logger.debug( - f"Attempting to get a slot for {campaign.organization_id} {campaign.id}" - ) - - # Wait before retrying - await asyncio.sleep(1) - + """Creates workflow run and initiates call. Requires a pre-acquired slot_id.""" # Get workflow details workflow = await db_client.get_workflow_by_id(campaign.workflow_id) if not workflow: @@ -351,6 +320,66 @@ class CampaignCallDispatcher: # Wait for next available slot await asyncio.sleep(wait_time) + async def acquire_concurrent_slot( + self, organization_id: int, campaign: any, timeout: float = 600 + ) -> str: + """ + Acquires a concurrent call slot - waits if necessary until a slot is available. + + Args: + organization_id: The organization ID + campaign: The campaign object + timeout: Maximum time to wait for a slot (default 10 minutes) + + Returns the slot_id which must be released when the call completes. + + Raises: + ConcurrentSlotAcquisitionError: If slot cannot be acquired within timeout + """ + # Get concurrent limit for organization + org_concurrent_limit = await self.get_org_concurrent_limit(organization_id) + + # Check for campaign-level max_concurrency in orchestrator_metadata + campaign_max_concurrency = None + if campaign.orchestrator_metadata: + campaign_max_concurrency = campaign.orchestrator_metadata.get( + "max_concurrency" + ) + + # Use the lower of campaign limit and org limit + if campaign_max_concurrency is not None: + max_concurrent = min(campaign_max_concurrency, org_concurrent_limit) + else: + max_concurrent = org_concurrent_limit + + # Track wait time for alerting + wait_start = time.time() + + # Wait until we can acquire a concurrent slot + while True: + slot_id = await rate_limiter.try_acquire_concurrent_slot( + organization_id, max_concurrent + ) + if slot_id: + return slot_id + + # Check if we've been waiting too long + wait_time = time.time() - wait_start + if wait_time > timeout: + raise ConcurrentSlotAcquisitionError( + organization_id=organization_id, + campaign_id=campaign.id, + wait_time=wait_time, + ) + + logger.debug( + f"Attempting to get a slot for {organization_id} {campaign.id}, " + f"waited {wait_time:.1f}s" + ) + + # Wait before retrying + await asyncio.sleep(1) + async def release_call_slot(self, workflow_run_id: int) -> bool: """ Release concurrent slot when a call completes. diff --git a/api/services/campaign/campaign_event_publisher.py b/api/services/campaign/campaign_event_publisher.py index 95b319a..4f6438d 100644 --- a/api/services/campaign/campaign_event_publisher.py +++ b/api/services/campaign/campaign_event_publisher.py @@ -12,6 +12,7 @@ from api.constants import REDIS_URL from api.enums import RedisChannel from api.services.campaign.campaign_event_protocol import ( BatchCompletedEvent, + BatchFailedEvent, CampaignCompletedEvent, RetryNeededEvent, SyncCompletedEvent, @@ -43,6 +44,23 @@ class CampaignEventPublisher: await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json()) + async def publish_batch_failed( + self, + campaign_id: int, + error: str, + processed_count: int = 0, + metadata: Optional[Dict] = None, + ): + """Publish batch failed event.""" + event = BatchFailedEvent( + campaign_id=campaign_id, + error=error, + processed_count=processed_count, + metadata=metadata, + ) + + await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json()) + async def publish_sync_completed( self, campaign_id: int, diff --git a/api/services/campaign/campaign_orchestrator.py b/api/services/campaign/campaign_orchestrator.py index 26346c5..d701e51 100644 --- a/api/services/campaign/campaign_orchestrator.py +++ b/api/services/campaign/campaign_orchestrator.py @@ -23,11 +23,13 @@ from api.db import db_client from api.db.models import CampaignModel, QueuedRunModel from api.enums import RedisChannel from api.services.campaign.campaign_event_protocol import ( - CampaignCompletedEvent, - CampaignEventType, + BatchCompletedEvent, + BatchFailedEvent, RetryNeededEvent, + SyncCompletedEvent, parse_campaign_event, ) +from api.services.campaign.campaign_event_publisher import CampaignEventPublisher from api.tasks.arq import enqueue_job from api.tasks.function_names import FunctionNames @@ -37,6 +39,7 @@ class CampaignOrchestrator: def __init__(self, redis_client: aioredis.Redis): self.redis = redis_client + self.publisher = CampaignEventPublisher(redis_client) self.completion_check_interval = 60 # 1 minute self.completion_timeout = 3600 # 1 hour self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling @@ -97,22 +100,21 @@ class CampaignOrchestrator: async def _handle_event(self, event): """Handle campaign events including retry events.""" - # Handle RetryNeededEvent - if isinstance(event, RetryNeededEvent): - await self._handle_retry_event(event) - return - # All events should have campaign_id if not hasattr(event, "campaign_id") or not event.campaign_id: logger.warning(f"Event missing campaign_id: {type(event).__name__}") return campaign_id = event.campaign_id - event_type = event.type - logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}") + logger.debug( + f"campaign_id: {campaign_id} - Received event: {type(event).__name__}" + ) - if event_type == CampaignEventType.BATCH_COMPLETED: + if isinstance(event, RetryNeededEvent): + await self._handle_retry_event(event) + + elif isinstance(event, BatchCompletedEvent): # Clear the batch in progress flag if campaign_id in self._batch_in_progress: del self._batch_in_progress[campaign_id] @@ -120,11 +122,42 @@ class CampaignOrchestrator: f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag" ) + # Check campaign state before scheduling next batch + campaign = await db_client.get_campaign_by_id(campaign_id) + if not campaign: + logger.error(f"campaign_id: {campaign_id} - Campaign not found") + self._clear_campaign_state(campaign_id) + return + + if campaign.state != "running": + logger.info( + f"campaign_id: {campaign_id} - Campaign not in running state ({campaign.state}), " + f"not scheduling next batch" + ) + self._clear_campaign_state(campaign_id) + return + # Immediately schedule next batch await self._schedule_next_batch(campaign_id) self._last_activity[campaign_id] = datetime.now(UTC) - elif event_type == CampaignEventType.SYNC_COMPLETED: + elif isinstance(event, BatchFailedEvent): + # Clear the batch in progress flag + if campaign_id in self._batch_in_progress: + del self._batch_in_progress[campaign_id] + + logger.warning( + f"campaign_id: {campaign_id} - Batch failed: {event.error}, " + f"scheduling next batch to continue processing" + ) + + # Lets not schedule another batch, since we mark the campaign + # as failed just to be on the safe side from process_campaign_batch + # if a batch fails + + self._last_activity[campaign_id] = datetime.now(UTC) + + elif isinstance(event, SyncCompletedEvent): # Start processing after sync logger.info( f"campaign_id: {campaign_id} - Sync completed, starting processing" @@ -309,6 +342,16 @@ class CampaignOrchestrator: del self._processing_locks[campaign_id] logger.debug(f"campaign_id: {campaign_id} - Released processing lock") + def _clear_campaign_state(self, campaign_id: int): + """Clear all in-memory state for a campaign.""" + if campaign_id in self._last_activity: + del self._last_activity[campaign_id] + if campaign_id in self._processing_locks: + del self._processing_locks[campaign_id] + if campaign_id in self._batch_in_progress: + del self._batch_in_progress[campaign_id] + logger.debug(f"campaign_id: {campaign_id} - Cleared all in-memory state") + async def _monitor_completion(self): """Periodically check for campaigns that should be marked complete.""" while self._running: @@ -457,30 +500,22 @@ class CampaignOrchestrator: logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed") - # Publish completion event using typed event - completion_event = CampaignCompletedEvent( + # Calculate duration if started_at is available + duration = None + if campaign.started_at: + duration = (datetime.now(UTC) - campaign.started_at).total_seconds() + + # Publish completion event + await self.publisher.publish_campaign_completed( campaign_id=campaign_id, total_rows=campaign.total_rows or 0, processed_rows=campaign.processed_rows, failed_rows=campaign.failed_rows, - ) - - # Calculate duration if started_at is available - if campaign.started_at: - duration = (datetime.now(UTC) - campaign.started_at).total_seconds() - completion_event.duration_seconds = duration - - await self.redis.publish( - RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json() + duration_seconds=duration, ) # Clean up in-memory state - if campaign_id in self._last_activity: - del self._last_activity[campaign_id] - if campaign_id in self._processing_locks: - del self._processing_locks[campaign_id] - if campaign_id in self._batch_in_progress: - del self._batch_in_progress[campaign_id] + self._clear_campaign_state(campaign_id) except Exception as e: logger.error( diff --git a/api/services/campaign/errors.py b/api/services/campaign/errors.py new file mode 100644 index 0000000..b8bd990 --- /dev/null +++ b/api/services/campaign/errors.py @@ -0,0 +1,16 @@ +""" +Campaign service exceptions. +""" + + +class ConcurrentSlotAcquisitionError(Exception): + """Raised when a concurrent call slot cannot be acquired within the timeout period.""" + + def __init__(self, organization_id: int, campaign_id: int, wait_time: float): + self.organization_id = organization_id + self.campaign_id = campaign_id + self.wait_time = wait_time + super().__init__( + f"Failed to acquire concurrent slot for org {organization_id}, " + f"campaign {campaign_id} after waiting {wait_time:.1f}s" + ) diff --git a/api/services/campaign/readme.md b/api/services/campaign/readme.md new file mode 100644 index 0000000..9b108aa --- /dev/null +++ b/api/services/campaign/readme.md @@ -0,0 +1,18 @@ +### campaign_orchestrator.py (CampaignOrchestrator) + +- Listens to retry events, batch completed event, sync completed events from redis pubsub, and schedules batches +- Monitors stale campaigns and schedules batches if one is not already scheduled +- Marks campaign as completed if no more tasks pending + +### runner.py (CampaignRunnerService) + +- Service layer to handle router requests, like run campaign, pause campaign, resume campaign, get campaign status etc. + +### call_dispatcher.py (CampaignCallDispatcher) + +- Ensures rate limit and concurrency limits and dispatches call using telephony provider + +### campaign_tasks.py + +- sync campaign from source +- process campaign batch diff --git a/api/services/campaign/runner.py b/api/services/campaign/runner.py index 008ee80..f1b397c 100644 --- a/api/services/campaign/runner.py +++ b/api/services/campaign/runner.py @@ -63,12 +63,10 @@ class CampaignRunnerService: f"Campaign must be in 'paused' state to resume, current state: {campaign.state}" ) - # Update state to running + # Update state to running. Do not queue batch since campaign orchestrator's + # stale campaign checker would do that if there are pending work. await db_client.update_campaign(campaign_id=campaign_id, state="running") - # Enqueue process batch task to continue processing - await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id) - logger.info(f"Campaign {campaign_id} resumed") async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]: diff --git a/api/services/pipecat/event_handlers.py b/api/services/pipecat/event_handlers.py index a898639..fea4a6f 100644 --- a/api/services/pipecat/event_handlers.py +++ b/api/services/pipecat/event_handlers.py @@ -2,7 +2,7 @@ from loguru import logger from api.db import db_client from api.enums import WorkflowRunState -from api.services.campaign.call_dispatcher import campaign_call_dispatcher +from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher from api.services.pipecat.audio_config import AudioConfig from api.services.pipecat.in_memory_buffers import ( InMemoryAudioBuffer, diff --git a/api/tasks/arq.py b/api/tasks/arq.py index 102e94e..c796a6d 100644 --- a/api/tasks/arq.py +++ b/api/tasks/arq.py @@ -42,7 +42,6 @@ REDIS_SETTINGS = RedisSettings( ) from api.tasks.campaign_tasks import ( - monitor_campaign_progress, process_campaign_batch, sync_campaign_source, ) @@ -62,7 +61,6 @@ class WorkerSettings: process_workflow_completion, sync_campaign_source, process_campaign_batch, - monitor_campaign_progress, process_knowledge_base_document, ] cron_jobs = [] diff --git a/api/tasks/campaign_tasks.py b/api/tasks/campaign_tasks.py index 964933f..db2fc24 100644 --- a/api/tasks/campaign_tasks.py +++ b/api/tasks/campaign_tasks.py @@ -4,12 +4,11 @@ from typing import Dict from loguru import logger from api.db import db_client -from api.enums import RedisChannel -from api.services.campaign.call_dispatcher import campaign_call_dispatcher -from api.services.campaign.campaign_event_protocol import BatchFailedEvent +from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher from api.services.campaign.campaign_event_publisher import ( get_campaign_event_publisher, ) +from api.services.campaign.errors import ConcurrentSlotAcquisitionError from api.services.campaign.source_sync_factory import get_sync_service @@ -95,6 +94,10 @@ async def process_campaign_batch( - Updates queued_run state to 'processed' - Updates campaign.processed_rows counter - Publishes batch_completed event for orchestrator + + # TODO: May be not fail the campaign immediately on a single batch failure + # and propagate the error to campaign orchestrator which can fail the campaign + # on some consecutive batch failures. """ logger.info(f"Processing batch for campaign {campaign_id}, batch_size={batch_size}") @@ -119,81 +122,34 @@ async def process_campaign_batch( f"failed={failed_count}" ) - except Exception as e: - logger.error(f"Error processing batch for campaign {campaign_id}: {e}") - - # Publish batch failed event - publisher = await get_campaign_event_publisher() - event = BatchFailedEvent( - campaign_id=campaign_id, - error=str(e), - processed_count=0, + except ConcurrentSlotAcquisitionError as e: + logger.warning( + f"Failed to acquire concurrent slot for campaign {campaign_id}: {e}" ) - await publisher.redis.publish( - RedisChannel.CAMPAIGN_EVENTS.value, event.to_json() + + # Publish batch failed event with specific error + publisher = await get_campaign_event_publisher() + await publisher.publish_batch_failed( + campaign_id=campaign_id, + error=f"Concurrent slot acquisition timeout: {e}", + processed_count=0, ) # Update campaign state to failed await db_client.update_campaign(campaign_id=campaign_id, state="failed") raise - -async def monitor_campaign_progress(ctx: Dict, campaign_id: int) -> None: - """ - Phase 3: Monitors campaign completion - - Checks if all queued runs are in 'processed' state - - Queries workflow_runs for final call statistics - - Updates campaign state to 'completed' - - Calculates total calls made, successful, failed - - Triggers post-campaign integrations - """ - logger.info(f"Monitoring progress for campaign {campaign_id}") - - try: - # Get campaign - campaign = await db_client.get_campaign_by_id(campaign_id) - if not campaign: - raise ValueError(f"Campaign {campaign_id} not found") - - # Check if all runs are processed - pending_runs = await db_client.count_queued_runs( - campaign_id=campaign_id, state="queued" - ) - - if pending_runs > 0: - logger.info(f"Campaign {campaign_id} still has {pending_runs} pending runs") - return - - # All runs processed, mark campaign as completed - await db_client.update_campaign( - campaign_id=campaign_id, state="completed", completed_at=datetime.now(UTC) - ) - - # Calculate statistics - workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id) - - total_calls = len(workflow_runs) - successful_calls = 0 - failed_calls = 0 - - for run in workflow_runs: - callbacks = run.logs.get("telephony_status_callbacks", []) - if callbacks: - final_status = callbacks[-1].get("status", "").lower() - if final_status == "completed": - successful_calls += 1 - elif final_status in ["failed", "busy", "no-answer"]: - failed_calls += 1 - - logger.info( - f"Campaign {campaign_id} completed: " - f"Total calls: {total_calls}, " - f"Successful: {successful_calls}, " - f"Failed: {failed_calls}" - ) - - # TODO: Trigger post-campaign integrations if configured - except Exception as e: - logger.error(f"Error monitoring campaign {campaign_id}: {e}") + logger.error(f"Error processing batch for campaign {campaign_id}: {e}") + + # Publish batch failed event + publisher = await get_campaign_event_publisher() + await publisher.publish_batch_failed( + campaign_id=campaign_id, + error=str(e), + processed_count=0, + ) + + # Update campaign state to failed + await db_client.update_campaign(campaign_id=campaign_id, state="failed") raise diff --git a/api/tasks/function_names.py b/api/tasks/function_names.py index e3aeb54..6d5e73a 100644 --- a/api/tasks/function_names.py +++ b/api/tasks/function_names.py @@ -5,5 +5,4 @@ class FunctionNames: UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3" SYNC_CAMPAIGN_SOURCE = "sync_campaign_source" PROCESS_CAMPAIGN_BATCH = "process_campaign_batch" - MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress" PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document" diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 9b0484b..15e6b54 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -3,7 +3,9 @@ from typing import Any, Dict, Optional from unittest.mock import Mock import pytest +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from api.constants import DATABASE_URL from api.services.workflow.dto import ( EdgeDataDTO, ExtractionVariableDTO, @@ -549,3 +551,22 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph: ], ) return WorkflowGraph(dto) + + +# ============================================================================= +# Database fixtures for integration tests +# ============================================================================= + + +@pytest.fixture(scope="session") +async def db_engine(): + """Create database engine for tests.""" + engine = create_async_engine(DATABASE_URL, echo=False) + yield engine + await engine.dispose() + + +@pytest.fixture(scope="session") +async def db_session_factory(db_engine): + """Create session factory for tests.""" + return async_sessionmaker(bind=db_engine, expire_on_commit=False) diff --git a/api/tests/test_campaign_call_dispatcher.py b/api/tests/test_campaign_call_dispatcher.py new file mode 100644 index 0000000..3b0c878 --- /dev/null +++ b/api/tests/test_campaign_call_dispatcher.py @@ -0,0 +1,603 @@ +""" +Tests for CampaignCallDispatcher.process_batch method. + +These tests verify: +1. Basic batch processing functionality +2. Thread-safety via SELECT FOR UPDATE SKIP LOCKED +3. Race condition handling when multiple workers process concurrently +""" + +import asyncio +import uuid +from dataclasses import dataclass +from typing import List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy import delete, text + +from api.db.models import ( + CampaignModel, + OrganizationModel, + QueuedRunModel, + UserModel, + WorkflowModel, + WorkflowRunModel, +) +from api.services.campaign.campaign_call_dispatcher import CampaignCallDispatcher + +# ============================================================================= +# Test-specific fixtures +# ============================================================================= + + +@dataclass +class CampaignTestData: + """Container for campaign test data IDs""" + + organization_id: int + user_id: int + workflow_id: int + campaign_id: int + queued_run_ids: List[int] + + +@pytest.fixture +async def campaign_test_data(db_session_factory) -> CampaignTestData: + """ + Create test data for campaign processing tests. + + Creates: + - Organization + - User + - Workflow + - Campaign (in 'running' state) + - 10 QueuedRuns (in 'queued' state) + """ + async with db_session_factory() as session: + # Create organization + org = OrganizationModel( + provider_id=f"test-org-{uuid.uuid4().hex[:8]}", + ) + session.add(org) + await session.flush() + + # Create user + user = UserModel( + provider_id=f"test-user-{uuid.uuid4().hex[:8]}", + selected_organization_id=org.id, + ) + session.add(user) + await session.flush() + + # Create workflow + workflow = WorkflowModel( + name=f"test-workflow-{uuid.uuid4().hex[:8]}", + user_id=user.id, + organization_id=org.id, + workflow_definition={ + "nodes": [ + { + "id": "1", + "type": "startCall", + "position": {"x": 0, "y": 0}, + "data": {"name": "Start", "prompt": "Hello"}, + } + ], + "edges": [], + }, + template_context_variables={}, + ) + session.add(workflow) + await session.flush() + + # Create campaign + campaign = CampaignModel( + name=f"test-campaign-{uuid.uuid4().hex[:8]}", + organization_id=org.id, + workflow_id=workflow.id, + created_by=user.id, + source_type="test", + source_id="test-source", + state="running", + rate_limit_per_second=100, # High limit to avoid rate limiting in tests + ) + session.add(campaign) + await session.flush() + + # Create queued runs + queued_run_ids = [] + for i in range(10): + queued_run = QueuedRunModel( + campaign_id=campaign.id, + source_uuid=f"test-uuid-{i}", + context_variables={"phone_number": f"+1555000{i:04d}"}, + state="queued", + ) + session.add(queued_run) + await session.flush() + queued_run_ids.append(queued_run.id) + + await session.commit() + + test_data = CampaignTestData( + organization_id=org.id, + user_id=user.id, + workflow_id=workflow.id, + campaign_id=campaign.id, + queued_run_ids=queued_run_ids, + ) + + yield test_data + + # Cleanup + async with db_session_factory() as cleanup_session: + # Delete in reverse order of dependencies + await cleanup_session.execute( + delete(QueuedRunModel).where(QueuedRunModel.campaign_id == campaign.id) + ) + await cleanup_session.execute( + delete(WorkflowRunModel).where( + WorkflowRunModel.campaign_id == campaign.id + ) + ) + await cleanup_session.execute( + delete(CampaignModel).where(CampaignModel.id == campaign.id) + ) + await cleanup_session.execute( + delete(WorkflowModel).where(WorkflowModel.id == workflow.id) + ) + await cleanup_session.execute( + delete(UserModel).where(UserModel.id == user.id) + ) + await cleanup_session.execute( + delete(OrganizationModel).where(OrganizationModel.id == org.id) + ) + await cleanup_session.commit() + + +@pytest.fixture +def mock_dispatch_call(): + """Mock dispatch_call to track which runs were processed.""" + processed_runs = [] + + async def mock_dispatch(queued_run, campaign, slot_id): + # Simulate some processing time + await asyncio.sleep(0.01) + processed_runs.append(queued_run.id) + # Return a mock workflow run + mock_run = MagicMock() + mock_run.id = len(processed_runs) + return mock_run + + return mock_dispatch, processed_runs + + +@pytest.fixture +def mock_rate_limiter(): + """Mock rate limiter to always allow calls.""" + + async def mock_acquire_token(*args, **kwargs): + return True + + async def mock_try_acquire_slot(*args, **kwargs): + return f"slot-{uuid.uuid4().hex[:8]}" + + async def mock_release_slot(*args, **kwargs): + return True + + async def mock_store_mapping(*args, **kwargs): + pass + + async def mock_get_mapping(*args, **kwargs): + return None + + async def mock_delete_mapping(*args, **kwargs): + pass + + return { + "acquire_token": mock_acquire_token, + "try_acquire_concurrent_slot": mock_try_acquire_slot, + "release_concurrent_slot": mock_release_slot, + "store_workflow_slot_mapping": mock_store_mapping, + "get_workflow_slot_mapping": mock_get_mapping, + "delete_workflow_slot_mapping": mock_delete_mapping, + } + + +# ============================================================================= +# Tests +# ============================================================================= + + +class TestProcessBatchBasic: + """Basic tests for process_batch functionality.""" + + @pytest.mark.asyncio + async def test_process_batch_processes_queued_runs( + self, campaign_test_data, mock_dispatch_call, mock_rate_limiter + ): + """Test that process_batch processes queued runs and marks them as processed.""" + mock_dispatch, processed_runs = mock_dispatch_call + + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + # Setup rate limiter mocks + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + mock_rl.release_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["release_concurrent_slot"] + ) + mock_rl.store_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["store_workflow_slot_mapping"] + ) + mock_rl.get_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["get_workflow_slot_mapping"] + ) + mock_rl.delete_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["delete_workflow_slot_mapping"] + ) + + dispatcher = CampaignCallDispatcher() + + # Mock dispatch_call + with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch): + # Process batch of 5 + processed_count = await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=5 + ) + + assert processed_count == 5 + assert len(processed_runs) == 5 + + +class TestProcessBatchConcurrency: + """Tests for concurrent batch processing and database locking.""" + + @pytest.mark.asyncio + async def test_concurrent_process_batch_no_duplicate_processing( + self, + campaign_test_data, + mock_dispatch_call, + mock_rate_limiter, + db_session_factory, + ): + """ + Test that two concurrent process_batch calls don't process the same runs. + + This verifies the SELECT FOR UPDATE SKIP LOCKED mechanism works correctly. + """ + mock_dispatch, processed_runs = mock_dispatch_call + + # Reset queued runs to 'queued' state for this test + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + async def run_process_batch(): + """Helper to run process_batch with mocked dependencies.""" + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + mock_rl.release_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["release_concurrent_slot"] + ) + mock_rl.store_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["store_workflow_slot_mapping"] + ) + mock_rl.get_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["get_workflow_slot_mapping"] + ) + mock_rl.delete_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["delete_workflow_slot_mapping"] + ) + + dispatcher = CampaignCallDispatcher() + + with patch.object( + dispatcher, "dispatch_call", side_effect=mock_dispatch + ): + return await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=5 + ) + + # Run two process_batch calls concurrently + results = await asyncio.gather( + run_process_batch(), + run_process_batch(), + ) + + # Total processed should be 10 (all queued runs) + total_processed = sum(results) + assert total_processed == 10, f"Expected 10 total, got {total_processed}" + + # Each run should be processed exactly once (no duplicates) + assert len(processed_runs) == 10, f"Expected 10 runs, got {len(processed_runs)}" + assert len(set(processed_runs)) == 10, "Duplicate runs were processed!" + + @pytest.mark.asyncio + async def test_concurrent_process_batch_with_different_batch_sizes( + self, + campaign_test_data, + mock_dispatch_call, + mock_rate_limiter, + db_session_factory, + ): + """ + Test concurrent processing with different batch sizes. + + Worker 1 requests 3 runs, Worker 2 requests 7 runs. + Total should still be 10 with no duplicates. + """ + mock_dispatch, processed_runs = mock_dispatch_call + + # Reset queued runs to 'queued' state + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + async def run_process_batch(batch_size: int): + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + mock_rl.release_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["release_concurrent_slot"] + ) + mock_rl.store_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["store_workflow_slot_mapping"] + ) + mock_rl.get_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["get_workflow_slot_mapping"] + ) + mock_rl.delete_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["delete_workflow_slot_mapping"] + ) + + dispatcher = CampaignCallDispatcher() + + with patch.object( + dispatcher, "dispatch_call", side_effect=mock_dispatch + ): + return await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, + batch_size=batch_size, + ) + + # Run with different batch sizes concurrently + results = await asyncio.gather( + run_process_batch(3), + run_process_batch(7), + ) + + total_processed = sum(results) + assert total_processed == 10 + + # Verify no duplicates + assert len(set(processed_runs)) == len(processed_runs) + + @pytest.mark.asyncio + async def test_multiple_concurrent_workers( + self, + campaign_test_data, + mock_dispatch_call, + mock_rate_limiter, + db_session_factory, + ): + """ + Test with many concurrent workers (simulating production scenario). + + 5 workers each requesting 4 runs from a pool of 10. + Should process all 10 exactly once. + """ + mock_dispatch, processed_runs = mock_dispatch_call + + # Reset queued runs + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + async def run_process_batch(): + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + mock_rl.release_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["release_concurrent_slot"] + ) + mock_rl.store_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["store_workflow_slot_mapping"] + ) + mock_rl.get_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["get_workflow_slot_mapping"] + ) + mock_rl.delete_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["delete_workflow_slot_mapping"] + ) + + dispatcher = CampaignCallDispatcher() + + with patch.object( + dispatcher, "dispatch_call", side_effect=mock_dispatch + ): + return await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=4 + ) + + # Run 5 workers concurrently + results = await asyncio.gather(*[run_process_batch() for _ in range(5)]) + + total_processed = sum(results) + assert total_processed == 10 + + # Verify no duplicates + assert len(set(processed_runs)) == 10, "Duplicate runs were processed!" + + @pytest.mark.asyncio + async def test_processing_state_transition( + self, + campaign_test_data, + mock_dispatch_call, + mock_rate_limiter, + db_session_factory, + ): + """ + Test that runs transition through processing -> processed states correctly. + """ + mock_dispatch, processed_runs = mock_dispatch_call + + # Reset queued runs + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + mock_rl.release_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["release_concurrent_slot"] + ) + mock_rl.store_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["store_workflow_slot_mapping"] + ) + mock_rl.get_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["get_workflow_slot_mapping"] + ) + mock_rl.delete_workflow_slot_mapping = AsyncMock( + side_effect=mock_rate_limiter["delete_workflow_slot_mapping"] + ) + + dispatcher = CampaignCallDispatcher() + + with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch): + await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=10 + ) + + # Verify all runs are in 'processed' state + async with db_session_factory() as session: + result = await session.execute( + text( + "SELECT state, COUNT(*) as count FROM queued_runs " + "WHERE campaign_id = :campaign_id GROUP BY state" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + states = {row[0]: row[1] for row in result.fetchall()} + + assert states.get("processed", 0) == 10 + assert states.get("queued", 0) == 0 + assert states.get("processing", 0) == 0 + + +class TestProcessBatchEdgeCases: + """Edge case tests for process_batch.""" + + @pytest.mark.asyncio + async def test_empty_queue( + self, campaign_test_data, mock_rate_limiter, db_session_factory + ): + """Test process_batch with no queued runs returns 0.""" + # Set all runs to processed + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE queued_runs SET state = 'processed' WHERE campaign_id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + with patch( + "api.services.campaign.campaign_call_dispatcher.rate_limiter" + ) as mock_rl: + mock_rl.acquire_token = AsyncMock( + side_effect=mock_rate_limiter["acquire_token"] + ) + mock_rl.try_acquire_concurrent_slot = AsyncMock( + side_effect=mock_rate_limiter["try_acquire_concurrent_slot"] + ) + + dispatcher = CampaignCallDispatcher() + result = await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=5 + ) + + assert result == 0 + + @pytest.mark.asyncio + async def test_campaign_not_running( + self, campaign_test_data, mock_rate_limiter, db_session_factory + ): + """Test process_batch returns 0 if campaign is not in running state.""" + # Set campaign to paused + async with db_session_factory() as session: + await session.execute( + text("UPDATE campaigns SET state = 'paused' WHERE id = :campaign_id"), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit() + + try: + dispatcher = CampaignCallDispatcher() + result = await dispatcher.process_batch( + campaign_id=campaign_test_data.campaign_id, batch_size=5 + ) + assert result == 0 + finally: + # Restore campaign state + async with db_session_factory() as session: + await session.execute( + text( + "UPDATE campaigns SET state = 'running' WHERE id = :campaign_id" + ), + {"campaign_id": campaign_test_data.campaign_id}, + ) + await session.commit()