fix: make campaign process batch thread safe (#141)

* fix: dont schedule new batch on resume

* fix: make process_batch thread safe
This commit is contained in:
Abhishek 2026-01-30 14:48:00 +05:30 committed by GitHub
parent e9c5da16c5
commit 6827744327
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1012 additions and 230 deletions

View file

@ -9,6 +9,7 @@ from api.constants import DEFAULT_ORG_CONCURRENCY_LIMIT
from api.db import db_client
from api.db.models import QueuedRunModel, WorkflowRunModel
from api.enums import OrganizationConfigurationKey, WorkflowRunState
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
from api.services.campaign.rate_limiter import rate_limiter
from api.services.telephony.base import TelephonyProvider
from api.services.telephony.factory import get_telephony_provider
@ -42,7 +43,8 @@ class CampaignCallDispatcher:
async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
"""
Processes a batch of queued runs with priority for scheduled retries
Processes a batch of queued runs with priority for scheduled retries.
Thread-safe: uses SELECT FOR UPDATE SKIP LOCKED to prevent concurrent processing.
Returns: number of processed runs
"""
# Get campaign details
@ -57,41 +59,34 @@ class CampaignCallDispatcher:
)
return 0
# First, get any scheduled retries that are due
scheduled_runs = await db_client.get_scheduled_queued_runs(
# Atomically claim queued runs for processing (thread-safe)
# This uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions
queued_runs = await db_client.claim_queued_runs_for_processing(
campaign_id=campaign_id,
scheduled_before=datetime.now(UTC),
limit=batch_size,
)
remaining_slots = batch_size - len(scheduled_runs)
# Then get regular queued runs
regular_runs = []
if remaining_slots > 0:
regular_runs = await db_client.get_queued_runs(
campaign_id=campaign_id,
state="queued",
scheduled_for=False, # Exclude scheduled runs
limit=remaining_slots,
)
queued_runs = scheduled_runs + regular_runs
if not queued_runs:
logger.info(f"No more queued runs for campaign {campaign_id}")
return 0
processed_count = 0
for queued_run in queued_runs:
for i, queued_run in enumerate(queued_runs):
try:
# Apply rate limiting
# Apply rate limiting, i.e lets not initiate more than rate_limit_per_second
# calls per second. It is different than concurrency limit.
await self.apply_rate_limit(
campaign.organization_id, campaign.rate_limit_per_second
)
# Acquire concurrent slot - waits until a slot is available
slot_id = await self.acquire_concurrent_slot(
campaign.organization_id, campaign
)
# Dispatch the call
workflow_run = await self.dispatch_call(queued_run, campaign)
workflow_run = await self.dispatch_call(queued_run, campaign, slot_id)
# Update queued run as processed
await db_client.update_queued_run(
@ -108,6 +103,25 @@ class CampaignCallDispatcher:
campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1
)
except ConcurrentSlotAcquisitionError:
# Revert all unprocessed runs (current and remaining) back to queued
# so they can be picked up again when campaign is resumed
for unprocessed_run in queued_runs[i:]:
try:
await db_client.update_queued_run(
queued_run_id=unprocessed_run.id,
state="queued",
)
logger.info(
f"Reverted queued run {unprocessed_run.id} back to queued state"
)
except Exception as revert_error:
logger.error(
f"Failed to revert queued run {unprocessed_run.id}: {revert_error}"
)
# Re-raise to propagate to process_campaign_batch
raise
except Exception as e:
logger.warning(f"Error processing queued run {queued_run.id}: {e}")
@ -129,54 +143,9 @@ class CampaignCallDispatcher:
return processed_count
async def dispatch_call(
self, queued_run: QueuedRunModel, campaign: any
self, queued_run: QueuedRunModel, campaign: any, slot_id: str
) -> Optional[WorkflowRunModel]:
"""Creates workflow run and initiates call with concurrent limiting"""
# Get concurrent limit for organization
org_concurrent_limit = await self.get_org_concurrent_limit(
campaign.organization_id
)
# Check for campaign-level max_concurrency in orchestrator_metadata
campaign_max_concurrency = None
if campaign.orchestrator_metadata:
campaign_max_concurrency = campaign.orchestrator_metadata.get(
"max_concurrency"
)
# Use the lower of campaign limit and org limit
if campaign_max_concurrency is not None:
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
else:
max_concurrent = org_concurrent_limit
# Track wait time for alerting
wait_start = time.time()
slot_id = None
# Wait until we can acquire a concurrent slot
while True:
slot_id = await rate_limiter.try_acquire_concurrent_slot(
campaign.organization_id, max_concurrent
)
if slot_id:
break
# Check if we've been waiting too long
wait_time = time.time() - wait_start
if wait_time > 600: # 10 minutes
logger.error(
f"Waiting for concurrent slot for {wait_time:.1f}s, "
f"org: {campaign.organization_id}, campaign: {campaign.id}"
)
logger.debug(
f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
)
# Wait before retrying
await asyncio.sleep(1)
"""Creates workflow run and initiates call. Requires a pre-acquired slot_id."""
# Get workflow details
workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
if not workflow:
@ -351,6 +320,66 @@ class CampaignCallDispatcher:
# Wait for next available slot
await asyncio.sleep(wait_time)
async def acquire_concurrent_slot(
self, organization_id: int, campaign: any, timeout: float = 600
) -> str:
"""
Acquires a concurrent call slot - waits if necessary until a slot is available.
Args:
organization_id: The organization ID
campaign: The campaign object
timeout: Maximum time to wait for a slot (default 10 minutes)
Returns the slot_id which must be released when the call completes.
Raises:
ConcurrentSlotAcquisitionError: If slot cannot be acquired within timeout
"""
# Get concurrent limit for organization
org_concurrent_limit = await self.get_org_concurrent_limit(organization_id)
# Check for campaign-level max_concurrency in orchestrator_metadata
campaign_max_concurrency = None
if campaign.orchestrator_metadata:
campaign_max_concurrency = campaign.orchestrator_metadata.get(
"max_concurrency"
)
# Use the lower of campaign limit and org limit
if campaign_max_concurrency is not None:
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
else:
max_concurrent = org_concurrent_limit
# Track wait time for alerting
wait_start = time.time()
# Wait until we can acquire a concurrent slot
while True:
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id, max_concurrent
)
if slot_id:
return slot_id
# Check if we've been waiting too long
wait_time = time.time() - wait_start
if wait_time > timeout:
raise ConcurrentSlotAcquisitionError(
organization_id=organization_id,
campaign_id=campaign.id,
wait_time=wait_time,
)
logger.debug(
f"Attempting to get a slot for {organization_id} {campaign.id}, "
f"waited {wait_time:.1f}s"
)
# Wait before retrying
await asyncio.sleep(1)
async def release_call_slot(self, workflow_run_id: int) -> bool:
"""
Release concurrent slot when a call completes.

View file

@ -12,6 +12,7 @@ from api.constants import REDIS_URL
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
BatchCompletedEvent,
BatchFailedEvent,
CampaignCompletedEvent,
RetryNeededEvent,
SyncCompletedEvent,
@ -43,6 +44,23 @@ class CampaignEventPublisher:
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_batch_failed(
self,
campaign_id: int,
error: str,
processed_count: int = 0,
metadata: Optional[Dict] = None,
):
"""Publish batch failed event."""
event = BatchFailedEvent(
campaign_id=campaign_id,
error=error,
processed_count=processed_count,
metadata=metadata,
)
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
async def publish_sync_completed(
self,
campaign_id: int,

View file

@ -23,11 +23,13 @@ from api.db import db_client
from api.db.models import CampaignModel, QueuedRunModel
from api.enums import RedisChannel
from api.services.campaign.campaign_event_protocol import (
CampaignCompletedEvent,
CampaignEventType,
BatchCompletedEvent,
BatchFailedEvent,
RetryNeededEvent,
SyncCompletedEvent,
parse_campaign_event,
)
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
@ -37,6 +39,7 @@ class CampaignOrchestrator:
def __init__(self, redis_client: aioredis.Redis):
self.redis = redis_client
self.publisher = CampaignEventPublisher(redis_client)
self.completion_check_interval = 60 # 1 minute
self.completion_timeout = 3600 # 1 hour
self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling
@ -97,22 +100,21 @@ class CampaignOrchestrator:
async def _handle_event(self, event):
"""Handle campaign events including retry events."""
# Handle RetryNeededEvent
if isinstance(event, RetryNeededEvent):
await self._handle_retry_event(event)
return
# All events should have campaign_id
if not hasattr(event, "campaign_id") or not event.campaign_id:
logger.warning(f"Event missing campaign_id: {type(event).__name__}")
return
campaign_id = event.campaign_id
event_type = event.type
logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")
logger.debug(
f"campaign_id: {campaign_id} - Received event: {type(event).__name__}"
)
if event_type == CampaignEventType.BATCH_COMPLETED:
if isinstance(event, RetryNeededEvent):
await self._handle_retry_event(event)
elif isinstance(event, BatchCompletedEvent):
# Clear the batch in progress flag
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
@ -120,11 +122,42 @@ class CampaignOrchestrator:
f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
)
# Check campaign state before scheduling next batch
campaign = await db_client.get_campaign_by_id(campaign_id)
if not campaign:
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
self._clear_campaign_state(campaign_id)
return
if campaign.state != "running":
logger.info(
f"campaign_id: {campaign_id} - Campaign not in running state ({campaign.state}), "
f"not scheduling next batch"
)
self._clear_campaign_state(campaign_id)
return
# Immediately schedule next batch
await self._schedule_next_batch(campaign_id)
self._last_activity[campaign_id] = datetime.now(UTC)
elif event_type == CampaignEventType.SYNC_COMPLETED:
elif isinstance(event, BatchFailedEvent):
# Clear the batch in progress flag
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
logger.warning(
f"campaign_id: {campaign_id} - Batch failed: {event.error}, "
f"scheduling next batch to continue processing"
)
# Lets not schedule another batch, since we mark the campaign
# as failed just to be on the safe side from process_campaign_batch
# if a batch fails
self._last_activity[campaign_id] = datetime.now(UTC)
elif isinstance(event, SyncCompletedEvent):
# Start processing after sync
logger.info(
f"campaign_id: {campaign_id} - Sync completed, starting processing"
@ -309,6 +342,16 @@ class CampaignOrchestrator:
del self._processing_locks[campaign_id]
logger.debug(f"campaign_id: {campaign_id} - Released processing lock")
def _clear_campaign_state(self, campaign_id: int):
"""Clear all in-memory state for a campaign."""
if campaign_id in self._last_activity:
del self._last_activity[campaign_id]
if campaign_id in self._processing_locks:
del self._processing_locks[campaign_id]
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
logger.debug(f"campaign_id: {campaign_id} - Cleared all in-memory state")
async def _monitor_completion(self):
"""Periodically check for campaigns that should be marked complete."""
while self._running:
@ -457,30 +500,22 @@ class CampaignOrchestrator:
logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")
# Publish completion event using typed event
completion_event = CampaignCompletedEvent(
# Calculate duration if started_at is available
duration = None
if campaign.started_at:
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
# Publish completion event
await self.publisher.publish_campaign_completed(
campaign_id=campaign_id,
total_rows=campaign.total_rows or 0,
processed_rows=campaign.processed_rows,
failed_rows=campaign.failed_rows,
)
# Calculate duration if started_at is available
if campaign.started_at:
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
completion_event.duration_seconds = duration
await self.redis.publish(
RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
duration_seconds=duration,
)
# Clean up in-memory state
if campaign_id in self._last_activity:
del self._last_activity[campaign_id]
if campaign_id in self._processing_locks:
del self._processing_locks[campaign_id]
if campaign_id in self._batch_in_progress:
del self._batch_in_progress[campaign_id]
self._clear_campaign_state(campaign_id)
except Exception as e:
logger.error(

View file

@ -0,0 +1,16 @@
"""
Campaign service exceptions.
"""
class ConcurrentSlotAcquisitionError(Exception):
"""Raised when a concurrent call slot cannot be acquired within the timeout period."""
def __init__(self, organization_id: int, campaign_id: int, wait_time: float):
self.organization_id = organization_id
self.campaign_id = campaign_id
self.wait_time = wait_time
super().__init__(
f"Failed to acquire concurrent slot for org {organization_id}, "
f"campaign {campaign_id} after waiting {wait_time:.1f}s"
)

View file

@ -0,0 +1,18 @@
### campaign_orchestrator.py (CampaignOrchestrator)
- Listens to retry events, batch completed event, sync completed events from redis pubsub, and schedules batches
- Monitors stale campaigns and schedules batches if one is not already scheduled
- Marks campaign as completed if no more tasks pending
### runner.py (CampaignRunnerService)
- Service layer to handle router requests, like run campaign, pause campaign, resume campaign, get campaign status etc.
### call_dispatcher.py (CampaignCallDispatcher)
- Ensures rate limit and concurrency limits and dispatches call using telephony provider
### campaign_tasks.py
- sync campaign from source
- process campaign batch

View file

@ -63,12 +63,10 @@ class CampaignRunnerService:
f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
)
# Update state to running
# Update state to running. Do not queue batch since campaign orchestrator's
# stale campaign checker would do that if there are pending work.
await db_client.update_campaign(campaign_id=campaign_id, state="running")
# Enqueue process batch task to continue processing
await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)
logger.info(f"Campaign {campaign_id} resumed")
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:

View file

@ -2,7 +2,7 @@ from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunState
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.in_memory_buffers import (
InMemoryAudioBuffer,