mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: agent versioning and model configurations override (#227)
* feat: add tests and migrations * feat: workflow versioning among published and draft * feat: add a new settings page to simplify workflow detail page * fix: fix tsclient generation
This commit is contained in:
parent
f5fa9ce717
commit
38d1d928b7
62 changed files with 10158 additions and 3131 deletions
|
|
@ -0,0 +1,132 @@
|
|||
"""add versioning in workflow definitions
|
||||
|
||||
Revision ID: a399b39479fe
|
||||
Revises: c71db647d354
|
||||
Create Date: 2026-04-07 14:43:50.042973
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a399b39479fe"
|
||||
down_revision: Union[str, None] = "c71db647d354"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_api_keys_key_hash"), table_name="api_keys")
|
||||
op.create_index(op.f("ix_api_keys_key_hash"), "api_keys", ["key_hash"], unique=True)
|
||||
op.add_column(
|
||||
"workflow_definitions",
|
||||
sa.Column(
|
||||
"status", sa.String(), server_default=sa.text("'published'"), nullable=False
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"workflow_definitions", sa.Column("version_number", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"workflow_definitions",
|
||||
sa.Column("published_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"workflow_definitions",
|
||||
sa.Column(
|
||||
"workflow_configurations",
|
||||
sa.JSON(),
|
||||
server_default=sa.text("'{}'::json"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"workflow_definitions",
|
||||
sa.Column(
|
||||
"template_context_variables",
|
||||
sa.JSON(),
|
||||
server_default=sa.text("'{}'::json"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"workflow_definitions",
|
||||
sa.Column(
|
||||
"call_disposition_codes",
|
||||
sa.JSON(),
|
||||
server_default=sa.text("'{}'::json"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"workflow_definitions",
|
||||
"workflow_hash",
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=True,
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_workflow_hash_workflow_id"), table_name="workflow_definitions"
|
||||
)
|
||||
op.drop_constraint(
|
||||
op.f("uq_workflow_hash_workflow_id"), "workflow_definitions", type_="unique"
|
||||
)
|
||||
op.create_index(
|
||||
"ix_workflow_definitions_workflow_status",
|
||||
"workflow_definitions",
|
||||
["workflow_id", "status"],
|
||||
unique=False,
|
||||
)
|
||||
op.add_column(
|
||||
"workflows", sa.Column("released_definition_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
None,
|
||||
"workflows",
|
||||
"workflow_definitions",
|
||||
["released_definition_id"],
|
||||
["id"],
|
||||
use_alter=True,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, "workflows", type_="foreignkey")
|
||||
op.drop_column("workflows", "released_definition_id")
|
||||
op.drop_index(
|
||||
"ix_workflow_definitions_workflow_status", table_name="workflow_definitions"
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
op.f("uq_workflow_hash_workflow_id"),
|
||||
"workflow_definitions",
|
||||
["workflow_hash", "workflow_id"],
|
||||
postgresql_nulls_not_distinct=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_workflow_hash_workflow_id"),
|
||||
"workflow_definitions",
|
||||
["workflow_hash", "workflow_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"workflow_definitions",
|
||||
"workflow_hash",
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=False,
|
||||
)
|
||||
op.drop_column("workflow_definitions", "call_disposition_codes")
|
||||
op.drop_column("workflow_definitions", "template_context_variables")
|
||||
op.drop_column("workflow_definitions", "workflow_configurations")
|
||||
op.drop_column("workflow_definitions", "published_at")
|
||||
op.drop_column("workflow_definitions", "version_number")
|
||||
op.drop_column("workflow_definitions", "status")
|
||||
op.drop_index(op.f("ix_api_keys_key_hash"), table_name="api_keys")
|
||||
op.create_index(
|
||||
op.f("ix_api_keys_key_hash"), "api_keys", ["key_hash"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""backfill workflow definition versioning
|
||||
|
||||
Copy workflow_configurations, template_context_variables, call_disposition_codes
|
||||
from the workflows table into the is_current=True definition for each workflow.
|
||||
Set that definition as status='published', version_number=1.
|
||||
Set all other definitions to status='archived'.
|
||||
Point workflows.released_definition_id to the published definition.
|
||||
|
||||
Revision ID: d688d0da1123
|
||||
Revises: a399b39479fe
|
||||
Create Date: 2026-04-07 15:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d688d0da1123"
|
||||
down_revision: Union[str, None] = "a399b39479fe"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Step 1: For each workflow's is_current=True definition, copy configs from
|
||||
# the workflow table and mark as published with version_number=1.
|
||||
conn.execute(
|
||||
sa.text("""
|
||||
UPDATE workflow_definitions wd
|
||||
SET
|
||||
workflow_configurations = w.workflow_configurations,
|
||||
template_context_variables = w.template_context_variables,
|
||||
status = 'published',
|
||||
version_number = 1,
|
||||
published_at = wd.created_at
|
||||
FROM workflows w
|
||||
WHERE wd.workflow_id = w.id
|
||||
AND wd.is_current = true
|
||||
""")
|
||||
)
|
||||
|
||||
# Step 2: Mark all pre-versioning non-current definitions as legacy.
|
||||
conn.execute(
|
||||
sa.text("""
|
||||
UPDATE workflow_definitions
|
||||
SET status = 'legacy'
|
||||
WHERE is_current = false
|
||||
""")
|
||||
)
|
||||
|
||||
# Step 3: Set released_definition_id on workflows to their published definition.
|
||||
conn.execute(
|
||||
sa.text("""
|
||||
UPDATE workflows w
|
||||
SET released_definition_id = wd.id
|
||||
FROM workflow_definitions wd
|
||||
WHERE wd.workflow_id = w.id
|
||||
AND wd.is_current = true
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Clear the released pointer
|
||||
conn.execute(
|
||||
sa.text("""
|
||||
UPDATE workflows SET released_definition_id = NULL
|
||||
""")
|
||||
)
|
||||
|
||||
# Reset all definitions back to server defaults
|
||||
conn.execute(
|
||||
sa.text("""
|
||||
UPDATE workflow_definitions
|
||||
SET
|
||||
status = 'published',
|
||||
version_number = NULL,
|
||||
published_at = NULL,
|
||||
workflow_configurations = '{}',
|
||||
template_context_variables = '{}'
|
||||
""")
|
||||
)
|
||||
|
|
@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
|
|||
from api.db.models import (
|
||||
LoopTalkConversation,
|
||||
LoopTalkTestSession,
|
||||
WorkflowModel,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -50,8 +51,12 @@ class LoopTalkClient(BaseDBClient):
|
|||
result = await session.execute(
|
||||
select(LoopTalkTestSession)
|
||||
.options(
|
||||
selectinload(LoopTalkTestSession.actor_workflow),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow),
|
||||
selectinload(LoopTalkTestSession.actor_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.conversations),
|
||||
)
|
||||
.where(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
|
|
@ -199,7 +198,7 @@ class IntegrationModel(Base):
|
|||
class WorkflowDefinitionModel(Base):
|
||||
__tablename__ = "workflow_definitions"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
workflow_hash = Column(String, nullable=False)
|
||||
workflow_hash = Column(String, nullable=True) # Legacy, no longer used
|
||||
workflow_json = Column(JSON, nullable=False, default=dict)
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=True)
|
||||
is_current = Column(
|
||||
|
|
@ -207,12 +206,29 @@ class WorkflowDefinitionModel(Base):
|
|||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Table constraints and indexes
|
||||
# Versioning columns
|
||||
status = Column(
|
||||
String,
|
||||
nullable=False,
|
||||
default="published",
|
||||
server_default=text("'published'"),
|
||||
) # draft | published | archived
|
||||
version_number = Column(
|
||||
Integer, nullable=True
|
||||
) # Sequential per workflow, display only
|
||||
published_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Full behavioral snapshot (moved from WorkflowModel to enable versioning)
|
||||
workflow_configurations = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
template_context_variables = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
|
||||
# Table constraints and indexes — unique hash constraint removed (no more dedup)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"workflow_hash", "workflow_id", name="uq_workflow_hash_workflow_id"
|
||||
),
|
||||
Index("ix_workflow_hash_workflow_id", "workflow_hash", "workflow_id"),
|
||||
Index("ix_workflow_definitions_workflow_status", "workflow_id", "status"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
|
|
@ -247,6 +263,19 @@ class WorkflowModel(Base):
|
|||
runs = relationship("WorkflowRunModel", back_populates="workflow")
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Pointer to the currently-live (published) version
|
||||
released_definition_id = Column(
|
||||
Integer,
|
||||
ForeignKey("workflow_definitions.id", use_alter=True),
|
||||
nullable=True,
|
||||
)
|
||||
released_definition = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
foreign_keys=[released_definition_id],
|
||||
uselist=False,
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
# All versions / historical definitions of this workflow
|
||||
definitions = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
|
|
@ -255,6 +284,7 @@ class WorkflowModel(Base):
|
|||
)
|
||||
|
||||
# Relationship to fetch the current (is_current=True) definition
|
||||
# Kept for backward compatibility during transition
|
||||
current_definition = relationship(
|
||||
"WorkflowDefinitionModel",
|
||||
primaryjoin=lambda: and_(
|
||||
|
|
@ -277,36 +307,6 @@ class WorkflowModel(Base):
|
|||
# that scenario so callers can handle the absence explicitly.
|
||||
return None
|
||||
|
||||
@property
|
||||
def workflow_definition_with_fallback(self):
|
||||
"""
|
||||
Get workflow definition with fallback to legacy workflow_definition field.
|
||||
|
||||
Returns:
|
||||
dict: The workflow definition JSON
|
||||
"""
|
||||
# Access the relationship only if it has ALREADY been eagerly loaded on this
|
||||
# instance to avoid triggering an implicit lazy load once the SQLAlchemy
|
||||
# Session has been closed (which would raise a DetachedInstanceError).
|
||||
|
||||
# ``__dict__`` will contain "current_definition" **only** when the attribute
|
||||
# has been populated (e.g. via `selectinload` or an explicit access while
|
||||
# the session was still open). Using ``__dict__.get`` guarantees that we
|
||||
# do not accidentally issue a lazy load query on a detached instance.
|
||||
|
||||
current_definition = self.__dict__.get("current_definition")
|
||||
|
||||
if current_definition is not None:
|
||||
return current_definition.workflow_json
|
||||
|
||||
# Fallback for backwards-compatibility when the relationship is not (yet)
|
||||
# loaded. In this case we fall back to the legacy ``workflow_definition``
|
||||
# column that always contains the most recent definition JSON.
|
||||
logger.warning(
|
||||
f"Workflow {self.id} has no loaded current definition, using workflow_definition as fallback",
|
||||
)
|
||||
return self.workflow_definition
|
||||
|
||||
|
||||
class WorkflowTemplates(Base):
|
||||
__tablename__ = "workflow_templates"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import hashlib
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -12,41 +11,15 @@ from api.db.models import WorkflowDefinitionModel, WorkflowModel, WorkflowRunMod
|
|||
|
||||
|
||||
class WorkflowClient(BaseDBClient):
|
||||
def _generate_workflow_hash(self, workflow_definition: dict) -> str:
|
||||
"""Generate a consistent hash for workflow definition."""
|
||||
# Convert to JSON with sorted keys for consistent hashing
|
||||
json_str = json.dumps(
|
||||
workflow_definition, sort_keys=True, separators=(",", ":")
|
||||
)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
async def _get_or_create_workflow_definition(
|
||||
self, workflow_definition: dict, session, workflow_id: int = None
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Get existing workflow definition by hash or create a new one."""
|
||||
workflow_hash = self._generate_workflow_hash(workflow_definition)
|
||||
|
||||
# Try to find existing definition
|
||||
async def _next_version_number(self, session, workflow_id: int) -> int:
|
||||
"""Get the next version number for a workflow."""
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_hash == workflow_hash,
|
||||
select(func.max(WorkflowDefinitionModel.version_number)).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
existing_definition = result.scalars().first()
|
||||
|
||||
if existing_definition:
|
||||
return existing_definition
|
||||
|
||||
# Create new definition if it doesn't exist
|
||||
new_definition = WorkflowDefinitionModel(
|
||||
workflow_hash=workflow_hash,
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(new_definition)
|
||||
await session.flush() # Flush to get the ID without committing
|
||||
return new_definition
|
||||
current_max = result.scalar()
|
||||
return (current_max or 0) + 1
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
|
|
@ -66,21 +39,23 @@ class WorkflowClient(BaseDBClient):
|
|||
session.add(new_workflow)
|
||||
await session.flush() # Flush to get the workflow ID
|
||||
|
||||
# Now get or create workflow definition with the workflow_id
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, new_workflow.id
|
||||
# Create the first definition as V1 published
|
||||
definition = WorkflowDefinitionModel(
|
||||
workflow_json=workflow_definition,
|
||||
workflow_id=new_workflow.id,
|
||||
is_current=True,
|
||||
status="published",
|
||||
version_number=1,
|
||||
published_at=datetime.now(UTC),
|
||||
workflow_configurations=new_workflow.workflow_configurations or {},
|
||||
template_context_variables=new_workflow.template_context_variables
|
||||
or {},
|
||||
)
|
||||
session.add(definition)
|
||||
await session.flush()
|
||||
|
||||
# Mark this definition as the current one and unset others
|
||||
definition.is_current = True
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == new_workflow.id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
.values(is_current=False)
|
||||
)
|
||||
# Set the released pointer
|
||||
new_workflow.released_definition_id = definition.id
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
|
|
@ -89,6 +64,257 @@ class WorkflowClient(BaseDBClient):
|
|||
await session.refresh(new_workflow)
|
||||
return new_workflow
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Versioning methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def save_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
workflow_definition: dict | None = None,
|
||||
workflow_configurations: dict | None = None,
|
||||
template_context_variables: dict | None = None,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Create or update a draft version for this workflow.
|
||||
|
||||
If a draft already exists, it is updated in place.
|
||||
If no draft exists, a new one is created with the next version number.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Check for existing draft
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
|
||||
if draft:
|
||||
# Update existing draft in place
|
||||
if workflow_definition is not None:
|
||||
draft.workflow_json = workflow_definition
|
||||
if workflow_configurations is not None:
|
||||
draft.workflow_configurations = workflow_configurations
|
||||
if template_context_variables is not None:
|
||||
draft.template_context_variables = template_context_variables
|
||||
else:
|
||||
# Get current published to use as base for unspecified fields
|
||||
pub_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "published",
|
||||
)
|
||||
)
|
||||
published = pub_result.scalars().first()
|
||||
|
||||
next_version = await self._next_version_number(session, workflow_id)
|
||||
|
||||
draft = WorkflowDefinitionModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_json=workflow_definition
|
||||
if workflow_definition is not None
|
||||
else (published.workflow_json if published else {}),
|
||||
workflow_configurations=workflow_configurations
|
||||
if workflow_configurations is not None
|
||||
else (published.workflow_configurations if published else {}),
|
||||
template_context_variables=template_context_variables
|
||||
if template_context_variables is not None
|
||||
else (published.template_context_variables if published else {}),
|
||||
status="draft",
|
||||
version_number=next_version,
|
||||
is_current=False,
|
||||
)
|
||||
session.add(draft)
|
||||
|
||||
# Keep legacy columns on workflows table in sync with draft
|
||||
wf_result = await session.execute(
|
||||
select(WorkflowModel).where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
workflow = wf_result.scalars().first()
|
||||
if workflow:
|
||||
workflow.workflow_definition = draft.workflow_json
|
||||
workflow.workflow_configurations = draft.workflow_configurations
|
||||
workflow.template_context_variables = draft.template_context_variables
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def publish_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Promote the current draft to published.
|
||||
|
||||
- Draft → published
|
||||
- Previous published → archived
|
||||
- Updates released_definition_id on the workflow
|
||||
- Sets is_current for backward compatibility
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Find the draft
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
if not draft:
|
||||
raise ValueError(f"No draft exists for workflow {workflow_id}")
|
||||
|
||||
# Archive the current published version
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "published",
|
||||
)
|
||||
.values(status="archived", is_current=False)
|
||||
)
|
||||
|
||||
# Promote draft → published
|
||||
draft.status = "published"
|
||||
draft.published_at = datetime.now(UTC)
|
||||
draft.is_current = True
|
||||
|
||||
# Update workflow's released pointer + legacy fields
|
||||
wf_result = await session.execute(
|
||||
select(WorkflowModel).where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
workflow = wf_result.scalars().first()
|
||||
workflow.released_definition_id = draft.id
|
||||
workflow.workflow_definition = draft.workflow_json
|
||||
workflow.workflow_configurations = draft.workflow_configurations
|
||||
workflow.template_context_variables = draft.template_context_variables
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def discard_workflow_draft(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> None:
|
||||
"""Delete the current draft version."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
draft = result.scalars().first()
|
||||
if not draft:
|
||||
raise ValueError(f"No draft exists for workflow {workflow_id}")
|
||||
|
||||
await session.delete(draft)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
async def revert_to_version(
|
||||
self,
|
||||
workflow_id: int,
|
||||
definition_id: int,
|
||||
) -> WorkflowDefinitionModel:
|
||||
"""Create a new draft from an archived version's snapshot.
|
||||
|
||||
Raises ValueError if a draft already exists (must discard first).
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Ensure no existing draft
|
||||
draft_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
if draft_result.scalars().first():
|
||||
raise ValueError(
|
||||
f"Draft already exists for workflow {workflow_id}. "
|
||||
"Discard it before reverting."
|
||||
)
|
||||
|
||||
# Fetch the source version
|
||||
source_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.id == definition_id,
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
)
|
||||
)
|
||||
source = source_result.scalars().first()
|
||||
if not source:
|
||||
raise ValueError(
|
||||
f"Version {definition_id} not found for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
next_version = await self._next_version_number(session, workflow_id)
|
||||
|
||||
# Create new draft from the source snapshot
|
||||
draft = WorkflowDefinitionModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_json=source.workflow_json,
|
||||
workflow_configurations=source.workflow_configurations,
|
||||
template_context_variables=source.template_context_variables,
|
||||
status="draft",
|
||||
version_number=next_version,
|
||||
is_current=False,
|
||||
)
|
||||
session.add(draft)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(draft)
|
||||
return draft
|
||||
|
||||
async def get_draft_version(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> WorkflowDefinitionModel | None:
|
||||
"""Get the draft version for a workflow, or None if no draft exists."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_versions(
|
||||
self,
|
||||
workflow_id: int,
|
||||
) -> list[WorkflowDefinitionModel]:
|
||||
"""List all versions for a workflow, newest first."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.status.in_(
|
||||
["published", "draft", "archived"]
|
||||
),
|
||||
)
|
||||
.order_by(WorkflowDefinitionModel.version_number.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_all_workflows(
|
||||
self, user_id: int = None, organization_id: int = None, status: str = None
|
||||
) -> list[WorkflowModel]:
|
||||
|
|
@ -191,7 +417,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
|
|
@ -209,7 +438,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
|
@ -227,11 +459,16 @@ class WorkflowClient(BaseDBClient):
|
|||
"""
|
||||
Update an existing workflow in the database.
|
||||
|
||||
Name changes are applied directly to the workflow.
|
||||
Definition/config/template_var changes are saved as a draft version
|
||||
via save_workflow_draft, keeping the published version unchanged.
|
||||
|
||||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
name: The new name for the workflow
|
||||
workflow_definition: The new workflow definition
|
||||
template_context_variables: The template context variables
|
||||
workflow_configurations: The workflow configurations
|
||||
user_id: The user ID (for backwards compatibility)
|
||||
organization_id: The organization ID
|
||||
|
||||
|
|
@ -249,10 +486,8 @@ class WorkflowClient(BaseDBClient):
|
|||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
elif user_id:
|
||||
# Fallback to user_id for backwards compatibility
|
||||
query = query.where(WorkflowModel.user_id == user_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
|
|
@ -260,42 +495,38 @@ class WorkflowClient(BaseDBClient):
|
|||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Name is a workflow-level field, not versioned
|
||||
if name is not None:
|
||||
workflow.name = name
|
||||
|
||||
if template_context_variables is not None:
|
||||
workflow.template_context_variables = template_context_variables
|
||||
|
||||
if workflow_configurations is not None:
|
||||
workflow.workflow_configurations = workflow_configurations
|
||||
|
||||
# In case of only name update, the workflow_definition can be None
|
||||
if workflow_definition:
|
||||
# Get or create new workflow definition
|
||||
definition = await self._get_or_create_workflow_definition(
|
||||
workflow_definition, session, workflow_id
|
||||
)
|
||||
|
||||
# Update legacy field for backwards compatibility
|
||||
workflow.workflow_definition = workflow_definition
|
||||
|
||||
# Mark new definition as current and reset others
|
||||
definition.is_current = True
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow_id,
|
||||
WorkflowDefinitionModel.id != definition.id,
|
||||
)
|
||||
.values(is_current=False)
|
||||
)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
|
||||
# Save versioned changes as a draft
|
||||
has_versioned_changes = any(
|
||||
v is not None
|
||||
for v in [
|
||||
workflow_definition,
|
||||
workflow_configurations,
|
||||
template_context_variables,
|
||||
]
|
||||
)
|
||||
if has_versioned_changes:
|
||||
await self.save_workflow_draft(
|
||||
workflow_id=workflow_id,
|
||||
workflow_definition=workflow_definition,
|
||||
workflow_configurations=workflow_configurations,
|
||||
template_context_variables=template_context_variables,
|
||||
)
|
||||
# Re-fetch with updated state
|
||||
workflow = await self.get_workflow(
|
||||
workflow_id, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
async def get_workflows_by_ids(
|
||||
|
|
@ -353,7 +584,10 @@ class WorkflowClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.current_definition))
|
||||
.options(
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class WorkflowRunClient(BaseDBClient):
|
|||
gathered_context: dict = None,
|
||||
campaign_id: int = None,
|
||||
queued_run_id: int = None,
|
||||
use_draft: bool = False,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Get workflow and user to check organization
|
||||
|
|
@ -44,41 +45,51 @@ class WorkflowRunClient(BaseDBClient):
|
|||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# # Check quota if user has an organization
|
||||
# if workflow.user and workflow.user.selected_organization_id:
|
||||
# # Import here to avoid circular dependency
|
||||
# from api.db.organization_usage_client import OrganizationUsageClient
|
||||
# Resolve which definition to bind to this run
|
||||
target_def = None
|
||||
|
||||
# usage_client = OrganizationUsageClient()
|
||||
|
||||
# # Check quota (no reservation for now, actual cost will be added after completion)
|
||||
# has_quota = await usage_client.check_and_reserve_quota(
|
||||
# workflow.user.selected_organization_id, estimated_tokens=0
|
||||
# )
|
||||
|
||||
# if not has_quota:
|
||||
# raise ValueError(
|
||||
# "Organization quota exceeded. Please contact your administrator."
|
||||
# )
|
||||
|
||||
# Fetch the current definition for this workflow
|
||||
current_def_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.is_current == True,
|
||||
if use_draft:
|
||||
# For test calls: prefer draft if it exists, fall back to published
|
||||
draft_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.status == "draft",
|
||||
)
|
||||
)
|
||||
)
|
||||
current_def = current_def_result.scalars().first()
|
||||
target_def = draft_result.scalars().first()
|
||||
|
||||
if target_def is None:
|
||||
# Use the published version via released_definition_id (preferred)
|
||||
# or fall back to is_current for backward compatibility
|
||||
if workflow.released_definition_id:
|
||||
target_def = await session.get(
|
||||
WorkflowDefinitionModel, workflow.released_definition_id
|
||||
)
|
||||
else:
|
||||
pub_result = await session.execute(
|
||||
select(WorkflowDefinitionModel).where(
|
||||
WorkflowDefinitionModel.workflow_id == workflow.id,
|
||||
WorkflowDefinitionModel.is_current == True,
|
||||
)
|
||||
)
|
||||
target_def = pub_result.scalars().first()
|
||||
|
||||
# Get the current storage backend based on ENABLE_AWS_S3 flag
|
||||
current_backend = StorageBackend.get_current_backend()
|
||||
|
||||
# Use initial_context from the version if available, else from workflow
|
||||
default_context = (
|
||||
target_def.template_context_variables
|
||||
if target_def and target_def.template_context_variables
|
||||
else workflow.template_context_variables
|
||||
)
|
||||
|
||||
new_run = WorkflowRunModel(
|
||||
name=name,
|
||||
workflow=workflow,
|
||||
mode=mode,
|
||||
definition_id=current_def.id if current_def else None,
|
||||
initial_context=initial_context or workflow.template_context_variables,
|
||||
definition_id=target_def.id if target_def else None,
|
||||
initial_context=initial_context or default_context,
|
||||
gathered_context=gathered_context or {},
|
||||
campaign_id=campaign_id,
|
||||
queued_run_id=queued_run_id,
|
||||
|
|
@ -189,7 +200,11 @@ class WorkflowRunClient(BaseDBClient):
|
|||
self, run_id: int, user_id: int = None, organization_id: int = None
|
||||
) -> WorkflowRunModel | None:
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRunModel).join(WorkflowRunModel.workflow)
|
||||
query = (
|
||||
select(WorkflowRunModel)
|
||||
.options(selectinload(WorkflowRunModel.definition))
|
||||
.join(WorkflowRunModel.workflow)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
# Filter by organization_id when provided
|
||||
|
|
|
|||
|
|
@ -295,7 +295,7 @@ async def create_campaign(
|
|||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
workflow_def = workflow.workflow_definition_with_fallback
|
||||
workflow_def = workflow.released_definition.workflow_json
|
||||
if workflow_def:
|
||||
try:
|
||||
dto = ReactFlowDTO(**workflow_def)
|
||||
|
|
|
|||
|
|
@ -103,14 +103,13 @@ async def initiate_call(
|
|||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Get workflow definition (with fallback to legacy field)
|
||||
workflow_definition = workflow.workflow_definition_with_fallback
|
||||
workflow_definition = workflow.released_definition.workflow_json
|
||||
|
||||
# Validate trigger node still exists in the workflow definition
|
||||
if not trigger_exists_in_workflow(workflow_definition, uuid):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Trigger not found or has been removed from workflow",
|
||||
detail="Trigger not found in the published Agent",
|
||||
)
|
||||
|
||||
# 6. Get telephony provider for the organization
|
||||
|
|
|
|||
|
|
@ -143,7 +143,8 @@ class StatusCallbackRequest(BaseModel):
|
|||
async def initiate_call(
|
||||
request: InitiateCallRequest, user: UserModel = Depends(get_user)
|
||||
):
|
||||
"""Initiate a call using the configured telephony provider."""
|
||||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
# Get the telephony provider for the organization
|
||||
provider = await get_telephony_provider(user.selected_organization_id)
|
||||
|
|
@ -190,6 +191,7 @@ async def initiate_call(
|
|||
"called_number": phone_number,
|
||||
"provider": provider.PROVIDER_NAME,
|
||||
},
|
||||
use_draft=True,
|
||||
)
|
||||
workflow_run_id = workflow_run.id
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@ from api.db.workflow_template_client import WorkflowTemplateClient
|
|||
from api.enums import CallType
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_definition,
|
||||
merge_workflow_api_keys,
|
||||
)
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
|
|
@ -104,6 +106,8 @@ class WorkflowResponse(BaseModel):
|
|||
call_disposition_codes: CallDispositionCodes | None = None
|
||||
total_runs: int | None = None
|
||||
workflow_configurations: dict | None = None
|
||||
version_number: int | None = None
|
||||
version_status: str | None = None
|
||||
|
||||
|
||||
class WorkflowListResponse(BaseModel):
|
||||
|
|
@ -149,6 +153,17 @@ class UpdateWorkflowRequest(BaseModel):
|
|||
workflow_configurations: dict | None = None
|
||||
|
||||
|
||||
class WorkflowVersionResponse(BaseModel):
|
||||
id: int
|
||||
version_number: int
|
||||
status: str
|
||||
created_at: datetime
|
||||
published_at: datetime | None = None
|
||||
workflow_json: dict
|
||||
workflow_configurations: dict | None = None
|
||||
template_context_variables: dict | None = None
|
||||
|
||||
|
||||
class UpdateWorkflowStatusRequest(BaseModel):
|
||||
status: str # "active" or "archived"
|
||||
|
||||
|
|
@ -200,8 +215,11 @@ async def validate_workflow(
|
|||
|
||||
errors: list[WorkflowError] = []
|
||||
|
||||
# Get workflow definition from WorkflowDefinition table, fallback to workflow_definition field
|
||||
workflow_definition = workflow.workflow_definition_with_fallback
|
||||
# Validate draft if it exists (user is editing), else validate published
|
||||
draft = await db_client.get_draft_version(workflow_id)
|
||||
workflow_definition = (
|
||||
draft.workflow_json if draft else workflow.released_definition.workflow_json
|
||||
)
|
||||
|
||||
# ----------- DTO Validation ------------
|
||||
dto: Optional[ReactFlowDTO] = None
|
||||
|
|
@ -282,9 +300,7 @@ async def create_workflow(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"workflow_definition": mask_workflow_definition(request.workflow_definition),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -362,9 +378,7 @@ async def create_workflow_from_template(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"workflow_definition": mask_workflow_definition(workflow_def),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -461,7 +475,11 @@ async def get_workflow(
|
|||
workflow_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowResponse:
|
||||
"""Get a single workflow by ID"""
|
||||
"""Get a single workflow by ID.
|
||||
|
||||
If a draft version exists, returns the draft content for editing.
|
||||
Otherwise returns the published version's content.
|
||||
"""
|
||||
workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -470,21 +488,123 @@ async def get_workflow(
|
|||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
|
||||
# Check for draft — editor should show draft content if it exists
|
||||
draft = await db_client.get_draft_version(workflow_id)
|
||||
|
||||
if draft:
|
||||
workflow_def = draft.workflow_json
|
||||
workflow_configs = draft.workflow_configurations
|
||||
template_vars = draft.template_context_variables
|
||||
else:
|
||||
published = workflow.released_definition
|
||||
workflow_def = published.workflow_json
|
||||
workflow_configs = published.workflow_configurations
|
||||
template_vars = published.template_context_variables
|
||||
|
||||
active_def = draft or workflow.released_definition
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"workflow_definition": mask_workflow_definition(workflow_def),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"template_context_variables": template_vars,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
"workflow_configurations": workflow.workflow_configurations,
|
||||
"workflow_configurations": workflow_configs,
|
||||
"version_number": active_def.version_number if active_def else None,
|
||||
"version_status": active_def.status if active_def else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/versions")
|
||||
async def get_workflow_versions(
|
||||
workflow_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> list[WorkflowVersionResponse]:
|
||||
"""List all versions for a workflow, newest first."""
|
||||
workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
|
||||
versions = await db_client.get_workflow_versions(workflow_id)
|
||||
return [
|
||||
WorkflowVersionResponse(
|
||||
id=v.id,
|
||||
version_number=v.version_number,
|
||||
status=v.status,
|
||||
created_at=v.created_at,
|
||||
published_at=v.published_at,
|
||||
workflow_json=mask_workflow_definition(v.workflow_json),
|
||||
workflow_configurations=v.workflow_configurations,
|
||||
template_context_variables=v.template_context_variables,
|
||||
)
|
||||
for v in versions
|
||||
if v.version_number is not None
|
||||
]
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/publish")
|
||||
async def publish_workflow(
|
||||
workflow_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Publish the current draft version of a workflow."""
|
||||
workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
|
||||
try:
|
||||
published = await db_client.publish_workflow_draft(workflow_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return {
|
||||
"id": published.id,
|
||||
"version_number": published.version_number,
|
||||
"status": published.status,
|
||||
"published_at": published.published_at,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/create-draft")
|
||||
async def create_workflow_draft(
|
||||
workflow_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowVersionResponse:
|
||||
"""Create a draft version from the current published version.
|
||||
|
||||
If a draft already exists, returns the existing draft.
|
||||
"""
|
||||
workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
|
||||
draft = await db_client.save_workflow_draft(workflow_id)
|
||||
return WorkflowVersionResponse(
|
||||
id=draft.id,
|
||||
version_number=draft.version_number,
|
||||
status=draft.status,
|
||||
created_at=draft.created_at,
|
||||
published_at=draft.published_at,
|
||||
workflow_json=mask_workflow_definition(draft.workflow_json),
|
||||
workflow_configurations=draft.workflow_configurations,
|
||||
template_context_variables=draft.template_context_variables,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_workflows_summary(
|
||||
user: UserModel = Depends(get_user),
|
||||
|
|
@ -528,7 +648,7 @@ async def update_workflow_status(
|
|||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
workflow.released_definition.workflow_json
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
|
|
@ -569,11 +689,37 @@ async def update_workflow(
|
|||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow:
|
||||
# Merge against what the user was editing (draft or published)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_def = (
|
||||
existing_draft.workflow_json
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_json
|
||||
)
|
||||
workflow_definition = merge_workflow_api_keys(
|
||||
workflow_definition,
|
||||
existing_workflow.workflow_definition_with_fallback,
|
||||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
if request.workflow_configurations and request.workflow_configurations.get(
|
||||
"model_overrides"
|
||||
):
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
effective = resolve_effective_config(
|
||||
user_config,
|
||||
request.workflow_configurations["model_overrides"],
|
||||
)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
workflow = await db_client.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
name=request.name,
|
||||
|
|
@ -592,19 +738,35 @@ async def update_workflow(
|
|||
trigger_paths=trigger_paths,
|
||||
)
|
||||
|
||||
# Return draft content if one exists (save creates a draft)
|
||||
draft = await db_client.get_draft_version(workflow_id)
|
||||
if draft:
|
||||
workflow_def = draft.workflow_json
|
||||
workflow_configs = draft.workflow_configurations
|
||||
template_vars = draft.template_context_variables
|
||||
else:
|
||||
published = workflow.released_definition
|
||||
workflow_def = published.workflow_json
|
||||
workflow_configs = published.workflow_configurations
|
||||
template_vars = published.template_context_variables
|
||||
|
||||
# Include version info from the active definition (draft or published)
|
||||
active_def = draft or workflow.released_definition
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"workflow_definition": mask_workflow_definition(workflow_def),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"template_context_variables": template_vars,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
"workflow_configurations": workflow.workflow_configurations,
|
||||
"workflow_configurations": workflow_configs,
|
||||
"version_number": active_def.version_number if active_def else None,
|
||||
"version_status": active_def.status if active_def else None,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
|
|
@ -629,7 +791,7 @@ async def duplicate_workflow_endpoint(
|
|||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
workflow.released_definition.workflow_json
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
|
|
@ -658,7 +820,7 @@ async def create_workflow_run(
|
|||
user: The user to create the workflow run for
|
||||
"""
|
||||
run = await db_client.create_workflow_run(
|
||||
request.name, workflow_id, request.mode, user.id
|
||||
request.name, workflow_id, request.mode, user.id, use_draft=True
|
||||
)
|
||||
return {
|
||||
"id": run.id,
|
||||
|
|
@ -862,9 +1024,7 @@ async def duplicate_workflow_template(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"workflow_definition": mask_workflow_definition(workflow_def),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
|
|||
83
api/services/configuration/resolve.py
Normal file
83
api/services/configuration/resolve.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Resolve effective config by merging per-workflow model overrides onto global config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
)
|
||||
|
||||
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
|
||||
_SECTION_MAP: dict[str, ServiceType] = {
|
||||
"llm": ServiceType.LLM,
|
||||
"tts": ServiceType.TTS,
|
||||
"stt": ServiceType.STT,
|
||||
"realtime": ServiceType.REALTIME,
|
||||
}
|
||||
|
||||
|
||||
def _build_section_from_override(service_type: ServiceType, override: dict):
|
||||
"""Construct a typed config object from a raw override dict using the registry."""
|
||||
provider = override.get("provider")
|
||||
if not provider:
|
||||
return None
|
||||
registry = REGISTRY.get(service_type, {})
|
||||
config_cls = registry.get(provider)
|
||||
if config_cls is None:
|
||||
return None
|
||||
return config_cls(**override)
|
||||
|
||||
|
||||
def resolve_effective_config(
|
||||
user_config: UserConfiguration,
|
||||
model_overrides: dict | None,
|
||||
) -> UserConfiguration:
|
||||
"""Deep-merge workflow model_overrides onto global user config.
|
||||
|
||||
- If model_overrides is None or empty, returns a copy of user_config unchanged.
|
||||
- For each section (llm, tts, stt, realtime), if the override contains that key:
|
||||
- If the global section is None, construct a new config from the override.
|
||||
- If the provider changes, construct a new config from the override.
|
||||
- Otherwise, merge override fields onto the existing config (model_copy).
|
||||
- is_realtime is a simple boolean override.
|
||||
- Sections not in the override are inherited from global unchanged.
|
||||
- The original user_config is never mutated.
|
||||
"""
|
||||
if not model_overrides:
|
||||
return user_config.model_copy(deep=True)
|
||||
|
||||
effective = user_config.model_copy(deep=True)
|
||||
|
||||
# Handle is_realtime boolean
|
||||
if "is_realtime" in model_overrides:
|
||||
effective.is_realtime = model_overrides["is_realtime"]
|
||||
|
||||
# Handle service sections
|
||||
for section_key, service_type in _SECTION_MAP.items():
|
||||
if section_key not in model_overrides:
|
||||
continue
|
||||
|
||||
override = model_overrides[section_key]
|
||||
base = getattr(effective, section_key)
|
||||
|
||||
if base is None:
|
||||
# No global config for this section — build from override
|
||||
setattr(
|
||||
effective,
|
||||
section_key,
|
||||
_build_section_from_override(service_type, override),
|
||||
)
|
||||
elif "provider" in override and override["provider"] != base.provider:
|
||||
# Provider changed — must construct new typed object
|
||||
setattr(
|
||||
effective,
|
||||
section_key,
|
||||
_build_section_from_override(service_type, override),
|
||||
)
|
||||
else:
|
||||
# Same provider — merge fields onto existing config
|
||||
merged = base.model_copy(update=override)
|
||||
setattr(effective, section_key, merged)
|
||||
|
||||
return effective
|
||||
|
|
@ -76,13 +76,15 @@ class LoopTalkPipelineBuilder:
|
|||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Use published definition for graph + configs
|
||||
released_def = workflow.released_definition
|
||||
wf_json = released_def.workflow_json
|
||||
wf_configs = released_def.workflow_configurations or {}
|
||||
|
||||
# Extract keyterms from workflow configurations
|
||||
keyterms = None
|
||||
if (
|
||||
workflow.workflow_configurations
|
||||
and "dictionary" in workflow.workflow_configurations
|
||||
):
|
||||
dictionary = workflow.workflow_configurations["dictionary"]
|
||||
if wf_configs and "dictionary" in wf_configs:
|
||||
dictionary = wf_configs["dictionary"]
|
||||
if dictionary and isinstance(dictionary, str):
|
||||
keyterms = [
|
||||
term.strip() for term in dictionary.split(",") if term.strip()
|
||||
|
|
@ -90,6 +92,12 @@ class LoopTalkPipelineBuilder:
|
|||
if keyterms:
|
||||
logger.info(f"Using {len(keyterms)} keyterms for STT: {keyterms}")
|
||||
|
||||
# Resolve model overrides from the version onto global user config
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
model_overrides = wf_configs.get("model_overrides")
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
|
||||
# Create services
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
llm = create_llm_service(user_config)
|
||||
|
|
@ -98,9 +106,7 @@ class LoopTalkPipelineBuilder:
|
|||
logger.debug(f"Created services for {role}: STT={stt}, LLM={llm}, TTS={tts}")
|
||||
|
||||
# Get workflow graph
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
|
||||
)
|
||||
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(wf_json))
|
||||
|
||||
# Create engine first (needed for create_pipeline_components)
|
||||
engine = PipecatEngine(
|
||||
|
|
|
|||
|
|
@ -562,50 +562,49 @@ async def _run_pipeline(
|
|||
# Get user configuration
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
|
||||
# Get workflow first so we can extract configurations before creating services
|
||||
# Get workflow for metadata (name, organization_id, call_disposition_codes)
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Extract configurations from workflow configurations
|
||||
# Use the run's pinned definition for graph + configs (not the workflow's current)
|
||||
run_definition = workflow_run.definition
|
||||
run_workflow_json = run_definition.workflow_json
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
# Extract configurations from the version's workflow_configurations
|
||||
max_call_duration_seconds = 300 # Default 5 minutes
|
||||
max_user_idle_timeout = 10.0 # Default 10 seconds
|
||||
smart_turn_stop_secs = 2.0 # Default 2 seconds for incomplete turn timeout
|
||||
turn_stop_strategy = "transcription" # Default to transcription-based detection
|
||||
keyterms = None # Dictionary words for STT boosting
|
||||
|
||||
if workflow.workflow_configurations:
|
||||
# Use workflow-specific max call duration if provided
|
||||
if "max_call_duration" in workflow.workflow_configurations:
|
||||
max_call_duration_seconds = workflow.workflow_configurations[
|
||||
"max_call_duration"
|
||||
]
|
||||
if run_configs:
|
||||
if "max_call_duration" in run_configs:
|
||||
max_call_duration_seconds = run_configs["max_call_duration"]
|
||||
|
||||
# Use workflow-specific max user idle timeout if provided
|
||||
if "max_user_idle_timeout" in workflow.workflow_configurations:
|
||||
max_user_idle_timeout = workflow.workflow_configurations[
|
||||
"max_user_idle_timeout"
|
||||
]
|
||||
if "max_user_idle_timeout" in run_configs:
|
||||
max_user_idle_timeout = run_configs["max_user_idle_timeout"]
|
||||
|
||||
# Use workflow-specific smart turn stop timeout if provided
|
||||
if "smart_turn_stop_secs" in workflow.workflow_configurations:
|
||||
smart_turn_stop_secs = workflow.workflow_configurations[
|
||||
"smart_turn_stop_secs"
|
||||
]
|
||||
if "smart_turn_stop_secs" in run_configs:
|
||||
smart_turn_stop_secs = run_configs["smart_turn_stop_secs"]
|
||||
|
||||
# Use workflow-specific turn stop strategy if provided
|
||||
if "turn_stop_strategy" in workflow.workflow_configurations:
|
||||
turn_stop_strategy = workflow.workflow_configurations["turn_stop_strategy"]
|
||||
if "turn_stop_strategy" in run_configs:
|
||||
turn_stop_strategy = run_configs["turn_stop_strategy"]
|
||||
|
||||
# Extract dictionary words and convert to keyterms list
|
||||
if "dictionary" in workflow.workflow_configurations:
|
||||
dictionary = workflow.workflow_configurations["dictionary"]
|
||||
if "dictionary" in run_configs:
|
||||
dictionary = run_configs["dictionary"]
|
||||
if dictionary and isinstance(dictionary, str):
|
||||
# Split by comma and strip whitespace from each term
|
||||
keyterms = [
|
||||
term.strip() for term in dictionary.split(",") if term.strip()
|
||||
]
|
||||
|
||||
# Resolve model overrides from the version onto global user config
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
model_overrides = run_configs.get("model_overrides")
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -619,9 +618,7 @@ async def _run_pipeline(
|
|||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
|
||||
)
|
||||
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(run_workflow_json))
|
||||
|
||||
# Pre-call fetch: fire early so it runs concurrently with remaining setup
|
||||
pre_call_fetch_task = None
|
||||
|
|
|
|||
|
|
@ -325,8 +325,6 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.RIME.value:
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
language_code = getattr(user_config.tts, "language", None) or "en"
|
||||
rime_language_mapping = {
|
||||
|
|
|
|||
|
|
@ -74,13 +74,17 @@ async def duplicate_workflow(
|
|||
if source is None:
|
||||
raise ValueError(f"Workflow with id {workflow_id} not found")
|
||||
|
||||
workflow_definition = copy.deepcopy(source.workflow_definition_with_fallback)
|
||||
# 2. Prefer draft over released definition (duplicate latest state)
|
||||
draft = await db_client.get_draft_version(workflow_id)
|
||||
source_def = draft if draft else source.released_definition
|
||||
|
||||
# 2. Regenerate trigger UUIDs to avoid conflicts
|
||||
workflow_definition = copy.deepcopy(source_def.workflow_json)
|
||||
|
||||
# 3. Regenerate trigger UUIDs to avoid conflicts
|
||||
if workflow_definition:
|
||||
workflow_definition = _regenerate_trigger_uuids(workflow_definition)
|
||||
|
||||
# 3. Create the new workflow
|
||||
# 4. Create the new workflow
|
||||
new_name = f"{source.name} - Duplicate"
|
||||
new_workflow = await db_client.create_workflow(
|
||||
name=new_name,
|
||||
|
|
@ -89,21 +93,20 @@ async def duplicate_workflow(
|
|||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 4. Copy template_context_variables and workflow_configurations
|
||||
has_extra_fields = (
|
||||
source.template_context_variables or source.workflow_configurations
|
||||
)
|
||||
if has_extra_fields:
|
||||
# 5. Copy template_context_variables and workflow_configurations from source definition
|
||||
source_tcv = source_def.template_context_variables
|
||||
source_wc = source_def.workflow_configurations
|
||||
if source_tcv or source_wc:
|
||||
new_workflow = await db_client.update_workflow(
|
||||
workflow_id=new_workflow.id,
|
||||
name=None,
|
||||
workflow_definition=None,
|
||||
template_context_variables=copy.deepcopy(source.template_context_variables),
|
||||
workflow_configurations=copy.deepcopy(source.workflow_configurations),
|
||||
template_context_variables=copy.deepcopy(source_tcv),
|
||||
workflow_configurations=copy.deepcopy(source_wc),
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 5. Copy recordings with new IDs and storage paths scoped to new workflow
|
||||
# 6. Copy recordings with new IDs and storage paths scoped to new workflow
|
||||
recording_id_map = await _duplicate_recordings(
|
||||
source_workflow_id=workflow_id,
|
||||
new_workflow_id=new_workflow.id,
|
||||
|
|
@ -111,7 +114,7 @@ async def duplicate_workflow(
|
|||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 6. Replace old recording IDs with new ones in the workflow definition
|
||||
# 7. Replace old recording IDs with new ones in the workflow definition
|
||||
if recording_id_map:
|
||||
workflow_definition = _replace_recording_ids(
|
||||
workflow_definition, recording_id_map
|
||||
|
|
@ -125,7 +128,7 @@ async def duplicate_workflow(
|
|||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 7. Sync triggers for the new workflow
|
||||
# 8. Sync triggers for the new workflow
|
||||
if workflow_definition:
|
||||
trigger_paths = _extract_trigger_paths(workflow_definition)
|
||||
if trigger_paths:
|
||||
|
|
|
|||
|
|
@ -187,15 +187,9 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
secret_key=langfuse_config.get("secret_key"),
|
||||
)
|
||||
|
||||
# Step 2: Get workflow definition (prefer the run-specific definition)
|
||||
if workflow_run.definition:
|
||||
workflow_definition = workflow_run.definition.workflow_json
|
||||
definition_id = workflow_run.definition.id
|
||||
else:
|
||||
workflow_definition = (
|
||||
workflow_run.workflow.workflow_definition_with_fallback
|
||||
)
|
||||
definition_id = workflow_run.workflow.current_definition_id
|
||||
# Step 2: Get workflow definition from the run's pinned version
|
||||
workflow_definition = workflow_run.definition.workflow_json
|
||||
definition_id = workflow_run.definition.id
|
||||
|
||||
if not workflow_definition:
|
||||
logger.debug("No workflow definition, skipping integrations")
|
||||
|
|
|
|||
|
|
@ -80,11 +80,6 @@ class MockWorkflowModel:
|
|||
workflow_id: int = 1
|
||||
organization_id: int = 1
|
||||
workflow_configurations: Dict[str, Any] = field(default_factory=dict)
|
||||
workflow_definition_with_fallback: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.workflow_definition_with_fallback:
|
||||
self.workflow_definition_with_fallback = DEFAULT_WORKFLOW_DEFINITION.copy()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -120,6 +115,7 @@ class MockToolModel:
|
|||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
category: str = "http_api"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
353
api/tests/test_resolve_effective_config.py
Normal file
353
api/tests/test_resolve_effective_config.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
TDD tests for resolve_effective_config().
|
||||
|
||||
This function deep-merges workflow-level model_overrides onto the global
|
||||
UserConfiguration. Fields not overridden inherit from global.
|
||||
|
||||
Module under test: api.services.configuration.resolve
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def global_config() -> UserConfiguration:
|
||||
"""A realistic global user configuration."""
|
||||
return UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-global-tts",
|
||||
voice="Rachel",
|
||||
model="eleven_flash_v2_5",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-global-stt",
|
||||
model="nova-3-general",
|
||||
language="multi",
|
||||
),
|
||||
is_realtime=False,
|
||||
realtime=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def global_config_realtime() -> UserConfiguration:
|
||||
"""Global config with realtime enabled."""
|
||||
return UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-global-tts",
|
||||
voice="Rachel",
|
||||
model="eleven_flash_v2_5",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-global-stt",
|
||||
model="nova-3-general",
|
||||
language="multi",
|
||||
),
|
||||
is_realtime=True,
|
||||
realtime=GoogleRealtimeLLMConfiguration(
|
||||
provider="google_realtime",
|
||||
api_key="goog-global-rt",
|
||||
model="gemini-3.1-flash-live-preview",
|
||||
voice="Puck",
|
||||
language="en",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# No overrides → global returned unchanged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoOverrides:
|
||||
def test_none_overrides_returns_global(self, global_config):
|
||||
result = resolve_effective_config(global_config, None)
|
||||
assert result.llm.model == "gpt-4.1"
|
||||
assert result.tts.voice == "Rachel"
|
||||
assert result.stt.model == "nova-3-general"
|
||||
assert result.is_realtime is False
|
||||
|
||||
def test_empty_dict_overrides_returns_global(self, global_config):
|
||||
result = resolve_effective_config(global_config, {})
|
||||
assert result.llm.model == "gpt-4.1"
|
||||
assert result.tts.voice == "Rachel"
|
||||
|
||||
def test_does_not_mutate_original(self, global_config):
|
||||
"""The original config object must not be modified."""
|
||||
resolve_effective_config(global_config, {"llm": {"model": "gpt-4.1-mini"}})
|
||||
assert global_config.llm.model == "gpt-4.1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-field overrides within a section (same provider)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleFieldOverride:
|
||||
def test_override_llm_model_only(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config, {"llm": {"model": "gpt-4.1-mini"}}
|
||||
)
|
||||
assert result.llm.model == "gpt-4.1-mini"
|
||||
assert result.llm.provider == "openai" # inherited
|
||||
assert result.llm.api_key == "sk-global-llm" # inherited
|
||||
|
||||
def test_override_tts_voice_only(self, global_config):
|
||||
result = resolve_effective_config(global_config, {"tts": {"voice": "shimmer"}})
|
||||
assert result.tts.voice == "shimmer"
|
||||
assert result.tts.provider == "elevenlabs" # inherited
|
||||
assert result.tts.api_key == "el-global-tts" # inherited
|
||||
|
||||
def test_override_stt_language_only(self, global_config):
|
||||
result = resolve_effective_config(global_config, {"stt": {"language": "en"}})
|
||||
assert result.stt.language == "en"
|
||||
assert result.stt.model == "nova-3-general" # inherited
|
||||
assert result.stt.provider == "deepgram" # inherited
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider change (requires full section replacement)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProviderChange:
|
||||
def test_override_llm_to_different_provider(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"api_key": "groq-key",
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
}
|
||||
},
|
||||
)
|
||||
assert result.llm.provider == "groq"
|
||||
assert result.llm.model == "llama-3.3-70b-versatile"
|
||||
assert result.llm.api_key == "groq-key"
|
||||
|
||||
def test_provider_change_does_not_affect_other_sections(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"api_key": "groq-key",
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
}
|
||||
},
|
||||
)
|
||||
# TTS and STT unchanged
|
||||
assert result.tts.provider == "elevenlabs"
|
||||
assert result.stt.provider == "deepgram"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API key inheritance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAPIKeyInheritance:
|
||||
def test_no_api_key_in_override_inherits_global(self, global_config):
|
||||
"""When override omits api_key, global key is used."""
|
||||
result = resolve_effective_config(
|
||||
global_config, {"llm": {"model": "gpt-4.1-mini"}}
|
||||
)
|
||||
assert result.llm.api_key == "sk-global-llm"
|
||||
|
||||
def test_explicit_api_key_in_override_wins(self, global_config):
|
||||
"""When override includes api_key, it takes precedence."""
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{"llm": {"model": "gpt-4.1-mini", "api_key": "sk-override-key"}},
|
||||
)
|
||||
assert result.llm.api_key == "sk-override-key"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_realtime override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRealtimeOverride:
|
||||
def test_enable_realtime_on_non_realtime_global(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"is_realtime": True,
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "goog-override",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Charon",
|
||||
"language": "en",
|
||||
},
|
||||
},
|
||||
)
|
||||
assert result.is_realtime is True
|
||||
assert result.realtime.provider == "google_realtime"
|
||||
assert result.realtime.voice == "Charon"
|
||||
|
||||
def test_disable_realtime_on_realtime_global(self, global_config_realtime):
|
||||
result = resolve_effective_config(
|
||||
global_config_realtime, {"is_realtime": False}
|
||||
)
|
||||
assert result.is_realtime is False
|
||||
# Realtime config may still be present but is_realtime flag controls usage
|
||||
|
||||
def test_override_realtime_voice_only(self, global_config_realtime):
|
||||
result = resolve_effective_config(
|
||||
global_config_realtime, {"realtime": {"voice": "Kore"}}
|
||||
)
|
||||
assert result.realtime.voice == "Kore"
|
||||
assert result.realtime.provider == "google_realtime" # inherited
|
||||
assert result.realtime.api_key == "goog-global-rt" # inherited
|
||||
|
||||
def test_override_is_realtime_only_without_realtime_section(self, global_config):
|
||||
"""Override is_realtime=True but provide no realtime config.
|
||||
Should set the flag; realtime section stays None from global."""
|
||||
result = resolve_effective_config(global_config, {"is_realtime": True})
|
||||
assert result.is_realtime is True
|
||||
assert result.realtime is None # no config provided
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section override when global has None for that section
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideOnNullGlobal:
|
||||
def test_override_stt_when_global_is_none(self):
|
||||
"""When global has no STT config, override creates one from scratch."""
|
||||
config = UserConfiguration(
|
||||
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
|
||||
stt=None,
|
||||
tts=None,
|
||||
is_realtime=False,
|
||||
)
|
||||
result = resolve_effective_config(
|
||||
config,
|
||||
{
|
||||
"stt": {
|
||||
"provider": "deepgram",
|
||||
"api_key": "dg-new",
|
||||
"model": "nova-3-general",
|
||||
"language": "en",
|
||||
}
|
||||
},
|
||||
)
|
||||
assert result.stt is not None
|
||||
assert result.stt.provider == "deepgram"
|
||||
assert result.stt.model == "nova-3-general"
|
||||
|
||||
def test_override_realtime_when_global_is_none(self):
|
||||
"""Realtime section can be created from override even if global has none."""
|
||||
config = UserConfiguration(
|
||||
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
|
||||
is_realtime=False,
|
||||
realtime=None,
|
||||
)
|
||||
result = resolve_effective_config(
|
||||
config,
|
||||
{
|
||||
"is_realtime": True,
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "goog-new",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Puck",
|
||||
"language": "en",
|
||||
},
|
||||
},
|
||||
)
|
||||
assert result.is_realtime is True
|
||||
assert result.realtime.provider == "google_realtime"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-section overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiSectionOverride:
|
||||
def test_override_llm_and_tts_not_stt(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"llm": {"model": "gpt-4.1-mini"},
|
||||
"tts": {"voice": "shimmer"},
|
||||
},
|
||||
)
|
||||
assert result.llm.model == "gpt-4.1-mini"
|
||||
assert result.tts.voice == "shimmer"
|
||||
# STT untouched
|
||||
assert result.stt.model == "nova-3-general"
|
||||
assert result.stt.language == "multi"
|
||||
|
||||
def test_override_all_sections(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"llm": {"model": "gpt-4.1-mini"},
|
||||
"tts": {"voice": "shimmer"},
|
||||
"stt": {"language": "en"},
|
||||
"is_realtime": True,
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "goog-key",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Fenrir",
|
||||
"language": "en",
|
||||
},
|
||||
},
|
||||
)
|
||||
assert result.llm.model == "gpt-4.1-mini"
|
||||
assert result.tts.voice == "shimmer"
|
||||
assert result.stt.language == "en"
|
||||
assert result.is_realtime is True
|
||||
assert result.realtime.voice == "Fenrir"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ignored / unknown keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnknownKeys:
|
||||
def test_unknown_section_in_overrides_is_ignored(self, global_config):
|
||||
"""Override with a key that doesn't map to any section should not crash."""
|
||||
result = resolve_effective_config(
|
||||
global_config, {"unknown_section": {"foo": "bar"}}
|
||||
)
|
||||
assert result.llm.model == "gpt-4.1"
|
||||
|
||||
def test_embeddings_not_overridable(self, global_config):
|
||||
"""Embeddings stay global — overrides for embeddings should be ignored."""
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{"embeddings": {"provider": "openai", "model": "text-embedding-3-small"}},
|
||||
)
|
||||
assert result.embeddings is None # was None in global, stays None
|
||||
|
|
@ -1,960 +0,0 @@
|
|||
"""Tests validating user turn stop strategy behavior during bot speaking scenarios.
|
||||
|
||||
These tests validate the scenarios described in scenarios.md. They demonstrate
|
||||
how the ExternalUserTurnStopStrategy and UserTurnController interact when frames
|
||||
are suppressed (muted) during bot speaking.
|
||||
|
||||
Key concepts:
|
||||
- When the bot is speaking, AlwaysUserMuteStrategy causes the LLMUserAggregator
|
||||
to suppress user frames (UserStartedSpeaking, UserStoppedSpeaking, Transcription, VAD).
|
||||
- The ExternalUserTurnStopStrategy accumulates _text from TranscriptionFrames and
|
||||
triggers a stop when _user_speaking is False and _text is truthy.
|
||||
- The UserTurnController only allows a stop if _user_turn is True (a start must
|
||||
have occurred first). When a stop is rejected, the controller unconditionally
|
||||
resets all stop strategies, clearing any dangling state (e.g. _text).
|
||||
- This unconditional reset prevents stale _text from causing premature stops
|
||||
or contaminating subsequent turns.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
EndTaskFrame,
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
from pipecat.turns.user_mute import AlwaysUserMuteStrategy
|
||||
from pipecat.turns.user_start import VADUserTurnStartStrategy
|
||||
from pipecat.turns.user_stop import ExternalUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
# Short timeout for faster tests
|
||||
STOP_STRATEGY_TIMEOUT = 0.15
|
||||
# Delay to allow async processing
|
||||
ASYNC_DELAY = 0.05
|
||||
# Delay to wait for stop strategy timeout to fire
|
||||
TIMEOUT_WAIT = STOP_STRATEGY_TIMEOUT + 0.1
|
||||
|
||||
|
||||
class FrameInjector(FrameProcessor):
|
||||
"""Simple processor that can inject frames into the pipeline."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def inject(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
"""Inject a frame into the pipeline."""
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
def _build_components(llm_steps=None):
|
||||
"""Build pipeline components for testing.
|
||||
|
||||
Uses:
|
||||
- VADUserTurnStartStrategy: turn starts only when VADUserStartedSpeakingFrame arrives
|
||||
- ExternalUserTurnStopStrategy: turn stops based on UserStoppedSpeakingFrame + _text
|
||||
- AlwaysUserMuteStrategy: suppresses user frames while bot is speaking
|
||||
|
||||
Returns a tuple of (injector, user_aggregator, stop_strategy, turn_controller, mock_llm, pipeline).
|
||||
"""
|
||||
context = LLMContext()
|
||||
|
||||
stop_strategy = ExternalUserTurnStopStrategy(timeout=STOP_STRATEGY_TIMEOUT)
|
||||
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy()],
|
||||
stop=[stop_strategy],
|
||||
)
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_mute_strategies=[AlwaysUserMuteStrategy()],
|
||||
)
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params, user_params=user_params
|
||||
)
|
||||
user_agg = context_aggregator.user()
|
||||
assistant_agg = context_aggregator.assistant()
|
||||
|
||||
if llm_steps is None:
|
||||
llm_steps = [
|
||||
MockLLMService.create_text_chunks(text="Response 1"),
|
||||
MockLLMService.create_text_chunks(text="Response 2"),
|
||||
MockLLMService.create_text_chunks(text="Response 3"),
|
||||
]
|
||||
mock_llm = MockLLMService(mock_steps=llm_steps, chunk_delay=0.001)
|
||||
|
||||
injector = FrameInjector()
|
||||
pipeline = Pipeline([injector, user_agg, mock_llm, assistant_agg])
|
||||
|
||||
turn_controller = user_agg._user_turn_controller
|
||||
|
||||
return (
|
||||
injector,
|
||||
user_agg,
|
||||
stop_strategy,
|
||||
turn_controller,
|
||||
mock_llm,
|
||||
context,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
async def _run_scenario(pipeline, inject_fn):
|
||||
"""Run a pipeline with a frame injection coroutine."""
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run():
|
||||
await runner.run(task)
|
||||
|
||||
async def inject():
|
||||
# Wait for pipeline to start (StartFrame to propagate)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await inject_fn()
|
||||
|
||||
await asyncio.gather(run(), inject())
|
||||
|
||||
|
||||
async def _inject_user_turn(injector, text, delay=ASYNC_DELAY):
|
||||
"""Inject a complete user turn: VAD start + external start + transcription + external stop.
|
||||
|
||||
This simulates what happens in a real pipeline when the user speaks:
|
||||
1. VAD detects speech -> VADUserStartedSpeakingFrame (triggers turn start)
|
||||
2. External processor sends UserStartedSpeakingFrame (stop strategy tracks _user_speaking)
|
||||
3. STT produces TranscriptionFrame (stop strategy accumulates _text)
|
||||
4. External processor sends UserStoppedSpeakingFrame (stop strategy triggers stop)
|
||||
"""
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(delay)
|
||||
await injector.inject(TranscriptionFrame(text, "user-1", time_now_iso8601()))
|
||||
|
||||
|
||||
class TestUserTurnStopScenarios:
|
||||
"""Test scenarios from scenarios.md.
|
||||
|
||||
Each test simulates a specific frame ordering to validate the interaction
|
||||
between ExternalUserTurnStopStrategy and UserTurnController, particularly
|
||||
around frame suppression during bot speaking.
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 1 (✅): All frames suppressed during bot speaking
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# TranscriptionFrame (suppressed)
|
||||
# UserStoppedSpeaking (suppressed)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
#
|
||||
# Stop strategy _text is empty because TranscriptionFrame was suppressed.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_1_all_suppressed_then_bot_stops(self):
|
||||
"""All user frames suppressed during bot speaking, then bot stops.
|
||||
|
||||
Expected: _text is empty, no turn triggered, clean state.
|
||||
Second turn works correctly.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Bot speaking, all user frames suppressed ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# These are all suppressed by mute
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(VADUserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: _text should be empty (all frames suppressed)
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after all frames suppressed, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Expected _user_turn to be False"
|
||||
|
||||
# === Turn 2: Normal turn should work correctly ===
|
||||
await _inject_user_turn(injector, "second turn text")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: turn completed, _text cleared by reset
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after clean turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, (
|
||||
"Expected _user_turn to be False after turn"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call (turn 2 only), got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 2 (✅): User frames suppressed, user stops after bot stops
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# TranscriptionFrame (suppressed)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# UserStoppedSpeaking (stop strategy has no _text -> no trigger)
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_2_user_stops_after_bot_stops_no_text(self):
|
||||
"""User stops speaking after bot stops, but transcription was suppressed.
|
||||
|
||||
Expected: _text is empty because transcription was suppressed.
|
||||
UserStoppedSpeaking doesn't trigger stop (no _text).
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Bot speaking, user frames partially suppressed ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Suppressed during bot speaking
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Bot stops -> unmuted
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# UserStoppedSpeaking arrives after unmute, but _text is empty
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: _text empty (TranscriptionFrame was suppressed)
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Expected _user_turn to be False"
|
||||
|
||||
# === Turn 2: Normal turn should work ===
|
||||
await _inject_user_turn(injector, "second turn")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", "Expected clean _text after turn 2"
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call, got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 3 (✅ after fix): Transcription arrives after unmute
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# TranscriptionFrame -> stop strategy _text = "hello"
|
||||
# UserStoppedSpeaking -> stop strategy triggers (text truthy, not speaking)
|
||||
# Turn controller ignores (user_turn is False), BUT unconditionally
|
||||
# resets stop strategies -> _text cleared. No dangling state.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_3_transcription_after_unmute_text_cleared(self):
|
||||
"""Transcription arrives after bot stops but turn was never started.
|
||||
|
||||
The VADUserStartedSpeakingFrame was suppressed, so no turn started.
|
||||
But TranscriptionFrame arrives after unmute and accumulates _text.
|
||||
The stop strategy triggers, but the turn controller rejects it
|
||||
(no active turn). The unconditional reset clears _text, preventing
|
||||
any dangling state from contaminating subsequent turns.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Rejected stop with unconditional reset ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Suppressed: VAD and UserStartedSpeaking
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Bot stops -> unmuted
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Install spy on trigger_user_turn_stopped to track every call
|
||||
# and the _user_turn state at the time of each call.
|
||||
trigger_stop_calls = []
|
||||
original_trigger_stop = stop_strategy.trigger_user_turn_stopped
|
||||
|
||||
async def spy_trigger_stop():
|
||||
trigger_stop_calls.append(turn_ctrl._user_turn)
|
||||
await original_trigger_stop()
|
||||
|
||||
stop_strategy.trigger_user_turn_stopped = spy_trigger_stop
|
||||
|
||||
# TranscriptionFrame arrives AFTER unmute -> reaches stop strategy
|
||||
await injector.inject(
|
||||
TranscriptionFrame("hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# UserStoppedSpeaking arrives AFTER unmute
|
||||
# Stop strategy: _user_speaking is False (UserStartedSpeaking was suppressed),
|
||||
# _text is "hello" -> triggers stop via _handle_user_stopped_speaking
|
||||
# Turn controller: _user_turn is False -> rejects, but resets -> _text cleared
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Call #1: _handle_user_stopped_speaking -> _maybe_trigger_user_turn_stopped
|
||||
assert len(trigger_stop_calls) == 1, (
|
||||
f"Expected exactly 1 trigger_user_turn_stopped call from "
|
||||
f"_handle_user_stopped_speaking, got {len(trigger_stop_calls)}"
|
||||
)
|
||||
assert trigger_stop_calls[0] is False, (
|
||||
"Expected _user_turn=False when _handle_user_stopped_speaking triggered stop"
|
||||
)
|
||||
|
||||
# Wait for _task_handler timeout period
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# The unconditional reset cleared _text after the rejected stop,
|
||||
# so the timeout's _maybe_trigger_user_turn_stopped sees _text="" and
|
||||
# does NOT call trigger_user_turn_stopped again.
|
||||
assert len(trigger_stop_calls) == 1, (
|
||||
f"Expected no additional trigger_user_turn_stopped calls after "
|
||||
f"reset cleared _text, but got {len(trigger_stop_calls)} total call(s)"
|
||||
)
|
||||
|
||||
# Restore original method
|
||||
stop_strategy.trigger_user_turn_stopped = original_trigger_stop
|
||||
|
||||
# Transcript is not suppressed, so we should have hello in user aggregator
|
||||
assert user_agg._aggregation[0].text == "hello"
|
||||
|
||||
# Assert: _text is cleared by the unconditional reset (no dangling state)
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after unconditional reset, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, (
|
||||
"Expected _user_turn to be False (turn was never started)"
|
||||
)
|
||||
# No LLM call should have happened
|
||||
assert mock_llm.get_current_step() == 0, (
|
||||
f"Expected 0 LLM calls, got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# === Turn 2: No premature stop, normal flow ===
|
||||
# _text is clean, so UserStoppedSpeaking won't trigger a premature stop.
|
||||
# The turn completes normally when the timeout fires after TranscriptionFrame.
|
||||
# The aggregator still has dangling "hello" from turn 1, which gets
|
||||
# combined with turn 2's "world" — this is acceptable behavior.
|
||||
await _inject_user_turn(injector, "world")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected clean _text after normal turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call (normal turn), got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# The LLM received both "hello" (dangling in aggregator from turn 1)
|
||||
# and "world" (from turn 2). This is acceptable — the aggregator's
|
||||
# _aggregation is a separate concern from the stop strategy's _text.
|
||||
messages = context.messages
|
||||
user_messages = [m for m in messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1, (
|
||||
f"Expected 1 user message, got {len(user_messages)}"
|
||||
)
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "hello" in user_text, (
|
||||
f"Expected 'hello' (from aggregator) in user message, got: '{user_text}'"
|
||||
)
|
||||
assert "world" in user_text, (
|
||||
f"Expected 'world' (from turn 2) in user message, got: '{user_text}'"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 4 (✅): User speaks after bot stops -> normal flow
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# UserStartedSpeaking (triggers interruption/turn start)
|
||||
# TranscriptionFrame
|
||||
# UserStoppedSpeaking
|
||||
#
|
||||
# Turn starts because VAD frame is not suppressed. Everything works.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_4_user_speaks_after_bot_stops(self):
|
||||
"""User speaks after bot stops speaking. Normal flow, everything works.
|
||||
|
||||
All frames arrive after unmute, so VAD triggers turn start normally.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Bot speaks, then user speaks after bot stops ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Normal user turn after bot stopped
|
||||
await _inject_user_turn(injector, "hello after bot")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: clean state
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after clean turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Expected _user_turn False after turn"
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call, got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# === Turn 2: Another normal turn ===
|
||||
await _inject_user_turn(injector, "second turn")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", "Expected clean _text after turn 2"
|
||||
assert mock_llm.get_current_step() == 2, (
|
||||
f"Expected 2 LLM calls, got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# Verify clean context - each turn should be separate
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 2, (
|
||||
f"Expected 2 user messages (one per turn), got {len(user_messages)}"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 5 (✅): Late transcription - all suppressed
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# UserStoppedSpeaking (suppressed)
|
||||
# TranscriptionFrame (suppressed) <- late, but still during bot speaking
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
#
|
||||
# Everything suppressed, _text empty. Clean state.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_5_late_transcription_all_suppressed(self):
|
||||
"""Late transcription arrives during bot speaking. All suppressed.
|
||||
|
||||
Even though transcription is late, it still arrives before BotStoppedSpeaking
|
||||
so it's still muted. Clean state.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Late transcription, but all still suppressed ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(VADUserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
# Late transcription - but still during bot speaking
|
||||
await injector.inject(
|
||||
TranscriptionFrame("late hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: all suppressed, clean state
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn
|
||||
|
||||
# === Turn 2: Normal turn works ===
|
||||
await _inject_user_turn(injector, "clean turn")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == ""
|
||||
assert mock_llm.get_current_step() == 1
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 6 (✅ after fix): Late transcription arrives after bot stops
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# UserStoppedSpeaking (suppressed)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# TranscriptionFrame -> reaches stop strategy, _text = "late hello"
|
||||
#
|
||||
# Stop strategy timeout fires: _user_speaking is False (from initial state,
|
||||
# UserStartedSpeaking was suppressed), _text truthy -> triggers stop.
|
||||
# Turn controller: _user_turn False -> rejects, but unconditionally resets
|
||||
# -> _text cleared. No dangling state.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_6_late_transcription_after_unmute_text_cleared(self):
|
||||
"""Late transcription arrives after bot stops. No turn was started.
|
||||
|
||||
UserStartedSpeaking was suppressed so _user_turn never started.
|
||||
The late TranscriptionFrame accumulates _text after unmute.
|
||||
The stop strategy timeout triggers, but controller rejects it.
|
||||
The unconditional reset clears _text, preventing dangling state.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Late transcription scenario ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Suppressed
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(VADUserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Bot stops -> unmuted
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Install spy on trigger_user_turn_stopped to track calls
|
||||
trigger_stop_calls = []
|
||||
original_trigger_stop = stop_strategy.trigger_user_turn_stopped
|
||||
|
||||
async def spy_trigger_stop():
|
||||
trigger_stop_calls.append(turn_ctrl._user_turn)
|
||||
await original_trigger_stop()
|
||||
|
||||
stop_strategy.trigger_user_turn_stopped = spy_trigger_stop
|
||||
|
||||
# Late transcription arrives after unmute
|
||||
await injector.inject(
|
||||
TranscriptionFrame("late hello", "user-1", time_now_iso8601())
|
||||
)
|
||||
|
||||
# No UserStoppedSpeakingFrame in this scenario — the stop is
|
||||
# triggered ONLY by the _task_handler timeout path.
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# The _task_handler timeout fired _maybe_trigger_user_turn_stopped:
|
||||
# _user_speaking=False (UserStartedSpeaking was suppressed),
|
||||
# _text="late hello" -> trigger_user_turn_stopped called
|
||||
# Turn controller: _user_turn=False -> rejects, but resets -> _text cleared
|
||||
assert len(trigger_stop_calls) == 1, (
|
||||
f"Expected exactly 1 trigger_user_turn_stopped call from "
|
||||
f"_task_handler timeout, got {len(trigger_stop_calls)}"
|
||||
)
|
||||
assert trigger_stop_calls[0] is False, (
|
||||
"Expected _user_turn=False when timeout triggered stop"
|
||||
)
|
||||
|
||||
# Restore original method
|
||||
stop_strategy.trigger_user_turn_stopped = original_trigger_stop
|
||||
|
||||
# Transcript is not suppressed, so we should have late hello in user aggregator
|
||||
assert user_agg._aggregation[0].text == "late hello"
|
||||
|
||||
# Assert: _text is cleared by the unconditional reset (no dangling state)
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after unconditional reset, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Turn should not have started"
|
||||
assert mock_llm.get_current_step() == 0, "No LLM call expected"
|
||||
|
||||
# === Turn 2: No premature stop, normal flow ===
|
||||
# _text is clean, so no premature stop occurs.
|
||||
# The turn completes normally when the timeout fires after TranscriptionFrame.
|
||||
# The aggregator still has dangling "late hello" from turn 1, which gets
|
||||
# combined with turn 2's "real speech" — this is acceptable behavior.
|
||||
await _inject_user_turn(injector, "real speech")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected clean _text after normal turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call (normal turn), got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# The LLM received both "late hello" (dangling in aggregator from turn 1)
|
||||
# and "real speech" (from turn 2).
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1, (
|
||||
f"Expected 1 user message, got {len(user_messages)}"
|
||||
)
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "late hello" in user_text, (
|
||||
f"Expected 'late hello' (from aggregator) in user message, got: '{user_text}'"
|
||||
)
|
||||
assert "real speech" in user_text, (
|
||||
f"Expected 'real speech' (from turn 2) in user message, got: '{user_text}'"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 7 (✅ after fix): Late transcription - user stops before transcription
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# UserStartedSpeaking (suppressed)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# UserStoppedSpeaking (no _text yet -> no trigger from _handle_user_stopped)
|
||||
# TranscriptionFrame -> _text = "late", timeout triggers stop
|
||||
#
|
||||
# Turn controller: _user_turn False -> rejects, but unconditionally resets
|
||||
# -> _text cleared. No dangling state.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_7_late_transcription_after_user_stops_text_cleared(self):
|
||||
"""User stops speaking, then late transcription arrives. No turn started.
|
||||
|
||||
UserStoppedSpeaking arrives first (no _text yet, so no trigger).
|
||||
Then TranscriptionFrame arrives (sets _text). The timeout fires and
|
||||
triggers stop, but controller rejects it. The unconditional reset
|
||||
clears _text, preventing dangling state.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Late transcription after user stops ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Suppressed
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Bot stops -> unmuted
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# UserStoppedSpeaking arrives after unmute, but _text is still empty
|
||||
# -> _maybe_trigger_user_turn_stopped: _text is "" -> no trigger
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Late transcription arrives AFTER user stopped
|
||||
await injector.inject(
|
||||
TranscriptionFrame("late text", "user-1", time_now_iso8601())
|
||||
)
|
||||
# Wait for timeout to fire
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Transcript is not suppressed, so we should have late text in user aggregator
|
||||
assert user_agg._aggregation[0].text == "late text"
|
||||
|
||||
# Assert: _text is cleared by the unconditional reset
|
||||
# The timeout fired _maybe_trigger_user_turn_stopped:
|
||||
# _user_speaking=False (was never set, UserStartedSpeaking suppressed),
|
||||
# _text="late text" -> triggers stop
|
||||
# Turn controller: _user_turn=False -> rejects, but resets -> _text cleared
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after unconditional reset, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn
|
||||
assert mock_llm.get_current_step() == 0
|
||||
|
||||
# === Turn 2: No premature stop, normal flow ===
|
||||
# _text is clean, so no premature stop occurs.
|
||||
# The turn completes normally when the timeout fires after TranscriptionFrame.
|
||||
# The aggregator still has dangling "late text" from turn 1, which gets
|
||||
# combined with turn 2's "next speech" — this is acceptable behavior.
|
||||
await _inject_user_turn(injector, "next speech")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected clean _text after normal turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1
|
||||
|
||||
# The LLM received both "late text" (dangling in aggregator from turn 1)
|
||||
# and "next speech" (from turn 2).
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "late text" in user_text, (
|
||||
f"Expected 'late text' (from aggregator) in context, got: '{user_text}'"
|
||||
)
|
||||
assert "next speech" in user_text, (
|
||||
f"Expected 'next speech' (from turn 2) in context, got: '{user_text}'"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Scenario 8 (✅): Late transcription - user speaks after bot stops
|
||||
#
|
||||
# BotStartedSpeaking (muted)
|
||||
# BotStoppedSpeaking (unmuted)
|
||||
# UserStartedSpeaking (not suppressed -> turn starts, start strategies reset)
|
||||
# UserStoppedSpeaking (no _text -> no trigger)
|
||||
# TranscriptionFrame (timeout triggers stop)
|
||||
#
|
||||
# Turn controller: _user_turn IS True -> allows stop -> resets strategies
|
||||
# Clean state!
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_8_late_transcription_user_speaks_after_bot_stops(self):
|
||||
"""User speaks after bot stops, then late transcription arrives.
|
||||
|
||||
Because user spoke after unmute, VAD triggers turn start -> _user_turn=True.
|
||||
When the late transcription triggers the stop, controller allows it and
|
||||
resets strategies. Clean state.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Turn 1: Late transcription but user spoke after unmute ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# User speaks AFTER bot stops -> not suppressed
|
||||
await injector.inject(VADUserStartedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# User stops speaking (no _text yet, so stop strategy doesn't trigger)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(VADUserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Late transcription arrives
|
||||
await injector.inject(
|
||||
TranscriptionFrame("late but ok", "user-1", time_now_iso8601())
|
||||
)
|
||||
# Wait for timeout to trigger stop
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Assert: turn controller allowed the stop, strategies were reset
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected clean _text after allowed stop, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert not turn_ctrl._user_turn, "Turn should have stopped"
|
||||
assert mock_llm.get_current_step() == 1, (
|
||||
f"Expected 1 LLM call, got {mock_llm.get_current_step()}"
|
||||
)
|
||||
|
||||
# === Turn 2: Clean subsequent turn ===
|
||||
await _inject_user_turn(injector, "clean turn")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == ""
|
||||
assert mock_llm.get_current_step() == 2
|
||||
|
||||
# Verify each turn is separate in context
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 2, (
|
||||
f"Expected 2 separate user messages, got {len(user_messages)}"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
|
||||
# =========================================================================
|
||||
# Combined test: validates _text is cleared independently after each
|
||||
# rejected stop, preventing accumulation across muted periods.
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_cleared_independently_across_failed_stops(self):
|
||||
"""Validates _text does not accumulate across multiple failed stop attempts.
|
||||
|
||||
Two consecutive muted periods with late transcriptions each trigger
|
||||
a rejected stop. The unconditional reset clears _text after each
|
||||
rejection, so no accumulation occurs. The subsequent normal turn
|
||||
completes correctly.
|
||||
"""
|
||||
injector, user_agg, stop_strategy, turn_ctrl, mock_llm, context, pipeline = (
|
||||
_build_components()
|
||||
)
|
||||
|
||||
async def inject():
|
||||
# === Muted period 1: _text cleared after rejected stop ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(VADUserStartedSpeakingFrame()) # suppressed
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame()) # suppressed
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
# Late transcription after unmute
|
||||
await injector.inject(
|
||||
TranscriptionFrame("first", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# Transcript is not suppressed, so we should have first in user aggregator
|
||||
assert user_agg._aggregation[0].text == "first"
|
||||
|
||||
# _text is cleared by the unconditional reset after rejected stop
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after unconditional reset, got '{stop_strategy._text}'"
|
||||
)
|
||||
|
||||
# === Muted period 2: _text cleared independently, no accumulation ===
|
||||
await injector.inject(BotStartedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(VADUserStartedSpeakingFrame()) # suppressed
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStartedSpeakingFrame()) # suppressed
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
await injector.inject(BotStoppedSpeakingFrame())
|
||||
await asyncio.sleep(ASYNC_DELAY)
|
||||
|
||||
await injector.inject(
|
||||
TranscriptionFrame("second", "user-1", time_now_iso8601())
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await injector.inject(UserStoppedSpeakingFrame())
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
# _text is cleared again — no accumulation of "first" + "second"
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected empty _text after second unconditional reset, got '{stop_strategy._text}'"
|
||||
)
|
||||
# Aggregator accumulated both (separate concern, acceptable)
|
||||
assert len(user_agg._aggregation) == 2
|
||||
assert user_agg._aggregation[0].text == "first"
|
||||
assert user_agg._aggregation[1].text == "second"
|
||||
|
||||
# === Turn 3: No premature stop, normal flow ===
|
||||
# _text is clean, so no premature stop occurs.
|
||||
# The turn completes normally when the timeout fires after TranscriptionFrame.
|
||||
# The aggregator has dangling "first" + "second" from muted periods,
|
||||
# which get combined with turn 3's "actual speech".
|
||||
await _inject_user_turn(injector, "actual speech")
|
||||
await asyncio.sleep(TIMEOUT_WAIT)
|
||||
|
||||
assert stop_strategy._text == "", (
|
||||
f"Expected clean _text after normal turn, got '{stop_strategy._text}'"
|
||||
)
|
||||
assert mock_llm.get_current_step() == 1
|
||||
|
||||
# The LLM received all three: "first" + "second" (from aggregator)
|
||||
# and "actual speech" (from turn 3).
|
||||
user_messages = [m for m in context.messages if m.get("role") == "user"]
|
||||
assert len(user_messages) == 1, (
|
||||
f"Expected 1 user message, got {len(user_messages)}"
|
||||
)
|
||||
user_text = user_messages[0]["content"]
|
||||
assert "first" in user_text, f"Expected 'first' in '{user_text}'"
|
||||
assert "second" in user_text, f"Expected 'second' in '{user_text}'"
|
||||
assert "actual speech" in user_text, (
|
||||
f"Expected 'actual speech' in '{user_text}'"
|
||||
)
|
||||
|
||||
await injector.inject(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
await _run_scenario(pipeline, inject)
|
||||
608
api/tests/test_workflow_versioning.py
Normal file
608
api/tests/test_workflow_versioning.py
Normal file
|
|
@ -0,0 +1,608 @@
|
|||
"""
|
||||
TDD tests for workflow versioning lifecycle.
|
||||
|
||||
Tests the version lifecycle on WorkflowDefinitionModel:
|
||||
- status: draft / published / archived
|
||||
- version_number: sequential per workflow
|
||||
- released_definition_id on WorkflowModel
|
||||
|
||||
Modules under test:
|
||||
- api.db.workflow_client (new versioning methods)
|
||||
- api.db.models (new columns on WorkflowDefinitionModel, WorkflowModel)
|
||||
|
||||
These are DB integration tests using the transactional test session.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from api.db.models import (
|
||||
OrganizationModel,
|
||||
UserModel,
|
||||
)
|
||||
|
||||
# Sample workflow definitions (graph JSON)
|
||||
GRAPH_V1 = {
|
||||
"nodes": [
|
||||
{"id": "1", "type": "startCall", "data": {"name": "Start", "prompt": "Hello"}},
|
||||
{"id": "2", "type": "endCall", "data": {"name": "End", "prompt": "Bye"}},
|
||||
],
|
||||
"edges": [{"id": "e1", "source": "1", "target": "2", "data": {"label": "End"}}],
|
||||
}
|
||||
|
||||
GRAPH_V2 = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"data": {"name": "Start", "prompt": "Hello v2"},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"data": {"name": "Agent", "prompt": "Collect info"},
|
||||
},
|
||||
{"id": "3", "type": "endCall", "data": {"name": "End", "prompt": "Bye"}},
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "1", "target": "2", "data": {"label": "Collect"}},
|
||||
{"id": "e2", "source": "2", "target": "3", "data": {"label": "End"}},
|
||||
],
|
||||
}
|
||||
|
||||
GRAPH_V3 = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"data": {"name": "Start", "prompt": "Hello v3"},
|
||||
},
|
||||
{"id": "2", "type": "endCall", "data": {"name": "End", "prompt": "Goodbye"}},
|
||||
],
|
||||
"edges": [{"id": "e1", "source": "1", "target": "2", "data": {"label": "End"}}],
|
||||
}
|
||||
|
||||
CONFIG_V1 = {"max_call_duration": 300}
|
||||
CONFIG_V2 = {
|
||||
"max_call_duration": 600,
|
||||
"model_overrides": {"llm": {"model": "gpt-4.1-mini"}},
|
||||
}
|
||||
TEMPLATE_VARS_V1 = {"company_name": "Acme"}
|
||||
TEMPLATE_VARS_V2 = {"company_name": "Acme Inc"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def org_and_user(async_session):
|
||||
"""Create an organization and user for workflow tests."""
|
||||
org = OrganizationModel(provider_id="test-org-versioning")
|
||||
async_session.add(org)
|
||||
await async_session.flush()
|
||||
|
||||
user = UserModel(
|
||||
provider_id="test-user-versioning", selected_organization_id=org.id
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.flush()
|
||||
|
||||
return org, user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def workflow_with_v1(db_session, org_and_user):
|
||||
"""Create a workflow — should produce V1 as published."""
|
||||
org, user = org_and_user
|
||||
workflow = await db_session.create_workflow(
|
||||
name="Test Workflow",
|
||||
workflow_definition=GRAPH_V1,
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
)
|
||||
return workflow, user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow creation → V1 published
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkflowCreation:
|
||||
async def test_create_workflow_produces_published_v1(
|
||||
self, db_session, org_and_user
|
||||
):
|
||||
"""Creating a new workflow should produce exactly one definition
|
||||
with status='published' and version_number=1."""
|
||||
org, user = org_and_user
|
||||
workflow = await db_session.create_workflow(
|
||||
name="New Workflow",
|
||||
workflow_definition=GRAPH_V1,
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
assert len(versions) == 1
|
||||
|
||||
v1 = versions[0]
|
||||
assert v1.status == "published"
|
||||
assert v1.version_number == 1
|
||||
assert v1.workflow_json == GRAPH_V1
|
||||
|
||||
async def test_create_workflow_sets_released_pointer(
|
||||
self, db_session, org_and_user
|
||||
):
|
||||
"""The workflow's released_definition_id should point to V1."""
|
||||
org, user = org_and_user
|
||||
workflow = await db_session.create_workflow(
|
||||
name="Pointer Test",
|
||||
workflow_definition=GRAPH_V1,
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
assert workflow.released_definition_id == versions[0].id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Saving a draft
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSaveDraft:
|
||||
async def test_save_draft_creates_draft_version(self, db_session, workflow_with_v1):
|
||||
"""Saving changes to a published workflow creates a draft version."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
workflow_configurations=CONFIG_V2,
|
||||
template_context_variables=TEMPLATE_VARS_V2,
|
||||
)
|
||||
|
||||
assert draft.status == "draft"
|
||||
assert draft.version_number == 2
|
||||
assert draft.workflow_json == GRAPH_V2
|
||||
assert draft.workflow_configurations == CONFIG_V2
|
||||
assert draft.template_context_variables == TEMPLATE_VARS_V2
|
||||
|
||||
async def test_save_draft_does_not_change_released_pointer(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Creating a draft must not move the released pointer."""
|
||||
workflow, user = workflow_with_v1
|
||||
original_released_id = workflow.released_definition_id
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
assert refreshed.released_definition_id == original_released_id
|
||||
|
||||
async def test_save_draft_twice_updates_in_place(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Saving a second draft should update the existing draft, not create a new row."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft1 = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
draft2 = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
|
||||
assert draft1.id == draft2.id # same row
|
||||
assert draft2.workflow_json == GRAPH_V3
|
||||
assert draft2.version_number == 2 # unchanged
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
assert len(versions) == 2 # V1 published + V2 draft, no extras
|
||||
|
||||
async def test_save_draft_with_only_config_change(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""A draft can change only configs, keeping the same graph."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V1, # same graph
|
||||
workflow_configurations=CONFIG_V2, # different config
|
||||
)
|
||||
|
||||
assert draft.status == "draft"
|
||||
assert draft.workflow_json == GRAPH_V1
|
||||
assert draft.workflow_configurations == CONFIG_V2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Publishing a draft
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPublishDraft:
|
||||
async def test_publish_promotes_draft_to_published(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Publishing moves draft → published and old published → archived."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
workflow_configurations=CONFIG_V2,
|
||||
)
|
||||
|
||||
published = await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
assert published.status == "published"
|
||||
assert published.workflow_json == GRAPH_V2
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
statuses = {v.version_number: v.status for v in versions}
|
||||
assert statuses[1] == "archived"
|
||||
assert statuses[2] == "published"
|
||||
|
||||
async def test_publish_updates_released_pointer(self, db_session, workflow_with_v1):
|
||||
"""After publishing, released_definition_id should point to the new version."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
assert refreshed.released_definition_id == draft.id
|
||||
|
||||
async def test_publish_sets_published_at(self, db_session, workflow_with_v1):
|
||||
"""Published version should have a published_at timestamp."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
published = await db_session.publish_workflow_draft(workflow.id)
|
||||
assert published.published_at is not None
|
||||
|
||||
async def test_publish_with_no_draft_raises(self, db_session, workflow_with_v1):
|
||||
"""Publishing when no draft exists should raise an error."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
with pytest.raises(ValueError, match="[Nn]o draft"):
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
async def test_exactly_one_published_after_multiple_cycles(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""After several draft/publish cycles, exactly one version is published."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
# Cycle 1
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
# Cycle 2
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
published = [v for v in versions if v.status == "published"]
|
||||
assert len(published) == 1
|
||||
assert published[0].version_number == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discarding a draft
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDiscardDraft:
|
||||
async def test_discard_removes_draft(self, db_session, workflow_with_v1):
|
||||
"""Discarding a draft should delete the draft row."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
await db_session.discard_workflow_draft(workflow.id)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
assert len(versions) == 1
|
||||
assert versions[0].status == "published"
|
||||
|
||||
async def test_discard_does_not_affect_published(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Published version and released pointer are unchanged after discard."""
|
||||
workflow, user = workflow_with_v1
|
||||
original_released_id = workflow.released_definition_id
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
await db_session.discard_workflow_draft(workflow.id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
assert refreshed.released_definition_id == original_released_id
|
||||
|
||||
async def test_discard_when_no_draft_raises(self, db_session, workflow_with_v1):
|
||||
"""Discarding when no draft exists should raise an error."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
with pytest.raises(ValueError, match="[Nn]o draft"):
|
||||
await db_session.discard_workflow_draft(workflow.id)
|
||||
|
||||
async def test_new_draft_after_discard_gets_next_version_number(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""After discarding V2 draft, the next draft should still be V2
|
||||
(since V2 was deleted and never published)."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
await db_session.discard_workflow_draft(workflow.id)
|
||||
|
||||
new_draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
# Version number reuse is acceptable since V2 was never published
|
||||
assert new_draft.version_number == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reverting to an archived version
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevert:
|
||||
async def _publish_v2(self, db_session, workflow):
|
||||
"""Helper: create and publish V2, making V1 archived."""
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
workflow_configurations=CONFIG_V2,
|
||||
template_context_variables=TEMPLATE_VARS_V2,
|
||||
)
|
||||
return await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
async def test_revert_creates_draft_from_archived(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Reverting copies the archived version's full snapshot into a new draft."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
# Get V1's definition ID before it gets archived
|
||||
versions_before = await db_session.get_workflow_versions(workflow.id)
|
||||
v1_id = versions_before[0].id
|
||||
|
||||
# Publish V2, archiving V1
|
||||
await self._publish_v2(db_session, workflow)
|
||||
|
||||
# Revert to V1
|
||||
draft = await db_session.revert_to_version(workflow.id, v1_id)
|
||||
|
||||
assert draft.status == "draft"
|
||||
assert draft.workflow_json == GRAPH_V1
|
||||
|
||||
async def test_revert_preserves_all_snapshot_fields(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Revert should copy graph, configs, and template vars."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
# Publish V2 with full config
|
||||
v2 = await self._publish_v2(db_session, workflow)
|
||||
|
||||
# Publish V3, archiving V2
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
# Revert to V2
|
||||
draft = await db_session.revert_to_version(workflow.id, v2.id)
|
||||
|
||||
assert draft.workflow_json == GRAPH_V2
|
||||
assert draft.workflow_configurations == CONFIG_V2
|
||||
assert draft.template_context_variables == TEMPLATE_VARS_V2
|
||||
|
||||
async def test_revert_when_draft_exists_raises(self, db_session, workflow_with_v1):
|
||||
"""Cannot revert when a draft already exists — must discard first."""
|
||||
workflow, user = workflow_with_v1
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
v1_id = versions[0].id
|
||||
|
||||
await self._publish_v2(db_session, workflow)
|
||||
|
||||
# Create a draft
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="[Dd]raft.*exists"):
|
||||
await db_session.revert_to_version(workflow.id, v1_id)
|
||||
|
||||
async def test_revert_does_not_change_released_pointer(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Revert creates a draft — the released pointer stays on the published version."""
|
||||
workflow, user = workflow_with_v1
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
v1_id = versions[0].id
|
||||
|
||||
v2 = await self._publish_v2(db_session, workflow)
|
||||
|
||||
await db_session.revert_to_version(workflow.id, v1_id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
assert refreshed.released_definition_id == v2.id # still V2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Version listing & ordering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVersionListing:
|
||||
async def test_versions_ordered_by_version_number_desc(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""Versions should be returned newest first."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
version_numbers = [v.version_number for v in versions]
|
||||
assert version_numbers == sorted(version_numbers, reverse=True)
|
||||
|
||||
async def test_versions_include_status(self, db_session, workflow_with_v1):
|
||||
"""Each version should have an explicit status."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V3,
|
||||
)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
statuses = {v.version_number: v.status for v in versions}
|
||||
assert statuses == {1: "archived", 2: "published", 3: "draft"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Version data stored on definition, not workflow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVersionDataOnDefinition:
|
||||
async def test_configs_stored_on_definition(self, db_session, workflow_with_v1):
|
||||
"""workflow_configurations should be on the definition, not just the workflow."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
workflow_configurations=CONFIG_V2,
|
||||
template_context_variables=TEMPLATE_VARS_V2,
|
||||
)
|
||||
|
||||
assert draft.workflow_configurations == CONFIG_V2
|
||||
assert draft.template_context_variables == TEMPLATE_VARS_V2
|
||||
|
||||
async def test_different_versions_have_different_configs(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""V1 and V2 can have different configs stored independently."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
workflow_configurations=CONFIG_V2,
|
||||
)
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
configs_by_version = {
|
||||
v.version_number: v.workflow_configurations for v in versions
|
||||
}
|
||||
|
||||
assert configs_by_version[1] != configs_by_version[2]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run creation uses published (or draft for testing)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunDefinitionBinding:
|
||||
async def test_campaign_run_uses_published_version(
|
||||
self, db_session, workflow_with_v1
|
||||
):
|
||||
"""A campaign-initiated run should use the published version, not draft."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
# Create a draft (unpublished)
|
||||
await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
# Create a run (simulating campaign dispatch)
|
||||
run = await db_session.create_workflow_run(
|
||||
name="Campaign Run",
|
||||
workflow_id=workflow.id,
|
||||
mode="webrtc",
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
# Run should be bound to the published V1, not the draft V2
|
||||
versions = await db_session.get_workflow_versions(workflow.id)
|
||||
published = next(v for v in versions if v.status == "published")
|
||||
assert run.definition_id == published.id
|
||||
|
||||
async def test_test_run_uses_draft_if_exists(self, db_session, workflow_with_v1):
|
||||
"""A test/phone call should use the draft version for pre-publish testing."""
|
||||
workflow, user = workflow_with_v1
|
||||
|
||||
draft = await db_session.save_workflow_draft(
|
||||
workflow_id=workflow.id,
|
||||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
# Create a test run
|
||||
run = await db_session.create_workflow_run(
|
||||
name="Test Run",
|
||||
workflow_id=workflow.id,
|
||||
mode="webrtc", # test mode
|
||||
user_id=user.id,
|
||||
use_draft=True,
|
||||
)
|
||||
|
||||
assert run.definition_id == draft.id
|
||||
Loading…
Add table
Add a link
Reference in a new issue