mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add google stt and tts. add folders to organize agents
This commit is contained in:
parent
21951eca18
commit
ad2fa07058
52 changed files with 3412 additions and 621 deletions
|
|
@ -0,0 +1,61 @@
|
|||
"""add folders and workflow folder_id
|
||||
|
||||
Revision ID: 6bd9f67ec994
|
||||
Revises: 2f638891cbb6
|
||||
Create Date: 2026-05-22 12:52:30.737380
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6bd9f67ec994"
|
||||
down_revision: Union[str, None] = "2f638891cbb6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"folders",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("organization_id", "name", name="uq_folder_org_name"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_folders_organization_id"), "folders", ["organization_id"], unique=False
|
||||
)
|
||||
op.add_column("workflows", sa.Column("folder_id", sa.Integer(), nullable=True))
|
||||
op.create_index(
|
||||
op.f("ix_workflows_folder_id"), "workflows", ["folder_id"], unique=False
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_workflows_folder_id",
|
||||
"workflows",
|
||||
"folders",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("fk_workflows_folder_id", "workflows", type_="foreignkey")
|
||||
op.drop_index(op.f("ix_workflows_folder_id"), table_name="workflows")
|
||||
op.drop_column("workflows", "folder_id")
|
||||
op.drop_index(op.f("ix_folders_organization_id"), table_name="folders")
|
||||
op.drop_table("folders")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -2,6 +2,7 @@ from api.db.agent_trigger_client import AgentTriggerClient
|
|||
from api.db.api_key_client import APIKeyClient
|
||||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.embed_token_client import EmbedTokenClient
|
||||
from api.db.folder_client import FolderClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.knowledge_base_client import KnowledgeBaseClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
|
|
@ -41,6 +42,7 @@ class DBClient(
|
|||
WorkflowRecordingClient,
|
||||
TelephonyConfigurationClient,
|
||||
TelephonyPhoneNumberClient,
|
||||
FolderClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -62,6 +64,7 @@ class DBClient(
|
|||
- WebhookCredentialClient: handles webhook credential operations
|
||||
- ToolClient: handles tool operations for reusable HTTP API tools
|
||||
- KnowledgeBaseClient: handles knowledge base document and vector search operations
|
||||
- FolderClient: handles folder operations for grouping workflows (agents)
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
115
api/db/folder_client.py
Normal file
115
api/db/folder_client.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import FolderModel, WorkflowModel
|
||||
from api.enums import WorkflowStatus
|
||||
|
||||
|
||||
class FolderNameConflictError(Exception):
|
||||
"""Raised when a folder name already exists within the organization."""
|
||||
|
||||
|
||||
class FolderClient(BaseDBClient):
|
||||
async def create_folder(self, name: str, organization_id: int) -> FolderModel:
|
||||
async with self.async_session() as session:
|
||||
folder = FolderModel(name=name, organization_id=organization_id)
|
||||
session.add(folder)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise FolderNameConflictError(
|
||||
f"A folder named '{name}' already exists."
|
||||
)
|
||||
await session.refresh(folder)
|
||||
return folder
|
||||
|
||||
async def get_folder(
|
||||
self, folder_id: int, organization_id: int
|
||||
) -> FolderModel | None:
|
||||
"""Fetch a single folder scoped to the organization (tenant isolation)."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(FolderModel).where(
|
||||
FolderModel.id == folder_id,
|
||||
FolderModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_folders(self, organization_id: int) -> list[FolderModel]:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(FolderModel)
|
||||
.where(FolderModel.organization_id == organization_id)
|
||||
.order_by(FolderModel.name.asc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def rename_folder(
|
||||
self, folder_id: int, name: str, organization_id: int
|
||||
) -> FolderModel:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(FolderModel).where(
|
||||
FolderModel.id == folder_id,
|
||||
FolderModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
folder = result.scalar_one_or_none()
|
||||
if folder is None:
|
||||
raise ValueError(f"Folder with id {folder_id} not found")
|
||||
|
||||
folder.name = name
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise FolderNameConflictError(
|
||||
f"A folder named '{name}' already exists."
|
||||
)
|
||||
await session.refresh(folder)
|
||||
return folder
|
||||
|
||||
async def delete_folder(self, folder_id: int, organization_id: int) -> bool:
|
||||
"""Delete a folder. Member workflows are un-filed (folder_id -> NULL)
|
||||
via the ON DELETE SET NULL foreign key, never deleted.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(FolderModel).where(
|
||||
FolderModel.id == folder_id,
|
||||
FolderModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
folder = result.scalar_one_or_none()
|
||||
if folder is None:
|
||||
return False
|
||||
|
||||
await session.delete(folder)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def get_active_workflow_counts_by_folder(
|
||||
self, organization_id: int
|
||||
) -> dict[int, int]:
|
||||
"""Return {folder_id: active_workflow_count} for the organization.
|
||||
|
||||
Only counts active (non-archived) workflows with a non-NULL folder_id.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(
|
||||
WorkflowModel.folder_id,
|
||||
func.count(WorkflowModel.id).label("count"),
|
||||
)
|
||||
.where(
|
||||
WorkflowModel.organization_id == organization_id,
|
||||
WorkflowModel.folder_id.is_not(None),
|
||||
WorkflowModel.status == WorkflowStatus.ACTIVE.value,
|
||||
)
|
||||
.group_by(WorkflowModel.folder_id)
|
||||
)
|
||||
return {folder_id: count for folder_id, count in result.all()}
|
||||
|
|
@ -352,6 +352,32 @@ class WorkflowDefinitionModel(Base):
|
|||
workflow_runs = relationship("WorkflowRunModel", back_populates="definition")
|
||||
|
||||
|
||||
class FolderModel(Base):
|
||||
"""A folder for grouping workflows (agents) within an organization.
|
||||
|
||||
Folders are flat (no nesting) and org-scoped. A workflow belongs to at
|
||||
most one folder via ``WorkflowModel.folder_id``; a NULL folder_id means
|
||||
the workflow is "Uncategorized".
|
||||
"""
|
||||
|
||||
__tablename__ = "folders"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
organization = relationship("OrganizationModel")
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
workflows = relationship("WorkflowModel", back_populates="folder")
|
||||
|
||||
# Folder names must be unique within an organization.
|
||||
__table_args__ = (
|
||||
UniqueConstraint("organization_id", "name", name="uq_folder_org_name"),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowModel(Base):
|
||||
__tablename__ = "workflows"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
|
@ -366,6 +392,15 @@ class WorkflowModel(Base):
|
|||
user = relationship("UserModel", back_populates="workflows")
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=True)
|
||||
organization = relationship("OrganizationModel")
|
||||
# Optional folder for grouping in the agents list. NULL = "Uncategorized".
|
||||
# ON DELETE SET NULL: deleting a folder un-files its agents, never deletes them.
|
||||
folder_id = Column(
|
||||
Integer,
|
||||
ForeignKey("folders.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
folder = relationship("FolderModel", back_populates="workflows")
|
||||
name = Column(String, index=True, nullable=False)
|
||||
status = Column(
|
||||
Enum(*[status.value for status in WorkflowStatus], name="workflow_status"),
|
||||
|
|
|
|||
|
|
@ -372,6 +372,8 @@ class WorkflowClient(BaseDBClient):
|
|||
WorkflowModel.name,
|
||||
WorkflowModel.status,
|
||||
WorkflowModel.created_at,
|
||||
WorkflowModel.folder_id,
|
||||
WorkflowModel.workflow_uuid,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -425,8 +427,26 @@ class WorkflowClient(BaseDBClient):
|
|||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workflow(
|
||||
self, workflow_id: int, user_id: int = None, organization_id: int = None
|
||||
self,
|
||||
workflow_id: int,
|
||||
user_id: int | None = None,
|
||||
organization_id: int | None = None,
|
||||
) -> WorkflowModel | None:
|
||||
"""Fetch a workflow by id, scoped to a tenant.
|
||||
|
||||
Scoping is mandatory: pass ``organization_id`` (preferred) or
|
||||
``user_id``. A fully unscoped lookup would let a request-supplied id
|
||||
reach another tenant's workflow. System/runtime paths that only have a
|
||||
``workflow_id`` and derive the org from the workflow itself (e.g.
|
||||
inbound telephony routing) must call ``get_workflow_by_id`` instead —
|
||||
the explicit unscoped variant.
|
||||
"""
|
||||
if user_id is None and organization_id is None:
|
||||
raise ValueError(
|
||||
"get_workflow requires organization_id (preferred) or user_id "
|
||||
"for tenant scoping; use get_workflow_by_id for unscoped "
|
||||
"system lookups."
|
||||
)
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(WorkflowModel)
|
||||
|
|
@ -448,6 +468,13 @@ class WorkflowClient(BaseDBClient):
|
|||
return result.scalars().first()
|
||||
|
||||
async def get_workflow_by_id(self, workflow_id: int) -> WorkflowModel | None:
|
||||
"""Fetch a workflow by id WITHOUT tenant scoping.
|
||||
|
||||
Explicit unscoped variant of ``get_workflow``. Only for system/runtime
|
||||
contexts that legitimately have just a workflow_id and derive the org
|
||||
from the workflow itself (e.g. inbound telephony). Never call this with
|
||||
a request-supplied id on a user-facing path.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
|
|
@ -609,7 +636,7 @@ class WorkflowClient(BaseDBClient):
|
|||
self,
|
||||
workflow_id: int,
|
||||
status: str,
|
||||
organization_id: int = None,
|
||||
organization_id: int,
|
||||
) -> WorkflowModel:
|
||||
"""
|
||||
Update the status of a workflow.
|
||||
|
|
@ -617,7 +644,9 @@ class WorkflowClient(BaseDBClient):
|
|||
Args:
|
||||
workflow_id: The ID of the workflow to update
|
||||
status: The new status (active/archived)
|
||||
organization_id: The organization ID
|
||||
organization_id: The organization ID. Required and always filtered
|
||||
on: this is a mutation, so an unscoped query would let a caller
|
||||
archive another org's workflow (tenant-isolation bypass).
|
||||
|
||||
Returns:
|
||||
The updated WorkflowModel
|
||||
|
|
@ -632,12 +661,12 @@ class WorkflowClient(BaseDBClient):
|
|||
selectinload(WorkflowModel.current_definition),
|
||||
selectinload(WorkflowModel.released_definition),
|
||||
)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.where(
|
||||
WorkflowModel.id == workflow_id,
|
||||
WorkflowModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
query = query.where(WorkflowModel.organization_id == organization_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
|
||||
|
|
@ -654,6 +683,47 @@ class WorkflowClient(BaseDBClient):
|
|||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def move_workflow_to_folder(
|
||||
self,
|
||||
workflow_id: int,
|
||||
folder_id: int | None,
|
||||
organization_id: int,
|
||||
) -> WorkflowModel:
|
||||
"""Set (or clear) a workflow's folder.
|
||||
|
||||
Pass ``folder_id=None`` to move the workflow to "Uncategorized". The
|
||||
caller must validate that ``folder_id`` belongs to ``organization_id``
|
||||
before calling (the FK only proves the folder exists, not ownership).
|
||||
|
||||
``organization_id`` is required and always filtered on: this is a
|
||||
mutation, so an unscoped query would let a caller move another org's
|
||||
workflow (tenant-isolation bypass).
|
||||
|
||||
Raises:
|
||||
ValueError: If the workflow is not found within the organization.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowModel).where(
|
||||
WorkflowModel.id == workflow_id,
|
||||
WorkflowModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
workflow = result.scalars().first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
workflow.folder_id = folder_id
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(workflow)
|
||||
return workflow
|
||||
|
||||
async def get_workflow_run_count(self, workflow_id: int) -> int:
|
||||
"""Get the count of runs for a workflow."""
|
||||
async with self.async_session() as session:
|
||||
|
|
|
|||
|
|
@ -34,6 +34,10 @@ from api.mcp_server.ts_bridge import TsBridgeError, parse_code
|
|||
from api.services.posthog_client import capture_event
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.layout import reconcile_positions
|
||||
from api.services.workflow.trigger_paths import (
|
||||
extract_trigger_paths,
|
||||
validate_trigger_paths,
|
||||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
||||
|
||||
|
|
@ -53,20 +57,6 @@ def _format_errors(errors: list[dict[str, Any]]) -> str:
|
|||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _extract_trigger_paths(workflow_definition: dict) -> list[str]:
|
||||
"""Mirror of `routes.workflow.extract_trigger_paths` — kept local so the
|
||||
MCP layer doesn't depend on the route module."""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
paths: list[str] = []
|
||||
for node in workflow_definition.get("nodes") or []:
|
||||
if node.get("type") == "trigger":
|
||||
trigger_path = (node.get("data") or {}).get("trigger_path")
|
||||
if trigger_path:
|
||||
paths.append(trigger_path)
|
||||
return paths
|
||||
|
||||
|
||||
@traced_tool
|
||||
async def create_workflow(code: str) -> dict[str, Any]:
|
||||
"""Parse SDK TypeScript and create a new published workflow.
|
||||
|
|
@ -129,6 +119,12 @@ async def create_workflow(code: str) -> dict[str, Any]:
|
|||
# 1b. New workflow — no prior version to reconcile against; layout
|
||||
# places new nodes adjacent to their first incoming neighbor.
|
||||
payload = reconcile_positions(payload, None)
|
||||
trigger_path_issues = validate_trigger_paths(payload)
|
||||
if trigger_path_issues:
|
||||
return _error_result(
|
||||
"validation_error",
|
||||
"\n".join(issue.message for issue in trigger_path_issues),
|
||||
)
|
||||
|
||||
# 2. Pydantic shape check (defence in depth — parser is spec-driven).
|
||||
try:
|
||||
|
|
@ -144,7 +140,7 @@ async def create_workflow(code: str) -> dict[str, Any]:
|
|||
|
||||
# 4. Reject upfront if any trigger path collides with another workflow's
|
||||
# trigger in this org so we don't leave an orphan workflow record.
|
||||
trigger_paths = _extract_trigger_paths(payload)
|
||||
trigger_paths = extract_trigger_paths(payload)
|
||||
if trigger_paths:
|
||||
try:
|
||||
await db_client.assert_trigger_paths_available(
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from api.mcp_server.tracing import traced_tool
|
|||
from api.mcp_server.ts_bridge import TsBridgeError, parse_code
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.layout import reconcile_positions
|
||||
from api.services.workflow.trigger_paths import validate_trigger_paths
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
||||
|
||||
|
|
@ -129,6 +130,12 @@ async def save_workflow(workflow_id: int, code: str) -> dict[str, Any]:
|
|||
# here we fill them back in from what was there before, and pick
|
||||
# approximate placements for newly-introduced nodes.
|
||||
payload = reconcile_positions(payload, await _previous_workflow_json(workflow))
|
||||
trigger_path_issues = validate_trigger_paths(payload)
|
||||
if trigger_path_issues:
|
||||
return _error_result(
|
||||
"validation_error",
|
||||
"\n".join(issue.message for issue in trigger_path_issues),
|
||||
)
|
||||
|
||||
# 2. Pydantic shape check (defence in depth — parser is spec-driven).
|
||||
try:
|
||||
|
|
|
|||
99
api/routes/folder.py
Normal file
99
api/routes/folder.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.folder_client import FolderNameConflictError
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
router = APIRouter(prefix="/folder")
|
||||
|
||||
|
||||
class FolderResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CreateFolderRequest(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("Folder name cannot be empty")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateFolderRequest(CreateFolderRequest):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_folders(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> list[FolderResponse]:
|
||||
"""List all folders in the authenticated user's organization."""
|
||||
folders = await db_client.list_folders(
|
||||
organization_id=user.selected_organization_id
|
||||
)
|
||||
return [
|
||||
FolderResponse(id=f.id, name=f.name, created_at=f.created_at) for f in folders
|
||||
]
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_folder(
|
||||
request: CreateFolderRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> FolderResponse:
|
||||
"""Create a new folder in the authenticated user's organization."""
|
||||
try:
|
||||
folder = await db_client.create_folder(
|
||||
name=request.name,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except FolderNameConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return FolderResponse(id=folder.id, name=folder.name, created_at=folder.created_at)
|
||||
|
||||
|
||||
@router.put("/{folder_id}")
|
||||
async def rename_folder(
|
||||
folder_id: int,
|
||||
request: UpdateFolderRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> FolderResponse:
|
||||
"""Rename a folder owned by the authenticated user's organization."""
|
||||
try:
|
||||
folder = await db_client.rename_folder(
|
||||
folder_id=folder_id,
|
||||
name=request.name,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except FolderNameConflictError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return FolderResponse(id=folder.id, name=folder.name, created_at=folder.created_at)
|
||||
|
||||
|
||||
@router.delete("/{folder_id}")
|
||||
async def delete_folder(
|
||||
folder_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a folder. Member agents are moved to "Uncategorized", not deleted."""
|
||||
deleted = await db_client.delete_folder(
|
||||
folder_id=folder_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Folder with id {folder_id} not found"
|
||||
)
|
||||
return {"success": True}
|
||||
|
|
@ -6,6 +6,7 @@ from api.routes.agent_stream import router as agent_stream_router
|
|||
from api.routes.auth import router as auth_router
|
||||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.folder import router as folder_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.node_types import router as node_types_router
|
||||
from api.routes.organization import router as organization_router
|
||||
|
|
@ -54,6 +55,7 @@ router.include_router(public_download_router)
|
|||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
router.include_router(workflow_recording_router)
|
||||
router.include_router(folder_router)
|
||||
router.include_router(auth_router)
|
||||
router.include_router(node_types_router)
|
||||
router.include_router(agent_stream_router)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
"""Public API endpoints for agent triggers.
|
||||
"""Public API endpoints for public agent execution.
|
||||
|
||||
These endpoints are accessible with API key authentication and allow
|
||||
external systems to programmatically trigger phone calls.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import TriggerState
|
||||
from api.enums import TriggerState, WorkflowStatus
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.telephony.factory import (
|
||||
get_default_telephony_provider,
|
||||
|
|
@ -39,6 +40,14 @@ class TriggerCallResponse(BaseModel):
|
|||
workflow_run_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAgentTarget:
|
||||
workflow: object
|
||||
organization_id: int
|
||||
identifier_type: str
|
||||
identifier_value: str
|
||||
|
||||
|
||||
def trigger_exists_in_workflow(workflow_definition: dict, trigger_path: str) -> bool:
|
||||
"""Check if trigger node exists in workflow definition.
|
||||
|
||||
|
|
@ -57,72 +66,133 @@ def trigger_exists_in_workflow(workflow_definition: dict, trigger_path: str) ->
|
|||
return False
|
||||
|
||||
|
||||
async def _initiate_call(
|
||||
uuid: str,
|
||||
request: TriggerCallRequest,
|
||||
x_api_key: str,
|
||||
*,
|
||||
use_draft: bool,
|
||||
) -> TriggerCallResponse:
|
||||
"""Shared core for production and test trigger endpoints.
|
||||
|
||||
When ``use_draft`` is True the latest draft definition is executed;
|
||||
otherwise the published (released) definition is used.
|
||||
"""
|
||||
# 1. Validate API key
|
||||
async def _validate_api_key(x_api_key: str):
|
||||
"""Validate the org API key used to invoke a public agent endpoint."""
|
||||
api_key = await db_client.validate_api_key(x_api_key)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
return api_key
|
||||
|
||||
# 2. Lookup agent trigger by UUID
|
||||
trigger = await db_client.get_agent_trigger_by_path(uuid)
|
||||
|
||||
def _ensure_workflow_is_active(workflow) -> None:
|
||||
if workflow.status != WorkflowStatus.ACTIVE.value:
|
||||
raise HTTPException(status_code=404, detail="Workflow is not active")
|
||||
|
||||
|
||||
def _get_execution_user_id(workflow) -> int:
|
||||
if workflow.user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Workflow has no execution owner",
|
||||
)
|
||||
return workflow.user_id
|
||||
|
||||
|
||||
async def _get_workflow_definition_for_execution(workflow, *, use_draft: bool) -> dict:
|
||||
"""Return the definition that would execute for this public agent request."""
|
||||
if use_draft:
|
||||
draft = await db_client.get_draft_version(workflow.id)
|
||||
if draft:
|
||||
return draft.workflow_json
|
||||
|
||||
if workflow.released_definition is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Workflow has no published definition"
|
||||
)
|
||||
|
||||
return workflow.released_definition.workflow_json
|
||||
|
||||
|
||||
async def _resolve_trigger_target(
|
||||
trigger_path: str,
|
||||
organization_id: int,
|
||||
*,
|
||||
use_draft: bool,
|
||||
) -> ResolvedAgentTarget:
|
||||
"""Resolve a trigger UUID to a workflow, scoped to the API key's org."""
|
||||
trigger = await db_client.get_agent_trigger_by_path(trigger_path)
|
||||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Agent trigger not found")
|
||||
|
||||
# 3. Validate organization match (API key org must match trigger org)
|
||||
if api_key.organization_id != trigger.organization_id:
|
||||
if organization_id != trigger.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# 4. Validate trigger is active
|
||||
if trigger.state != TriggerState.ACTIVE.value:
|
||||
raise HTTPException(status_code=404, detail="Agent trigger is not active")
|
||||
|
||||
# 4.5 Check Dograh quota before initiating the call (apply the trigger's
|
||||
# workflow's model_overrides so we evaluate the keys this run will use).
|
||||
workflow = await db_client.get_workflow(
|
||||
trigger.workflow_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
_ensure_workflow_is_active(workflow)
|
||||
workflow_definition = await _get_workflow_definition_for_execution(
|
||||
workflow,
|
||||
use_draft=use_draft,
|
||||
)
|
||||
if not trigger_exists_in_workflow(workflow_definition, trigger_path):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Trigger not found in the selected Agent",
|
||||
)
|
||||
|
||||
return ResolvedAgentTarget(
|
||||
workflow=workflow,
|
||||
organization_id=organization_id,
|
||||
identifier_type="trigger_path",
|
||||
identifier_value=trigger_path,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_workflow_uuid_target(
|
||||
workflow_uuid: str,
|
||||
organization_id: int,
|
||||
*,
|
||||
use_draft: bool,
|
||||
) -> ResolvedAgentTarget:
|
||||
"""Resolve a workflow UUID directly, scoped to the API key's org."""
|
||||
workflow = await db_client.get_workflow_by_uuid(workflow_uuid, organization_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
_ensure_workflow_is_active(workflow)
|
||||
await _get_workflow_definition_for_execution(workflow, use_draft=use_draft)
|
||||
|
||||
return ResolvedAgentTarget(
|
||||
workflow=workflow,
|
||||
organization_id=organization_id,
|
||||
identifier_type="workflow_uuid",
|
||||
identifier_value=workflow_uuid,
|
||||
)
|
||||
|
||||
|
||||
async def _execute_resolved_target(
|
||||
target: ResolvedAgentTarget,
|
||||
request: TriggerCallRequest,
|
||||
*,
|
||||
use_draft: bool,
|
||||
api_key_id: int | None,
|
||||
api_key_created_by: int | None,
|
||||
) -> TriggerCallResponse:
|
||||
"""Shared execution path once the target workflow has been resolved."""
|
||||
execution_user_id = _get_execution_user_id(target.workflow)
|
||||
|
||||
# Check Dograh quota using the workflow owner's config and model overrides.
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
api_key.created_by, workflow_id=trigger.workflow_id
|
||||
execution_user_id,
|
||||
workflow_id=target.workflow.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# 5. Get workflow and resolve the definition (published vs draft)
|
||||
workflow = await db_client.get_workflow_by_id(trigger.workflow_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
if use_draft:
|
||||
draft = await db_client.get_draft_version(trigger.workflow_id)
|
||||
# Fall back to the published definition when no draft exists, so the
|
||||
# test URL always runs *something* — typically the same agent the
|
||||
# production URL would run.
|
||||
workflow_definition = (
|
||||
draft.workflow_json if draft else workflow.released_definition.workflow_json
|
||||
)
|
||||
else:
|
||||
workflow_definition = workflow.released_definition.workflow_json
|
||||
|
||||
# Validate trigger node still exists in the resolved definition
|
||||
if not trigger_exists_in_workflow(workflow_definition, uuid):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Trigger not found in the published Agent",
|
||||
)
|
||||
|
||||
# 6. Get telephony provider — either the caller-specified config (validated
|
||||
# against the trigger's org) or the org's default config.
|
||||
# Get telephony provider — either the caller-specified config (validated
|
||||
# against the workflow's org) or the org's default config.
|
||||
if request.telephony_configuration_id is not None:
|
||||
cfg = await db_client.get_telephony_configuration_for_org(
|
||||
request.telephony_configuration_id, trigger.organization_id
|
||||
request.telephony_configuration_id,
|
||||
target.organization_id,
|
||||
)
|
||||
if not cfg:
|
||||
raise HTTPException(
|
||||
|
|
@ -130,7 +200,7 @@ async def _initiate_call(
|
|||
)
|
||||
try:
|
||||
provider = await get_telephony_provider_by_id(
|
||||
cfg.id, trigger.organization_id
|
||||
cfg.id, target.organization_id
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
|
|
@ -140,14 +210,14 @@ async def _initiate_call(
|
|||
resolved_cfg_id = cfg.id
|
||||
else:
|
||||
try:
|
||||
provider = await get_default_telephony_provider(trigger.organization_id)
|
||||
provider = await get_default_telephony_provider(target.organization_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Telephony provider not configured for this organization",
|
||||
)
|
||||
default_cfg = await db_client.get_default_telephony_configuration(
|
||||
trigger.organization_id
|
||||
target.organization_id
|
||||
)
|
||||
resolved_cfg_id = default_cfg.id if default_cfg else None
|
||||
|
||||
|
|
@ -164,24 +234,36 @@ async def _initiate_call(
|
|||
# 8. Create workflow run
|
||||
mode_label = "TEST" if use_draft else "API"
|
||||
workflow_run_name = f"WR-{mode_label}-{random.randint(1000, 9999)}"
|
||||
initial_context = {
|
||||
"provider": provider.PROVIDER_NAME,
|
||||
"phone_number": request.phone_number,
|
||||
"trigger_mode": "test" if use_draft else "production",
|
||||
"telephony_configuration_id": resolved_cfg_id,
|
||||
"agent_identifier": target.identifier_value,
|
||||
"agent_identifier_type": target.identifier_type,
|
||||
"workflow_uuid": target.workflow.workflow_uuid,
|
||||
}
|
||||
if target.identifier_type == "trigger_path":
|
||||
initial_context["agent_uuid"] = target.identifier_value
|
||||
if api_key_id is not None:
|
||||
initial_context["api_key_id"] = api_key_id
|
||||
if api_key_created_by is not None:
|
||||
initial_context["api_key_created_by"] = api_key_created_by
|
||||
initial_context.update(request.initial_context or {})
|
||||
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=workflow_run_name,
|
||||
workflow_id=trigger.workflow_id,
|
||||
workflow_id=target.workflow.id,
|
||||
mode=workflow_run_mode,
|
||||
initial_context={
|
||||
"provider": provider.PROVIDER_NAME,
|
||||
"phone_number": request.phone_number,
|
||||
"agent_uuid": uuid,
|
||||
"trigger_mode": "test" if use_draft else "production",
|
||||
"telephony_configuration_id": resolved_cfg_id,
|
||||
**(request.initial_context or {}),
|
||||
},
|
||||
user_id=api_key.created_by,
|
||||
initial_context=initial_context,
|
||||
user_id=execution_user_id,
|
||||
use_draft=use_draft,
|
||||
organization_id=target.organization_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created workflow run {workflow_run.id} for API trigger {uuid} "
|
||||
f"Created workflow run {workflow_run.id} for public agent "
|
||||
f"{target.identifier_type}={target.identifier_value} "
|
||||
f"(mode={'test' if use_draft else 'production'}) "
|
||||
f"to phone number {request.phone_number}"
|
||||
)
|
||||
|
|
@ -192,10 +274,10 @@ async def _initiate_call(
|
|||
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={trigger.workflow_id}"
|
||||
f"&user_id={api_key.created_by}"
|
||||
f"?workflow_id={target.workflow.id}"
|
||||
f"&user_id={execution_user_id}"
|
||||
f"&workflow_run_id={workflow_run.id}"
|
||||
f"&organization_id={trigger.organization_id}"
|
||||
f"&organization_id={target.organization_id}"
|
||||
)
|
||||
|
||||
# 10. Initiate call via telephony provider. workflow_id and user_id are
|
||||
|
|
@ -207,8 +289,8 @@ async def _initiate_call(
|
|||
to_number=request.phone_number,
|
||||
webhook_url=webhook_url,
|
||||
workflow_run_id=workflow_run.id,
|
||||
workflow_id=trigger.workflow_id,
|
||||
user_id=api_key.created_by,
|
||||
workflow_id=target.workflow.id,
|
||||
user_id=execution_user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
|
@ -221,7 +303,7 @@ async def _initiate_call(
|
|||
|
||||
logger.info(
|
||||
f"Call initiated successfully for workflow run {workflow_run.id} "
|
||||
f"via trigger {uuid}"
|
||||
f"via {target.identifier_type}={target.identifier_value}"
|
||||
)
|
||||
|
||||
return TriggerCallResponse(
|
||||
|
|
@ -231,6 +313,30 @@ async def _initiate_call(
|
|||
)
|
||||
|
||||
|
||||
async def _initiate_call(
|
||||
identifier: str,
|
||||
request: TriggerCallRequest,
|
||||
x_api_key: str,
|
||||
*,
|
||||
use_draft: bool,
|
||||
target_resolver: Callable[..., Awaitable[ResolvedAgentTarget]],
|
||||
) -> TriggerCallResponse:
|
||||
"""Resolve the requested public target, then execute the common call flow."""
|
||||
api_key = await _validate_api_key(x_api_key)
|
||||
target = await target_resolver(
|
||||
identifier,
|
||||
api_key.organization_id,
|
||||
use_draft=use_draft,
|
||||
)
|
||||
return await _execute_resolved_target(
|
||||
target,
|
||||
request,
|
||||
use_draft=use_draft,
|
||||
api_key_id=api_key.id,
|
||||
api_key_created_by=api_key.created_by,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{uuid}", response_model=TriggerCallResponse)
|
||||
async def initiate_call(
|
||||
uuid: str,
|
||||
|
|
@ -241,7 +347,13 @@ async def initiate_call(
|
|||
|
||||
Executes the workflow's currently released definition.
|
||||
"""
|
||||
return await _initiate_call(uuid, request, x_api_key, use_draft=False)
|
||||
return await _initiate_call(
|
||||
uuid,
|
||||
request,
|
||||
x_api_key,
|
||||
use_draft=False,
|
||||
target_resolver=_resolve_trigger_target,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/test/{uuid}", response_model=TriggerCallResponse)
|
||||
|
|
@ -255,4 +367,42 @@ async def initiate_call_test(
|
|||
Useful for verifying changes before publishing. Falls back to the
|
||||
published definition when no draft exists.
|
||||
"""
|
||||
return await _initiate_call(uuid, request, x_api_key, use_draft=True)
|
||||
return await _initiate_call(
|
||||
uuid,
|
||||
request,
|
||||
x_api_key,
|
||||
use_draft=True,
|
||||
target_resolver=_resolve_trigger_target,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/{workflow_uuid}", response_model=TriggerCallResponse)
|
||||
async def initiate_call_by_workflow_uuid(
|
||||
workflow_uuid: str,
|
||||
request: TriggerCallRequest,
|
||||
x_api_key: str = Header(..., alias="X-API-Key"),
|
||||
):
|
||||
"""Initiate a phone call against the published workflow identified by UUID."""
|
||||
return await _initiate_call(
|
||||
workflow_uuid,
|
||||
request,
|
||||
x_api_key,
|
||||
use_draft=False,
|
||||
target_resolver=_resolve_workflow_uuid_target,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/test/workflow/{workflow_uuid}", response_model=TriggerCallResponse)
|
||||
async def initiate_call_test_by_workflow_uuid(
|
||||
workflow_uuid: str,
|
||||
request: TriggerCallRequest,
|
||||
x_api_key: str = Header(..., alias="X-API-Key"),
|
||||
):
|
||||
"""Initiate a phone call against the latest draft of the workflow by UUID."""
|
||||
return await _initiate_call(
|
||||
workflow_uuid,
|
||||
request,
|
||||
x_api_key,
|
||||
use_draft=True,
|
||||
target_resolver=_resolve_workflow_uuid_target,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -305,7 +305,9 @@ async def _validate_inbound_request(
|
|||
"""
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id)
|
||||
# System lookup: inbound routing only has the workflow_id and derives the
|
||||
# org/user from the workflow itself, so use the explicit unscoped variant.
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
return False, TelephonyError.WORKFLOW_NOT_FOUND, {}, None
|
||||
|
||||
|
|
@ -528,8 +530,9 @@ async def _handle_telephony_websocket(
|
|||
await websocket.close(code=4404, reason="Workflow run not found")
|
||||
return
|
||||
|
||||
# Get workflow for organization info
|
||||
workflow = await db_client.get_workflow(workflow_id)
|
||||
# Get workflow for organization info. System lookup keyed only on the
|
||||
# workflow_id (org is derived below) — use the explicit unscoped variant.
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
logger.error(f"Workflow {workflow_id} not found")
|
||||
await websocket.close(code=4404, reason="Workflow not found")
|
||||
|
|
|
|||
|
|
@ -32,99 +32,16 @@ from api.services.storage import storage_fs
|
|||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
extract_trigger_paths,
|
||||
regenerate_trigger_uuids,
|
||||
trigger_path_to_node_id,
|
||||
validate_trigger_paths,
|
||||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
||||
|
||||
def extract_trigger_paths(workflow_definition: dict) -> List[str]:
|
||||
"""Extract trigger UUIDs from workflow definition.
|
||||
|
||||
Args:
|
||||
workflow_definition: The workflow definition JSON
|
||||
|
||||
Returns:
|
||||
List of trigger UUIDs found in the workflow
|
||||
"""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
trigger_paths = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("type") == "trigger":
|
||||
trigger_path = node.get("data", {}).get("trigger_path")
|
||||
if trigger_path:
|
||||
trigger_paths.append(trigger_path)
|
||||
|
||||
return trigger_paths
|
||||
|
||||
|
||||
def _trigger_path_to_node_id(workflow_definition: dict) -> dict[str, str]:
|
||||
"""Map each trigger node's trigger_path to its node id."""
|
||||
if not workflow_definition:
|
||||
return {}
|
||||
out: dict[str, str] = {}
|
||||
for node in workflow_definition.get("nodes", []):
|
||||
if node.get("type") == "trigger":
|
||||
tp = node.get("data", {}).get("trigger_path")
|
||||
if tp:
|
||||
out[tp] = node.get("id")
|
||||
return out
|
||||
|
||||
|
||||
def regenerate_trigger_uuids(workflow_definition: dict) -> dict:
|
||||
"""Regenerate UUIDs for all trigger nodes in a workflow definition.
|
||||
|
||||
This should be called when creating a new workflow from a template or
|
||||
duplicating a workflow to avoid trigger UUID conflicts.
|
||||
|
||||
Args:
|
||||
workflow_definition: The workflow definition JSON
|
||||
|
||||
Returns:
|
||||
Updated workflow definition with new trigger UUIDs
|
||||
"""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
# Deep copy to avoid modifying the original
|
||||
import copy
|
||||
|
||||
updated_definition = copy.deepcopy(workflow_definition)
|
||||
|
||||
nodes = updated_definition.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("type") == "trigger":
|
||||
# Generate a new UUID for this trigger
|
||||
if "data" not in node:
|
||||
node["data"] = {}
|
||||
node["data"]["trigger_path"] = str(uuid.uuid4())
|
||||
|
||||
return updated_definition
|
||||
|
||||
|
||||
def ensure_trigger_paths(workflow_definition: Optional[dict]) -> Optional[dict]:
|
||||
"""Mint a UUID for any trigger node that's missing ``data.trigger_path``.
|
||||
|
||||
Trigger nodes that already carry a non-empty trigger_path are left
|
||||
untouched so stable IDs survive edits. The input is not mutated; the
|
||||
returned dict is what should be persisted and echoed in the response.
|
||||
"""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
import copy
|
||||
|
||||
out = copy.deepcopy(workflow_definition)
|
||||
for node in out.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
data = node.setdefault("data", {})
|
||||
if not data.get("trigger_path"):
|
||||
data["trigger_path"] = str(uuid.uuid4())
|
||||
return out
|
||||
|
||||
|
||||
router = APIRouter(prefix="/workflow")
|
||||
|
||||
|
||||
|
|
@ -139,7 +56,7 @@ def _trigger_conflict_http_exception(
|
|||
"""Build a 409 with the same detail shape as validate's 422 so the editor
|
||||
can highlight the offending trigger node(s) using the same code path."""
|
||||
path_to_node = (
|
||||
_trigger_path_to_node_id(workflow_definition) if workflow_definition else {}
|
||||
trigger_path_to_node_id(workflow_definition) if workflow_definition else {}
|
||||
)
|
||||
errors: list[WorkflowError] = [
|
||||
WorkflowError(
|
||||
|
|
@ -159,6 +76,24 @@ def _trigger_conflict_http_exception(
|
|||
)
|
||||
|
||||
|
||||
def _trigger_path_validation_http_exception(
|
||||
issues: list[TriggerPathIssue],
|
||||
) -> HTTPException:
|
||||
errors = [
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=issue.node_id,
|
||||
field="data.trigger_path",
|
||||
message=issue.message,
|
||||
)
|
||||
for issue in issues
|
||||
]
|
||||
return HTTPException(
|
||||
status_code=422,
|
||||
detail=ValidateWorkflowResponse(is_valid=False, errors=errors).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
async def _validate_workflow_definition(
|
||||
workflow_definition: Optional[dict],
|
||||
exclude_workflow_id: Optional[int] = None,
|
||||
|
|
@ -187,6 +122,17 @@ async def _validate_workflow_definition(
|
|||
except ValueError as e:
|
||||
errors.extend(e.args[0])
|
||||
|
||||
# ----------- Trigger Path Format Check ------------
|
||||
for issue in validate_trigger_paths(workflow_definition):
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
kind=ItemKind.node,
|
||||
id=issue.node_id,
|
||||
field="data.trigger_path",
|
||||
message=issue.message,
|
||||
)
|
||||
)
|
||||
|
||||
# ----------- Trigger Path Conflict Check ------------
|
||||
trigger_paths = extract_trigger_paths(workflow_definition)
|
||||
if trigger_paths:
|
||||
|
|
@ -195,7 +141,7 @@ async def _validate_workflow_definition(
|
|||
exclude_workflow_id=exclude_workflow_id,
|
||||
)
|
||||
if conflicts:
|
||||
path_to_node = _trigger_path_to_node_id(workflow_definition)
|
||||
path_to_node = trigger_path_to_node_id(workflow_definition)
|
||||
for conflicting_path in conflicts:
|
||||
errors.append(
|
||||
WorkflowError(
|
||||
|
|
@ -251,6 +197,14 @@ class WorkflowListResponse(BaseModel):
|
|||
status: str
|
||||
created_at: datetime
|
||||
total_runs: int
|
||||
folder_id: int | None = None
|
||||
workflow_uuid: str | None = None
|
||||
|
||||
|
||||
class MoveWorkflowToFolderRequest(BaseModel):
|
||||
"""Move a workflow into a folder, or to "Uncategorized" when null."""
|
||||
|
||||
folder_id: int | None = None
|
||||
|
||||
|
||||
class WorkflowCountResponse(BaseModel):
|
||||
|
|
@ -404,6 +358,9 @@ async def create_workflow(
|
|||
# Auto-mint trigger_path for any trigger node that didn't ship one so
|
||||
# clients don't need to generate UUIDs themselves.
|
||||
workflow_definition = ensure_trigger_paths(request.workflow_definition)
|
||||
trigger_path_issues = validate_trigger_paths(workflow_definition)
|
||||
if trigger_path_issues:
|
||||
raise _trigger_path_validation_http_exception(trigger_path_issues)
|
||||
|
||||
# Validate trigger path uniqueness BEFORE creating the workflow so we
|
||||
# don't leave an orphaned workflow record when the trigger conflicts.
|
||||
|
|
@ -641,6 +598,8 @@ async def get_workflows(
|
|||
status=workflow.status,
|
||||
created_at=workflow.created_at,
|
||||
total_runs=run_counts.get(workflow.id, 0),
|
||||
folder_id=workflow.folder_id,
|
||||
workflow_uuid=workflow.workflow_uuid,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
|
@ -883,6 +842,48 @@ async def update_workflow_status(
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{workflow_id}/folder")
|
||||
async def move_workflow_to_folder(
|
||||
workflow_id: int,
|
||||
request: MoveWorkflowToFolderRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> WorkflowListResponse:
|
||||
"""Move a workflow into a folder, or to "Uncategorized" (folder_id=null).
|
||||
|
||||
Validates that the target folder belongs to the caller's organization —
|
||||
the FK alone proves the folder exists, not that the caller may use it.
|
||||
"""
|
||||
# Validate target folder ownership (tenant isolation) unless un-filing.
|
||||
if request.folder_id is not None:
|
||||
folder = await db_client.get_folder(
|
||||
request.folder_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if folder is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Folder with id {request.folder_id} not found",
|
||||
)
|
||||
|
||||
try:
|
||||
workflow = await db_client.move_workflow_to_folder(
|
||||
workflow_id=workflow_id,
|
||||
folder_id=request.folder_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
run_count = await db_client.get_workflow_run_count(workflow.id)
|
||||
return WorkflowListResponse(
|
||||
id=workflow.id,
|
||||
name=workflow.name,
|
||||
status=workflow.status,
|
||||
created_at=workflow.created_at,
|
||||
total_runs=run_count,
|
||||
folder_id=workflow.folder_id,
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{workflow_id}",
|
||||
**sdk_expose(
|
||||
|
|
@ -917,6 +918,9 @@ async def update_workflow(
|
|||
# response echoes workflow_definition so the client picks up the new
|
||||
# UUID without a refetch.
|
||||
workflow_definition = ensure_trigger_paths(workflow_definition)
|
||||
trigger_path_issues = validate_trigger_paths(workflow_definition)
|
||||
if trigger_path_issues:
|
||||
raise _trigger_path_validation_http_exception(trigger_path_issues)
|
||||
if workflow_definition:
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
|
|
|
|||
49
api/services/configuration/options/__init__.py
Normal file
49
api/services/configuration/options/__init__.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from .deepgram import DEEPGRAM_LANGUAGES, DEEPGRAM_STT_MODELS
|
||||
from .gladia import GLADIA_STT_LANGUAGES, GLADIA_STT_MODELS
|
||||
from .google import (
|
||||
GOOGLE_MODELS,
|
||||
GOOGLE_REALTIME_LANGUAGES,
|
||||
GOOGLE_REALTIME_MODELS,
|
||||
GOOGLE_REALTIME_VOICES,
|
||||
GOOGLE_STT_LANGUAGES,
|
||||
GOOGLE_STT_MODELS,
|
||||
GOOGLE_TTS_LANGUAGES,
|
||||
GOOGLE_TTS_MODELS,
|
||||
GOOGLE_TTS_VOICES,
|
||||
GOOGLE_VERTEX_REALTIME_LANGUAGES,
|
||||
GOOGLE_VERTEX_REALTIME_MODELS,
|
||||
GOOGLE_VERTEX_REALTIME_VOICES,
|
||||
)
|
||||
from .sarvam import (
|
||||
SARVAM_LANGUAGES,
|
||||
SARVAM_STT_MODELS,
|
||||
SARVAM_TTS_MODELS,
|
||||
SARVAM_V2_VOICES,
|
||||
SARVAM_V3_VOICES,
|
||||
)
|
||||
from .speechmatics import SPEECHMATICS_STT_LANGUAGES
|
||||
|
||||
__all__ = [
|
||||
"DEEPGRAM_LANGUAGES",
|
||||
"DEEPGRAM_STT_MODELS",
|
||||
"GLADIA_STT_LANGUAGES",
|
||||
"GLADIA_STT_MODELS",
|
||||
"GOOGLE_MODELS",
|
||||
"GOOGLE_REALTIME_LANGUAGES",
|
||||
"GOOGLE_REALTIME_MODELS",
|
||||
"GOOGLE_REALTIME_VOICES",
|
||||
"GOOGLE_STT_LANGUAGES",
|
||||
"GOOGLE_STT_MODELS",
|
||||
"GOOGLE_TTS_LANGUAGES",
|
||||
"GOOGLE_TTS_MODELS",
|
||||
"GOOGLE_TTS_VOICES",
|
||||
"GOOGLE_VERTEX_REALTIME_LANGUAGES",
|
||||
"GOOGLE_VERTEX_REALTIME_MODELS",
|
||||
"GOOGLE_VERTEX_REALTIME_VOICES",
|
||||
"SARVAM_LANGUAGES",
|
||||
"SARVAM_STT_MODELS",
|
||||
"SARVAM_TTS_MODELS",
|
||||
"SARVAM_V2_VOICES",
|
||||
"SARVAM_V3_VOICES",
|
||||
"SPEECHMATICS_STT_LANGUAGES",
|
||||
]
|
||||
84
api/services/configuration/options/deepgram.py
Normal file
84
api/services/configuration/options/deepgram.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
DEEPGRAM_STT_MODELS = ("nova-3-general", "flux-general-en", "flux-general-multi")
|
||||
DEEPGRAM_LANGUAGES = (
|
||||
"multi",
|
||||
"ar",
|
||||
"ar-AE",
|
||||
"ar-SA",
|
||||
"ar-QA",
|
||||
"ar-KW",
|
||||
"ar-SY",
|
||||
"ar-LB",
|
||||
"ar-PS",
|
||||
"ar-JO",
|
||||
"ar-EG",
|
||||
"ar-SD",
|
||||
"ar-TD",
|
||||
"ar-MA",
|
||||
"ar-DZ",
|
||||
"ar-TN",
|
||||
"ar-IQ",
|
||||
"ar-IR",
|
||||
"be",
|
||||
"bn",
|
||||
"bs",
|
||||
"bg",
|
||||
"ca",
|
||||
"cs",
|
||||
"da",
|
||||
"da-DK",
|
||||
"de",
|
||||
"de-CH",
|
||||
"el",
|
||||
"en",
|
||||
"en-US",
|
||||
"en-AU",
|
||||
"en-GB",
|
||||
"en-IN",
|
||||
"en-NZ",
|
||||
"es",
|
||||
"es-419",
|
||||
"et",
|
||||
"fa",
|
||||
"fi",
|
||||
"fr",
|
||||
"fr-CA",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"hu",
|
||||
"id",
|
||||
"it",
|
||||
"ja",
|
||||
"kn",
|
||||
"ko",
|
||||
"ko-KR",
|
||||
"lt",
|
||||
"lv",
|
||||
"mk",
|
||||
"mr",
|
||||
"ms",
|
||||
"nl",
|
||||
"nl-BE",
|
||||
"no",
|
||||
"pl",
|
||||
"pt",
|
||||
"pt-BR",
|
||||
"pt-PT",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"sr",
|
||||
"sv",
|
||||
"sv-SE",
|
||||
"ta",
|
||||
"te",
|
||||
"th",
|
||||
"tl",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"vi",
|
||||
"zh-CN",
|
||||
"zh-TW",
|
||||
)
|
||||
103
api/services/configuration/options/gladia.py
Normal file
103
api/services/configuration/options/gladia.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
GLADIA_STT_MODELS = ("solaria-1",)
|
||||
GLADIA_STT_LANGUAGES = (
|
||||
"af",
|
||||
"am",
|
||||
"ar",
|
||||
"as",
|
||||
"az",
|
||||
"ba",
|
||||
"be",
|
||||
"bg",
|
||||
"bn",
|
||||
"bo",
|
||||
"br",
|
||||
"bs",
|
||||
"ca",
|
||||
"cs",
|
||||
"cy",
|
||||
"da",
|
||||
"de",
|
||||
"el",
|
||||
"en",
|
||||
"es",
|
||||
"et",
|
||||
"eu",
|
||||
"fa",
|
||||
"fi",
|
||||
"fo",
|
||||
"fr",
|
||||
"gl",
|
||||
"gu",
|
||||
"ha",
|
||||
"haw",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"ht",
|
||||
"hu",
|
||||
"hy",
|
||||
"id",
|
||||
"is",
|
||||
"it",
|
||||
"ja",
|
||||
"jw",
|
||||
"ka",
|
||||
"kk",
|
||||
"km",
|
||||
"kn",
|
||||
"ko",
|
||||
"la",
|
||||
"lb",
|
||||
"ln",
|
||||
"lo",
|
||||
"lt",
|
||||
"lv",
|
||||
"mg",
|
||||
"mi",
|
||||
"mk",
|
||||
"ml",
|
||||
"mn",
|
||||
"mr",
|
||||
"ms",
|
||||
"mt",
|
||||
"my",
|
||||
"ne",
|
||||
"nl",
|
||||
"nn",
|
||||
"no",
|
||||
"oc",
|
||||
"pa",
|
||||
"pl",
|
||||
"ps",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sa",
|
||||
"sd",
|
||||
"si",
|
||||
"sk",
|
||||
"sl",
|
||||
"sn",
|
||||
"so",
|
||||
"sq",
|
||||
"sr",
|
||||
"su",
|
||||
"sv",
|
||||
"sw",
|
||||
"ta",
|
||||
"te",
|
||||
"tg",
|
||||
"th",
|
||||
"tk",
|
||||
"tl",
|
||||
"tr",
|
||||
"tt",
|
||||
"uk",
|
||||
"ur",
|
||||
"uz",
|
||||
"vi",
|
||||
"wo",
|
||||
"yi",
|
||||
"yo",
|
||||
"zh",
|
||||
)
|
||||
273
api/services/configuration/options/google.py
Normal file
273
api/services/configuration/options/google.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
GOOGLE_MODELS = (
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
)
|
||||
|
||||
GOOGLE_REALTIME_MODELS = ("gemini-3.1-flash-live-preview",)
|
||||
GOOGLE_REALTIME_VOICES = ("Puck", "Charon", "Kore", "Fenrir", "Aoede")
|
||||
GOOGLE_REALTIME_LANGUAGES = (
|
||||
"ar",
|
||||
"bn",
|
||||
"de",
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"gu",
|
||||
"hi",
|
||||
"id",
|
||||
"it",
|
||||
"ja",
|
||||
"kn",
|
||||
"ko",
|
||||
"ml",
|
||||
"mr",
|
||||
"nl",
|
||||
"pl",
|
||||
"pt",
|
||||
"ru",
|
||||
"ta",
|
||||
"te",
|
||||
"th",
|
||||
"tr",
|
||||
"vi",
|
||||
"zh",
|
||||
)
|
||||
|
||||
GOOGLE_VERTEX_REALTIME_MODELS = ("google/gemini-live-2.5-flash-native-audio",)
|
||||
GOOGLE_VERTEX_REALTIME_VOICES = GOOGLE_REALTIME_VOICES
|
||||
GOOGLE_VERTEX_REALTIME_LANGUAGES = GOOGLE_REALTIME_LANGUAGES
|
||||
|
||||
GOOGLE_STT_MODELS = ("latest_long", "latest_short", "chirp_3")
|
||||
# Docs-derived from Google Cloud Speech-to-Text V2 supported languages.
|
||||
GOOGLE_STT_LANGUAGES = (
|
||||
"af-ZA",
|
||||
"am-ET",
|
||||
"ar-AE",
|
||||
"ar-BH",
|
||||
"ar-DZ",
|
||||
"ar-EG",
|
||||
"ar-IL",
|
||||
"ar-IQ",
|
||||
"ar-JO",
|
||||
"ar-KW",
|
||||
"ar-LB",
|
||||
"ar-MA",
|
||||
"ar-MR",
|
||||
"ar-OM",
|
||||
"ar-PS",
|
||||
"ar-QA",
|
||||
"ar-SA",
|
||||
"ar-SY",
|
||||
"ar-TN",
|
||||
"ar-XA",
|
||||
"ar-YE",
|
||||
"as-IN",
|
||||
"ast-ES",
|
||||
"az-AZ",
|
||||
"be-BY",
|
||||
"bg-BG",
|
||||
"bn-BD",
|
||||
"bn-IN",
|
||||
"bs-BA",
|
||||
"ca-ES",
|
||||
"ceb-PH",
|
||||
"ckb-IQ",
|
||||
"cmn-Hans-CN",
|
||||
"cmn-Hant-TW",
|
||||
"cs-CZ",
|
||||
"cy-GB",
|
||||
"da-DK",
|
||||
"de-AT",
|
||||
"de-CH",
|
||||
"de-DE",
|
||||
"el-GR",
|
||||
"en-AU",
|
||||
"en-GB",
|
||||
"en-HK",
|
||||
"en-IE",
|
||||
"en-IN",
|
||||
"en-NZ",
|
||||
"en-PH",
|
||||
"en-PK",
|
||||
"en-SG",
|
||||
"en-US",
|
||||
"es-419",
|
||||
"es-AR",
|
||||
"es-BO",
|
||||
"es-CL",
|
||||
"es-CO",
|
||||
"es-CR",
|
||||
"es-DO",
|
||||
"es-EC",
|
||||
"es-ES",
|
||||
"es-GT",
|
||||
"es-HN",
|
||||
"es-MX",
|
||||
"es-NI",
|
||||
"es-PA",
|
||||
"es-PE",
|
||||
"es-PR",
|
||||
"es-SV",
|
||||
"es-US",
|
||||
"es-UY",
|
||||
"es-VE",
|
||||
"et-EE",
|
||||
"eu-ES",
|
||||
"fa-IR",
|
||||
"ff-SN",
|
||||
"fi-FI",
|
||||
"fil-PH",
|
||||
"fr-BE",
|
||||
"fr-CA",
|
||||
"fr-CH",
|
||||
"fr-FR",
|
||||
"ga-IE",
|
||||
"gl-ES",
|
||||
"gu-IN",
|
||||
"ha-NG",
|
||||
"hi-IN",
|
||||
"hr-HR",
|
||||
"hu-HU",
|
||||
"hy-AM",
|
||||
"id-ID",
|
||||
"ig-NG",
|
||||
"is-IS",
|
||||
"it-CH",
|
||||
"it-IT",
|
||||
"iw-IL",
|
||||
"ja-JP",
|
||||
"jv-ID",
|
||||
"ka-GE",
|
||||
"kam-KE",
|
||||
"kea-CV",
|
||||
"kk-KZ",
|
||||
"km-KH",
|
||||
"kn-IN",
|
||||
"ko-KR",
|
||||
"ky-KG",
|
||||
"lb-LU",
|
||||
"lg-UG",
|
||||
"ln-CD",
|
||||
"lo-LA",
|
||||
"lt-LT",
|
||||
"luo-KE",
|
||||
"lv-LV",
|
||||
"mi-NZ",
|
||||
"mk-MK",
|
||||
"ml-IN",
|
||||
"mn-MN",
|
||||
"mr-IN",
|
||||
"ms-MY",
|
||||
"mt-MT",
|
||||
"my-MM",
|
||||
"ne-NP",
|
||||
"nl-BE",
|
||||
"nl-NL",
|
||||
"no-NO",
|
||||
"nso-ZA",
|
||||
"ny-MW",
|
||||
"oc-FR",
|
||||
"om-ET",
|
||||
"or-IN",
|
||||
"pa-Guru-IN",
|
||||
"pl-PL",
|
||||
"ps-AF",
|
||||
"pt-BR",
|
||||
"pt-PT",
|
||||
"ro-RO",
|
||||
"ru-RU",
|
||||
"rup-BG",
|
||||
"rw-RW",
|
||||
"sd-IN",
|
||||
"si-LK",
|
||||
"sk-SK",
|
||||
"sl-SI",
|
||||
"sn-ZW",
|
||||
"so-SO",
|
||||
"sq-AL",
|
||||
"sr-RS",
|
||||
"ss-Latn-ZA",
|
||||
"st-ZA",
|
||||
"su-ID",
|
||||
"sv-SE",
|
||||
"sw",
|
||||
"sw-KE",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
"tg-TJ",
|
||||
"th-TH",
|
||||
"tn-Latn-ZA",
|
||||
"tr-TR",
|
||||
"ts-ZA",
|
||||
"uk-UA",
|
||||
"umb-AO",
|
||||
"ur-PK",
|
||||
"uz-UZ",
|
||||
"ve-ZA",
|
||||
"vi-VN",
|
||||
"wo-SN",
|
||||
"xh-ZA",
|
||||
"yo-NG",
|
||||
"yue-Hant-HK",
|
||||
"zu-ZA",
|
||||
)
|
||||
|
||||
GOOGLE_TTS_MODELS = ("chirp_3_hd",)
|
||||
GOOGLE_TTS_VOICES = ("en-US-Chirp3-HD-Charon",)
|
||||
GOOGLE_TTS_LANGUAGES = (
|
||||
"ar-XA",
|
||||
"bn-IN",
|
||||
"bg-BG",
|
||||
"yue-HK",
|
||||
"hr-HR",
|
||||
"cs-CZ",
|
||||
"da-DK",
|
||||
"nl-BE",
|
||||
"nl-NL",
|
||||
"en-AU",
|
||||
"en-IN",
|
||||
"en-GB",
|
||||
"en-US",
|
||||
"et-EE",
|
||||
"fi-FI",
|
||||
"fr-CA",
|
||||
"fr-FR",
|
||||
"de-DE",
|
||||
"el-GR",
|
||||
"gu-IN",
|
||||
"he-IL",
|
||||
"hi-IN",
|
||||
"hu-HU",
|
||||
"id-ID",
|
||||
"it-IT",
|
||||
"ja-JP",
|
||||
"kn-IN",
|
||||
"ko-KR",
|
||||
"lv-LV",
|
||||
"lt-LT",
|
||||
"ml-IN",
|
||||
"cmn-CN",
|
||||
"mr-IN",
|
||||
"nb-NO",
|
||||
"pl-PL",
|
||||
"pt-BR",
|
||||
"pa-IN",
|
||||
"ro-RO",
|
||||
"ru-RU",
|
||||
"sr-RS",
|
||||
"sk-SK",
|
||||
"sl-SI",
|
||||
"es-ES",
|
||||
"es-US",
|
||||
"sw-KE",
|
||||
"sv-SE",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
"th-TH",
|
||||
"tr-TR",
|
||||
"uk-UA",
|
||||
"ur-IN",
|
||||
"vi-VN",
|
||||
)
|
||||
66
api/services/configuration/options/sarvam.py
Normal file
66
api/services/configuration/options/sarvam.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
SARVAM_TTS_MODELS = ("bulbul:v2", "bulbul:v3")
|
||||
SARVAM_V2_VOICES = (
|
||||
"anushka",
|
||||
"manisha",
|
||||
"vidya",
|
||||
"arya",
|
||||
"abhilash",
|
||||
"karun",
|
||||
"hitesh",
|
||||
)
|
||||
SARVAM_V3_VOICES = (
|
||||
"shubh",
|
||||
"aditya",
|
||||
"ritu",
|
||||
"priya",
|
||||
"neha",
|
||||
"rahul",
|
||||
"pooja",
|
||||
"rohan",
|
||||
"simran",
|
||||
"kavya",
|
||||
"amit",
|
||||
"dev",
|
||||
"ishita",
|
||||
"shreya",
|
||||
"ratan",
|
||||
"varun",
|
||||
"manan",
|
||||
"sumit",
|
||||
"roopa",
|
||||
"kabir",
|
||||
"aayan",
|
||||
"ashutosh",
|
||||
"advait",
|
||||
"amelia",
|
||||
"sophia",
|
||||
"anand",
|
||||
"tanya",
|
||||
"tarun",
|
||||
"sunny",
|
||||
"mani",
|
||||
"gokul",
|
||||
"vijay",
|
||||
"shruti",
|
||||
"suhani",
|
||||
"mohit",
|
||||
"kavitha",
|
||||
"rehan",
|
||||
"soham",
|
||||
"rupali",
|
||||
)
|
||||
SARVAM_LANGUAGES = (
|
||||
"bn-IN",
|
||||
"en-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"od-IN",
|
||||
"pa-IN",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
"as-IN",
|
||||
)
|
||||
SARVAM_STT_MODELS = ("saarika:v2.5", "saaras:v2")
|
||||
63
api/services/configuration/options/speechmatics.py
Normal file
63
api/services/configuration/options/speechmatics.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
SPEECHMATICS_STT_LANGUAGES = (
|
||||
"ar",
|
||||
"ar_en",
|
||||
"ba",
|
||||
"eu",
|
||||
"be",
|
||||
"bn",
|
||||
"bg",
|
||||
"yue",
|
||||
"ca",
|
||||
"hr",
|
||||
"cs",
|
||||
"da",
|
||||
"nl",
|
||||
"en",
|
||||
"eo",
|
||||
"et",
|
||||
"fi",
|
||||
"fr",
|
||||
"gl",
|
||||
"de",
|
||||
"el",
|
||||
"he",
|
||||
"hi",
|
||||
"hu",
|
||||
"id",
|
||||
"ia",
|
||||
"ga",
|
||||
"it",
|
||||
"ja",
|
||||
"ko",
|
||||
"lv",
|
||||
"lt",
|
||||
"ms",
|
||||
"en_ms",
|
||||
"mt",
|
||||
"cmn",
|
||||
"cmn_en",
|
||||
"cmn_en_ms_ta",
|
||||
"mr",
|
||||
"mn",
|
||||
"no",
|
||||
"fa",
|
||||
"pl",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"es",
|
||||
"sw",
|
||||
"sv",
|
||||
"tl",
|
||||
"ta",
|
||||
"en_ta",
|
||||
"th",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"ug",
|
||||
"vi",
|
||||
"cy",
|
||||
)
|
||||
|
|
@ -2,7 +2,32 @@ import random
|
|||
from enum import Enum, auto
|
||||
from typing import Annotated, Dict, Literal, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
|
||||
from api.services.configuration.options import (
|
||||
DEEPGRAM_LANGUAGES,
|
||||
DEEPGRAM_STT_MODELS,
|
||||
GLADIA_STT_LANGUAGES,
|
||||
GLADIA_STT_MODELS,
|
||||
GOOGLE_MODELS,
|
||||
GOOGLE_REALTIME_LANGUAGES,
|
||||
GOOGLE_REALTIME_MODELS,
|
||||
GOOGLE_REALTIME_VOICES,
|
||||
GOOGLE_STT_LANGUAGES,
|
||||
GOOGLE_STT_MODELS,
|
||||
GOOGLE_TTS_LANGUAGES,
|
||||
GOOGLE_TTS_MODELS,
|
||||
GOOGLE_TTS_VOICES,
|
||||
GOOGLE_VERTEX_REALTIME_LANGUAGES,
|
||||
GOOGLE_VERTEX_REALTIME_MODELS,
|
||||
GOOGLE_VERTEX_REALTIME_VOICES,
|
||||
SARVAM_LANGUAGES,
|
||||
SARVAM_STT_MODELS,
|
||||
SARVAM_TTS_MODELS,
|
||||
SARVAM_V2_VOICES,
|
||||
SARVAM_V3_VOICES,
|
||||
SPEECHMATICS_STT_LANGUAGES,
|
||||
)
|
||||
|
||||
|
||||
class ServiceType(Enum):
|
||||
|
|
@ -153,9 +178,56 @@ def register_embeddings(cls: Type[BaseEmbeddingsConfiguration]):
|
|||
return register_service(ServiceType.EMBEDDINGS)(cls)
|
||||
|
||||
|
||||
def provider_model_config(
|
||||
title: str,
|
||||
*,
|
||||
description: str | None = None,
|
||||
provider_docs_url: str | None = None,
|
||||
) -> ConfigDict:
|
||||
json_schema_extra: dict[str, str] = {}
|
||||
if description is not None:
|
||||
json_schema_extra["description"] = description
|
||||
if provider_docs_url is not None:
|
||||
json_schema_extra["provider_docs_url"] = provider_docs_url
|
||||
if json_schema_extra:
|
||||
return ConfigDict(title=title, json_schema_extra=json_schema_extra)
|
||||
return ConfigDict(title=title)
|
||||
|
||||
|
||||
###################################################### LLM ########################################################################
|
||||
|
||||
# Suggested models for each provider (used for UI dropdown)
|
||||
OPENAI_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI")
|
||||
GOOGLE_PROVIDER_MODEL_CONFIG = provider_model_config("Google")
|
||||
GROQ_PROVIDER_MODEL_CONFIG = provider_model_config("Groq")
|
||||
OPENROUTER_PROVIDER_MODEL_CONFIG = provider_model_config("Open Router")
|
||||
AZURE_OPENAI_PROVIDER_MODEL_CONFIG = provider_model_config("Azure OpenAI")
|
||||
DOGRAH_PROVIDER_MODEL_CONFIG = provider_model_config("Dograh")
|
||||
AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock")
|
||||
OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime")
|
||||
GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime")
|
||||
GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Google Vertex Realtime"
|
||||
)
|
||||
DEEPGRAM_PROVIDER_MODEL_CONFIG = provider_model_config("Deepgram")
|
||||
ELEVENLABS_PROVIDER_MODEL_CONFIG = provider_model_config("ElevenLabs")
|
||||
CARTESIA_PROVIDER_MODEL_CONFIG = provider_model_config("Cartesia")
|
||||
SARVAM_PROVIDER_MODEL_CONFIG = provider_model_config("Sarvam")
|
||||
CAMB_PROVIDER_MODEL_CONFIG = provider_model_config("Camb.ai")
|
||||
RIME_PROVIDER_MODEL_CONFIG = provider_model_config("Rime")
|
||||
GOOGLE_CLOUD_PROVIDER_MODEL_CONFIG = provider_model_config("Google Cloud")
|
||||
SPEECHMATICS_PROVIDER_MODEL_CONFIG = provider_model_config("Speechmatics")
|
||||
ASSEMBLYAI_PROVIDER_MODEL_CONFIG = provider_model_config("AssemblyAI")
|
||||
GLADIA_PROVIDER_MODEL_CONFIG = provider_model_config("Gladia")
|
||||
SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Local Models (Speaches)",
|
||||
description=(
|
||||
"Self-hosted OpenAI-compatible local models. See the Speaches project "
|
||||
"for setup and supported backends."
|
||||
),
|
||||
provider_docs_url="https://github.com/speaches-ai/speaches",
|
||||
)
|
||||
|
||||
OPENAI_MODELS = [
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
|
|
@ -165,12 +237,6 @@ OPENAI_MODELS = [
|
|||
"gpt-5-nano",
|
||||
"gpt-3.5-turbo",
|
||||
]
|
||||
GOOGLE_MODELS = [
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
]
|
||||
GROQ_MODELS = [
|
||||
"llama-3.3-70b-versatile",
|
||||
"deepseek-r1-distill-llama-70b",
|
||||
|
|
@ -204,6 +270,7 @@ AWS_BEDROCK_MODELS = [
|
|||
|
||||
@register_llm
|
||||
class OpenAILLMService(BaseLLMConfiguration):
|
||||
model_config = OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="gpt-4.1",
|
||||
|
|
@ -214,6 +281,7 @@ class OpenAILLMService(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class GoogleLLMService(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
|
||||
model: str = Field(
|
||||
default="gemini-2.0-flash",
|
||||
|
|
@ -224,6 +292,7 @@ class GoogleLLMService(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class GroqLLMService(BaseLLMConfiguration):
|
||||
model_config = GROQ_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GROQ] = ServiceProviders.GROQ
|
||||
model: str = Field(
|
||||
default="llama-3.3-70b-versatile",
|
||||
|
|
@ -234,6 +303,7 @@ class GroqLLMService(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class OpenRouterLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = OPENROUTER_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENROUTER] = ServiceProviders.OPENROUTER
|
||||
model: str = Field(
|
||||
default="openai/gpt-4.1",
|
||||
|
|
@ -249,6 +319,7 @@ class OpenRouterLLMConfiguration(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class AzureLLMService(BaseLLMConfiguration):
|
||||
model_config = AZURE_OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE
|
||||
model: str = Field(
|
||||
default="gpt-4.1-mini",
|
||||
|
|
@ -263,6 +334,7 @@ class AzureLLMService(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class DograhLLMService(BaseLLMConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
|
|
@ -273,6 +345,7 @@ class DograhLLMService(BaseLLMConfiguration):
|
|||
|
||||
@register_llm
|
||||
class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = AWS_BEDROCK_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.AWS_BEDROCK] = ServiceProviders.AWS_BEDROCK
|
||||
model: str = Field(
|
||||
default="us.amazon.nova-pro-v1:0",
|
||||
|
|
@ -302,6 +375,7 @@ SPEACHES_LLM_MODELS = ["llama3", "mistral", "phi3", "qwen2", "gemma2", "deepseek
|
|||
|
||||
@register_llm
|
||||
class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = SPEACHES_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="llama3",
|
||||
|
|
@ -336,6 +410,7 @@ OPENAI_REALTIME_VOICES = [
|
|||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = OPENAI_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENAI_REALTIME] = (
|
||||
ServiceProviders.OPENAI_REALTIME
|
||||
)
|
||||
|
|
@ -357,39 +432,9 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
GOOGLE_REALTIME_MODELS = ["gemini-3.1-flash-live-preview"]
|
||||
GOOGLE_REALTIME_VOICES = ["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
|
||||
GOOGLE_REALTIME_LANGUAGES = [
|
||||
"ar",
|
||||
"bn",
|
||||
"de",
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"gu",
|
||||
"hi",
|
||||
"id",
|
||||
"it",
|
||||
"ja",
|
||||
"kn",
|
||||
"ko",
|
||||
"ml",
|
||||
"mr",
|
||||
"nl",
|
||||
"pl",
|
||||
"pt",
|
||||
"ru",
|
||||
"ta",
|
||||
"te",
|
||||
"th",
|
||||
"tr",
|
||||
"vi",
|
||||
"zh",
|
||||
]
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE_REALTIME] = (
|
||||
ServiceProviders.GOOGLE_REALTIME
|
||||
)
|
||||
|
|
@ -419,15 +464,9 @@ class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
GOOGLE_VERTEX_REALTIME_MODELS = [
|
||||
"google/gemini-live-2.5-flash-native-audio",
|
||||
]
|
||||
GOOGLE_VERTEX_REALTIME_VOICES = GOOGLE_REALTIME_VOICES
|
||||
GOOGLE_VERTEX_REALTIME_LANGUAGES = GOOGLE_REALTIME_LANGUAGES
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE_VERTEX_REALTIME] = (
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME
|
||||
)
|
||||
|
|
@ -512,6 +551,7 @@ RealtimeConfig = Annotated[
|
|||
|
||||
@register_tts
|
||||
class DeepgramTTSConfiguration(BaseServiceConfiguration):
|
||||
model_config = DEEPGRAM_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
|
||||
voice: str = Field(
|
||||
default="aura-2-helena-en",
|
||||
|
|
@ -537,6 +577,7 @@ ELEVENLABS_TTS_MODELS = ["eleven_flash_v2_5"]
|
|||
|
||||
@register_tts
|
||||
class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
|
||||
model_config = ELEVENLABS_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.ELEVENLABS] = ServiceProviders.ELEVENLABS
|
||||
voice: str = Field(
|
||||
default="21m00Tcm4TlvDq8ikWAM",
|
||||
|
|
@ -558,11 +599,70 @@ class ElevenlabsTTSConfiguration(BaseServiceConfiguration):
|
|||
)
|
||||
|
||||
|
||||
@register_tts
|
||||
class GoogleTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = GOOGLE_CLOUD_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
|
||||
model: str = Field(
|
||||
default="chirp_3_hd",
|
||||
description=(
|
||||
"Google Cloud low-latency TTS engine. Dograh maps this to Pipecat's "
|
||||
"streaming Google TTS service for Chirp 3 HD and Journey voices."
|
||||
),
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_TTS_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="en-US-Chirp3-HD-Charon",
|
||||
description="Google Cloud voice name. Use a Chirp 3 HD or Journey voice for streaming TTS.",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_TTS_VOICES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en-US",
|
||||
description="BCP-47 language code for synthesis.",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_TTS_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0,
|
||||
ge=0.25,
|
||||
le=2.0,
|
||||
description="Speech speed multiplier for Google streaming TTS.",
|
||||
)
|
||||
location: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional Google Cloud regional Text-to-Speech endpoint (for example "
|
||||
"'us-central1'). Leave blank to use the default endpoint."
|
||||
),
|
||||
)
|
||||
credentials: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Paste the entire Google Cloud service-account JSON. If omitted, "
|
||||
"the server falls back to Application Default Credentials (ADC)."
|
||||
),
|
||||
json_schema_extra={"multiline": True},
|
||||
)
|
||||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Not used for Google Cloud TTS. Leave blank.",
|
||||
)
|
||||
|
||||
|
||||
OPENAI_TTS_MODELS = ["gpt-4o-mini-tts"]
|
||||
|
||||
|
||||
@register_tts
|
||||
class OpenAITTSService(BaseTTSConfiguration):
|
||||
model_config = OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="gpt-4o-mini-tts",
|
||||
|
|
@ -580,6 +680,7 @@ DOGRAH_TTS_MODELS = ["default"]
|
|||
|
||||
@register_tts
|
||||
class DograhTTSService(BaseTTSConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
|
|
@ -598,6 +699,7 @@ CARTESIA_TTS_MODELS = ["sonic-3"]
|
|||
|
||||
@register_tts
|
||||
class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
|
||||
model: str = Field(
|
||||
default="sonic-3",
|
||||
|
|
@ -617,75 +719,9 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
)
|
||||
|
||||
|
||||
SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"]
|
||||
SARVAM_V2_VOICES = [
|
||||
"anushka",
|
||||
"manisha",
|
||||
"vidya",
|
||||
"arya",
|
||||
"abhilash",
|
||||
"karun",
|
||||
"hitesh",
|
||||
]
|
||||
SARVAM_V3_VOICES = [
|
||||
"shubh",
|
||||
"aditya",
|
||||
"ritu",
|
||||
"priya",
|
||||
"neha",
|
||||
"rahul",
|
||||
"pooja",
|
||||
"rohan",
|
||||
"simran",
|
||||
"kavya",
|
||||
"amit",
|
||||
"dev",
|
||||
"ishita",
|
||||
"shreya",
|
||||
"ratan",
|
||||
"varun",
|
||||
"manan",
|
||||
"sumit",
|
||||
"roopa",
|
||||
"kabir",
|
||||
"aayan",
|
||||
"ashutosh",
|
||||
"advait",
|
||||
"amelia",
|
||||
"sophia",
|
||||
"anand",
|
||||
"tanya",
|
||||
"tarun",
|
||||
"sunny",
|
||||
"mani",
|
||||
"gokul",
|
||||
"vijay",
|
||||
"shruti",
|
||||
"suhani",
|
||||
"mohit",
|
||||
"kavitha",
|
||||
"rehan",
|
||||
"soham",
|
||||
"rupali",
|
||||
]
|
||||
SARVAM_LANGUAGES = [
|
||||
"bn-IN",
|
||||
"en-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"od-IN",
|
||||
"pa-IN",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
"as-IN",
|
||||
]
|
||||
|
||||
|
||||
@register_tts
|
||||
class SarvamTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = SARVAM_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SARVAM] = ServiceProviders.SARVAM
|
||||
model: str = Field(
|
||||
default="bulbul:v2",
|
||||
|
|
@ -715,6 +751,7 @@ CAMB_TTS_MODELS = ["mars-flash", "mars-pro", "mars-instruct"]
|
|||
|
||||
@register_tts
|
||||
class CambTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = CAMB_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CAMB] = ServiceProviders.CAMB
|
||||
model: str = Field(
|
||||
default="mars-flash",
|
||||
|
|
@ -731,6 +768,7 @@ RIME_TTS_LANGUAGES = ["en", "de", "fr", "es", "hi"]
|
|||
|
||||
@register_tts
|
||||
class RimeTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = RIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.RIME] = ServiceProviders.RIME
|
||||
model: str = Field(
|
||||
default="arcana",
|
||||
|
|
@ -756,6 +794,7 @@ SPEACHES_TTS_MODELS = ["hexgrad/Kokoro-82M"]
|
|||
|
||||
@register_tts
|
||||
class SpeachesTTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = SPEACHES_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="kokoro",
|
||||
|
|
@ -786,6 +825,7 @@ class SpeachesTTSConfiguration(BaseTTSConfiguration):
|
|||
TTSConfig = Annotated[
|
||||
Union[
|
||||
DeepgramTTSConfiguration,
|
||||
GoogleTTSConfiguration,
|
||||
OpenAITTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
CartesiaTTSConfiguration,
|
||||
|
|
@ -801,94 +841,9 @@ TTSConfig = Annotated[
|
|||
###################################################### STT ########################################################################
|
||||
|
||||
|
||||
DEEPGRAM_STT_MODELS = ["nova-3-general", "flux-general-en", "flux-general-multi"]
|
||||
DEEPGRAM_LANGUAGES = [
|
||||
"multi",
|
||||
"ar",
|
||||
"ar-AE",
|
||||
"ar-SA",
|
||||
"ar-QA",
|
||||
"ar-KW",
|
||||
"ar-SY",
|
||||
"ar-LB",
|
||||
"ar-PS",
|
||||
"ar-JO",
|
||||
"ar-EG",
|
||||
"ar-SD",
|
||||
"ar-TD",
|
||||
"ar-MA",
|
||||
"ar-DZ",
|
||||
"ar-TN",
|
||||
"ar-IQ",
|
||||
"ar-IR",
|
||||
"be",
|
||||
"bn",
|
||||
"bs",
|
||||
"bg",
|
||||
"ca",
|
||||
"cs",
|
||||
"da",
|
||||
"da-DK",
|
||||
"de",
|
||||
"de-CH",
|
||||
"el",
|
||||
"en",
|
||||
"en-US",
|
||||
"en-AU",
|
||||
"en-GB",
|
||||
"en-IN",
|
||||
"en-NZ",
|
||||
"es",
|
||||
"es-419",
|
||||
"et",
|
||||
"fa",
|
||||
"fi",
|
||||
"fr",
|
||||
"fr-CA",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"hu",
|
||||
"id",
|
||||
"it",
|
||||
"ja",
|
||||
"kn",
|
||||
"ko",
|
||||
"ko-KR",
|
||||
"lt",
|
||||
"lv",
|
||||
"mk",
|
||||
"mr",
|
||||
"ms",
|
||||
"nl",
|
||||
"nl-BE",
|
||||
"no",
|
||||
"pl",
|
||||
"pt",
|
||||
"pt-BR",
|
||||
"pt-PT",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"sr",
|
||||
"sv",
|
||||
"sv-SE",
|
||||
"ta",
|
||||
"te",
|
||||
"th",
|
||||
"tl",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"vi",
|
||||
"zh-CN",
|
||||
"zh-TW",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class DeepgramSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = DEEPGRAM_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DEEPGRAM] = ServiceProviders.DEEPGRAM
|
||||
model: str = Field(
|
||||
default="nova-3-general",
|
||||
|
|
@ -902,7 +857,7 @@ class DeepgramSTTConfiguration(BaseSTTConfiguration):
|
|||
"examples": DEEPGRAM_LANGUAGES,
|
||||
"model_options": {
|
||||
"nova-3-general": DEEPGRAM_LANGUAGES,
|
||||
"flux-general-en": ["en"],
|
||||
"flux-general-en": ("en",),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
@ -913,6 +868,7 @@ CARTESIA_STT_MODELS = ["ink-whisper"]
|
|||
|
||||
@register_stt
|
||||
class CartesiaSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
|
||||
model: str = Field(
|
||||
default="ink-whisper",
|
||||
|
|
@ -926,6 +882,7 @@ OPENAI_STT_MODELS = ["gpt-4o-transcribe"]
|
|||
|
||||
@register_stt
|
||||
class OpenAISTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="gpt-4o-transcribe",
|
||||
|
|
@ -934,6 +891,45 @@ class OpenAISTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
@register_stt
|
||||
class GoogleSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = GOOGLE_CLOUD_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
|
||||
model: str = Field(
|
||||
default="latest_long",
|
||||
description="Google Cloud Speech-to-Text V2 recognition model.",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_STT_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en-US",
|
||||
description="Primary BCP-47 language code for recognition.",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_STT_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
"docs_url": "https://docs.cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages",
|
||||
},
|
||||
)
|
||||
location: str = Field(
|
||||
default="global",
|
||||
description="Google Cloud Speech-to-Text region (for example 'global' or 'us-central1').",
|
||||
)
|
||||
credentials: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Paste the entire Google Cloud service-account JSON. If omitted, "
|
||||
"the server falls back to Application Default Credentials (ADC)."
|
||||
),
|
||||
json_schema_extra={"multiline": True},
|
||||
)
|
||||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Not used for Google Cloud STT. Leave blank.",
|
||||
)
|
||||
|
||||
|
||||
# Dograh STT Service
|
||||
DOGRAH_STT_MODELS = ["default"]
|
||||
DOGRAH_STT_LANGUAGES = DEEPGRAM_LANGUAGES
|
||||
|
|
@ -941,6 +937,7 @@ DOGRAH_STT_LANGUAGES = DEEPGRAM_LANGUAGES
|
|||
|
||||
@register_stt
|
||||
class DograhSTTService(BaseSTTConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
|
|
@ -954,12 +951,9 @@ class DograhSTTService(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
# Sarvam STT Service
|
||||
SARVAM_STT_MODELS = ["saarika:v2.5", "saaras:v2"]
|
||||
|
||||
|
||||
@register_stt
|
||||
class SarvamSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = SARVAM_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SARVAM] = ServiceProviders.SARVAM
|
||||
model: str = Field(
|
||||
default="saarika:v2.5",
|
||||
|
|
@ -973,74 +967,9 @@ class SarvamSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
# Speechmatics STT Service
|
||||
SPEECHMATICS_STT_LANGUAGES = [
|
||||
"ar",
|
||||
"ar_en",
|
||||
"ba",
|
||||
"eu",
|
||||
"be",
|
||||
"bn",
|
||||
"bg",
|
||||
"yue",
|
||||
"ca",
|
||||
"hr",
|
||||
"cs",
|
||||
"da",
|
||||
"nl",
|
||||
"en",
|
||||
"eo",
|
||||
"et",
|
||||
"fi",
|
||||
"fr",
|
||||
"gl",
|
||||
"de",
|
||||
"el",
|
||||
"he",
|
||||
"hi",
|
||||
"hu",
|
||||
"id",
|
||||
"ia",
|
||||
"ga",
|
||||
"it",
|
||||
"ja",
|
||||
"ko",
|
||||
"lv",
|
||||
"lt",
|
||||
"ms",
|
||||
"en_ms",
|
||||
"mt",
|
||||
"cmn",
|
||||
"cmn_en",
|
||||
"cmn_en_ms_ta",
|
||||
"mr",
|
||||
"mn",
|
||||
"no",
|
||||
"fa",
|
||||
"pl",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"es",
|
||||
"sw",
|
||||
"sv",
|
||||
"tl",
|
||||
"ta",
|
||||
"en_ta",
|
||||
"th",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"ug",
|
||||
"vi",
|
||||
"cy",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class SpeechmaticsSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = SPEECHMATICS_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SPEECHMATICS] = ServiceProviders.SPEECHMATICS
|
||||
model: str = Field(
|
||||
default="enhanced",
|
||||
|
|
@ -1062,6 +991,7 @@ SPEACHES_STT_LANGUAGES = ["en", "ar", "nl", "fr", "de", "hi", "it", "pt", "es"]
|
|||
|
||||
@register_stt
|
||||
class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = SPEACHES_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="Systran/faster-distil-whisper-small.en",
|
||||
|
|
@ -1095,6 +1025,7 @@ ASSEMBLYAI_STT_LANGUAGES = ["en", "es", "de", "fr", "pt", "it"]
|
|||
|
||||
@register_stt
|
||||
class AssemblyAISTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = ASSEMBLYAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.ASSEMBLYAI] = ServiceProviders.ASSEMBLYAI
|
||||
model: str = Field(
|
||||
default="u3-rt-pro",
|
||||
|
|
@ -1108,113 +1039,9 @@ class AssemblyAISTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
GLADIA_STT_MODELS = ["solaria-1"]
|
||||
GLADIA_STT_LANGUAGES = [
|
||||
"af",
|
||||
"am",
|
||||
"ar",
|
||||
"as",
|
||||
"az",
|
||||
"ba",
|
||||
"be",
|
||||
"bg",
|
||||
"bn",
|
||||
"bo",
|
||||
"br",
|
||||
"bs",
|
||||
"ca",
|
||||
"cs",
|
||||
"cy",
|
||||
"da",
|
||||
"de",
|
||||
"el",
|
||||
"en",
|
||||
"es",
|
||||
"et",
|
||||
"eu",
|
||||
"fa",
|
||||
"fi",
|
||||
"fo",
|
||||
"fr",
|
||||
"gl",
|
||||
"gu",
|
||||
"ha",
|
||||
"haw",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"ht",
|
||||
"hu",
|
||||
"hy",
|
||||
"id",
|
||||
"is",
|
||||
"it",
|
||||
"ja",
|
||||
"jw",
|
||||
"ka",
|
||||
"kk",
|
||||
"km",
|
||||
"kn",
|
||||
"ko",
|
||||
"la",
|
||||
"lb",
|
||||
"ln",
|
||||
"lo",
|
||||
"lt",
|
||||
"lv",
|
||||
"mg",
|
||||
"mi",
|
||||
"mk",
|
||||
"ml",
|
||||
"mn",
|
||||
"mr",
|
||||
"ms",
|
||||
"mt",
|
||||
"my",
|
||||
"ne",
|
||||
"nl",
|
||||
"nn",
|
||||
"no",
|
||||
"oc",
|
||||
"pa",
|
||||
"pl",
|
||||
"ps",
|
||||
"pt",
|
||||
"ro",
|
||||
"ru",
|
||||
"sa",
|
||||
"sd",
|
||||
"si",
|
||||
"sk",
|
||||
"sl",
|
||||
"sn",
|
||||
"so",
|
||||
"sq",
|
||||
"sr",
|
||||
"su",
|
||||
"sv",
|
||||
"sw",
|
||||
"ta",
|
||||
"te",
|
||||
"tg",
|
||||
"th",
|
||||
"tk",
|
||||
"tl",
|
||||
"tr",
|
||||
"tt",
|
||||
"uk",
|
||||
"ur",
|
||||
"uz",
|
||||
"vi",
|
||||
"wo",
|
||||
"yi",
|
||||
"yo",
|
||||
"zh",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class GladiaSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = GLADIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GLADIA] = ServiceProviders.GLADIA
|
||||
model: str = Field(
|
||||
default="solaria-1",
|
||||
|
|
@ -1233,6 +1060,7 @@ STTConfig = Annotated[
|
|||
DeepgramSTTConfiguration,
|
||||
CartesiaSTTConfiguration,
|
||||
OpenAISTTConfiguration,
|
||||
GoogleSTTConfiguration,
|
||||
DograhSTTService,
|
||||
SpeechmaticsSTTConfiguration,
|
||||
SarvamSTTConfiguration,
|
||||
|
|
@ -1250,6 +1078,7 @@ OPENAI_EMBEDDING_MODELS = ["text-embedding-3-small"]
|
|||
|
||||
@register_embeddings
|
||||
class OpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = OPENAI_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
|
|
@ -1263,6 +1092,7 @@ OPENROUTER_EMBEDDING_MODELS = ["openai/text-embedding-3-small"]
|
|||
|
||||
@register_embeddings
|
||||
class OpenRouterEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = OPENROUTER_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.OPENROUTER] = ServiceProviders.OPENROUTER
|
||||
model: str = Field(
|
||||
default="openai/text-embedding-3-small",
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
|
|||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
|
||||
from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings
|
||||
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
||||
from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings
|
||||
from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
|
@ -101,6 +103,23 @@ def create_stt_service(
|
|||
api_key=user_config.stt.api_key,
|
||||
settings=OpenAISTTSettings(model=user_config.stt.model),
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.GOOGLE.value:
|
||||
language = getattr(user_config.stt, "language", None) or "en-US"
|
||||
location = getattr(user_config.stt, "location", None) or "global"
|
||||
credentials = getattr(user_config.stt, "credentials", None)
|
||||
|
||||
settings_kwargs = {"model": user_config.stt.model}
|
||||
try:
|
||||
settings_kwargs["languages"] = [Language(language)]
|
||||
except ValueError:
|
||||
settings_kwargs["language_codes"] = [language]
|
||||
|
||||
return GoogleSTTService(
|
||||
credentials=credentials,
|
||||
location=location,
|
||||
settings=GoogleSTTSettings(**settings_kwargs),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
|
|
@ -241,6 +260,30 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.GOOGLE.value:
|
||||
model = getattr(user_config.tts, "model", None) or "chirp_3_hd"
|
||||
language = getattr(user_config.tts, "language", None) or "en-US"
|
||||
voice = getattr(user_config.tts, "voice", None) or "en-US-Chirp3-HD-Charon"
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
location = getattr(user_config.tts, "location", None) or None
|
||||
credentials = getattr(user_config.tts, "credentials", None)
|
||||
|
||||
settings_kwargs = {
|
||||
"model": model,
|
||||
"voice": voice,
|
||||
"language": language,
|
||||
}
|
||||
if speed is not None and speed != 1.0:
|
||||
settings_kwargs["speaking_rate"] = speed
|
||||
|
||||
return GoogleTTSService(
|
||||
credentials=credentials,
|
||||
location=location,
|
||||
settings=GoogleTTSSettings(**settings_kwargs),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
# Backward compatible with older configuration "Name - voice_id"
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -610,13 +610,14 @@ class GlobalNodeData(BaseNodeData, _PromptedNodeDataMixin):
|
|||
"trigger_path": {
|
||||
"display_name": "Trigger Path",
|
||||
"description": (
|
||||
"Auto-generated UUID-style path segment that uniquely identifies "
|
||||
"Path segment that uniquely identifies "
|
||||
"this trigger. Used in both URLs:\n"
|
||||
" • Production: `/api/v1/public/agent/<trigger_path>` — executes "
|
||||
"the published agent.\n"
|
||||
" • Test: `/api/v1/public/agent/test/<trigger_path>` — executes "
|
||||
"the latest draft.\n"
|
||||
"Do not edit manually."
|
||||
"Can be customized to a descriptive value up to 36 characters "
|
||||
"using letters, numbers, hyphens, or underscores."
|
||||
),
|
||||
},
|
||||
},
|
||||
|
|
|
|||
142
api/services/workflow/trigger_paths.py
Normal file
142
api/services/workflow/trigger_paths.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import copy
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
TRIGGER_PATH_MAX_LENGTH = 36
|
||||
TRIGGER_PATH_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TriggerPathIssue:
|
||||
node_id: str | None
|
||||
trigger_path: str
|
||||
message: str
|
||||
|
||||
|
||||
def extract_trigger_paths(workflow_definition: Optional[dict]) -> list[str]:
|
||||
"""Extract trigger paths from a workflow definition."""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
|
||||
trigger_paths: list[str] = []
|
||||
for node in workflow_definition.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
trigger_path = (node.get("data") or {}).get("trigger_path")
|
||||
if isinstance(trigger_path, str) and trigger_path:
|
||||
trigger_paths.append(trigger_path)
|
||||
return trigger_paths
|
||||
|
||||
|
||||
def trigger_path_to_node_id(workflow_definition: Optional[dict]) -> dict[str, str]:
|
||||
"""Map each trigger node's trigger_path to its node id."""
|
||||
if not workflow_definition:
|
||||
return {}
|
||||
|
||||
out: dict[str, str] = {}
|
||||
for node in workflow_definition.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
trigger_path = (node.get("data") or {}).get("trigger_path")
|
||||
if isinstance(trigger_path, str) and trigger_path:
|
||||
out[trigger_path] = node.get("id")
|
||||
return out
|
||||
|
||||
|
||||
def regenerate_trigger_uuids(workflow_definition: Optional[dict]) -> Optional[dict]:
|
||||
"""Regenerate UUIDs for all trigger nodes in a workflow definition."""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
updated_definition = copy.deepcopy(workflow_definition)
|
||||
for node in updated_definition.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
data = node.setdefault("data", {})
|
||||
data["trigger_path"] = str(uuid.uuid4())
|
||||
return updated_definition
|
||||
|
||||
|
||||
def ensure_trigger_paths(workflow_definition: Optional[dict]) -> Optional[dict]:
|
||||
"""Mint UUIDs for trigger nodes that do not already have a path."""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
out = copy.deepcopy(workflow_definition)
|
||||
for node in out.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
data = node.setdefault("data", {})
|
||||
if not data.get("trigger_path"):
|
||||
data["trigger_path"] = str(uuid.uuid4())
|
||||
return out
|
||||
|
||||
|
||||
def validate_trigger_paths(
|
||||
workflow_definition: Optional[dict],
|
||||
) -> list[TriggerPathIssue]:
|
||||
"""Validate custom trigger paths before they reach persistence/runtime."""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
|
||||
issues: list[TriggerPathIssue] = []
|
||||
seen_paths: dict[str, str | None] = {}
|
||||
|
||||
for node in workflow_definition.get("nodes") or []:
|
||||
if node.get("type") != "trigger":
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
trigger_path = (node.get("data") or {}).get("trigger_path")
|
||||
if not trigger_path:
|
||||
continue
|
||||
|
||||
if not isinstance(trigger_path, str):
|
||||
issues.append(
|
||||
TriggerPathIssue(
|
||||
node_id=node_id,
|
||||
trigger_path=repr(trigger_path),
|
||||
message="Trigger path must be a string.",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if len(trigger_path) > TRIGGER_PATH_MAX_LENGTH:
|
||||
issues.append(
|
||||
TriggerPathIssue(
|
||||
node_id=node_id,
|
||||
trigger_path=trigger_path,
|
||||
message=(
|
||||
f"Trigger path must be {TRIGGER_PATH_MAX_LENGTH} "
|
||||
"characters or fewer."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not TRIGGER_PATH_PATTERN.fullmatch(trigger_path):
|
||||
issues.append(
|
||||
TriggerPathIssue(
|
||||
node_id=node_id,
|
||||
trigger_path=trigger_path,
|
||||
message=(
|
||||
"Trigger path must be a single URL path segment using "
|
||||
"only letters, numbers, hyphens, and underscores."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
first_node_id = seen_paths.get(trigger_path)
|
||||
if first_node_id is None:
|
||||
seen_paths[trigger_path] = node_id
|
||||
else:
|
||||
issues.append(
|
||||
TriggerPathIssue(
|
||||
node_id=node_id,
|
||||
trigger_path=trigger_path,
|
||||
message="Trigger path is duplicated in this workflow.",
|
||||
)
|
||||
)
|
||||
|
||||
return issues
|
||||
55
api/tests/test_google_stt_service_factory.py
Normal file
55
api/tests/test_google_stt_service_factory.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.service_factory import create_stt_service
|
||||
|
||||
|
||||
def test_create_google_stt_service_uses_credentials_location_and_language():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE.value,
|
||||
credentials='{"project_id":"demo-project"}',
|
||||
api_key=None,
|
||||
model="latest_long",
|
||||
language="en-US",
|
||||
location="us-central1",
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.GoogleSTTService") as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["credentials"] == '{"project_id":"demo-project"}'
|
||||
assert kwargs["location"] == "us-central1"
|
||||
assert kwargs["sample_rate"] == 16000
|
||||
assert kwargs["settings"].model == "latest_long"
|
||||
assert kwargs["settings"].languages == [Language.EN_US]
|
||||
|
||||
|
||||
def test_create_google_stt_service_falls_back_to_raw_language_codes():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE.value,
|
||||
credentials=None,
|
||||
api_key=None,
|
||||
model="chirp_3",
|
||||
language="cmn-Hans-CN",
|
||||
location="global",
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=24000)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.GoogleSTTService") as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["sample_rate"] == 24000
|
||||
assert kwargs["settings"].model == "chirp_3"
|
||||
assert kwargs["settings"].language_codes == ["cmn-Hans-CN"]
|
||||
67
api/tests/test_google_tts_service_factory.py
Normal file
67
api/tests/test_google_tts_service_factory.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from pipecat.services.settings import NOT_GIVEN
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.service_factory import create_tts_service
|
||||
|
||||
|
||||
def test_create_google_tts_service_uses_credentials_location_and_settings():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE.value,
|
||||
credentials='{"project_id":"demo-project"}',
|
||||
api_key=None,
|
||||
model="chirp_3_hd",
|
||||
voice="en-US-Chirp3-HD-Charon",
|
||||
language="en-US",
|
||||
speed=1.15,
|
||||
location="us-central1",
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(
|
||||
transport_out_sample_rate=24000,
|
||||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.GoogleTTSService") as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["credentials"] == '{"project_id":"demo-project"}'
|
||||
assert kwargs["location"] == "us-central1"
|
||||
assert kwargs["settings"].model == "chirp_3_hd"
|
||||
assert kwargs["settings"].voice == "en-US-Chirp3-HD-Charon"
|
||||
assert kwargs["settings"].language == "en-US"
|
||||
assert kwargs["settings"].speaking_rate == 1.15
|
||||
|
||||
|
||||
def test_create_google_tts_service_omits_default_speed():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE.value,
|
||||
credentials=None,
|
||||
api_key=None,
|
||||
model="chirp_3_hd",
|
||||
voice="en-US-Chirp3-HD-Charon",
|
||||
language="sw-KE",
|
||||
speed=1.0,
|
||||
location=None,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(
|
||||
transport_out_sample_rate=24000,
|
||||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.GoogleTTSService") as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["location"] is None
|
||||
assert kwargs["settings"].model == "chirp_3_hd"
|
||||
assert kwargs["settings"].language == "sw-KE"
|
||||
assert kwargs["settings"].speaking_rate is NOT_GIVEN
|
||||
|
|
@ -186,6 +186,45 @@ const n = wf.addTyped(startCall({ name: "g", prompt: "hi", promt: "typo" }));
|
|||
update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_trigger_path_surfaces_validation_error(mock_backends):
|
||||
save_mock, update_mock = mock_backends
|
||||
payload = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "support/west"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.save_workflow.parse_code",
|
||||
AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"workflowName": _FakeWorkflowModel.name,
|
||||
"workflow": payload,
|
||||
}
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"api.mcp_server.tools.save_workflow.reconcile_positions",
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
result = await save_workflow(workflow_id=1, code="ignored")
|
||||
|
||||
assert result["saved"] is False
|
||||
assert result["error_code"] == "validation_error"
|
||||
assert "single URL path segment" in result["error"]
|
||||
save_mock.assert_not_awaited()
|
||||
update_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ─── Graph-stage rejections ──────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
191
api/tests/test_public_agent_routes.py
Normal file
191
api/tests/test_public_agent_routes.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.public_agent import router
|
||||
|
||||
|
||||
def _make_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
return app
|
||||
|
||||
|
||||
def _active_workflow(*, trigger_path: str | None = None):
|
||||
nodes = []
|
||||
if trigger_path is not None:
|
||||
nodes.append(
|
||||
{
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": trigger_path},
|
||||
}
|
||||
)
|
||||
|
||||
return SimpleNamespace(
|
||||
id=33,
|
||||
user_id=99,
|
||||
organization_id=11,
|
||||
status="active",
|
||||
workflow_uuid="workflow-uuid-123",
|
||||
released_definition=SimpleNamespace(
|
||||
workflow_json={"nodes": nodes, "edges": []}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _provider():
|
||||
return SimpleNamespace(
|
||||
PROVIDER_NAME="twilio",
|
||||
WEBHOOK_ENDPOINT="outbound",
|
||||
validate_config=Mock(return_value=True),
|
||||
initiate_call=AsyncMock(),
|
||||
)
|
||||
|
||||
|
||||
def test_trigger_route_executes_as_workflow_owner():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _active_workflow(trigger_path="trigger-uuid-123")
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.public_agent.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
patch(
|
||||
"api.routes.public_agent.get_backend_endpoints",
|
||||
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
|
||||
),
|
||||
):
|
||||
mock_db.validate_api_key = AsyncMock(
|
||||
return_value=SimpleNamespace(id=7, organization_id=11, created_by=22)
|
||||
)
|
||||
mock_db.get_agent_trigger_by_path = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=11,
|
||||
state="active",
|
||||
)
|
||||
)
|
||||
mock_db.get_workflow = AsyncMock(return_value=workflow)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(return_value=SimpleNamespace(id=501))
|
||||
|
||||
response = client.post(
|
||||
"/public/agent/trigger-uuid-123",
|
||||
headers={"X-API-Key": "test-api-key"},
|
||||
json={"phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
|
||||
assert create_kwargs["workflow_id"] == workflow.id
|
||||
assert create_kwargs["user_id"] == workflow.user_id
|
||||
assert create_kwargs["organization_id"] == workflow.organization_id
|
||||
assert create_kwargs["initial_context"]["agent_uuid"] == "trigger-uuid-123"
|
||||
assert create_kwargs["initial_context"]["agent_identifier"] == "trigger-uuid-123"
|
||||
assert create_kwargs["initial_context"]["agent_identifier_type"] == "trigger_path"
|
||||
assert create_kwargs["initial_context"]["workflow_uuid"] == workflow.workflow_uuid
|
||||
assert create_kwargs["initial_context"]["api_key_id"] == 7
|
||||
assert create_kwargs["initial_context"]["api_key_created_by"] == 22
|
||||
|
||||
initiate_kwargs = provider.initiate_call.await_args.kwargs
|
||||
assert initiate_kwargs["workflow_id"] == workflow.id
|
||||
assert initiate_kwargs["user_id"] == workflow.user_id
|
||||
|
||||
|
||||
def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _active_workflow()
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.public_agent.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
patch(
|
||||
"api.routes.public_agent.get_backend_endpoints",
|
||||
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
|
||||
),
|
||||
):
|
||||
mock_db.validate_api_key = AsyncMock(
|
||||
return_value=SimpleNamespace(id=8, organization_id=11, created_by=22)
|
||||
)
|
||||
mock_db.get_workflow_by_uuid = AsyncMock(return_value=workflow)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(return_value=SimpleNamespace(id=601))
|
||||
|
||||
response = client.post(
|
||||
f"/public/agent/workflow/{workflow.workflow_uuid}",
|
||||
headers={"X-API-Key": "test-api-key"},
|
||||
json={"phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_db.get_workflow_by_uuid.assert_awaited_once_with(
|
||||
workflow.workflow_uuid,
|
||||
11,
|
||||
)
|
||||
assert not mock_db.get_agent_trigger_by_path.called
|
||||
|
||||
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
|
||||
assert create_kwargs["user_id"] == workflow.user_id
|
||||
assert (
|
||||
create_kwargs["initial_context"]["agent_identifier"] == workflow.workflow_uuid
|
||||
)
|
||||
assert create_kwargs["initial_context"]["agent_identifier_type"] == "workflow_uuid"
|
||||
assert "agent_uuid" not in create_kwargs["initial_context"]
|
||||
|
||||
|
||||
def test_workflow_uuid_route_rejects_archived_workflows():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _active_workflow()
|
||||
workflow.status = "archived"
|
||||
|
||||
with patch("api.routes.public_agent.db_client") as mock_db:
|
||||
mock_db.validate_api_key = AsyncMock(
|
||||
return_value=SimpleNamespace(id=9, organization_id=11, created_by=22)
|
||||
)
|
||||
mock_db.get_workflow_by_uuid = AsyncMock(return_value=workflow)
|
||||
|
||||
response = client.post(
|
||||
f"/public/agent/workflow/{workflow.workflow_uuid}",
|
||||
headers={"X-API-Key": "test-api-key"},
|
||||
json={"phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Workflow is not active"
|
||||
assert not mock_db.create_workflow_run.called
|
||||
56
api/tests/test_trigger_path_validation.py
Normal file
56
api/tests/test_trigger_path_validation.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from api.services.workflow.trigger_paths import (
|
||||
TRIGGER_PATH_MAX_LENGTH,
|
||||
validate_trigger_paths,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_trigger_paths_rejects_invalid_path_segments():
|
||||
workflow_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "support/west"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
issues = validate_trigger_paths(workflow_definition)
|
||||
|
||||
assert len(issues) == 1
|
||||
assert issues[0].node_id == "trigger-1"
|
||||
assert "single URL path segment" in issues[0].message
|
||||
|
||||
|
||||
def test_validate_trigger_paths_rejects_long_and_duplicate_paths():
|
||||
long_path = "a" * (TRIGGER_PATH_MAX_LENGTH + 1)
|
||||
workflow_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": long_path},
|
||||
},
|
||||
{
|
||||
"id": "trigger-2",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "sales_agent"},
|
||||
},
|
||||
{
|
||||
"id": "trigger-3",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "sales_agent"},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
issues = validate_trigger_paths(workflow_definition)
|
||||
messages = [issue.message for issue in issues]
|
||||
|
||||
assert (
|
||||
f"Trigger path must be {TRIGGER_PATH_MAX_LENGTH} characters or fewer."
|
||||
in messages
|
||||
)
|
||||
assert "Trigger path is duplicated in this workflow." in messages
|
||||
49
api/tests/test_workflow_create_route.py
Normal file
49
api/tests/test_workflow_create_route.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.workflow import router
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
|
||||
def _make_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user] = lambda: SimpleNamespace(
|
||||
id=1,
|
||||
provider_id="provider-1",
|
||||
selected_organization_id=11,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_create_workflow_rejects_invalid_trigger_path_before_db_write():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
response = client.post(
|
||||
"/workflow/create/definition",
|
||||
json={
|
||||
"name": "Support Agent",
|
||||
"workflow_definition": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "trigger-1",
|
||||
"type": "trigger",
|
||||
"data": {"trigger_path": "support/west"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
detail = response.json()["detail"]
|
||||
assert detail["is_valid"] is False
|
||||
assert detail["errors"][0]["field"] == "data.trigger_path"
|
||||
assert "single URL path segment" in detail["errors"][0]["message"]
|
||||
assert mock_db.mock_calls == []
|
||||
52
api/tests/test_workflow_list_route.py
Normal file
52
api/tests/test_workflow_list_route.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.workflow import router
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
|
||||
def _make_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user] = lambda: SimpleNamespace(
|
||||
id=1,
|
||||
selected_organization_id=11,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_workflow_fetch_list_includes_workflow_uuid():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id=5,
|
||||
name="Sales Agent",
|
||||
status="active",
|
||||
created_at=datetime(2026, 5, 22, 10, 30, tzinfo=timezone.utc),
|
||||
folder_id=3,
|
||||
workflow_uuid="workflow-uuid-123",
|
||||
)
|
||||
|
||||
with patch("api.routes.workflow.db_client") as mock_db:
|
||||
mock_db.get_all_workflows_for_listing = AsyncMock(return_value=[workflow])
|
||||
mock_db.get_workflow_run_counts = AsyncMock(return_value={workflow.id: 9})
|
||||
|
||||
response = client.get("/workflow/fetch")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
{
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": "2026-05-22T10:30:00Z",
|
||||
"total_runs": 9,
|
||||
"folder_id": workflow.folder_id,
|
||||
"workflow_uuid": workflow.workflow_uuid,
|
||||
}
|
||||
]
|
||||
|
|
@ -182,7 +182,7 @@ class TestSaveDraft:
|
|||
workflow_definition=GRAPH_V2,
|
||||
)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
refreshed = await db_session.get_workflow_by_id(workflow.id)
|
||||
assert refreshed.released_definition_id == original_released_id
|
||||
|
||||
async def test_save_draft_twice_updates_in_place(
|
||||
|
|
@ -264,7 +264,7 @@ class TestPublishDraft:
|
|||
|
||||
await db_session.publish_workflow_draft(workflow.id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
refreshed = await db_session.get_workflow_by_id(workflow.id)
|
||||
assert refreshed.released_definition_id == draft.id
|
||||
|
||||
async def test_publish_sets_published_at(self, db_session, workflow_with_v1):
|
||||
|
|
@ -346,7 +346,7 @@ class TestDiscardDraft:
|
|||
)
|
||||
await db_session.discard_workflow_draft(workflow.id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
refreshed = await db_session.get_workflow_by_id(workflow.id)
|
||||
assert refreshed.released_definition_id == original_released_id
|
||||
|
||||
async def test_discard_when_no_draft_raises(self, db_session, workflow_with_v1):
|
||||
|
|
@ -464,7 +464,7 @@ class TestRevert:
|
|||
|
||||
await db_session.revert_to_version(workflow.id, v1_id)
|
||||
|
||||
refreshed = await db_session.get_workflow(workflow.id)
|
||||
refreshed = await db_session.get_workflow_by_id(workflow.id)
|
||||
assert refreshed.released_definition_id == v2.id # still V2
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue