mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: agent versioning and model configurations override (#227)
* feat: add tests and migrations * feat: workflow versioning among published and draft * feat: add a new settings page to simplify workflow detail page * fix: fix tsclient generation
This commit is contained in:
parent
f5fa9ce717
commit
38d1d928b7
62 changed files with 10158 additions and 3131 deletions
|
|
@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
|
|||
from api.db.models import (
|
||||
LoopTalkConversation,
|
||||
LoopTalkTestSession,
|
||||
WorkflowModel,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -50,8 +51,12 @@ class LoopTalkClient(BaseDBClient):
|
|||
result = await session.execute(
|
||||
select(LoopTalkTestSession)
|
||||
.options(
|
||||
selectinload(LoopTalkTestSession.actor_workflow),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow),
|
||||
selectinload(LoopTalkTestSession.actor_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.conversations),
|
||||
)
|
||||
.where(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
|
|
@ -199,7 +198,7 @@ class IntegrationModel(Base):
|
|||
class WorkflowDefinitionModel(Base):
|
||||
__tablename__ = "workflow_definitions"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
workflow_hash = Column(String, nullable=False)
|
||||
workflow_hash = Column(String, nullable=True) # Legacy, no longer used
|
||||
workflow_json = Column(JSON, nullable=False, default=dict)
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=True)
|
||||
is_current = Column(
|
||||
|
|
@ -207,12 +206,29 @@ class WorkflowDefinitionModel(Base):
|
|||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Table constraints and indexes
|
||||
# Versioning columns
|
||||
status = Column(
|
||||
String,
|
||||
nullable=False,
|
||||
default="published",
|
||||
server_default=text("'published'"),
|
||||
) # draft | published | archived
|
||||
version_number = Column(
|
||||
Integer, nullable=True
|
||||
) # Sequential per workflow, display only
|
||||
published_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Full behavioral snapshot (moved from WorkflowModel to enable versioning)
|
||||
workflow_configurations = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
template_context_variables = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
|
||||
# Table constraints and indexes — unique hash constraint removed (no more dedup)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"workflow_hash", "workflow_id", name="uq_workflow_hash_workflow_id"
|
||||
),
|
||||
Index("ix_workflow_hash_workflow_id", "workflow_hash", "workflow_id"),
|
||||
Index("ix_workflow_definitions_workflow_status", "workflow_id", "status"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
|
|
@ -247,6 +263,19 @@ class WorkflowModel(Base):
|
|||
runs = relationship("WorkflowRunModel", back_populates="workflow")
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Pointer to the currently-live (published) version
|
||||
released_definition_id = Column(
|
||||
Integer,
|
||||
ForeignKey("workflow_definitions.id", use_alter=True),
|
||||
nullable=True,
|
||||
)
|
||||
released_definition = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
foreign_keys=[released_definition_id],
|
||||
uselist=False,
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
# All versions / historical definitions of this workflow
|
||||
definitions = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
|
|
@ -255,6 +284,7 @@ class WorkflowModel(Base):
|
|||
)
|
||||
|
||||
# Relationship to fetch the current (is_current=True) definition
|
||||
# Kept for backward compatibility during transition
|
||||
current_definition = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
primaryjoin=lambda: and_(
|
||||
|
|
@ -277,36 +307,6 @@ class WorkflowModel(Base):
|
|||
# that scenario so callers can handle the absence explicitly.
|
||||
return None
|
||||
|
||||
@property
|
||||
def workflow_definition_with_fallback(self):
|
||||
"""
|
||||
Get workflow definition with fallback to legacy workflow_definition field.
|
||||
|
||||
Returns:
|
||||
dict: The workflow definition JSON
|
||||
"""
|
||||
# Access the relationship only if it has ALREADY been eagerly loaded on this
|
||||
# instance to avoid triggering an implicit lazy load once the SQLAlchemy
|
||||
# Session has been closed (which would raise a DetachedInstanceError).
|
||||
|
||||
# ``__dict__`` will contain "current_definition" **only** when the attribute
|
||||
# has been populated (e.g. via `selectinload` or an explicit access while
|
||||
# the session was still open). Using ``__dict__.get`` guarantees that we
|
||||
# do not accidentally issue a lazy load query on a detached instance.
|
||||
|
||||
current_definition = self.__dict__.get("current_definition")
|
||||
|
||||
if current_definition is not None:
|
||||
return current_definition.workflow_json
|
||||
|
||||
# Fallback for backwards-compatibility when the relationship is not (yet)
|
||||
# loaded. In this case we fall back to the legacy ``workflow_definition``
|
||||
# column that always contains the most recent definition JSON.
|
||||
logger.warning(
|
||||
f"Workflow {self.id} has no loaded current definition, using workflow_definition as fallback",
|
||||
)
|
||||
return self.workflow_definition
|
||||
|
||||
|
||||
class WorkflowTemplates(Base):
|
||||
__tablename__ = "workflow_templates"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import hashlib
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -12,41 +11,15 @@ from api.db.models import WorkflowDefinitionModel, WorkflowModel, WorkflowRunMod
|
|||
|
||||
|
||||
class WorkflowClient(BaseDBClient):
|
||||
def _generate_workflow_hash(self, workflow_definition: dict) -> str:
|
||||
"""Generate a consistent hash for workflow definition."""
|
||||
# Convert to JSON with sorted keys for consistent hashing
|
||||
json_str = json.dumps(
|
||||
workflow_definition, sort_keys=True, separators=(",", ":")
|
||||
)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
async def _get_or_create_workflow_definition(
|
||||
self, workflow_definition: dict, session, workflow_id: int = None
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Get existing workflow definition by hash or create a new one."""
|
||||
workflow_hash = self._generate_workflow_hash(workflow_definition)
|
||||
|
||||
# Try to find existing definition
|
||||
async def _next_version_number(self, session, workflow_id: int) -> int:
|
||||
"""Get the next version number for a workflow."""
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_hash == workflow_hash,
|
||||
select(func.max(WorkflowDefinitionModel.version_number)).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
existing_definition = result.scalars().first()
|
||||
|
||||
if existing_definition:
|
||||
return existing_definition
|
||||
|
||||
# Create new definition if it doesn't exist
|
||||
new_definition = WorkflowDefinitionModel(
|
||||
workflow_hash=workflow_hash,
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(new_definition)
|
||||
await session.flush() # Flush to get the ID without committing
|
||||
return new_definition
|
||||
current_max = result.scalar()
|
||||
return (current_max or 0) + 1
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
|
|
@ -66,21 +39,23 @@ class WorkflowClient(BaseDBClient):
|
|||
session.add(new_workflow)
|
||||
await session.flush() # Flush to get the workflow ID
|
||||
|
||||
# Now get or create workflow definition with the workflow_id
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, new_workflow.id
|
||||
# Create the first definition as V1 published
|
||||
definition = WorkflowDefinitionModel(
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=new_workflow.id,
|
||||
is_current=True,
|
||||
status="published",
|
||||
version_number=1,
|
||||
published_at=datetime.now(UTC),
|
||||
workflow_configurations=new_workflow.workflow_configurations or {},
|
||||
template_context_variables=new_workflow.template_context_variables
|
||||
or {},
|
||||
)
|
||||
session.add(definition)
|
||||
await session.flush()
|
||||
|
||||
# Mark this definition as the current one and unset others
|
||||
definition.is_current = True
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == new_workflow.id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
.values(is_current=False)
|
||||
)
|
||||
# Set the released pointer
|
||||
new_workflow.released_definition_id = definition.id
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
|
|
@ -89,6 +64,257 @@ class WorkflowClient(BaseDBClient):
|
|||
await session.refresh(new_workflow)
|
||||
return new_workflow
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Versioning methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def save_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
workflow_definition: dict | None = None,
|
||||
workflow_configurations: dict | None = None,
|
||||
template_context_variables: dict | None = None,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Create or update a draft version for this workflow.
|
||||
|
||||
If a draft already exists, it is updated in place.
|
||||
If no draft exists, a new one is created with the next version number.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Check for existing draft
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
|
||||
if draft:
|
||||
# Update existing draft in place
|
||||
if workflow_definition is not None:
|
||||
draft.workflow_json = workflow_definition
|
||||
if workflow_configurations is not None:
|
||||
draft.workflow_configurations = workflow_configurations
|
||||
if template_context_variables is not None:
|
||||
draft.template_context_variables = template_context_variables
|
||||
else:
|
||||
# Get current published to use as base for unspecified fields
|
||||
pub_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "published",
|
||||
)
|
||||
)
|
||||
published = pub_result.scalars().first()
|
||||
|
||||
next_version = await self._next_version_number(session, workflow_id)
|
||||
|
||||
draft = WorkflowDefinitionModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_json=workflow_definition
|
||||
if workflow_definition is not None
|
||||
else (published.workflow_json if published else {}),
|
||||
workflow_configurations=workflow_configurations
|
||||
if workflow_configurations is not None
|
||||
else (published.workflow_configurations if published else {}),
|
||||
template_context_variables=template_context_variables
|
||||
if template_context_variables is not None
|
||||
else (published.template_context_variables if published else {}),
|
||||
status="draft",
|
||||
version_number=next_version,
|
||||
is_current=False,
|
||||
)
|
||||
session.add(draft)
|
||||
|
||||
# Keep legacy columns on workflows table in sync with draft
|
||||
wf_result = await session.execute(
|
||||
select(WorkflowModel).where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
workflow = wf_result.scalars().first()
|
||||
if workflow:
|
||||
workflow.workflow_definition = draft.workflow_json
|
||||
workflow.workflow_configurations = draft.workflow_configurations
|
||||
workflow.template_context_variables = draft.template_context_variables
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def publish_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Promote the current draft to published.
|
||||
|
||||
- Draft → published
|
||||
- Previous published → archived
|
||||
- Updates released_definition_id on the workflow
|
||||
- Sets is_current for backward compatibility
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Find the draft
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
if not draft:
|
||||
raise ValueError(f"No draft exists for workflow {workflow_id}")
|
||||
|
||||
# Archive the current published version
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "published",
|
||||
)
|
||||
.values(status="archived", is_current=False)
|
||||
)
|
||||
|
||||
# Promote draft → published
|
||||
draft.status = "published"
|
||||
draft.published_at = datetime.now(UTC)
|
||||
draft.is_current = True
|
||||
|
||||
# Update workflow's released pointer + legacy fields
|
||||
wf_result = await session.execute(
|
||||
select(WorkflowModel).where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
workflow = wf_result.scalars().first()
|
||||
workflow.released_definition_id = draft.id
|
||||
workflow.workflow_definition = draft.workflow_json
|
||||
workflow.workflow_configurations = draft.workflow_configurations
|
||||
workflow.template_context_variables = draft.template_context_variables
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def discard_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> None:
|
||||
"""Delete the current draft version."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
if not draft:
|
||||
raise ValueError(f"No draft exists for workflow {workflow_id}")
|
||||
|
||||
await session.delete(draft)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def revert_to_version(
|
||||
self,
|
||||
workflow_id: int,
|
||||
definition_id: int,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Create a new draft from an archived version's snapshot.
|
||||
|
||||
Raises ValueError if a draft already exists (must discard first).
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Ensure no existing draft
|
||||
draft_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
if draft_result.scalars().first():
|
||||
raise ValueError(
|
||||
f"Draft already exists for workflow {workflow_id}. "
|
||||
"Discard it before reverting."
|
||||
)
|
||||
|
||||
# Fetch the source version
|
||||
source_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.id == definition_id,
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
source = source_result.scalars().first()
|
||||
if not source:
|
||||
raise ValueError(
|
||||
f"Version {definition_id} not found for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
next_version = await self._next_version_number(session, workflow_id)
|
||||
|
||||
# Create new draft from the source snapshot
|
||||
draft = WorkflowDefinitionModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_json=source.workflow_json,
|
||||
workflow_configurations=source.workflow_configurations,
|
||||
template_context_variables=source.template_context_variables,
|
||||
status="draft",
|
||||
version_number=next_version,
|
||||
is_current=False,
|
||||
)
|
||||
session.add(draft)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def get_draft_version(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> WorkflowDefinitionModel | None:
|
||||
"""Get the draft version for a workflow, or None if no draft exists."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_versions(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> list[WorkflowDefinitionModel]:
|
||||
"""List all versions for a workflow, newest first."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status.in_(
|
||||
["published", "draft", "archived"]
|
||||
),
|
||||
)
|
||||
.order_by(WorkflowDefinitionModel.version_number.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_all_workflows(
|
||||
self, user_id: int = None, organization_id: int = None, status: str = None
|
||||
) -> list[WorkflowModel]:
|
||||
|
|
@ -191,7 +417,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
|
|
@ -209,7 +438,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
|
@ -227,11 +459,16 @@ class WorkflowClient(BaseDBClient):
|
|||
"""
|
||||
Update an existing workflow in the database.
|
||||
|
||||
Name changes are applied directly to the workflow.
|
||||
Definition/config/template_var changes are saved as a draft version
|
||||
via save_workflow_draft, keeping the published version unchanged.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
name: The new name for the workflow
|
||||
workflow_definition: The new workflow definition
|
||||
template_context_variables: The template context variables
|
||||
workflow_configurations: The workflow configurations
|
||||
user_id: The user ID (for backwards compatibility)
|
||||
organization_id: The organization ID
|
||||
|
||||
|
|
@ -249,10 +486,8 @@ class WorkflowClient(BaseDBClient):
|
|||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
|
|
@ -260,42 +495,38 @@ class WorkflowClient(BaseDBClient):
|
|||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Name is a workflow-level field, not versioned
|
||||
if name is not None:
|
||||
workflow.name = name
|
||||
|
||||
if template_context_variables is not None:
|
||||
workflow.template_context_variables = template_context_variables
|
||||
|
||||
if workflow_configurations is not None:
|
||||
workflow.workflow_configurations = workflow_configurations
|
||||
|
||||
# In case of only name update, the workflow_definition can be None
|
||||
if workflow_definition:
|
||||
# Get or create new workflow definition
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, workflow_id
|
||||
)
|
||||
|
||||
# Update legacy field for backwards compatibility
|
||||
workflow.workflow_definition = workflow_definition
|
||||
|
||||
# Mark new definition as current and reset others
|
||||
definition.is_current = True
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
.values(is_current=False)
|
||||
)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
|
||||
# Save versioned changes as a draft
|
||||
has_versioned_changes = any(
|
||||
v is not None
|
||||
for v in [
|
||||
workflow_definition,
|
||||
workflow_configurations,
|
||||
template_context_variables,
|
||||
]
|
||||
)
|
||||
if has_versioned_changes:
|
||||
await self.save_workflow_draft(
|
||||
workflow_id=workflow_id,
|
||||
workflow_definition=workflow_definition,
|
||||
workflow_configurations=workflow_configurations,
|
||||
template_context_variables=template_context_variables,
|
||||
)
|
||||
# Re-fetch with updated state
|
||||
workflow = await self.get_workflow(
|
||||
workflow_id, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
async def get_workflows_by_ids(
|
||||
|
|
@ -353,7 +584,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class WorkflowRunClient(BaseDBClient):
|
|||
gathered_context: dict = None,
|
||||
campaign_id: int = None,
|
||||
queued_run_id: int = None,
|
||||
use_draft: bool = False,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Get workflow and user to check organization
|
||||
|
|
@ -44,41 +45,51 @@ class WorkflowRunClient(BaseDBClient):
|
|||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# # Check quota if user has an organization
|
||||
# if workflow.user and workflow.user.selected_organization_id:
|
||||
# # Import here to avoid circular dependency
|
||||
# from api.db.organization_usage_client import OrganizationUsageClient
|
||||
# Resolve which definition to bind to this run
|
||||
target_def = None
|
||||
|
||||
# usage_client = OrganizationUsageClient()
|
||||
|
||||
# # Check quota (no reservation for now, actual cost will be added after completion)
|
||||
# has_quota = await usage_client.check_and_reserve_quota(
|
||||
# workflow.user.selected_organization_id, estimated_tokens=0
|
||||
# )
|
||||
|
||||
# if not has_quota:
|
||||
# raise ValueError(
|
||||
# "Organization quota exceeded. Please contact your administrator."
|
||||
# )
|
||||
|
||||
# Fetch the current definition for this workflow
|
||||
current_def_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.is_current == True,
|
||||
if use_draft:
|
||||
# For test calls: prefer draft if it exists, fall back to published
|
||||
draft_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
)
|
||||
current_def = current_def_result.scalars().first()
|
||||
target_def = draft_result.scalars().first()
|
||||
|
||||
if target_def is None:
|
||||
# Use the published version via released_definition_id (preferred)
|
||||
# or fall back to is_current for backward compatibility
|
||||
if workflow.released_definition_id:
|
||||
target_def = await session.get(
|
||||
WorkflowDefinitionModel, workflow.released_definition_id
|
||||
)
|
||||
else:
|
||||
pub_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.is_current == True,
|
||||
)
|
||||
)
|
||||
target_def = pub_result.scalars().first()
|
||||
|
||||
# Get the current storage backend based on ENABLE_AWS_S3 flag
|
||||
current_backend = StorageBackend.get_current_backend()
|
||||
|
||||
# Use initial_context from the version if available, else from workflow
|
||||
default_context = (
|
||||
target_def.template_context_variables
|
||||
if target_def and target_def.template_context_variables
|
||||
else workflow.template_context_variables
|
||||
)
|
||||
|
||||
new_run = WorkflowRunModel(
|
||||
name=name,
|
||||
workflow=workflow,
|
||||
mode=mode,
|
||||
definition_id=current_def.id if current_def else None,
|
||||
initial_context=initial_context or workflow.template_context_variables,
|
||||
definition_id=target_def.id if target_def else None,
|
||||
initial_context=initial_context or default_context,
|
||||
gathered_context=gathered_context or {},
|
||||
campaign_id=campaign_id,
|
||||
queued_run_id=queued_run_id,
|
||||
|
|
@ -189,7 +200,11 @@ class WorkflowRunClient(BaseDBClient):
|
|||
self, run_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> WorkflowRunModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRunModel).join(WorkflowRunModel.workflow)
|
||||
query = (
|
||||
select(WorkflowRunModel)
|
||||
.options(selectinload(WorkflowRunModel.definition))
|
||||
.join(WorkflowRunModel.workflow)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue