mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
312
api/db/workflow_client.py
Normal file
312
api/db/workflow_client.py
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
import hashlib
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel, WorkflowRunModel
|
||||
|
||||
|
||||
class WorkflowClient(BaseDBClient):
|
||||
def _generate_workflow_hash(self, workflow_definition: dict) -> str:
|
||||
"""Generate a consistent hash for workflow definition."""
|
||||
# Convert to JSON with sorted keys for consistent hashing
|
||||
json_str = json.dumps(
|
||||
workflow_definition, sort_keys=True, separators=(",", ":")
|
||||
)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
async def _get_or_create_workflow_definition(
|
||||
self, workflow_definition: dict, session, workflow_id: int = None
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Get existing workflow definition by hash or create a new one."""
|
||||
workflow_hash = self._generate_workflow_hash(workflow_definition)
|
||||
|
||||
# Try to find existing definition
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_hash == workflow_hash,
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
existing_definition = result.scalars().first()
|
||||
|
||||
if existing_definition:
|
||||
return existing_definition
|
||||
|
||||
# Create new definition if it doesn't exist
|
||||
new_definition = WorkflowDefinitionModel(
|
||||
workflow_hash=workflow_hash,
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(new_definition)
|
||||
await session.flush() # Flush to get the ID without committing
|
||||
return new_definition
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
name: str,
|
||||
workflow_definition: dict,
|
||||
user_id: int,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
async with self.async_session() as session:
|
||||
try:
|
||||
new_workflow = WorkflowModel(
|
||||
name=name,
|
||||
workflow_definition=workflow_definition, # Keep for backwards compatibility
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_workflow)
|
||||
await session.flush() # Flush to get the workflow ID
|
||||
|
||||
# Now get or create workflow definition with the workflow_id
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, new_workflow.id
|
||||
)
|
||||
|
||||
# Mark this definition as the current one and unset others
|
||||
definition.is_current = True
|
||||
# Set any other definitions for this workflow to not current
|
||||
other_defs_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == new_workflow.id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
)
|
||||
for other_def in other_defs_result.scalars().all():
|
||||
other_def.is_current = False
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(new_workflow)
|
||||
return new_workflow
|
||||
|
||||
async def get_all_workflows(
|
||||
self, user_id: int = None, organization_id: int = None, status: str = None
|
||||
) -> list[WorkflowModel]:
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowModel).options(
|
||||
selectinload(WorkflowModel.current_definition)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
# Filter by status if provided
|
||||
if status:
|
||||
query = query.where(WorkflowModel.status == status)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_workflow(
|
||||
self, workflow_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> WorkflowModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_by_id(self, workflow_id: int) -> WorkflowModel | None:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: int,
|
||||
name: str,
|
||||
workflow_definition: dict | None,
|
||||
template_context_variables: dict | None,
|
||||
workflow_configurations: dict | None,
|
||||
user_id: int = None,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
"""
|
||||
Update an existing workflow in the database.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
name: The new name for the workflow
|
||||
workflow_definition: The new workflow definition
|
||||
template_context_variables: The template context variables
|
||||
user_id: The user ID (for backwards compatibility)
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
The updated WorkflowModel
|
||||
|
||||
Raises:
|
||||
ValueError: If the workflow with the given ID is not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
workflow.name = name
|
||||
|
||||
if template_context_variables is not None:
|
||||
workflow.template_context_variables = template_context_variables
|
||||
|
||||
if workflow_configurations is not None:
|
||||
workflow.workflow_configurations = workflow_configurations
|
||||
|
||||
# In case of only name update, the workflow_definition can be None
|
||||
if workflow_definition:
|
||||
# Get or create new workflow definition
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, workflow_id
|
||||
)
|
||||
|
||||
# Update legacy field for backwards compatibility
|
||||
workflow.workflow_definition = workflow_definition
|
||||
|
||||
# Mark new definition as current and reset others
|
||||
definition.is_current = True
|
||||
other_defs_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
)
|
||||
for other_def in other_defs_result.scalars().all():
|
||||
other_def.is_current = False
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def get_workflows_by_ids(
|
||||
self, workflow_ids: list[int], organization_id: int
|
||||
) -> list[WorkflowModel]:
|
||||
"""Get workflows by IDs for a specific organization"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.join(WorkflowModel.user)
|
||||
.where(
|
||||
WorkflowModel.id.in_(workflow_ids),
|
||||
WorkflowModel.user.has(selected_organization_id=organization_id),
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_workflow_name(
|
||||
self, workflow_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> Optional[str]:
|
||||
"""Get just the workflow name by ID"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowModel.name).where(WorkflowModel.id == workflow_id)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_workflow_status(
|
||||
self,
|
||||
workflow_id: int,
|
||||
status: str,
|
||||
organization_id: int = None,
|
||||
) -> WorkflowModel:
|
||||
"""
|
||||
Update the status of a workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
status: The new status (active/archived)
|
||||
organization_id: The organization ID
|
||||
|
||||
Returns:
|
||||
The updated WorkflowModel
|
||||
|
||||
Raises:
|
||||
ValueError: If the workflow is not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
workflow.status = status
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def get_workflow_run_count(self, workflow_id: int) -> int:
|
||||
"""Get the count of runs for a workflow."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(func.count(WorkflowRunModel.id)).where(
|
||||
WorkflowRunModel.workflow_id == workflow_id
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue