mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: make campaign process batch thread safe (#141)
* fix: dont schedule new batch on resume * fix: make process_batch thread safe
This commit is contained in:
parent
e9c5da16c5
commit
6827744327
17 changed files with 1012 additions and 230 deletions
|
|
@ -5,28 +5,31 @@ Revises: d1dac4c93e61
|
|||
Create Date: 2026-01-29 20:36:57.924887
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '02ffd7f23d1d'
|
||||
down_revision: Union[str, None] = 'd1dac4c93e61'
|
||||
revision: str = "02ffd7f23d1d"
|
||||
down_revision: Union[str, None] = "d1dac4c93e61"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index('idx_workflow_runs_campaign_id', 'workflow_runs', ['campaign_id'], unique=False)
|
||||
op.create_index('idx_workflow_runs_workflow_id', 'workflow_runs', ['workflow_id'], unique=False)
|
||||
op.create_index(
|
||||
"idx_workflow_runs_campaign_id", "workflow_runs", ["campaign_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"idx_workflow_runs_workflow_id", "workflow_runs", ["workflow_id"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('idx_workflow_runs_workflow_id', table_name='workflow_runs')
|
||||
op.drop_index('idx_workflow_runs_campaign_id', table_name='workflow_runs')
|
||||
op.drop_index("idx_workflow_runs_workflow_id", table_name="workflow_runs")
|
||||
op.drop_index("idx_workflow_runs_campaign_id", table_name="workflow_runs")
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
"""add processing state in queued runs
|
||||
|
||||
Revision ID: 34c8537dfde5
|
||||
Revises: 02ffd7f23d1d
|
||||
Create Date: 2026-01-30 14:40:29.905325
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from alembic_postgresql_enum import TableReference
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "34c8537dfde5"
|
||||
down_revision: Union[str, None] = "02ffd7f23d1d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_api_keys_key_hash"), table_name="api_keys")
|
||||
op.create_index("ix_api_keys_key_hash", "api_keys", ["key_hash"], unique=False)
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="queued_run_state",
|
||||
new_values=["queued", "processed", "processing", "failed"],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="queued_runs", column_name="state"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.sync_enum_values(
|
||||
enum_schema="public",
|
||||
enum_name="queued_run_state",
|
||||
new_values=["queued", "processed", "failed"],
|
||||
affected_columns=[
|
||||
TableReference(
|
||||
table_schema="public", table_name="queued_runs", column_name="state"
|
||||
)
|
||||
],
|
||||
enum_values_to_rename=[],
|
||||
)
|
||||
op.drop_index("ix_api_keys_key_hash", table_name="api_keys")
|
||||
op.create_index(op.f("ix_api_keys_key_hash"), "api_keys", ["key_hash"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -209,31 +209,6 @@ class CampaignClient(BaseDBClient):
|
|||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def get_queued_runs(
|
||||
self,
|
||||
campaign_id: int,
|
||||
state: str = "queued",
|
||||
limit: int = 10,
|
||||
scheduled_for: Optional[bool] = None,
|
||||
) -> list[QueuedRunModel]:
|
||||
"""Get queued runs for processing, optionally filtering by scheduled status"""
|
||||
async with self.async_session() as session:
|
||||
query = select(QueuedRunModel).where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == state,
|
||||
)
|
||||
|
||||
# Filter by scheduled status if specified
|
||||
if scheduled_for is True:
|
||||
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
|
||||
elif scheduled_for is False:
|
||||
query = query.where(QueuedRunModel.scheduled_for.is_(None))
|
||||
|
||||
query = query.order_by(QueuedRunModel.created_at).limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
|
||||
"""Update queued run"""
|
||||
async with self.async_session() as session:
|
||||
|
|
@ -284,26 +259,6 @@ class CampaignClient(BaseDBClient):
|
|||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# New methods for retry support
|
||||
async def get_scheduled_queued_runs(
|
||||
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
|
||||
) -> list[QueuedRunModel]:
|
||||
"""Get scheduled queued runs that are due for processing"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(QueuedRunModel)
|
||||
.where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.isnot(None),
|
||||
QueuedRunModel.scheduled_for <= scheduled_before,
|
||||
)
|
||||
.order_by(QueuedRunModel.scheduled_for)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create_queued_run(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
@ -388,3 +343,82 @@ class CampaignClient(BaseDBClient):
|
|||
query = select(func.count(QueuedRunModel.id)).where(*conditions)
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def claim_queued_runs_for_processing(
|
||||
self,
|
||||
campaign_id: int,
|
||||
scheduled_before: datetime,
|
||||
limit: int = 10,
|
||||
) -> list[QueuedRunModel]:
|
||||
"""
|
||||
Atomically claim queued runs for processing using SELECT FOR UPDATE SKIP LOCKED.
|
||||
|
||||
This method is thread-safe - multiple workers can call it concurrently without
|
||||
processing the same runs. It:
|
||||
1. Prioritizes scheduled retries that are due
|
||||
2. Falls back to regular queued runs if more slots available
|
||||
3. Locks selected rows and marks them as 'processing' atomically
|
||||
|
||||
Returns: List of claimed QueuedRunModel objects
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
claimed_runs = []
|
||||
|
||||
# First, get scheduled retries that are due (with lock)
|
||||
scheduled_query = (
|
||||
select(QueuedRunModel)
|
||||
.where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.isnot(None),
|
||||
QueuedRunModel.scheduled_for <= scheduled_before,
|
||||
)
|
||||
.order_by(QueuedRunModel.scheduled_for)
|
||||
.limit(limit)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
|
||||
scheduled_result = await session.execute(scheduled_query)
|
||||
scheduled_runs = list(scheduled_result.scalars().all())
|
||||
|
||||
# Mark scheduled runs as processing
|
||||
for run in scheduled_runs:
|
||||
run.state = "processing"
|
||||
claimed_runs.append(run)
|
||||
|
||||
remaining_slots = limit - len(scheduled_runs)
|
||||
|
||||
# Then get regular queued runs if we have remaining slots
|
||||
if remaining_slots > 0:
|
||||
regular_query = (
|
||||
select(QueuedRunModel)
|
||||
.where(
|
||||
QueuedRunModel.campaign_id == campaign_id,
|
||||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.is_(None),
|
||||
)
|
||||
.order_by(QueuedRunModel.created_at)
|
||||
.limit(remaining_slots)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
|
||||
regular_result = await session.execute(regular_query)
|
||||
regular_runs = list(regular_result.scalars().all())
|
||||
|
||||
# Mark regular runs as processing
|
||||
for run in regular_runs:
|
||||
run.state = "processing"
|
||||
claimed_runs.append(run)
|
||||
|
||||
# Commit the state changes
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
# Refresh to get updated state
|
||||
for run in claimed_runs:
|
||||
await session.refresh(run)
|
||||
|
||||
return claimed_runs
|
||||
|
|
|
|||
|
|
@ -598,7 +598,7 @@ class QueuedRunModel(Base):
|
|||
source_uuid = Column(String, nullable=False)
|
||||
context_variables = Column(JSON, nullable=False, default=dict)
|
||||
state = Column(
|
||||
Enum("queued", "processed", "failed", name="queued_run_state"),
|
||||
Enum("queued", "processed", "processing", "failed", name="queued_run_state"),
|
||||
nullable=False,
|
||||
default="queued",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from api.db.workflow_run_client import WorkflowRunClient
|
|||
from api.enums import CallType, OrganizationConfigurationKey, WorkflowRunState
|
||||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id
|
||||
from api.services.telephony.factory import (
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.constants import DEFAULT_ORG_CONCURRENCY_LIMIT
|
|||
from api.db import db_client
|
||||
from api.db.models import QueuedRunModel, WorkflowRunModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunState
|
||||
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.telephony.base import TelephonyProvider
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
|
|
@ -42,7 +43,8 @@ class CampaignCallDispatcher:
|
|||
|
||||
async def process_batch(self, campaign_id: int, batch_size: int = 10) -> int:
|
||||
"""
|
||||
Processes a batch of queued runs with priority for scheduled retries
|
||||
Processes a batch of queued runs with priority for scheduled retries.
|
||||
Thread-safe: uses SELECT FOR UPDATE SKIP LOCKED to prevent concurrent processing.
|
||||
Returns: number of processed runs
|
||||
"""
|
||||
# Get campaign details
|
||||
|
|
@ -57,41 +59,34 @@ class CampaignCallDispatcher:
|
|||
)
|
||||
return 0
|
||||
|
||||
# First, get any scheduled retries that are due
|
||||
scheduled_runs = await db_client.get_scheduled_queued_runs(
|
||||
# Atomically claim queued runs for processing (thread-safe)
|
||||
# This uses SELECT FOR UPDATE SKIP LOCKED to prevent race conditions
|
||||
queued_runs = await db_client.claim_queued_runs_for_processing(
|
||||
campaign_id=campaign_id,
|
||||
scheduled_before=datetime.now(UTC),
|
||||
limit=batch_size,
|
||||
)
|
||||
|
||||
remaining_slots = batch_size - len(scheduled_runs)
|
||||
|
||||
# Then get regular queued runs
|
||||
regular_runs = []
|
||||
if remaining_slots > 0:
|
||||
regular_runs = await db_client.get_queued_runs(
|
||||
campaign_id=campaign_id,
|
||||
state="queued",
|
||||
scheduled_for=False, # Exclude scheduled runs
|
||||
limit=remaining_slots,
|
||||
)
|
||||
|
||||
queued_runs = scheduled_runs + regular_runs
|
||||
|
||||
if not queued_runs:
|
||||
logger.info(f"No more queued runs for campaign {campaign_id}")
|
||||
return 0
|
||||
|
||||
processed_count = 0
|
||||
for queued_run in queued_runs:
|
||||
for i, queued_run in enumerate(queued_runs):
|
||||
try:
|
||||
# Apply rate limiting
|
||||
# Apply rate limiting, i.e lets not initiate more than rate_limit_per_second
|
||||
# calls per second. It is different than concurrency limit.
|
||||
await self.apply_rate_limit(
|
||||
campaign.organization_id, campaign.rate_limit_per_second
|
||||
)
|
||||
|
||||
# Acquire concurrent slot - waits until a slot is available
|
||||
slot_id = await self.acquire_concurrent_slot(
|
||||
campaign.organization_id, campaign
|
||||
)
|
||||
|
||||
# Dispatch the call
|
||||
workflow_run = await self.dispatch_call(queued_run, campaign)
|
||||
workflow_run = await self.dispatch_call(queued_run, campaign, slot_id)
|
||||
|
||||
# Update queued run as processed
|
||||
await db_client.update_queued_run(
|
||||
|
|
@ -108,6 +103,25 @@ class CampaignCallDispatcher:
|
|||
campaign_id=campaign_id, processed_rows=campaign.processed_rows + 1
|
||||
)
|
||||
|
||||
except ConcurrentSlotAcquisitionError:
|
||||
# Revert all unprocessed runs (current and remaining) back to queued
|
||||
# so they can be picked up again when campaign is resumed
|
||||
for unprocessed_run in queued_runs[i:]:
|
||||
try:
|
||||
await db_client.update_queued_run(
|
||||
queued_run_id=unprocessed_run.id,
|
||||
state="queued",
|
||||
)
|
||||
logger.info(
|
||||
f"Reverted queued run {unprocessed_run.id} back to queued state"
|
||||
)
|
||||
except Exception as revert_error:
|
||||
logger.error(
|
||||
f"Failed to revert queued run {unprocessed_run.id}: {revert_error}"
|
||||
)
|
||||
# Re-raise to propagate to process_campaign_batch
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing queued run {queued_run.id}: {e}")
|
||||
|
||||
|
|
@ -129,54 +143,9 @@ class CampaignCallDispatcher:
|
|||
return processed_count
|
||||
|
||||
async def dispatch_call(
|
||||
self, queued_run: QueuedRunModel, campaign: any
|
||||
self, queued_run: QueuedRunModel, campaign: any, slot_id: str
|
||||
) -> Optional[WorkflowRunModel]:
|
||||
"""Creates workflow run and initiates call with concurrent limiting"""
|
||||
# Get concurrent limit for organization
|
||||
org_concurrent_limit = await self.get_org_concurrent_limit(
|
||||
campaign.organization_id
|
||||
)
|
||||
|
||||
# Check for campaign-level max_concurrency in orchestrator_metadata
|
||||
campaign_max_concurrency = None
|
||||
if campaign.orchestrator_metadata:
|
||||
campaign_max_concurrency = campaign.orchestrator_metadata.get(
|
||||
"max_concurrency"
|
||||
)
|
||||
|
||||
# Use the lower of campaign limit and org limit
|
||||
if campaign_max_concurrency is not None:
|
||||
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
|
||||
else:
|
||||
max_concurrent = org_concurrent_limit
|
||||
|
||||
# Track wait time for alerting
|
||||
wait_start = time.time()
|
||||
slot_id = None
|
||||
|
||||
# Wait until we can acquire a concurrent slot
|
||||
while True:
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
campaign.organization_id, max_concurrent
|
||||
)
|
||||
if slot_id:
|
||||
break
|
||||
|
||||
# Check if we've been waiting too long
|
||||
wait_time = time.time() - wait_start
|
||||
if wait_time > 600: # 10 minutes
|
||||
logger.error(
|
||||
f"Waiting for concurrent slot for {wait_time:.1f}s, "
|
||||
f"org: {campaign.organization_id}, campaign: {campaign.id}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Attempting to get a slot for {campaign.organization_id} {campaign.id}"
|
||||
)
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(1)
|
||||
|
||||
"""Creates workflow run and initiates call. Requires a pre-acquired slot_id."""
|
||||
# Get workflow details
|
||||
workflow = await db_client.get_workflow_by_id(campaign.workflow_id)
|
||||
if not workflow:
|
||||
|
|
@ -351,6 +320,66 @@ class CampaignCallDispatcher:
|
|||
# Wait for next available slot
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
async def acquire_concurrent_slot(
|
||||
self, organization_id: int, campaign: any, timeout: float = 600
|
||||
) -> str:
|
||||
"""
|
||||
Acquires a concurrent call slot - waits if necessary until a slot is available.
|
||||
|
||||
Args:
|
||||
organization_id: The organization ID
|
||||
campaign: The campaign object
|
||||
timeout: Maximum time to wait for a slot (default 10 minutes)
|
||||
|
||||
Returns the slot_id which must be released when the call completes.
|
||||
|
||||
Raises:
|
||||
ConcurrentSlotAcquisitionError: If slot cannot be acquired within timeout
|
||||
"""
|
||||
# Get concurrent limit for organization
|
||||
org_concurrent_limit = await self.get_org_concurrent_limit(organization_id)
|
||||
|
||||
# Check for campaign-level max_concurrency in orchestrator_metadata
|
||||
campaign_max_concurrency = None
|
||||
if campaign.orchestrator_metadata:
|
||||
campaign_max_concurrency = campaign.orchestrator_metadata.get(
|
||||
"max_concurrency"
|
||||
)
|
||||
|
||||
# Use the lower of campaign limit and org limit
|
||||
if campaign_max_concurrency is not None:
|
||||
max_concurrent = min(campaign_max_concurrency, org_concurrent_limit)
|
||||
else:
|
||||
max_concurrent = org_concurrent_limit
|
||||
|
||||
# Track wait time for alerting
|
||||
wait_start = time.time()
|
||||
|
||||
# Wait until we can acquire a concurrent slot
|
||||
while True:
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id, max_concurrent
|
||||
)
|
||||
if slot_id:
|
||||
return slot_id
|
||||
|
||||
# Check if we've been waiting too long
|
||||
wait_time = time.time() - wait_start
|
||||
if wait_time > timeout:
|
||||
raise ConcurrentSlotAcquisitionError(
|
||||
organization_id=organization_id,
|
||||
campaign_id=campaign.id,
|
||||
wait_time=wait_time,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Attempting to get a slot for {organization_id} {campaign.id}, "
|
||||
f"waited {wait_time:.1f}s"
|
||||
)
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def release_call_slot(self, workflow_run_id: int) -> bool:
|
||||
"""
|
||||
Release concurrent slot when a call completes.
|
||||
|
|
@ -12,6 +12,7 @@ from api.constants import REDIS_URL
|
|||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
CampaignCompletedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
|
|
@ -43,6 +44,23 @@ class CampaignEventPublisher:
|
|||
|
||||
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
|
||||
|
||||
async def publish_batch_failed(
|
||||
self,
|
||||
campaign_id: int,
|
||||
error: str,
|
||||
processed_count: int = 0,
|
||||
metadata: Optional[Dict] = None,
|
||||
):
|
||||
"""Publish batch failed event."""
|
||||
event = BatchFailedEvent(
|
||||
campaign_id=campaign_id,
|
||||
error=error,
|
||||
processed_count=processed_count,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.redis.publish(RedisChannel.CAMPAIGN_EVENTS.value, event.to_json())
|
||||
|
||||
async def publish_sync_completed(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ from api.db import db_client
|
|||
from api.db.models import CampaignModel, QueuedRunModel
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.campaign_event_protocol import (
|
||||
CampaignCompletedEvent,
|
||||
CampaignEventType,
|
||||
BatchCompletedEvent,
|
||||
BatchFailedEvent,
|
||||
RetryNeededEvent,
|
||||
SyncCompletedEvent,
|
||||
parse_campaign_event,
|
||||
)
|
||||
from api.services.campaign.campaign_event_publisher import CampaignEventPublisher
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
|
|
@ -37,6 +39,7 @@ class CampaignOrchestrator:
|
|||
|
||||
def __init__(self, redis_client: aioredis.Redis):
|
||||
self.redis = redis_client
|
||||
self.publisher = CampaignEventPublisher(redis_client)
|
||||
self.completion_check_interval = 60 # 1 minute
|
||||
self.completion_timeout = 3600 # 1 hour
|
||||
self._processing_locks: Dict[int, datetime] = {} # prevent duplicate scheduling
|
||||
|
|
@ -97,22 +100,21 @@ class CampaignOrchestrator:
|
|||
|
||||
async def _handle_event(self, event):
|
||||
"""Handle campaign events including retry events."""
|
||||
# Handle RetryNeededEvent
|
||||
if isinstance(event, RetryNeededEvent):
|
||||
await self._handle_retry_event(event)
|
||||
return
|
||||
|
||||
# All events should have campaign_id
|
||||
if not hasattr(event, "campaign_id") or not event.campaign_id:
|
||||
logger.warning(f"Event missing campaign_id: {type(event).__name__}")
|
||||
return
|
||||
|
||||
campaign_id = event.campaign_id
|
||||
event_type = event.type
|
||||
|
||||
logger.debug(f"campaign_id: {campaign_id} - Received event: {event_type}")
|
||||
logger.debug(
|
||||
f"campaign_id: {campaign_id} - Received event: {type(event).__name__}"
|
||||
)
|
||||
|
||||
if event_type == CampaignEventType.BATCH_COMPLETED:
|
||||
if isinstance(event, RetryNeededEvent):
|
||||
await self._handle_retry_event(event)
|
||||
|
||||
elif isinstance(event, BatchCompletedEvent):
|
||||
# Clear the batch in progress flag
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
|
|
@ -120,11 +122,42 @@ class CampaignOrchestrator:
|
|||
f"campaign_id: {campaign_id} - Batch completed, cleared in-progress flag"
|
||||
)
|
||||
|
||||
# Check campaign state before scheduling next batch
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
logger.error(f"campaign_id: {campaign_id} - Campaign not found")
|
||||
self._clear_campaign_state(campaign_id)
|
||||
return
|
||||
|
||||
if campaign.state != "running":
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Campaign not in running state ({campaign.state}), "
|
||||
f"not scheduling next batch"
|
||||
)
|
||||
self._clear_campaign_state(campaign_id)
|
||||
return
|
||||
|
||||
# Immediately schedule next batch
|
||||
await self._schedule_next_batch(campaign_id)
|
||||
self._last_activity[campaign_id] = datetime.now(UTC)
|
||||
|
||||
elif event_type == CampaignEventType.SYNC_COMPLETED:
|
||||
elif isinstance(event, BatchFailedEvent):
|
||||
# Clear the batch in progress flag
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
|
||||
logger.warning(
|
||||
f"campaign_id: {campaign_id} - Batch failed: {event.error}, "
|
||||
f"scheduling next batch to continue processing"
|
||||
)
|
||||
|
||||
# Lets not schedule another batch, since we mark the campaign
|
||||
# as failed just to be on the safe side from process_campaign_batch
|
||||
# if a batch fails
|
||||
|
||||
self._last_activity[campaign_id] = datetime.now(UTC)
|
||||
|
||||
elif isinstance(event, SyncCompletedEvent):
|
||||
# Start processing after sync
|
||||
logger.info(
|
||||
f"campaign_id: {campaign_id} - Sync completed, starting processing"
|
||||
|
|
@ -309,6 +342,16 @@ class CampaignOrchestrator:
|
|||
del self._processing_locks[campaign_id]
|
||||
logger.debug(f"campaign_id: {campaign_id} - Released processing lock")
|
||||
|
||||
def _clear_campaign_state(self, campaign_id: int):
|
||||
"""Clear all in-memory state for a campaign."""
|
||||
if campaign_id in self._last_activity:
|
||||
del self._last_activity[campaign_id]
|
||||
if campaign_id in self._processing_locks:
|
||||
del self._processing_locks[campaign_id]
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
logger.debug(f"campaign_id: {campaign_id} - Cleared all in-memory state")
|
||||
|
||||
async def _monitor_completion(self):
|
||||
"""Periodically check for campaigns that should be marked complete."""
|
||||
while self._running:
|
||||
|
|
@ -457,30 +500,22 @@ class CampaignOrchestrator:
|
|||
|
||||
logger.info(f"campaign_id: {campaign_id} - Campaign marked as completed")
|
||||
|
||||
# Publish completion event using typed event
|
||||
completion_event = CampaignCompletedEvent(
|
||||
# Calculate duration if started_at is available
|
||||
duration = None
|
||||
if campaign.started_at:
|
||||
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
|
||||
|
||||
# Publish completion event
|
||||
await self.publisher.publish_campaign_completed(
|
||||
campaign_id=campaign_id,
|
||||
total_rows=campaign.total_rows or 0,
|
||||
processed_rows=campaign.processed_rows,
|
||||
failed_rows=campaign.failed_rows,
|
||||
)
|
||||
|
||||
# Calculate duration if started_at is available
|
||||
if campaign.started_at:
|
||||
duration = (datetime.now(UTC) - campaign.started_at).total_seconds()
|
||||
completion_event.duration_seconds = duration
|
||||
|
||||
await self.redis.publish(
|
||||
RedisChannel.CAMPAIGN_EVENTS.value, completion_event.to_json()
|
||||
duration_seconds=duration,
|
||||
)
|
||||
|
||||
# Clean up in-memory state
|
||||
if campaign_id in self._last_activity:
|
||||
del self._last_activity[campaign_id]
|
||||
if campaign_id in self._processing_locks:
|
||||
del self._processing_locks[campaign_id]
|
||||
if campaign_id in self._batch_in_progress:
|
||||
del self._batch_in_progress[campaign_id]
|
||||
self._clear_campaign_state(campaign_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
|
|||
16
api/services/campaign/errors.py
Normal file
16
api/services/campaign/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""
|
||||
Campaign service exceptions.
|
||||
"""
|
||||
|
||||
|
||||
class ConcurrentSlotAcquisitionError(Exception):
|
||||
"""Raised when a concurrent call slot cannot be acquired within the timeout period."""
|
||||
|
||||
def __init__(self, organization_id: int, campaign_id: int, wait_time: float):
|
||||
self.organization_id = organization_id
|
||||
self.campaign_id = campaign_id
|
||||
self.wait_time = wait_time
|
||||
super().__init__(
|
||||
f"Failed to acquire concurrent slot for org {organization_id}, "
|
||||
f"campaign {campaign_id} after waiting {wait_time:.1f}s"
|
||||
)
|
||||
18
api/services/campaign/readme.md
Normal file
18
api/services/campaign/readme.md
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
### campaign_orchestrator.py (CampaignOrchestrator)
|
||||
|
||||
- Listens to retry events, batch completed event, sync completed events from redis pubsub, and schedules batches
|
||||
- Monitors stale campaigns and schedules batches if one is not already scheduled
|
||||
- Marks campaign as completed if no more tasks pending
|
||||
|
||||
### runner.py (CampaignRunnerService)
|
||||
|
||||
- Service layer to handle router requests, like run campaign, pause campaign, resume campaign, get campaign status etc.
|
||||
|
||||
### call_dispatcher.py (CampaignCallDispatcher)
|
||||
|
||||
- Ensures rate limit and concurrency limits and dispatches call using telephony provider
|
||||
|
||||
### campaign_tasks.py
|
||||
|
||||
- sync campaign from source
|
||||
- process campaign batch
|
||||
|
|
@ -63,12 +63,10 @@ class CampaignRunnerService:
|
|||
f"Campaign must be in 'paused' state to resume, current state: {campaign.state}"
|
||||
)
|
||||
|
||||
# Update state to running
|
||||
# Update state to running. Do not queue batch since campaign orchestrator's
|
||||
# stale campaign checker would do that if there are pending work.
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="running")
|
||||
|
||||
# Enqueue process batch task to continue processing
|
||||
await enqueue_job(FunctionNames.PROCESS_CAMPAIGN_BATCH, campaign_id)
|
||||
|
||||
logger.info(f"Campaign {campaign_id} resumed")
|
||||
|
||||
async def get_campaign_status(self, campaign_id: int) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunState
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.in_memory_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ REDIS_SETTINGS = RedisSettings(
|
|||
)
|
||||
|
||||
from api.tasks.campaign_tasks import (
|
||||
monitor_campaign_progress,
|
||||
process_campaign_batch,
|
||||
sync_campaign_source,
|
||||
)
|
||||
|
|
@ -62,7 +61,6 @@ class WorkerSettings:
|
|||
process_workflow_completion,
|
||||
sync_campaign_source,
|
||||
process_campaign_batch,
|
||||
monitor_campaign_progress,
|
||||
process_knowledge_base_document,
|
||||
]
|
||||
cron_jobs = []
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ from typing import Dict
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_protocol import BatchFailedEvent
|
||||
from api.services.campaign.campaign_call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import (
|
||||
get_campaign_event_publisher,
|
||||
)
|
||||
from api.services.campaign.errors import ConcurrentSlotAcquisitionError
|
||||
from api.services.campaign.source_sync_factory import get_sync_service
|
||||
|
||||
|
||||
|
|
@ -95,6 +94,10 @@ async def process_campaign_batch(
|
|||
- Updates queued_run state to 'processed'
|
||||
- Updates campaign.processed_rows counter
|
||||
- Publishes batch_completed event for orchestrator
|
||||
|
||||
# TODO: May be not fail the campaign immediately on a single batch failure
|
||||
# and propagate the error to campaign orchestrator which can fail the campaign
|
||||
# on some consecutive batch failures.
|
||||
"""
|
||||
logger.info(f"Processing batch for campaign {campaign_id}, batch_size={batch_size}")
|
||||
|
||||
|
|
@ -119,81 +122,34 @@ async def process_campaign_batch(
|
|||
f"failed={failed_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch for campaign {campaign_id}: {e}")
|
||||
|
||||
# Publish batch failed event
|
||||
publisher = await get_campaign_event_publisher()
|
||||
event = BatchFailedEvent(
|
||||
campaign_id=campaign_id,
|
||||
error=str(e),
|
||||
processed_count=0,
|
||||
except ConcurrentSlotAcquisitionError as e:
|
||||
logger.warning(
|
||||
f"Failed to acquire concurrent slot for campaign {campaign_id}: {e}"
|
||||
)
|
||||
await publisher.redis.publish(
|
||||
RedisChannel.CAMPAIGN_EVENTS.value, event.to_json()
|
||||
|
||||
# Publish batch failed event with specific error
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_batch_failed(
|
||||
campaign_id=campaign_id,
|
||||
error=f"Concurrent slot acquisition timeout: {e}",
|
||||
processed_count=0,
|
||||
)
|
||||
|
||||
# Update campaign state to failed
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
raise
|
||||
|
||||
|
||||
async def monitor_campaign_progress(ctx: Dict, campaign_id: int) -> None:
|
||||
"""
|
||||
Phase 3: Monitors campaign completion
|
||||
- Checks if all queued runs are in 'processed' state
|
||||
- Queries workflow_runs for final call statistics
|
||||
- Updates campaign state to 'completed'
|
||||
- Calculates total calls made, successful, failed
|
||||
- Triggers post-campaign integrations
|
||||
"""
|
||||
logger.info(f"Monitoring progress for campaign {campaign_id}")
|
||||
|
||||
try:
|
||||
# Get campaign
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
# Check if all runs are processed
|
||||
pending_runs = await db_client.count_queued_runs(
|
||||
campaign_id=campaign_id, state="queued"
|
||||
)
|
||||
|
||||
if pending_runs > 0:
|
||||
logger.info(f"Campaign {campaign_id} still has {pending_runs} pending runs")
|
||||
return
|
||||
|
||||
# All runs processed, mark campaign as completed
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id, state="completed", completed_at=datetime.now(UTC)
|
||||
)
|
||||
|
||||
# Calculate statistics
|
||||
workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id)
|
||||
|
||||
total_calls = len(workflow_runs)
|
||||
successful_calls = 0
|
||||
failed_calls = 0
|
||||
|
||||
for run in workflow_runs:
|
||||
callbacks = run.logs.get("telephony_status_callbacks", [])
|
||||
if callbacks:
|
||||
final_status = callbacks[-1].get("status", "").lower()
|
||||
if final_status == "completed":
|
||||
successful_calls += 1
|
||||
elif final_status in ["failed", "busy", "no-answer"]:
|
||||
failed_calls += 1
|
||||
|
||||
logger.info(
|
||||
f"Campaign {campaign_id} completed: "
|
||||
f"Total calls: {total_calls}, "
|
||||
f"Successful: {successful_calls}, "
|
||||
f"Failed: {failed_calls}"
|
||||
)
|
||||
|
||||
# TODO: Trigger post-campaign integrations if configured
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring campaign {campaign_id}: {e}")
|
||||
logger.error(f"Error processing batch for campaign {campaign_id}: {e}")
|
||||
|
||||
# Publish batch failed event
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_batch_failed(
|
||||
campaign_id=campaign_id,
|
||||
error=str(e),
|
||||
processed_count=0,
|
||||
)
|
||||
|
||||
# Update campaign state to failed
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -5,5 +5,4 @@ class FunctionNames:
|
|||
UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3"
|
||||
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
|
||||
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
|
||||
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
|
||||
PROCESS_KNOWLEDGE_BASE_DOCUMENT = "process_knowledge_base_document"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ from typing import Any, Dict, Optional
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from api.constants import DATABASE_URL
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
|
|
@ -549,3 +551,22 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database fixtures for integration tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_engine():
|
||||
"""Create database engine for tests."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_session_factory(db_engine):
|
||||
"""Create session factory for tests."""
|
||||
return async_sessionmaker(bind=db_engine, expire_on_commit=False)
|
||||
|
|
|
|||
603
api/tests/test_campaign_call_dispatcher.py
Normal file
603
api/tests/test_campaign_call_dispatcher.py
Normal file
|
|
@ -0,0 +1,603 @@
|
|||
"""
|
||||
Tests for CampaignCallDispatcher.process_batch method.
|
||||
|
||||
These tests verify:
|
||||
1. Basic batch processing functionality
|
||||
2. Thread-safety via SELECT FOR UPDATE SKIP LOCKED
|
||||
3. Race condition handling when multiple workers process concurrently
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, text
|
||||
|
||||
from api.db.models import (
|
||||
CampaignModel,
|
||||
OrganizationModel,
|
||||
QueuedRunModel,
|
||||
UserModel,
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.services.campaign.campaign_call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
# =============================================================================
|
||||
# Test-specific fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignTestData:
|
||||
"""Container for campaign test data IDs"""
|
||||
|
||||
organization_id: int
|
||||
user_id: int
|
||||
workflow_id: int
|
||||
campaign_id: int
|
||||
queued_run_ids: List[int]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def campaign_test_data(db_session_factory) -> CampaignTestData:
|
||||
"""
|
||||
Create test data for campaign processing tests.
|
||||
|
||||
Creates:
|
||||
- Organization
|
||||
- User
|
||||
- Workflow
|
||||
- Campaign (in 'running' state)
|
||||
- 10 QueuedRuns (in 'queued' state)
|
||||
"""
|
||||
async with db_session_factory() as session:
|
||||
# Create organization
|
||||
org = OrganizationModel(
|
||||
provider_id=f"test-org-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
# Create user
|
||||
user = UserModel(
|
||||
provider_id=f"test-user-{uuid.uuid4().hex[:8]}",
|
||||
selected_organization_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Create workflow
|
||||
workflow = WorkflowModel(
|
||||
name=f"test-workflow-{uuid.uuid4().hex[:8]}",
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
workflow_definition={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Hello"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
template_context_variables={},
|
||||
)
|
||||
session.add(workflow)
|
||||
await session.flush()
|
||||
|
||||
# Create campaign
|
||||
campaign = CampaignModel(
|
||||
name=f"test-campaign-{uuid.uuid4().hex[:8]}",
|
||||
organization_id=org.id,
|
||||
workflow_id=workflow.id,
|
||||
created_by=user.id,
|
||||
source_type="test",
|
||||
source_id="test-source",
|
||||
state="running",
|
||||
rate_limit_per_second=100, # High limit to avoid rate limiting in tests
|
||||
)
|
||||
session.add(campaign)
|
||||
await session.flush()
|
||||
|
||||
# Create queued runs
|
||||
queued_run_ids = []
|
||||
for i in range(10):
|
||||
queued_run = QueuedRunModel(
|
||||
campaign_id=campaign.id,
|
||||
source_uuid=f"test-uuid-{i}",
|
||||
context_variables={"phone_number": f"+1555000{i:04d}"},
|
||||
state="queued",
|
||||
)
|
||||
session.add(queued_run)
|
||||
await session.flush()
|
||||
queued_run_ids.append(queued_run.id)
|
||||
|
||||
await session.commit()
|
||||
|
||||
test_data = CampaignTestData(
|
||||
organization_id=org.id,
|
||||
user_id=user.id,
|
||||
workflow_id=workflow.id,
|
||||
campaign_id=campaign.id,
|
||||
queued_run_ids=queued_run_ids,
|
||||
)
|
||||
|
||||
yield test_data
|
||||
|
||||
# Cleanup
|
||||
async with db_session_factory() as cleanup_session:
|
||||
# Delete in reverse order of dependencies
|
||||
await cleanup_session.execute(
|
||||
delete(QueuedRunModel).where(QueuedRunModel.campaign_id == campaign.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(WorkflowRunModel).where(
|
||||
WorkflowRunModel.campaign_id == campaign.id
|
||||
)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(CampaignModel).where(CampaignModel.id == campaign.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(WorkflowModel).where(WorkflowModel.id == workflow.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(UserModel).where(UserModel.id == user.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(OrganizationModel).where(OrganizationModel.id == org.id)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dispatch_call():
|
||||
"""Mock dispatch_call to track which runs were processed."""
|
||||
processed_runs = []
|
||||
|
||||
async def mock_dispatch(queued_run, campaign, slot_id):
|
||||
# Simulate some processing time
|
||||
await asyncio.sleep(0.01)
|
||||
processed_runs.append(queued_run.id)
|
||||
# Return a mock workflow run
|
||||
mock_run = MagicMock()
|
||||
mock_run.id = len(processed_runs)
|
||||
return mock_run
|
||||
|
||||
return mock_dispatch, processed_runs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rate_limiter():
|
||||
"""Mock rate limiter to always allow calls."""
|
||||
|
||||
async def mock_acquire_token(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def mock_try_acquire_slot(*args, **kwargs):
|
||||
return f"slot-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def mock_release_slot(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def mock_store_mapping(*args, **kwargs):
|
||||
pass
|
||||
|
||||
async def mock_get_mapping(*args, **kwargs):
|
||||
return None
|
||||
|
||||
async def mock_delete_mapping(*args, **kwargs):
|
||||
pass
|
||||
|
||||
return {
|
||||
"acquire_token": mock_acquire_token,
|
||||
"try_acquire_concurrent_slot": mock_try_acquire_slot,
|
||||
"release_concurrent_slot": mock_release_slot,
|
||||
"store_workflow_slot_mapping": mock_store_mapping,
|
||||
"get_workflow_slot_mapping": mock_get_mapping,
|
||||
"delete_workflow_slot_mapping": mock_delete_mapping,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestProcessBatchBasic:
|
||||
"""Basic tests for process_batch functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_processes_queued_runs(
|
||||
self, campaign_test_data, mock_dispatch_call, mock_rate_limiter
|
||||
):
|
||||
"""Test that process_batch processes queued runs and marks them as processed."""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
# Setup rate limiter mocks
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dispatch_call
|
||||
with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch):
|
||||
# Process batch of 5
|
||||
processed_count = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
assert processed_count == 5
|
||||
assert len(processed_runs) == 5
|
||||
|
||||
|
||||
class TestProcessBatchConcurrency:
|
||||
"""Tests for concurrent batch processing and database locking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_process_batch_no_duplicate_processing(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test that two concurrent process_batch calls don't process the same runs.
|
||||
|
||||
This verifies the SELECT FOR UPDATE SKIP LOCKED mechanism works correctly.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs to 'queued' state for this test
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch():
|
||||
"""Helper to run process_batch with mocked dependencies."""
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
# Run two process_batch calls concurrently
|
||||
results = await asyncio.gather(
|
||||
run_process_batch(),
|
||||
run_process_batch(),
|
||||
)
|
||||
|
||||
# Total processed should be 10 (all queued runs)
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10, f"Expected 10 total, got {total_processed}"
|
||||
|
||||
# Each run should be processed exactly once (no duplicates)
|
||||
assert len(processed_runs) == 10, f"Expected 10 runs, got {len(processed_runs)}"
|
||||
assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_process_batch_with_different_batch_sizes(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test concurrent processing with different batch sizes.
|
||||
|
||||
Worker 1 requests 3 runs, Worker 2 requests 7 runs.
|
||||
Total should still be 10 with no duplicates.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs to 'queued' state
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch(batch_size: int):
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
# Run with different batch sizes concurrently
|
||||
results = await asyncio.gather(
|
||||
run_process_batch(3),
|
||||
run_process_batch(7),
|
||||
)
|
||||
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10
|
||||
|
||||
# Verify no duplicates
|
||||
assert len(set(processed_runs)) == len(processed_runs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_concurrent_workers(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test with many concurrent workers (simulating production scenario).
|
||||
|
||||
5 workers each requesting 4 runs from a pool of 10.
|
||||
Should process all 10 exactly once.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch():
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=4
|
||||
)
|
||||
|
||||
# Run 5 workers concurrently
|
||||
results = await asyncio.gather(*[run_process_batch() for _ in range(5)])
|
||||
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10
|
||||
|
||||
# Verify no duplicates
|
||||
assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_state_transition(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test that runs transition through processing -> processed states correctly.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch):
|
||||
await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=10
|
||||
)
|
||||
|
||||
# Verify all runs are in 'processed' state
|
||||
async with db_session_factory() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT state, COUNT(*) as count FROM queued_runs "
|
||||
"WHERE campaign_id = :campaign_id GROUP BY state"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
states = {row[0]: row[1] for row in result.fetchall()}
|
||||
|
||||
assert states.get("processed", 0) == 10
|
||||
assert states.get("queued", 0) == 0
|
||||
assert states.get("processing", 0) == 0
|
||||
|
||||
|
||||
class TestProcessBatchEdgeCases:
|
||||
"""Edge case tests for process_batch."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_queue(
|
||||
self, campaign_test_data, mock_rate_limiter, db_session_factory
|
||||
):
|
||||
"""Test process_batch with no queued runs returns 0."""
|
||||
# Set all runs to processed
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'processed' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
result = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_campaign_not_running(
|
||||
self, campaign_test_data, mock_rate_limiter, db_session_factory
|
||||
):
|
||||
"""Test process_batch returns 0 if campaign is not in running state."""
|
||||
# Set campaign to paused
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text("UPDATE campaigns SET state = 'paused' WHERE id = :campaign_id"),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
result = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
assert result == 0
|
||||
finally:
|
||||
# Restore campaign state
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE campaigns SET state = 'running' WHERE id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
Loading…
Add table
Add a link
Reference in a new issue