mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: make campaign process batch thread safe (#141)
* fix: dont schedule new batch on resume * fix: make process_batch thread safe
This commit is contained in:
parent
e9c5da16c5
commit
6827744327
17 changed files with 1012 additions and 230 deletions
|
|
@ -9,6 +9,7 @@ from api.constants import DEFAULT_ORG_CONCURRENCY_LIMIT
|
|||
from api.db import db_client
|
||||
from api.db.models import QueuedRunModel, WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunState
|
||||
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.telephony.base import TelephonyProvider
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
|
|
@ -42,7 +43,8 @@ class CampaignCallDispatcher:
|
|||
|
||||
async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
|
||||
"""
|
||||
Processes a batch of queued runs with priority for scheduled retries
|
||||
Processes a batch of queued runs with priority for scheduled retries.
|
||||
Thread-safe: uses SELECT FOR UPDATE SKIP LOCKED to prevent concurrent processing.
|
||||
Returns: number of processed runs
|
||||
"""
|
||||
# Get campaign details
|
||||
|
|
@ -57,41 +59,34 @@ class CampaignCallDispatcher:
|
|||
)
|
||||
return 0
|
||||
|
||||
# First, get any scheduled retries that are due
|
||||
scheduled_runs = await db_client.get_scheduled_queued_runs(
|
||||
# Atomically claim queued runs for processing (thread-safe)
|
||||
# This uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions
|
||||
queued_runs = await db_client.claim_queued_runs_for_processing(
|
||||
campaign_id=campaign_id,
|
||||
scheduled_before=datetime.now(UTC),
|
||||
limit=batch_size,
|
||||
)
|
||||
|
||||
remaining_slots = batch_size - len(scheduled_runs)
|
||||
|
||||
# Then get regular queued runs
|
||||
regular_runs = []
|
||||
if remaining_slots > 0:
|
||||
regular_runs = await db_client.get_queued_runs(
|
||||
campaign_id=campaign_id,
|
||||
state="queued",
|
||||
scheduled_for=False, # Exclude scheduled runs
|
||||
limit=remaining_slots,
|
||||
)
|
||||
|
||||
queued_runs = scheduled_runs + regular_runs
|
||||
|
||||
if not queued_runs:
|
||||
logger.info(f"No more queued runs for campaign {campaign_id}")
|
||||
return 0
|
||||
|
||||
processed_count = 0
|
||||
for queued_run in queued_runs:
|
||||
for i, queued_run in enumerate(queued_runs):
|
||||
try:
|
||||
# Apply rate limiting
|
||||
# Apply rate limiting, i.e lets not initiate more than rate_limit_per_second
|
||||
# calls per second. It is different than concurrency limit.
|
||||
await self.apply_rate_limit(
|
||||
campaign.organization_id, campaign.rate_limit_per_second
|
||||
)
|
||||
|
||||
# Acquire concurrent slot - waits until a slot is available
|
||||
slot_id = await self.acquire_concurrent_slot(
|
||||
campaign.organization_id, campaign
|
||||
)
|
||||
|
||||
# Dispatch the call
|
||||
workflow_run = await self.dispatch_call(queued_run, campaign)
|
||||
workflow_run = await self.dispatch_call(queued_run, campaign, slot_id)
|
||||
|
||||
# Update queued run as processed
|
||||
await db_client.update_queued_run(
|
||||
|
|
@ -108,6 +103,25 @@ class CampaignCallDispatcher:
|
|||
campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1
|
||||
)
|
||||
|
||||
except ConcurrentSlotAcquisitionError:
|
||||
# Revert all unprocessed runs (current and remaining) back to queued
|
||||
# so they can be picked up again when campaign is resumed
|
||||
for unprocessed_run in queued_runs[i:]:
|
||||
try:
|
||||
await db_client.update_queued_run(
|
||||
queued_run_id=unprocessed_run.id,
|
||||
state="queued",
|
||||
)
|
||||
logger.info(
|
||||
f"Reverted queued run {unprocessed_run.id} back to queued state"
|
||||
)
|
||||
except Exception as revert_error:
|
||||
logger.error(
|
||||
f"Failed to revert queued run {unprocessed_run.id}: {revert_error}"
|
||||
)
|
||||
# Re-raise to propagate to process_campaign_batch
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing queued run {queued_run.id}: {e}")
|
||||
|
||||
|
|
@ -129,54 +143,9 @@ class CampaignCallDispatcher:
|
|||
return processed_count
|
||||
|
||||
async def dispatch_call(
|
||||
self, queued_run: QueuedRunModel, campaign: any
|
||||
self, queued_run: QueuedRunModel, campaign: any, slot_id: str
|
||||
) -> Optional[WorkflowRunModel]:
|
||||
"""Creates workflow run and initiates call with concurrent limiting"""
|
||||
# Get concurrent limit for organization
|
||||
org_concurrent_limit = await self.get_org_concurrent_limit(
|
||||
campaign.organization_id
|
||||
)
|
||||
|
||||
# Check for campaign-level max_concurrency in orchestrator_metadata
|
||||
campaign_max_concurrency = None
|
||||
if campaign.orchestrator_metadata:
|
||||
campaign_max_concurrency = campaign.orchestrator_metadata.get(
|
||||
"max_concurrency"
|
||||
)
|
||||
|
||||
# Use the lower of campaign limit and org limit
|
||||
if campaign_max_concurrency is not None:
|
||||
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
|
||||
else:
|
||||
max_concurrent = org_concurrent_limit
|
||||
|
||||
# Track wait time for alerting
|
||||
wait_start = time.time()
|
||||
slot_id = None
|
||||
|
||||
# Wait until we can acquire a concurrent slot
|
||||
while True:
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
campaign.organization_id, max_concurrent
|
||||
)
|
||||
if slot_id:
|
||||
break
|
||||
|
||||
# Check if we've been waiting too long
|
||||
wait_time = time.time() - wait_start
|
||||
if wait_time > 600: # 10 minutes
|
||||
logger.error(
|
||||
f"Waiting for concurrent slot for {wait_time:.1f}s, "
|
||||
f"org: {campaign.organization_id}, campaign: {campaign.id}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
|
||||
)
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(1)
|
||||
|
||||
"""Creates workflow run and initiates call. Requires a pre-acquired slot_id."""
|
||||
# Get workflow details
|
||||
workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
|
||||
if not workflow:
|
||||
|
|
@ -351,6 +320,66 @@ class CampaignCallDispatcher:
|
|||
# Wait for next available slot
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
async def acquire_concurrent_slot(
|
||||
self, organization_id: int, campaign: any, timeout: float = 600
|
||||
) -> str:
|
||||
"""
|
||||
Acquires a concurrent call slot - waits if necessary until a slot is available.
|
||||
|
||||
Args:
|
||||
organization_id: The organization ID
|
||||
campaign: The campaign object
|
||||
timeout: Maximum time to wait for a slot (default 10 minutes)
|
||||
|
||||
Returns the slot_id which must be released when the call completes.
|
||||
|
||||
Raises:
|
||||
ConcurrentSlotAcquisitionError: If slot cannot be acquired within timeout
|
||||
"""
|
||||
# Get concurrent limit for organization
|
||||
org_concurrent_limit = await self.get_org_concurrent_limit(organization_id)
|
||||
|
||||
# Check for campaign-level max_concurrency in orchestrator_metadata
|
||||
campaign_max_concurrency = None
|
||||
if campaign.orchestrator_metadata:
|
||||
campaign_max_concurrency = campaign.orchestrator_metadata.get(
|
||||
"max_concurrency"
|
||||
)
|
||||
|
||||
# Use the lower of campaign limit and org limit
|
||||
if campaign_max_concurrency is not None:
|
||||
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
|
||||
else:
|
||||
max_concurrent = org_concurrent_limit
|
||||
|
||||
# Track wait time for alerting
|
||||
wait_start = time.time()
|
||||
|
||||
# Wait until we can acquire a concurrent slot
|
||||
while True:
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id, max_concurrent
|
||||
)
|
||||
if slot_id:
|
||||
return slot_id
|
||||
|
||||
# Check if we've been waiting too long
|
||||
wait_time = time.time() - wait_start
|
||||
if wait_time > timeout:
|
||||
raise ConcurrentSlotAcquisitionError(
|
||||
organization_id=organization_id,
|
||||
campaign_id=campaign.id,
|
||||
wait_time=wait_time,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Attempting to get a slot for {organization_id} {campaign.id}, "
|
||||
f"waited {wait_time:.1f}s"
|
||||
)
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def release_call_slot(self, workflow_run_id: int) -> bool:
|
||||
"""
|
||||
Release concurrent slot when a call completes.
|
||||
|
|
@ -12,6 +12,7 @@ from api.constants import REDIS_URL
|
|||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
CampaignCompletedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
|
|
@ -43,6 +44,23 @@ class CampaignEventPublisher:
|
|||
|
||||
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
|
||||
|
||||
async def publish_batch_failed(
|
||||
self,
|
||||
campaign_id: int,
|
||||
error: str,
|
||||
processed_count: int = 0,
|
||||
metadata: Optional[Dict] = None,
|
||||
):
|
||||
"""Publish batch failed event."""
|
||||
event = BatchFailedEvent(
|
||||
campaign_id=campaign_id,
|
||||
error=error,
|
||||
processed_count=processed_count,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
|
||||
|
||||
async def publish_sync_completed(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ from api.db import db_client
|
|||
from api.db.models import CampaignModel, QueuedRunModel
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
CampaignCompletedEvent,
|
||||
CampaignEventType,
|
||||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
parse_campaign_event,
|
||||
)
|
||||
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
|
@ -37,6 +39,7 @@ class CampaignOrchestrator:
|
|||
|
||||
def __init__(self, redis_client: aioredis.Redis):
|
||||
self.redis = redis_client
|
||||
self.publisher = CampaignEventPublisher(redis_client)
|
||||
self.completion_check_interval = 60 # 1 minute
|
||||
self.completion_timeout = 3600 # 1 hour
|
||||
self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling
|
||||
|
|
@ -97,22 +100,21 @@ class CampaignOrchestrator:
|
|||
|
||||
async def _handle_event(self, event):
|
||||
"""Handle campaign events including retry events."""
|
||||
# Handle RetryNeededEvent
|
||||
if isinstance(event, RetryNeededEvent):
|
||||
await self._handle_retry_event(event)
|
||||
return
|
||||
|
||||
# All events should have campaign_id
|
||||
if not hasattr(event, "campaign_id") or not event.campaign_id:
|
||||
logger.warning(f"Event missing campaign_id: {type(event).__name__}")
|
||||
return
|
||||
|
||||
campaign_id = event.campaign_id
|
||||
event_type = event.type
|
||||
|
||||
logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")
|
||||
logger.debug(
|
||||
f"campaign_id: {campaign_id} - Received event: {type(event).__name__}"
|
||||
)
|
||||
|
||||
if event_type == CampaignEventType.BATCH_COMPLETED:
|
||||
if isinstance(event, RetryNeededEvent):
|
||||
await self._handle_retry_event(event)
|
||||
|
||||
elif isinstance(event, BatchCompletedEvent):
|
||||
# Clear the batch in progress flag
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
|
|
@ -120,11 +122,42 @@ class CampaignOrchestrator:
|
|||
f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
|
||||
)
|
||||
|
||||
# Check campaign state before scheduling next batch
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
|
||||
self._clear_campaign_state(campaign_id)
|
||||
return
|
||||
|
||||
if campaign.state != "running":
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Campaign not in running state ({campaign.state}), "
|
||||
f"not scheduling next batch"
|
||||
)
|
||||
self._clear_campaign_state(campaign_id)
|
||||
return
|
||||
|
||||
# Immediately schedule next batch
|
||||
await self._schedule_next_batch(campaign_id)
|
||||
self._last_activity[campaign_id] = datetime.now(UTC)
|
||||
|
||||
elif event_type == CampaignEventType.SYNC_COMPLETED:
|
||||
elif isinstance(event, BatchFailedEvent):
|
||||
# Clear the batch in progress flag
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
|
||||
logger.warning(
|
||||
f"campaign_id: {campaign_id} - Batch failed: {event.error}, "
|
||||
f"scheduling next batch to continue processing"
|
||||
)
|
||||
|
||||
# Lets not schedule another batch, since we mark the campaign
|
||||
# as failed just to be on the safe side from process_campaign_batch
|
||||
# if a batch fails
|
||||
|
||||
self._last_activity[campaign_id] = datetime.now(UTC)
|
||||
|
||||
elif isinstance(event, SyncCompletedEvent):
|
||||
# Start processing after sync
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Sync completed, starting processing"
|
||||
|
|
@ -309,6 +342,16 @@ class CampaignOrchestrator:
|
|||
del self._processing_locks[campaign_id]
|
||||
logger.debug(f"campaign_id: {campaign_id} - Released processing lock")
|
||||
|
||||
def _clear_campaign_state(self, campaign_id: int):
|
||||
"""Clear all in-memory state for a campaign."""
|
||||
if campaign_id in self._last_activity:
|
||||
del self._last_activity[campaign_id]
|
||||
if campaign_id in self._processing_locks:
|
||||
del self._processing_locks[campaign_id]
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
logger.debug(f"campaign_id: {campaign_id} - Cleared all in-memory state")
|
||||
|
||||
async def _monitor_completion(self):
|
||||
"""Periodically check for campaigns that should be marked complete."""
|
||||
while self._running:
|
||||
|
|
@ -457,30 +500,22 @@ class CampaignOrchestrator:
|
|||
|
||||
logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")
|
||||
|
||||
# Publish completion event using typed event
|
||||
completion_event = CampaignCompletedEvent(
|
||||
# Calculate duration if started_at is available
|
||||
duration = None
|
||||
if campaign.started_at:
|
||||
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
|
||||
|
||||
# Publish completion event
|
||||
await self.publisher.publish_campaign_completed(
|
||||
campaign_id=campaign_id,
|
||||
total_rows=campaign.total_rows or 0,
|
||||
processed_rows=campaign.processed_rows,
|
||||
failed_rows=campaign.failed_rows,
|
||||
)
|
||||
|
||||
# Calculate duration if started_at is available
|
||||
if campaign.started_at:
|
||||
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
|
||||
completion_event.duration_seconds = duration
|
||||
|
||||
await self.redis.publish(
|
||||
RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
|
||||
duration_seconds=duration,
|
||||
)
|
||||
|
||||
# Clean up in-memory state
|
||||
if campaign_id in self._last_activity:
|
||||
del self._last_activity[campaign_id]
|
||||
if campaign_id in self._processing_locks:
|
||||
del self._processing_locks[campaign_id]
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
self._clear_campaign_state(campaign_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
16
api/services/campaign/errors.py
Normal file
16
api/services/campaign/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""
|
||||
Campaign service exceptions.
|
||||
"""
|
||||
|
||||
|
||||
class ConcurrentSlotAcquisitionError(Exception):
|
||||
"""Raised when a concurrent call slot cannot be acquired within the timeout period."""
|
||||
|
||||
def __init__(self, organization_id: int, campaign_id: int, wait_time: float):
|
||||
self.organization_id = organization_id
|
||||
self.campaign_id = campaign_id
|
||||
self.wait_time = wait_time
|
||||
super().__init__(
|
||||
f"Failed to acquire concurrent slot for org {organization_id}, "
|
||||
f"campaign {campaign_id} after waiting {wait_time:.1f}s"
|
||||
)
|
||||
18
api/services/campaign/readme.md
Normal file
18
api/services/campaign/readme.md
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
### campaign_orchestrator.py (CampaignOrchestrator)
|
||||
|
||||
- Listens to retry events, batch completed event, sync completed events from redis pubsub, and schedules batches
|
||||
- Monitors stale campaigns and schedules batches if one is not already scheduled
|
||||
- Marks campaign as completed if no more tasks pending
|
||||
|
||||
### runner.py (CampaignRunnerService)
|
||||
|
||||
- Service layer to handle router requests, like run campaign, pause campaign, resume campaign, get campaign status etc.
|
||||
|
||||
### call_dispatcher.py (CampaignCallDispatcher)
|
||||
|
||||
- Ensures rate limit and concurrency limits and dispatches call using telephony provider
|
||||
|
||||
### campaign_tasks.py
|
||||
|
||||
- sync campaign from source
|
||||
- process campaign batch
|
||||
|
|
@ -63,12 +63,10 @@ class CampaignRunnerService:
|
|||
f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
|
||||
)
|
||||
|
||||
# Update state to running
|
||||
# Update state to running. Do not queue batch since campaign orchestrator's
|
||||
# stale campaign checker would do that if there are pending work.
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="running")
|
||||
|
||||
# Enqueue process batch task to continue processing
|
||||
await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)
|
||||
|
||||
logger.info(f"Campaign {campaign_id} resumed")
|
||||
|
||||
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunState
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.in_memory_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue