fix: make campaign process batch thread safe (#141)

* fix: dont schedule new batch on resume

* fix: make process_batch thread safe
This commit is contained in:
Abhishek 2026-01-30 14:48:00 +05:30 committed by GitHub
parent e9c5da16c5
commit 6827744327
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1012 additions and 230 deletions

View file

@ -209,31 +209,6 @@ class CampaignClient(BaseDBClient):
await session.rollback()
raise e
async def get_queued_runs(
self,
campaign_id: int,
state: str = "queued",
limit: int = 10,
scheduled_for: Optional[bool] = None,
) -> list[QueuedRunModel]:
"""Get queued runs for processing, optionally filtering by scheduled status"""
async with self.async_session() as session:
query = select(QueuedRunModel).where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == state,
)
# Filter by scheduled status if specified
if scheduled_for is True:
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
elif scheduled_for is False:
query = query.where(QueuedRunModel.scheduled_for.is_(None))
query = query.order_by(QueuedRunModel.created_at).limit(limit)
result = await session.execute(query)
return list(result.scalars().all())
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
"""Update queued run"""
async with self.async_session() as session:
@ -284,26 +259,6 @@ class CampaignClient(BaseDBClient):
result = await session.execute(query)
return list(result.scalars().all())
# New methods for retry support
async def get_scheduled_queued_runs(
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
) -> list[QueuedRunModel]:
"""Get scheduled queued runs that are due for processing"""
async with self.async_session() as session:
query = (
select(QueuedRunModel)
.where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.isnot(None),
QueuedRunModel.scheduled_for <= scheduled_before,
)
.order_by(QueuedRunModel.scheduled_for)
.limit(limit)
)
result = await session.execute(query)
return list(result.scalars().all())
async def create_queued_run(
self,
campaign_id: int,
@ -388,3 +343,82 @@ class CampaignClient(BaseDBClient):
query = select(func.count(QueuedRunModel.id)).where(*conditions)
result = await session.execute(query)
return result.scalar() or 0
async def claim_queued_runs_for_processing(
self,
campaign_id: int,
scheduled_before: datetime,
limit: int = 10,
) -> list[QueuedRunModel]:
"""
Atomically claim queued runs for processing using SELECT FOR UPDATE SKIP LOCKED.
This method is thread-safe - multiple workers can call it concurrently without
processing the same runs. It:
1. Prioritizes scheduled retries that are due
2. Falls back to regular queued runs if more slots available
3. Locks selected rows and marks them as 'processing' atomically
Returns: List of claimed QueuedRunModel objects
"""
async with self.async_session() as session:
claimed_runs = []
# First, get scheduled retries that are due (with lock)
scheduled_query = (
select(QueuedRunModel)
.where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.isnot(None),
QueuedRunModel.scheduled_for <= scheduled_before,
)
.order_by(QueuedRunModel.scheduled_for)
.limit(limit)
.with_for_update(skip_locked=True)
)
scheduled_result = await session.execute(scheduled_query)
scheduled_runs = list(scheduled_result.scalars().all())
# Mark scheduled runs as processing
for run in scheduled_runs:
run.state = "processing"
claimed_runs.append(run)
remaining_slots = limit - len(scheduled_runs)
# Then get regular queued runs if we have remaining slots
if remaining_slots > 0:
regular_query = (
select(QueuedRunModel)
.where(
QueuedRunModel.campaign_id == campaign_id,
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.is_(None),
)
.order_by(QueuedRunModel.created_at)
.limit(remaining_slots)
.with_for_update(skip_locked=True)
)
regular_result = await session.execute(regular_query)
regular_runs = list(regular_result.scalars().all())
# Mark regular runs as processing
for run in regular_runs:
run.state = "processing"
claimed_runs.append(run)
# Commit the state changes
try:
await session.commit()
except Exception as e:
await session.rollback()
raise e
# Refresh to get updated state
for run in claimed_runs:
await session.refresh(run)
return claimed_runs

View file

@ -598,7 +598,7 @@ class QueuedRunModel(Base):
source_uuid = Column(String, nullable=False)
context_variables = Column(JSON, nullable=False, default=dict)
state = Column(
Enum("queued", "processed", "failed", name="queued_run_state"),
Enum("queued", "processed", "processing", "failed", name="queued_run_state"),
nullable=False,
default="queued",
)