mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: make campaign process batch thread safe (#141)
* fix: dont schedule new batch on resume * fix: make process_batch thread safe
This commit is contained in:
parent
e9c5da16c5
commit
6827744327
17 changed files with 1012 additions and 230 deletions
|
|
@ -209,31 +209,6 @@ class CampaignClient(BaseDBClient):
|
|||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def get_queued_runs(
    self,
    campaign_id: int,
    state: str = "queued",
    limit: int = 10,
    scheduled_for: Optional[bool] = None,
) -> list[QueuedRunModel]:
    """Fetch queued runs of a campaign in a given state, oldest first.

    Args:
        campaign_id: Campaign whose queued runs are selected.
        state: Value the run's ``state`` column must equal.
        limit: Maximum number of rows to return.
        scheduled_for: ``True`` keeps only rows whose ``scheduled_for`` is
            set, ``False`` keeps only rows where it is NULL, and ``None``
            applies no filter on that column.

    Returns:
        Matching ``QueuedRunModel`` rows ordered by ``created_at``.
    """
    filters = [
        QueuedRunModel.campaign_id == campaign_id,
        QueuedRunModel.state == state,
    ]

    # Identity checks on purpose: only the literal booleans select a filter.
    if scheduled_for is True:
        filters.append(QueuedRunModel.scheduled_for.isnot(None))
    elif scheduled_for is False:
        filters.append(QueuedRunModel.scheduled_for.is_(None))

    stmt = (
        select(QueuedRunModel)
        .where(*filters)
        .order_by(QueuedRunModel.created_at)
        .limit(limit)
    )

    async with self.async_session() as session:
        rows = await session.execute(stmt)
        return list(rows.scalars().all())
|
||||
|
||||
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
|
||||
"""Update queued run"""
|
||||
async with self.async_session() as session:
|
||||
|
|
@ -284,26 +259,6 @@ class CampaignClient(BaseDBClient):
|
|||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# New methods for retry support
|
||||
async def get_scheduled_queued_runs(
    self, campaign_id: int, scheduled_before: datetime, limit: int = 10
) -> list[QueuedRunModel]:
    """Return queued runs whose scheduled time has been reached.

    Selects rows of the campaign that are in state ``"queued"``, have a
    non-NULL ``scheduled_for`` and whose ``scheduled_for`` is at or before
    ``scheduled_before``. Results are ordered by ``scheduled_for`` and
    capped at ``limit`` rows.
    """
    due_filters = (
        QueuedRunModel.campaign_id == campaign_id,
        QueuedRunModel.state == "queued",
        QueuedRunModel.scheduled_for.isnot(None),
        QueuedRunModel.scheduled_for <= scheduled_before,
    )
    stmt = select(QueuedRunModel).where(*due_filters)
    stmt = stmt.order_by(QueuedRunModel.scheduled_for).limit(limit)

    async with self.async_session() as session:
        rows = await session.execute(stmt)
        return list(rows.scalars().all())
|
||||
|
||||
async def create_queued_run(
|
||||
self,
|
||||
campaign_id: int,
|
||||
|
|
@ -388,3 +343,82 @@ class CampaignClient(BaseDBClient):
|
|||
query = select(func.count(QueuedRunModel.id)).where(*conditions)
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def claim_queued_runs_for_processing(
    self,
    campaign_id: int,
    scheduled_before: datetime,
    limit: int = 10,
) -> list[QueuedRunModel]:
    """
    Atomically claim queued runs for processing using SELECT FOR UPDATE SKIP LOCKED.

    Rows are selected with a row lock and ``skip_locked=True``, so a
    concurrent caller skips rows this call has already locked instead of
    waiting on them or claiming them a second time. Within one transaction
    the method:
    1. Prioritizes scheduled retries that are due
       (``scheduled_for <= scheduled_before``), ordered by ``scheduled_for``
    2. Falls back to regular queued runs (``scheduled_for`` is NULL),
       ordered by ``created_at``, if slots remain
    3. Marks every selected row as 'processing' and commits

    Args:
        campaign_id: Campaign whose queued runs are claimed.
        scheduled_before: Cut-off for scheduled retries to count as due.
        limit: Maximum total number of runs to claim. A value of zero or
            less claims nothing and returns an empty list without opening
            a session.

    Returns: List of claimed QueuedRunModel objects

    Raises: Any exception raised by the commit, after rolling back.
    """
    # Nothing to claim; also avoids sending a negative LIMIT to the database.
    if limit <= 0:
        return []

    # Priority tiers as (extra filters, ordering column): due retries first,
    # then ordinary queued runs.
    tiers = (
        (
            (
                QueuedRunModel.scheduled_for.isnot(None),
                QueuedRunModel.scheduled_for <= scheduled_before,
            ),
            QueuedRunModel.scheduled_for,
        ),
        (
            (QueuedRunModel.scheduled_for.is_(None),),
            QueuedRunModel.created_at,
        ),
    )

    async with self.async_session() as session:
        claimed_runs: list[QueuedRunModel] = []

        for tier_filters, order_column in tiers:
            remaining_slots = limit - len(claimed_runs)
            if remaining_slots <= 0:
                break

            query = (
                select(QueuedRunModel)
                .where(
                    QueuedRunModel.campaign_id == campaign_id,
                    QueuedRunModel.state == "queued",
                    *tier_filters,
                )
                .order_by(order_column)
                .limit(remaining_slots)
                .with_for_update(skip_locked=True)
            )
            result = await session.execute(query)

            # Mark the locked rows; the change is persisted by the commit below.
            for run in result.scalars().all():
                run.state = "processing"
                claimed_runs.append(run)

        # Commit the state changes (this also releases the row locks).
        try:
            await session.commit()
        except Exception:
            await session.rollback()
            raise

        # Refresh to get updated state
        for run in claimed_runs:
            await session.refresh(run)

        return claimed_runs
|
||||
|
|
|
|||
|
|
@ -598,7 +598,7 @@ class QueuedRunModel(Base):
|
|||
source_uuid = Column(String, nullable=False)
|
||||
context_variables = Column(JSON, nullable=False, default=dict)
|
||||
state = Column(
|
||||
Enum("queued", "processed", "failed", name="queued_run_state"),
|
||||
Enum("queued", "processed", "processing", "failed", name="queued_run_state"),
|
||||
nullable=False,
|
||||
default="queued",
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue