mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
380 lines
14 KiB
Python
380 lines
14 KiB
Python
|
|
from datetime import UTC, datetime
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from sqlalchemy import func
|
||
|
|
from sqlalchemy.future import select
|
||
|
|
|
||
|
|
from api.db.base_client import BaseDBClient
|
||
|
|
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||
|
|
|
||
|
|
|
||
|
|
class CampaignClient(BaseDBClient):
|
||
|
|
async def create_campaign(
|
||
|
|
self,
|
||
|
|
name: str,
|
||
|
|
workflow_id: int,
|
||
|
|
source_type: str,
|
||
|
|
source_id: str,
|
||
|
|
user_id: int,
|
||
|
|
organization_id: int,
|
||
|
|
) -> CampaignModel:
|
||
|
|
"""Create a new campaign"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
campaign = CampaignModel(
|
||
|
|
name=name,
|
||
|
|
workflow_id=workflow_id,
|
||
|
|
source_type=source_type,
|
||
|
|
source_id=source_id,
|
||
|
|
created_by=user_id,
|
||
|
|
organization_id=organization_id,
|
||
|
|
)
|
||
|
|
session.add(campaign)
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
await session.refresh(campaign)
|
||
|
|
return campaign
|
||
|
|
|
||
|
|
async def get_campaigns(
|
||
|
|
self,
|
||
|
|
organization_id: int,
|
||
|
|
) -> list[CampaignModel]:
|
||
|
|
"""Get all campaigns for organization"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = (
|
||
|
|
select(CampaignModel)
|
||
|
|
.where(CampaignModel.organization_id == organization_id)
|
||
|
|
.order_by(CampaignModel.created_at.desc())
|
||
|
|
)
|
||
|
|
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
async def get_campaign(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
organization_id: int,
|
||
|
|
) -> Optional[CampaignModel]:
|
||
|
|
"""Get single campaign by ID, ensuring organization access"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(CampaignModel).where(
|
||
|
|
CampaignModel.id == campaign_id,
|
||
|
|
CampaignModel.organization_id == organization_id,
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar_one_or_none()
|
||
|
|
|
||
|
|
async def update_campaign_state(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
state: str,
|
||
|
|
organization_id: int,
|
||
|
|
) -> CampaignModel:
|
||
|
|
"""Update campaign state (start/pause/resume)"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(CampaignModel).where(
|
||
|
|
CampaignModel.id == campaign_id,
|
||
|
|
CampaignModel.organization_id == organization_id,
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
campaign = result.scalar_one_or_none()
|
||
|
|
|
||
|
|
if not campaign:
|
||
|
|
raise ValueError(f"Campaign {campaign_id} not found")
|
||
|
|
|
||
|
|
campaign.state = state
|
||
|
|
if state == "running" and not campaign.started_at:
|
||
|
|
campaign.started_at = datetime.now(UTC)
|
||
|
|
elif state in ["completed", "failed"]:
|
||
|
|
campaign.completed_at = datetime.now(UTC)
|
||
|
|
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
await session.refresh(campaign)
|
||
|
|
return campaign
|
||
|
|
|
||
|
|
async def update_campaign_progress(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
processed_rows: int,
|
||
|
|
failed_rows: int,
|
||
|
|
organization_id: int,
|
||
|
|
) -> None:
|
||
|
|
"""Update campaign progress counters"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(CampaignModel).where(
|
||
|
|
CampaignModel.id == campaign_id,
|
||
|
|
CampaignModel.organization_id == organization_id,
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
campaign = result.scalar_one_or_none()
|
||
|
|
|
||
|
|
if not campaign:
|
||
|
|
raise ValueError(f"Campaign {campaign_id} not found")
|
||
|
|
|
||
|
|
campaign.processed_rows = processed_rows
|
||
|
|
campaign.failed_rows = failed_rows
|
||
|
|
campaign.updated_at = datetime.now(UTC)
|
||
|
|
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
|
||
|
|
async def get_campaign_runs(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
organization_id: int,
|
||
|
|
) -> list[WorkflowRunModel]:
|
||
|
|
"""Get workflow runs for a campaign"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
# First verify campaign belongs to organization
|
||
|
|
campaign_query = select(CampaignModel).where(
|
||
|
|
CampaignModel.id == campaign_id,
|
||
|
|
CampaignModel.organization_id == organization_id,
|
||
|
|
)
|
||
|
|
campaign_result = await session.execute(campaign_query)
|
||
|
|
campaign = campaign_result.scalar_one_or_none()
|
||
|
|
|
||
|
|
if not campaign:
|
||
|
|
raise ValueError(f"Campaign {campaign_id} not found")
|
||
|
|
|
||
|
|
query = (
|
||
|
|
select(WorkflowRunModel)
|
||
|
|
.where(WorkflowRunModel.campaign_id == campaign_id)
|
||
|
|
.order_by(WorkflowRunModel.created_at.desc())
|
||
|
|
)
|
||
|
|
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
async def get_campaign_by_id(self, campaign_id: int) -> Optional[CampaignModel]:
|
||
|
|
"""Get campaign by ID without organization check (for internal use)"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar_one_or_none()
|
||
|
|
|
||
|
|
async def update_campaign(self, campaign_id: int, **kwargs) -> CampaignModel:
|
||
|
|
"""Update campaign with arbitrary fields"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(CampaignModel).where(CampaignModel.id == campaign_id)
|
||
|
|
result = await session.execute(query)
|
||
|
|
campaign = result.scalar_one_or_none()
|
||
|
|
|
||
|
|
if not campaign:
|
||
|
|
raise ValueError(f"Campaign {campaign_id} not found")
|
||
|
|
|
||
|
|
# Update fields
|
||
|
|
for key, value in kwargs.items():
|
||
|
|
if hasattr(campaign, key):
|
||
|
|
setattr(campaign, key, value)
|
||
|
|
|
||
|
|
campaign.updated_at = datetime.now(UTC)
|
||
|
|
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
await session.refresh(campaign)
|
||
|
|
return campaign
|
||
|
|
|
||
|
|
# QueuedRun methods
|
||
|
|
async def bulk_create_queued_runs(self, queued_runs_data: list[dict]) -> None:
|
||
|
|
"""Bulk create queued runs"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
queued_runs = [QueuedRunModel(**data) for data in queued_runs_data]
|
||
|
|
session.add_all(queued_runs)
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
|
||
|
|
async def get_queued_runs(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
state: str = "queued",
|
||
|
|
limit: int = 10,
|
||
|
|
scheduled_for: Optional[bool] = None,
|
||
|
|
) -> list[QueuedRunModel]:
|
||
|
|
"""Get queued runs for processing, optionally filtering by scheduled status"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(QueuedRunModel).where(
|
||
|
|
QueuedRunModel.campaign_id == campaign_id,
|
||
|
|
QueuedRunModel.state == state,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Filter by scheduled status if specified
|
||
|
|
if scheduled_for is True:
|
||
|
|
query = query.where(QueuedRunModel.scheduled_for.isnot(None))
|
||
|
|
elif scheduled_for is False:
|
||
|
|
query = query.where(QueuedRunModel.scheduled_for.is_(None))
|
||
|
|
|
||
|
|
query = query.order_by(QueuedRunModel.created_at).limit(limit)
|
||
|
|
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
async def update_queued_run(self, queued_run_id: int, **kwargs) -> QueuedRunModel:
|
||
|
|
"""Update queued run"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
|
||
|
|
result = await session.execute(query)
|
||
|
|
queued_run = result.scalar_one_or_none()
|
||
|
|
|
||
|
|
if not queued_run:
|
||
|
|
raise ValueError(f"QueuedRun {queued_run_id} not found")
|
||
|
|
|
||
|
|
# Update fields
|
||
|
|
for key, value in kwargs.items():
|
||
|
|
if hasattr(queued_run, key):
|
||
|
|
setattr(queued_run, key, value)
|
||
|
|
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
await session.refresh(queued_run)
|
||
|
|
return queued_run
|
||
|
|
|
||
|
|
async def count_queued_runs(
|
||
|
|
self, campaign_id: int, state: Optional[str] = None
|
||
|
|
) -> int:
|
||
|
|
"""Count queued runs, optionally filtered by state"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(func.count(QueuedRunModel.id)).where(
|
||
|
|
QueuedRunModel.campaign_id == campaign_id
|
||
|
|
)
|
||
|
|
if state:
|
||
|
|
query = query.where(QueuedRunModel.state == state)
|
||
|
|
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar() or 0
|
||
|
|
|
||
|
|
async def get_workflow_runs_by_campaign(
|
||
|
|
self, campaign_id: int
|
||
|
|
) -> list[WorkflowRunModel]:
|
||
|
|
"""Get all workflow runs for a campaign (internal use)"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = (
|
||
|
|
select(WorkflowRunModel)
|
||
|
|
.where(WorkflowRunModel.campaign_id == campaign_id)
|
||
|
|
.order_by(WorkflowRunModel.created_at)
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
# New methods for retry support
|
||
|
|
async def get_scheduled_queued_runs(
|
||
|
|
self, campaign_id: int, scheduled_before: datetime, limit: int = 10
|
||
|
|
) -> list[QueuedRunModel]:
|
||
|
|
"""Get scheduled queued runs that are due for processing"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = (
|
||
|
|
select(QueuedRunModel)
|
||
|
|
.where(
|
||
|
|
QueuedRunModel.campaign_id == campaign_id,
|
||
|
|
QueuedRunModel.state == "queued",
|
||
|
|
QueuedRunModel.scheduled_for.isnot(None),
|
||
|
|
QueuedRunModel.scheduled_for <= scheduled_before,
|
||
|
|
)
|
||
|
|
.order_by(QueuedRunModel.scheduled_for)
|
||
|
|
.limit(limit)
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
async def create_queued_run(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
source_uuid: str,
|
||
|
|
context_variables: dict,
|
||
|
|
state: str = "queued",
|
||
|
|
retry_count: int = 0,
|
||
|
|
parent_queued_run_id: Optional[int] = None,
|
||
|
|
scheduled_for: Optional[datetime] = None,
|
||
|
|
retry_reason: Optional[str] = None,
|
||
|
|
) -> QueuedRunModel:
|
||
|
|
"""Create a single queued run with retry support"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
queued_run = QueuedRunModel(
|
||
|
|
campaign_id=campaign_id,
|
||
|
|
source_uuid=source_uuid,
|
||
|
|
context_variables=context_variables,
|
||
|
|
state=state,
|
||
|
|
retry_count=retry_count,
|
||
|
|
parent_queued_run_id=parent_queued_run_id,
|
||
|
|
scheduled_for=scheduled_for,
|
||
|
|
retry_reason=retry_reason,
|
||
|
|
)
|
||
|
|
session.add(queued_run)
|
||
|
|
try:
|
||
|
|
await session.commit()
|
||
|
|
except Exception as e:
|
||
|
|
await session.rollback()
|
||
|
|
raise e
|
||
|
|
await session.refresh(queued_run)
|
||
|
|
return queued_run
|
||
|
|
|
||
|
|
async def get_queued_run_by_id(
|
||
|
|
self, queued_run_id: int
|
||
|
|
) -> Optional[QueuedRunModel]:
|
||
|
|
"""Get a queued run by ID"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(QueuedRunModel).where(QueuedRunModel.id == queued_run_id)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar_one_or_none()
|
||
|
|
|
||
|
|
async def get_campaigns_by_status(self, statuses: list[str]) -> list[CampaignModel]:
|
||
|
|
"""Get campaigns by status"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = (
|
||
|
|
select(CampaignModel)
|
||
|
|
.where(CampaignModel.state.in_(statuses))
|
||
|
|
.order_by(CampaignModel.created_at.desc())
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return list(result.scalars().all())
|
||
|
|
|
||
|
|
async def get_queued_runs_count(self, campaign_id: int, states: list[str]) -> int:
|
||
|
|
"""Get count of queued runs for a campaign in specified states"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
query = select(func.count(QueuedRunModel.id)).where(
|
||
|
|
QueuedRunModel.campaign_id == campaign_id,
|
||
|
|
QueuedRunModel.state.in_(states),
|
||
|
|
)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar() or 0
|
||
|
|
|
||
|
|
async def get_scheduled_runs_count(
|
||
|
|
self,
|
||
|
|
campaign_id: int,
|
||
|
|
scheduled_before: Optional[datetime] = None,
|
||
|
|
scheduled_after: Optional[datetime] = None,
|
||
|
|
) -> int:
|
||
|
|
"""Get count of scheduled runs for a campaign"""
|
||
|
|
async with self.async_session() as session:
|
||
|
|
conditions = [
|
||
|
|
QueuedRunModel.campaign_id == campaign_id,
|
||
|
|
QueuedRunModel.scheduled_for.isnot(None),
|
||
|
|
QueuedRunModel.state == "queued",
|
||
|
|
]
|
||
|
|
|
||
|
|
if scheduled_before:
|
||
|
|
conditions.append(QueuedRunModel.scheduled_for <= scheduled_before)
|
||
|
|
if scheduled_after:
|
||
|
|
conditions.append(QueuedRunModel.scheduled_for > scheduled_after)
|
||
|
|
|
||
|
|
query = select(func.count(QueuedRunModel.id)).where(*conditions)
|
||
|
|
result = await session.execute(query)
|
||
|
|
return result.scalar() or 0
|