mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
fix: make campaign process batch thread safe (#141)
* fix: dont schedule new batch on resume * fix: make process_batch thread safe
This commit is contained in:
parent
e9c5da16c5
commit
6827744327
17 changed files with 1012 additions and 230 deletions
|
|
@ -3,7 +3,9 @@ from typing import Any, Dict, Optional
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from api.constants import DATABASE_URL
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
|
|
@ -549,3 +551,22 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database fixtures for integration tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_engine():
|
||||
"""Create database engine for tests."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_session_factory(db_engine):
|
||||
"""Create session factory for tests."""
|
||||
return async_sessionmaker(bind=db_engine, expire_on_commit=False)
|
||||
|
|
|
|||
603
api/tests/test_campaign_call_dispatcher.py
Normal file
603
api/tests/test_campaign_call_dispatcher.py
Normal file
|
|
@ -0,0 +1,603 @@
|
|||
"""
|
||||
Tests for CampaignCallDispatcher.process_batch method.
|
||||
|
||||
These tests verify:
|
||||
1. Basic batch processing functionality
|
||||
2. Thread-safety via SELECT FOR UPDATE SKIP LOCKED
|
||||
3. Race condition handling when multiple workers process concurrently
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, text
|
||||
|
||||
from api.db.models import (
|
||||
CampaignModel,
|
||||
OrganizationModel,
|
||||
QueuedRunModel,
|
||||
UserModel,
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.services.campaign.campaign_call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
# =============================================================================
|
||||
# Test-specific fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignTestData:
|
||||
"""Container for campaign test data IDs"""
|
||||
|
||||
organization_id: int
|
||||
user_id: int
|
||||
workflow_id: int
|
||||
campaign_id: int
|
||||
queued_run_ids: List[int]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def campaign_test_data(db_session_factory) -> CampaignTestData:
|
||||
"""
|
||||
Create test data for campaign processing tests.
|
||||
|
||||
Creates:
|
||||
- Organization
|
||||
- User
|
||||
- Workflow
|
||||
- Campaign (in 'running' state)
|
||||
- 10 QueuedRuns (in 'queued' state)
|
||||
"""
|
||||
async with db_session_factory() as session:
|
||||
# Create organization
|
||||
org = OrganizationModel(
|
||||
provider_id=f"test-org-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
# Create user
|
||||
user = UserModel(
|
||||
provider_id=f"test-user-{uuid.uuid4().hex[:8]}",
|
||||
selected_organization_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
# Create workflow
|
||||
workflow = WorkflowModel(
|
||||
name=f"test-workflow-{uuid.uuid4().hex[:8]}",
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
workflow_definition={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Start", "prompt": "Hello"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
template_context_variables={},
|
||||
)
|
||||
session.add(workflow)
|
||||
await session.flush()
|
||||
|
||||
# Create campaign
|
||||
campaign = CampaignModel(
|
||||
name=f"test-campaign-{uuid.uuid4().hex[:8]}",
|
||||
organization_id=org.id,
|
||||
workflow_id=workflow.id,
|
||||
created_by=user.id,
|
||||
source_type="test",
|
||||
source_id="test-source",
|
||||
state="running",
|
||||
rate_limit_per_second=100, # High limit to avoid rate limiting in tests
|
||||
)
|
||||
session.add(campaign)
|
||||
await session.flush()
|
||||
|
||||
# Create queued runs
|
||||
queued_run_ids = []
|
||||
for i in range(10):
|
||||
queued_run = QueuedRunModel(
|
||||
campaign_id=campaign.id,
|
||||
source_uuid=f"test-uuid-{i}",
|
||||
context_variables={"phone_number": f"+1555000{i:04d}"},
|
||||
state="queued",
|
||||
)
|
||||
session.add(queued_run)
|
||||
await session.flush()
|
||||
queued_run_ids.append(queued_run.id)
|
||||
|
||||
await session.commit()
|
||||
|
||||
test_data = CampaignTestData(
|
||||
organization_id=org.id,
|
||||
user_id=user.id,
|
||||
workflow_id=workflow.id,
|
||||
campaign_id=campaign.id,
|
||||
queued_run_ids=queued_run_ids,
|
||||
)
|
||||
|
||||
yield test_data
|
||||
|
||||
# Cleanup
|
||||
async with db_session_factory() as cleanup_session:
|
||||
# Delete in reverse order of dependencies
|
||||
await cleanup_session.execute(
|
||||
delete(QueuedRunModel).where(QueuedRunModel.campaign_id == campaign.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(WorkflowRunModel).where(
|
||||
WorkflowRunModel.campaign_id == campaign.id
|
||||
)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(CampaignModel).where(CampaignModel.id == campaign.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(WorkflowModel).where(WorkflowModel.id == workflow.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(UserModel).where(UserModel.id == user.id)
|
||||
)
|
||||
await cleanup_session.execute(
|
||||
delete(OrganizationModel).where(OrganizationModel.id == org.id)
|
||||
)
|
||||
await cleanup_session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dispatch_call():
|
||||
"""Mock dispatch_call to track which runs were processed."""
|
||||
processed_runs = []
|
||||
|
||||
async def mock_dispatch(queued_run, campaign, slot_id):
|
||||
# Simulate some processing time
|
||||
await asyncio.sleep(0.01)
|
||||
processed_runs.append(queued_run.id)
|
||||
# Return a mock workflow run
|
||||
mock_run = MagicMock()
|
||||
mock_run.id = len(processed_runs)
|
||||
return mock_run
|
||||
|
||||
return mock_dispatch, processed_runs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rate_limiter():
|
||||
"""Mock rate limiter to always allow calls."""
|
||||
|
||||
async def mock_acquire_token(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def mock_try_acquire_slot(*args, **kwargs):
|
||||
return f"slot-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def mock_release_slot(*args, **kwargs):
|
||||
return True
|
||||
|
||||
async def mock_store_mapping(*args, **kwargs):
|
||||
pass
|
||||
|
||||
async def mock_get_mapping(*args, **kwargs):
|
||||
return None
|
||||
|
||||
async def mock_delete_mapping(*args, **kwargs):
|
||||
pass
|
||||
|
||||
return {
|
||||
"acquire_token": mock_acquire_token,
|
||||
"try_acquire_concurrent_slot": mock_try_acquire_slot,
|
||||
"release_concurrent_slot": mock_release_slot,
|
||||
"store_workflow_slot_mapping": mock_store_mapping,
|
||||
"get_workflow_slot_mapping": mock_get_mapping,
|
||||
"delete_workflow_slot_mapping": mock_delete_mapping,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestProcessBatchBasic:
|
||||
"""Basic tests for process_batch functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_processes_queued_runs(
|
||||
self, campaign_test_data, mock_dispatch_call, mock_rate_limiter
|
||||
):
|
||||
"""Test that process_batch processes queued runs and marks them as processed."""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
# Setup rate limiter mocks
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dispatch_call
|
||||
with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch):
|
||||
# Process batch of 5
|
||||
processed_count = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
assert processed_count == 5
|
||||
assert len(processed_runs) == 5
|
||||
|
||||
|
||||
class TestProcessBatchConcurrency:
|
||||
"""Tests for concurrent batch processing and database locking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_process_batch_no_duplicate_processing(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test that two concurrent process_batch calls don't process the same runs.
|
||||
|
||||
This verifies the SELECT FOR UPDATE SKIP LOCKED mechanism works correctly.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs to 'queued' state for this test
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch():
|
||||
"""Helper to run process_batch with mocked dependencies."""
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
# Run two process_batch calls concurrently
|
||||
results = await asyncio.gather(
|
||||
run_process_batch(),
|
||||
run_process_batch(),
|
||||
)
|
||||
|
||||
# Total processed should be 10 (all queued runs)
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10, f"Expected 10 total, got {total_processed}"
|
||||
|
||||
# Each run should be processed exactly once (no duplicates)
|
||||
assert len(processed_runs) == 10, f"Expected 10 runs, got {len(processed_runs)}"
|
||||
assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_process_batch_with_different_batch_sizes(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test concurrent processing with different batch sizes.
|
||||
|
||||
Worker 1 requests 3 runs, Worker 2 requests 7 runs.
|
||||
Total should still be 10 with no duplicates.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs to 'queued' state
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch(batch_size: int):
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
# Run with different batch sizes concurrently
|
||||
results = await asyncio.gather(
|
||||
run_process_batch(3),
|
||||
run_process_batch(7),
|
||||
)
|
||||
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10
|
||||
|
||||
# Verify no duplicates
|
||||
assert len(set(processed_runs)) == len(processed_runs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_concurrent_workers(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test with many concurrent workers (simulating production scenario).
|
||||
|
||||
5 workers each requesting 4 runs from a pool of 10.
|
||||
Should process all 10 exactly once.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def run_process_batch():
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(
|
||||
dispatcher, "dispatch_call", side_effect=mock_dispatch
|
||||
):
|
||||
return await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=4
|
||||
)
|
||||
|
||||
# Run 5 workers concurrently
|
||||
results = await asyncio.gather(*[run_process_batch() for _ in range(5)])
|
||||
|
||||
total_processed = sum(results)
|
||||
assert total_processed == 10
|
||||
|
||||
# Verify no duplicates
|
||||
assert len(set(processed_runs)) == 10, "Duplicate runs were processed!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_state_transition(
|
||||
self,
|
||||
campaign_test_data,
|
||||
mock_dispatch_call,
|
||||
mock_rate_limiter,
|
||||
db_session_factory,
|
||||
):
|
||||
"""
|
||||
Test that runs transition through processing -> processed states correctly.
|
||||
"""
|
||||
mock_dispatch, processed_runs = mock_dispatch_call
|
||||
|
||||
# Reset queued runs
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'queued' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
mock_rl.release_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["release_concurrent_slot"]
|
||||
)
|
||||
mock_rl.store_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["store_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.get_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["get_workflow_slot_mapping"]
|
||||
)
|
||||
mock_rl.delete_workflow_slot_mapping = AsyncMock(
|
||||
side_effect=mock_rate_limiter["delete_workflow_slot_mapping"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
with patch.object(dispatcher, "dispatch_call", side_effect=mock_dispatch):
|
||||
await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=10
|
||||
)
|
||||
|
||||
# Verify all runs are in 'processed' state
|
||||
async with db_session_factory() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT state, COUNT(*) as count FROM queued_runs "
|
||||
"WHERE campaign_id = :campaign_id GROUP BY state"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
states = {row[0]: row[1] for row in result.fetchall()}
|
||||
|
||||
assert states.get("processed", 0) == 10
|
||||
assert states.get("queued", 0) == 0
|
||||
assert states.get("processing", 0) == 0
|
||||
|
||||
|
||||
class TestProcessBatchEdgeCases:
|
||||
"""Edge case tests for process_batch."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_queue(
|
||||
self, campaign_test_data, mock_rate_limiter, db_session_factory
|
||||
):
|
||||
"""Test process_batch with no queued runs returns 0."""
|
||||
# Set all runs to processed
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE queued_runs SET state = 'processed' WHERE campaign_id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.rate_limiter"
|
||||
) as mock_rl:
|
||||
mock_rl.acquire_token = AsyncMock(
|
||||
side_effect=mock_rate_limiter["acquire_token"]
|
||||
)
|
||||
mock_rl.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_rate_limiter["try_acquire_concurrent_slot"]
|
||||
)
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
result = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_campaign_not_running(
|
||||
self, campaign_test_data, mock_rate_limiter, db_session_factory
|
||||
):
|
||||
"""Test process_batch returns 0 if campaign is not in running state."""
|
||||
# Set campaign to paused
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text("UPDATE campaigns SET state = 'paused' WHERE id = :campaign_id"),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
result = await dispatcher.process_batch(
|
||||
campaign_id=campaign_test_data.campaign_id, batch_size=5
|
||||
)
|
||||
assert result == 0
|
||||
finally:
|
||||
# Restore campaign state
|
||||
async with db_session_factory() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE campaigns SET state = 'running' WHERE id = :campaign_id"
|
||||
),
|
||||
{"campaign_id": campaign_test_data.campaign_id},
|
||||
)
|
||||
await session.commit()
|
||||
Loading…
Add table
Add a link
Reference in a new issue