mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add chat based testing for voice agent (#308)
* feat: add backend foundations * feat: add text chat UI * chore: simplify the reload behaviour * fix: fix upgrade banner to be triggered after package upload * feat: simplify TesterPanel design * chore: fix formatting and generate client * chore: fix tracing for text chat mode * fix: fix revert and edit CTA * refactor: refactor TesterPanel into smaller components * feat: enable runtime transition of nodes * fix: fix review comments
This commit is contained in:
parent
67479e98fd
commit
d97d1d72cd
96 changed files with 7630 additions and 1684 deletions
|
|
@ -41,6 +41,10 @@ api/
|
|||
- Telephony is a full subsystem under `services/telephony/`, with provider-specific packages under `services/telephony/providers/`
|
||||
- Integrations extend through `services/integrations/`; package-specific rules should live in that subtree's own `AGENTS.md`
|
||||
|
||||
## Routes vs Service Layer
|
||||
|
||||
**Keep route handlers thin** — parse/validate the request, resolve auth and `organization_id`, delegate, shape the response. Domain logic (orchestration, business rules, external calls, computation) belongs in `services/`. Before adding logic to a handler, find its home: extend an existing `services/<domain>/` module that owns the concern (see *Where to Find Things*) before adding a focused new module; never a catch-all. Keep DB access in `db/` clients — routes call services, services call DB clients. Litmus test: if `tasks/`, `mcp_server/`, or another route could reuse it, it must live in `services/` to be importable.
|
||||
|
||||
## Database Migrations
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
Revision ID: 19d2a4b6c8ef
|
||||
Revises: 0a1b2c3d4e5f
|
||||
|
||||
Create Date: 2026-05-19 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
"""add workflow_run_text_sessions
|
||||
|
||||
Revision ID: 2f638891cbb6
|
||||
Revises: 19d2a4b6c8ef
|
||||
Create Date: 2026-05-18 12:58:58.573381
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2f638891cbb6"
|
||||
down_revision: Union[str, None] = "19d2a4b6c8ef"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``workflow_run_text_sessions`` table and its ``updated_at`` index.

    The table is keyed by ``workflow_run_id``, which is both the primary key
    and a foreign key to ``workflow_runs.id`` declared with
    ``ON DELETE CASCADE``.
    """
    table_name = "workflow_run_text_sessions"

    def json_column(column_name):
        # Non-null JSON column whose server-side default is an empty object.
        return sa.Column(
            column_name,
            sa.JSON(),
            server_default=sa.text("'{}'::json"),
            nullable=False,
        )

    def timestamp_column(column_name):
        # Timezone-aware timestamp; nullable and without a server default.
        return sa.Column(column_name, sa.DateTime(timezone=True), nullable=True)

    op.create_table(
        table_name,
        sa.Column("workflow_run_id", sa.Integer(), nullable=False),
        sa.Column(
            "revision",
            sa.Integer(),
            server_default=sa.text("0"),
            nullable=False,
        ),
        json_column("session_data"),
        json_column("checkpoint"),
        timestamp_column("created_at"),
        timestamp_column("updated_at"),
        sa.ForeignKeyConstraint(
            ["workflow_run_id"],
            ["workflow_runs.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("workflow_run_id"),
    )
    op.create_index(
        "ix_workflow_run_text_sessions_updated_at",
        table_name,
        ["updated_at"],
        unique=False,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverse ``upgrade``: drop the ``updated_at`` index, then the table."""
    table_name = "workflow_run_text_sessions"
    # The index is dropped first, mirroring the creation order in reverse.
    op.drop_index("ix_workflow_run_text_sessions_updated_at", table_name=table_name)
    op.drop_table(table_name)
|
||||
|
|
@ -16,12 +16,14 @@ from api.db.webhook_credential_client import WebhookCredentialClient
|
|||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_recording_client import WorkflowRecordingClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_run_text_session_client import WorkflowRunTextSessionClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
||||
class DBClient(
|
||||
WorkflowClient,
|
||||
WorkflowRunClient,
|
||||
WorkflowRunTextSessionClient,
|
||||
UserClient,
|
||||
OrganizationClient,
|
||||
OrganizationConfigurationClient,
|
||||
|
|
|
|||
|
|
@ -484,6 +484,12 @@ class WorkflowRunModel(Base):
|
|||
queued_run_id = Column(Integer, ForeignKey("queued_runs.id"), nullable=True)
|
||||
queued_run = relationship("QueuedRunModel", foreign_keys=[queued_run_id])
|
||||
public_access_token = Column(String(36), nullable=True)
|
||||
text_session = relationship(
|
||||
"WorkflowRunTextSessionModel",
|
||||
back_populates="workflow_run",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
|
|
@ -503,6 +509,43 @@ class WorkflowRunModel(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionModel(Base):
    """Text-session row attached to a workflow run, keyed by the run's id."""

    __tablename__ = "workflow_run_text_sessions"

    # The owning run's id doubles as the primary key, so a run can have at
    # most one text session; the row is deleted with its run (ON DELETE CASCADE).
    workflow_run_id = Column(
        Integer,
        ForeignKey("workflow_runs.id", ondelete="CASCADE"),
        primary_key=True,
    )
    workflow_run = relationship("WorkflowRunModel", back_populates="text_session")

    # Integer revision counter starting at 0 (both ORM-side and server-side).
    revision = Column(
        Integer,
        server_default=text("0"),
        default=0,
        nullable=False,
    )

    # JSON payloads; both default to an empty object and are never NULL.
    session_data = Column(
        JSON,
        server_default=text("'{}'::json"),
        default=dict,
        nullable=False,
    )
    checkpoint = Column(
        JSON,
        server_default=text("'{}'::json"),
        default=dict,
        nullable=False,
    )

    # Timestamps are filled in by the ORM with the current UTC time;
    # ``updated_at`` is additionally refreshed on every ORM update.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(UTC),
        default=lambda: datetime.now(UTC),
    )

    __table_args__ = (Index("ix_workflow_run_text_sessions_updated_at", "updated_at"),)
|
||||
|
||||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
|
|
|
|||
|
|
@ -151,9 +151,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
async def update_usage_after_run(
|
||||
self,
|
||||
organization_id: int,
|
||||
actual_tokens: int,
|
||||
duration_seconds: int = 0,
|
||||
charge_usd: float = None,
|
||||
actual_tokens: float,
|
||||
duration_seconds: float = 0,
|
||||
charge_usd: float | None = None,
|
||||
) -> None:
|
||||
"""Update usage after a workflow run completes with actual token count and duration.
|
||||
|
||||
|
|
@ -354,6 +354,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"caller_number": caller_number,
|
||||
"called_number": called_number,
|
||||
"call_type": run.call_type,
|
||||
"mode": run.mode,
|
||||
"disposition": disposition,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
|
|||
|
|
@ -32,16 +32,22 @@ class WorkflowRunClient(BaseDBClient):
|
|||
campaign_id: int = None,
|
||||
queued_run_id: int = None,
|
||||
use_draft: bool = False,
|
||||
organization_id: int | None = None,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Get workflow and user to check organization
|
||||
workflow = await session.execute(
|
||||
workflow_query = (
|
||||
select(WorkflowModel)
|
||||
.options(joinedload(WorkflowModel.user))
|
||||
.where(
|
||||
WorkflowModel.id == workflow_id, WorkflowModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
if organization_id is not None:
|
||||
workflow_query = workflow_query.where(
|
||||
WorkflowModel.organization_id == organization_id
|
||||
)
|
||||
|
||||
workflow = await session.execute(workflow_query)
|
||||
workflow = workflow.scalars().first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
|
|
|||
124
api/db/workflow_run_text_session_client.py
Normal file
124
api/db/workflow_run_text_session_client.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunTextSessionModel,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTextSessionRevisionConflictError(Exception):
    """Raised when a caller's expected revision differs from the stored one.

    Attributes:
        expected_revision: The revision the caller supplied.
        actual_revision: The revision found on the stored text session.
    """

    def __init__(self, expected_revision: int, actual_revision: int):
        message = (
            "Workflow run text session revision conflict: "
            f"expected {expected_revision}, found {actual_revision}"
        )
        super().__init__(message)
        self.expected_revision = expected_revision
        self.actual_revision = actual_revision
|
||||
|
||||
|
||||
class WorkflowRunTextSessionClient(BaseDBClient):
    """Database operations for ``WorkflowRunTextSessionModel`` rows."""

    async def ensure_workflow_run_text_session(
        self,
        workflow_run_id: int,
        session_data: dict | None = None,
        checkpoint: dict | None = None,
    ) -> WorkflowRunTextSessionModel:
        """Return the text session for a run, creating it if it does not exist.

        Args:
            workflow_run_id: Id of the workflow run that owns the session.
            session_data: Initial session payload for a new row; a falsy value
                is stored as ``{}``. Ignored when a row already exists.
            checkpoint: Initial checkpoint payload for a new row; a falsy value
                is stored as ``{}``. Ignored when a row already exists.

        Returns:
            The existing or newly created text session.

        Raises:
            ValueError: If no workflow run with ``workflow_run_id`` exists.
        """
        async with self.async_session() as db:
            # NOTE(review): FOR UPDATE can only lock a row that already exists;
            # two concurrent callers that both find no row would presumably
            # both try to insert, and the loser's commit error is re-raised
            # below — confirm callers tolerate that.
            lookup = (
                select(WorkflowRunTextSessionModel)
                .where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
                .with_for_update()
            )
            existing = (await db.execute(lookup)).scalars().first()
            if existing:
                return existing

            run_lookup = select(WorkflowRunModel).where(
                WorkflowRunModel.id == workflow_run_id
            )
            owning_run = (await db.execute(run_lookup)).scalars().first()
            if not owning_run:
                raise ValueError(f"Workflow run with ID {workflow_run_id} not found")

            created = WorkflowRunTextSessionModel(
                workflow_run_id=workflow_run_id,
                session_data=session_data or {},
                checkpoint=checkpoint or {},
            )
            db.add(created)
            try:
                await db.commit()
            except Exception:
                await db.rollback()
                raise
            await db.refresh(created)
            return created

    async def get_workflow_run_text_session(
        self,
        workflow_run_id: int,
        *,
        organization_id: int,
    ) -> WorkflowRunTextSessionModel | None:
        """Fetch a run's text session, scoped to an organization.

        The run and its workflow are eagerly loaded onto the result. The query
        only matches when the run's workflow has the given ``organization_id``.

        Returns:
            The matching text session, or ``None`` when there is no match.
        """
        async with self.async_session() as db:
            eager_run_and_workflow = joinedload(
                WorkflowRunTextSessionModel.workflow_run
            ).joinedload(WorkflowRunModel.workflow)

            stmt = select(WorkflowRunTextSessionModel).options(eager_run_and_workflow)
            # Explicit joins are what the organization filter applies to.
            stmt = stmt.join(WorkflowRunTextSessionModel.workflow_run)
            stmt = stmt.join(WorkflowRunModel.workflow)
            stmt = stmt.where(
                WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id
            )
            stmt = stmt.where(WorkflowModel.organization_id == organization_id)

            rows = await db.execute(stmt)
            return rows.scalars().first()

    async def update_workflow_run_text_session(
        self,
        workflow_run_id: int,
        *,
        session_data: dict | None = None,
        checkpoint: dict | None = None,
        expected_revision: int | None = None,
    ) -> WorkflowRunTextSessionModel:
        """Update a text session and bump its revision.

        Args:
            workflow_run_id: Id of the run whose session is updated.
            session_data: Replacement session payload; left untouched when ``None``.
            checkpoint: Replacement checkpoint payload; left untouched when ``None``.
            expected_revision: When not ``None``, the update is rejected unless
                the stored revision equals this value.

        Returns:
            The refreshed text session. Its revision is incremented by one on
            every successful call, even when neither payload is supplied.

        Raises:
            ValueError: If the run has no text session.
            WorkflowRunTextSessionRevisionConflictError: If ``expected_revision``
                is given and differs from the stored revision.
        """
        async with self.async_session() as db:
            locked_lookup = (
                select(WorkflowRunTextSessionModel)
                .where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
                .with_for_update()
            )
            stored = (await db.execute(locked_lookup)).scalars().first()
            if not stored:
                raise ValueError(
                    f"Workflow run text session with run ID {workflow_run_id} not found"
                )

            if expected_revision is not None:
                if stored.revision != expected_revision:
                    raise WorkflowRunTextSessionRevisionConflictError(
                        expected_revision=expected_revision,
                        actual_revision=stored.revision,
                    )

            if session_data is not None:
                stored.session_data = session_data
            if checkpoint is not None:
                stored.checkpoint = checkpoint
            stored.revision += 1

            try:
                await db.commit()
            except Exception:
                await db.rollback()
                raise
            await db.refresh(stored)
            return stored
|
||||
|
|
@ -27,6 +27,7 @@ class WorkflowRunMode(Enum):
|
|||
TELNYX = "telnyx"
|
||||
WEBRTC = "webrtc"
|
||||
SMALLWEBRTC = "smallwebrtc"
|
||||
TEXTCHAT = "textchat"
|
||||
|
||||
# Historical, not used anymore. Don't
|
||||
# use and don't remove
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
|||
from api.routes.workflow import router as workflow_router
|
||||
from api.routes.workflow_embed import router as workflow_embed_router
|
||||
from api.routes.workflow_recording import router as workflow_recording_router
|
||||
from api.routes.workflow_text_chat import router as workflow_text_chat_router
|
||||
from api.services.integrations import all_routers
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -35,6 +36,7 @@ router = APIRouter(
|
|||
router.include_router(telephony_router)
|
||||
router.include_router(superuser_router)
|
||||
router.include_router(workflow_router)
|
||||
router.include_router(workflow_text_chat_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(campaign_router)
|
||||
router.include_router(credentials_router)
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ class WorkflowRunUsageResponse(BaseModel):
|
|||
caller_number: Optional[str] = None
|
||||
called_number: Optional[str] = None
|
||||
call_type: Optional[str] = None
|
||||
mode: Optional[str] = None
|
||||
disposition: Optional[str] = None
|
||||
initial_context: Optional[Dict[str, Any]] = None
|
||||
gathered_context: Optional[Dict[str, Any]] = None
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ async def initiate_call(
|
|||
"telephony_configuration_id": telephony_configuration_id,
|
||||
},
|
||||
use_draft=True,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
workflow_run_id = workflow_run.id
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -53,9 +53,9 @@ router = APIRouter(prefix="/ws")
|
|||
class NonRelayFilterPolicy(Enum):
|
||||
"""What to filter from non-relay ICE candidates. Relay candidates always pass."""
|
||||
|
||||
NONE = "none" # filter nothing — pass all candidates
|
||||
NONE = "none" # filter nothing — pass all candidates
|
||||
PRIVATE = "private" # filter non-relay candidates with private/CGNAT IPs
|
||||
ALL = "all" # filter all non-relay candidates (relay-only mode)
|
||||
ALL = "all" # filter all non-relay candidates (relay-only mode)
|
||||
|
||||
|
||||
def is_local_or_cgnat_ip(ip_str: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -1081,7 +1081,12 @@ async def create_workflow_run(
|
|||
user: The user to create the workflow run for
|
||||
"""
|
||||
run = await db_client.create_workflow_run(
|
||||
request.name, workflow_id, request.mode, user.id, use_draft=True
|
||||
request.name,
|
||||
workflow_id,
|
||||
request.mode,
|
||||
user.id,
|
||||
use_draft=True,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
return {
|
||||
"id": run.id,
|
||||
|
|
|
|||
282
api/routes/workflow_text_chat.py
Normal file
282
api/routes/workflow_text_chat.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
TextChatSessionExecutionError,
|
||||
TextChatSessionRevisionConflictError,
|
||||
TextChatTurnNotFoundError,
|
||||
append_text_chat_user_message,
|
||||
default_text_chat_checkpoint,
|
||||
default_text_chat_session_data,
|
||||
execute_pending_text_chat_turn,
|
||||
initialize_text_chat_session,
|
||||
normalize_text_chat_checkpoint,
|
||||
normalize_text_chat_session_data,
|
||||
rewind_text_chat_session_state,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
|
||||
|
||||
|
||||
class CreateTextChatSessionRequest(BaseModel):
    # Request body for creating a text chat session; every field is optional.

    # Name for the new workflow run; a "WR-TEXT-…" name is generated when omitted.
    name: str | None = None
    # Passed through as the run's initial context.
    initial_context: Dict[str, Any] | None = None
    # Merged over the default tester annotations at the top level.
    annotations: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AppendTextChatMessageRequest(BaseModel):
    # Request body for appending a user message to a text chat session.

    # The user's message; must contain at least one character.
    text: str = Field(min_length=1)
    # Optional revision guard, forwarded to the session service.
    expected_revision: int | None = None
|
||||
|
||||
|
||||
class RewindTextChatSessionRequest(BaseModel):
    # Request body for rewinding a text chat session.

    # Turn id forwarded to the session service as the rewind cursor; optional.
    cursor_turn_id: str | None = None
    # Optional revision guard, forwarded to the session service.
    expected_revision: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunTextSessionResponse(BaseModel):
    # Response payload combining workflow-run fields with text-session fields.

    # --- taken from the workflow run ---
    workflow_run_id: int
    workflow_id: int
    name: str
    mode: str
    state: str
    is_completed: bool

    # --- taken from the text session ---
    revision: int

    # --- optional run context (from the workflow run) ---
    initial_context: Dict[str, Any] | None = None
    gathered_context: Dict[str, Any] | None = None
    annotations: Dict[str, Any] | None = None

    # --- normalized text-session payloads ---
    session_data: Dict[str, Any]
    checkpoint: Dict[str, Any]

    # --- text-session timestamps ---
    created_at: datetime
    updated_at: datetime | None = None
|
||||
|
||||
|
||||
def _get_state_value(state: Any) -> str:
    """Return ``state.value`` when the object has one, otherwise ``str(state)``."""
    if hasattr(state, "value"):
        return state.value
    return str(state)
|
||||
|
||||
|
||||
def _build_response(
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
    """Build the API response for a text session and its loaded workflow run.

    ``session_data`` and ``checkpoint`` are passed through their normalizers;
    the run's state is converted to a string via ``_get_state_value``.
    """
    run = text_session.workflow_run

    # Fields sourced from the workflow run.
    run_fields = {
        "workflow_run_id": run.id,
        "workflow_id": run.workflow_id,
        "name": run.name,
        "mode": run.mode,
        "state": _get_state_value(run.state),
        "is_completed": run.is_completed,
        "initial_context": run.initial_context,
        "gathered_context": run.gathered_context,
        "annotations": run.annotations,
    }
    # Fields sourced from the text session itself.
    session_fields = {
        "revision": text_session.revision,
        "session_data": normalize_text_chat_session_data(text_session.session_data),
        "checkpoint": normalize_text_chat_checkpoint(text_session.checkpoint),
        "created_at": text_session.created_at,
        "updated_at": text_session.updated_at,
    }
    return WorkflowRunTextSessionResponse(**run_fields, **session_fields)
|
||||
|
||||
|
||||
def _revision_conflict_detail(e: Any) -> dict[str, Any]:
    """Build the error-detail payload describing a revision conflict.

    ``e`` must expose ``expected_revision`` and ``actual_revision`` attributes.
    """
    detail: dict[str, Any] = {"message": "Text chat session revision conflict"}
    detail["expected_revision"] = e.expected_revision
    detail["actual_revision"] = e.actual_revision
    return detail
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
    """Return the user's selected organization id.

    Raises:
        HTTPException: 403 when the user has no selected organization.
    """
    organization_id = user.selected_organization_id
    if organization_id is not None:
        return organization_id
    raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
    """Run the quota check for ``user`` and ``workflow_id``.

    Raises:
        HTTPException: 402 carrying the quota check's error message when the
            check reports no remaining quota.
    """
    quota = await check_dograh_quota(user, workflow_id=workflow_id)
    if quota.has_quota:
        return
    raise HTTPException(status_code=402, detail=quota.error_message)
|
||||
|
||||
|
||||
async def _load_text_session_or_404(
    workflow_id: int,
    run_id: int,
    user: UserModel,
) -> WorkflowRunTextSessionModel:
    """Load a text session for the user's selected organization, or fail.

    Sets the current run id first, then fetches the session scoped to the
    user's selected organization.

    Raises:
        HTTPException: 403 when the user has no selected organization; 404
            when no session is found or it belongs to a different workflow;
            400 when the run's mode is not text chat.
    """
    set_current_run_id(run_id)
    organization_id = _require_selected_organization_id(user)

    text_session = await db_client.get_workflow_run_text_session(
        run_id, organization_id=organization_id
    )

    # A missing session, a missing run, and a workflow mismatch all produce
    # the same 404.
    if (
        not text_session
        or not text_session.workflow_run
        or text_session.workflow_run.workflow_id != workflow_id
    ):
        raise HTTPException(status_code=404, detail="Text chat session not found")

    if text_session.workflow_run.mode == WorkflowRunMode.TEXTCHAT.value:
        return text_session

    raise HTTPException(
        status_code=400, detail="Workflow run is not a text chat session"
    )
|
||||
|
||||
|
||||
async def _execute_pending_turn_response(
    *,
    workflow_id: int,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionResponse:
    """Execute the session's pending turn and return the resulting response.

    Args:
        workflow_id: Id of the workflow the run belongs to.
        run_id: Id of the workflow run backing the text session.
        text_session: The session whose pending turn should be executed.

    Returns:
        The response built from the updated text session.

    Raises:
        HTTPException: 409 with revision details on a revision conflict; 500
            when the pending turn was lost or execution failed. The original
            service error is attached as the exception's cause.
    """
    try:
        updated_text_session = await execute_pending_text_chat_turn(
            workflow_id=workflow_id,
            run_id=run_id,
            text_session=text_session,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e
    except (TextChatPendingTurnLostError, TextChatSessionExecutionError) as e:
        # Both failure modes map to the same 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return _build_response(updated_text_session)
|
||||
|
||||
|
||||
@router.post(
    "/{workflow_id}/text-chat/sessions",
    response_model=WorkflowRunTextSessionResponse,
)
async def create_text_chat_session(
    workflow_id: int,
    request: CreateTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Create a text chat session for a workflow and run its first pending turn.

    Creates a text-chat workflow run (with ``use_draft=True``), annotates it,
    ensures the text-session row exists, initializes it, and then executes the
    pending turn.

    Raises:
        HTTPException: 403 without a selected organization; 402 when the quota
            check fails; 404 when creating the run raises ``ValueError``;
            409 on a revision conflict; 500 when turn execution fails.
    """
    organization_id = _require_selected_organization_id(user)
    await _ensure_text_chat_quota(user, workflow_id)

    session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
    try:
        workflow_run = await db_client.create_workflow_run(
            name=session_name,
            workflow_id=workflow_id,
            mode=WorkflowRunMode.TEXTCHAT.value,
            user_id=user.id,
            initial_context=request.initial_context,
            use_draft=True,
            organization_id=organization_id,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    set_current_run_id(workflow_run.id)

    annotations = {
        "tester": {
            "source": "workflow_editor",
            "modality": "text",
        }
    }
    if request.annotations:
        # Shallow merge: a caller-supplied top-level key replaces the default.
        annotations = {**annotations, **request.annotations}
    workflow_run = await db_client.update_workflow_run(
        workflow_run.id,
        annotations=annotations,
    )

    text_session = await db_client.ensure_workflow_run_text_session(
        workflow_run.id,
        session_data=default_text_chat_session_data(),
        checkpoint=default_text_chat_checkpoint(),
    )

    try:
        text_session = await initialize_text_chat_session(
            run_id=workflow_run.id,
            text_session=text_session,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e

    return await _execute_pending_turn_response(
        workflow_id=workflow_id,
        run_id=workflow_run.id,
        text_session=text_session,
    )
|
||||
|
||||
|
||||
@router.get(
    "/{workflow_id}/text-chat/sessions/{run_id}",
    response_model=WorkflowRunTextSessionResponse,
)
async def get_text_chat_session(
    workflow_id: int,
    run_id: int,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Return the current state of a text chat session.

    Lookup failures surface as the HTTP errors raised by
    ``_load_text_session_or_404``.
    """
    return _build_response(await _load_text_session_or_404(workflow_id, run_id, user))
|
||||
|
||||
|
||||
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/messages",
    response_model=WorkflowRunTextSessionResponse,
)
async def append_text_chat_message(
    workflow_id: int,
    run_id: int,
    request: AppendTextChatMessageRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Append a user message to a text chat session and execute the pending turn.

    The session is loaded (and validated) before the quota check runs.

    Raises:
        HTTPException: the errors from ``_load_text_session_or_404``; 402 when
            the quota check fails; 409 on a revision conflict; 500 when turn
            execution fails.
    """
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    await _ensure_text_chat_quota(user, workflow_id)

    try:
        text_session = await append_text_chat_user_message(
            run_id=run_id,
            text_session=text_session,
            user_text=request.text,
            expected_revision=request.expected_revision,
        )
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e

    return await _execute_pending_turn_response(
        workflow_id=workflow_id,
        run_id=run_id,
        text_session=text_session,
    )
|
||||
|
||||
|
||||
@router.post(
    "/{workflow_id}/text-chat/sessions/{run_id}/rewind",
    response_model=WorkflowRunTextSessionResponse,
)
async def rewind_text_chat_session(
    workflow_id: int,
    run_id: int,
    request: RewindTextChatSessionRequest,
    user: UserModel = Depends(get_user),
) -> WorkflowRunTextSessionResponse:
    """Rewind a text chat session using the requested turn cursor.

    Unlike appending a message, this does not execute a turn afterwards; the
    rewound session state is returned directly.

    Raises:
        HTTPException: the errors from ``_load_text_session_or_404``; 404 when
            the service reports the turn was not found; 409 on a revision
            conflict.
    """
    text_session = await _load_text_session_or_404(workflow_id, run_id, user)
    try:
        text_session = await rewind_text_chat_session_state(
            run_id=run_id,
            text_session=text_session,
            cursor_turn_id=request.cursor_turn_id,
            expected_revision=request.expected_revision,
        )
    except TextChatTurnNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except TextChatSessionRevisionConflictError as e:
        raise HTTPException(
            status_code=409, detail=_revision_conflict_detail(e)
        ) from e

    return _build_response(text_session)
|
||||
|
|
@ -7,7 +7,7 @@ from api.enums import PostHogEvent, WorkflowRunState
|
|||
from api.services.campaign.circuit_breaker import circuit_breaker
|
||||
from api.services.integrations import IntegrationRuntimeSession
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.audio_playback import play_audio, play_audio_loop
|
||||
from api.services.pipecat.audio_playback import play_audio_loop
|
||||
from api.services.pipecat.in_memory_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryLogsBuffer,
|
||||
|
|
@ -20,8 +20,6 @@ from api.tasks.arq import enqueue_job
|
|||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
|
|
@ -69,7 +67,6 @@ def register_event_handlers(
|
|||
pipeline_metrics_aggregator: PipelineMetricsAggregator,
|
||||
audio_config=AudioConfig,
|
||||
pre_call_fetch_task: asyncio.Task | None = None,
|
||||
fetch_recording_audio=None,
|
||||
user_provider_id: str | None = None,
|
||||
integration_runtime_sessions: list[IntegrationRuntimeSession] | None = None,
|
||||
):
|
||||
|
|
@ -99,20 +96,11 @@ def register_event_handlers(
|
|||
"initial_response_triggered": False,
|
||||
}
|
||||
|
||||
async def queue_initial_llm_context():
|
||||
# Queue LLMContextFrame after the VoicemailDetector since the detector
|
||||
# gates LLMContextFrames until voicemail detection completes. We also
|
||||
# don't want to trigger the Voicemail LLM with this initial frame.
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
async def maybe_trigger_initial_response():
|
||||
"""Start the conversation after both pipeline_started and client_connected events.
|
||||
|
||||
If a pre-call fetch is in progress, plays a ringer while waiting for the
|
||||
response, then merges the result into the call context before proceeding.
|
||||
|
||||
If the start node has a greeting configured, play it directly via TTS.
|
||||
Otherwise, trigger an LLM generation for the opening message.
|
||||
"""
|
||||
if (
|
||||
ready_state["pipeline_started"]
|
||||
|
|
@ -167,46 +155,11 @@ def register_event_handlers(
|
|||
# Set the start node now (after pre-call fetch data is merged)
|
||||
# so that render_template() has the complete _call_context_vars.
|
||||
await engine.set_node(engine.workflow.start_node_id)
|
||||
|
||||
greeting_info = engine.get_start_greeting()
|
||||
if greeting_info:
|
||||
greeting_type, greeting_value = greeting_info
|
||||
if (
|
||||
greeting_type == "audio"
|
||||
and greeting_value
|
||||
and fetch_recording_audio
|
||||
):
|
||||
logger.debug(f"Playing audio greeting recording: {greeting_value}")
|
||||
result = await fetch_recording_audio(
|
||||
recording_pk=int(greeting_value)
|
||||
)
|
||||
if result:
|
||||
await play_audio(
|
||||
result.audio,
|
||||
sample_rate=audio_config.pipeline_sample_rate or 16000,
|
||||
queue_frame=transport.output().queue_frame,
|
||||
transcript=result.transcript,
|
||||
append_to_context=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch audio greeting {greeting_value}, "
|
||||
"falling back to LLM generation"
|
||||
)
|
||||
await queue_initial_llm_context()
|
||||
else:
|
||||
logger.debug("Playing text greeting via TTS")
|
||||
# append_to_context=True so the assistant aggregator commits
|
||||
# the greeting to the LLM context once TTS finishes; without
|
||||
# it the LLM would re-greet on its first generation.
|
||||
await task.queue_frame(
|
||||
TTSSpeakFrame(greeting_value, append_to_context=True)
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - triggering initial LLM generation"
|
||||
)
|
||||
await queue_initial_llm_context()
|
||||
await engine.queue_node_opening(
|
||||
node_id=engine.workflow.start_node_id,
|
||||
previous_node_id=None,
|
||||
generate_if_no_greeting=True,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _participant):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ from typing import List, Optional
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.realtime_feedback_events import (
|
||||
realtime_feedback_event_sort_key,
|
||||
stamp_realtime_feedback_event,
|
||||
)
|
||||
from api.utils.transcript import generate_transcript_text as _generate_transcript_text
|
||||
from pipecat.utils.enums import RealtimeFeedbackType
|
||||
|
||||
|
|
@ -98,16 +102,13 @@ class InMemoryLogsBuffer:
|
|||
|
||||
async def append(self, event: dict):
|
||||
"""Append a feedback event to the buffer with timestamp and current node."""
|
||||
# Add timestamp, turn tracking, and current node
|
||||
timestamped_event = {
|
||||
**event,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"turn": self._turn_counter,
|
||||
}
|
||||
if self._current_node_id:
|
||||
timestamped_event["node_id"] = self._current_node_id
|
||||
if self._current_node_name:
|
||||
timestamped_event["node_name"] = self._current_node_name
|
||||
timestamped_event = stamp_realtime_feedback_event(
|
||||
event,
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
turn=self._turn_counter,
|
||||
node_id=self._current_node_id,
|
||||
node_name=self._current_node_name,
|
||||
)
|
||||
self._events.append(timestamped_event)
|
||||
logger.trace(
|
||||
f"Appended event {event.get('type')} to logs buffer for workflow {self._workflow_run_id}"
|
||||
|
|
@ -120,17 +121,12 @@ class InMemoryLogsBuffer:
|
|||
f"Incremented turn counter to {self._turn_counter} for workflow {self._workflow_run_id}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _event_sort_key(event: dict) -> str:
|
||||
payload_ts = event.get("payload", {}).get("timestamp")
|
||||
return payload_ts or event.get("timestamp", "")
|
||||
|
||||
def _sorted_events(self) -> List[dict]:
|
||||
# Stable sort by the realtime (payload) timestamp when available, falling
|
||||
# back to the buffer-append timestamp. Python's sort is stable, so events
|
||||
# sharing a key retain their original insertion order — this keeps
|
||||
# consecutive bot-text chunks of a single turn contiguous.
|
||||
return sorted(self._events, key=self._event_sort_key)
|
||||
return sorted(self._events, key=realtime_feedback_event_sort_key)
|
||||
|
||||
def get_events(self) -> List[dict]:
|
||||
"""Get all events for final storage, ordered by realtime timestamp."""
|
||||
|
|
|
|||
|
|
@ -152,8 +152,30 @@ def build_realtime_pipeline(
|
|||
return Pipeline(processors)
|
||||
|
||||
|
||||
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
|
||||
"""Create a pipeline task with appropriate parameters"""
|
||||
def create_pipeline_task(
|
||||
pipeline,
|
||||
workflow_run_id,
|
||||
audio_config: AudioConfig = None,
|
||||
*,
|
||||
conversation_parent_context=None,
|
||||
conversation_type: str = "voice",
|
||||
additional_span_attributes: dict | None = None,
|
||||
):
|
||||
"""Create a pipeline task with appropriate parameters.
|
||||
|
||||
Args:
|
||||
pipeline: The pipeline to run.
|
||||
workflow_run_id: Run id, used as the conversation id.
|
||||
audio_config: Optional audio configuration.
|
||||
conversation_parent_context: Optional OTEL context carrying a fixed
|
||||
trace id. When provided, the conversation span attaches to that
|
||||
trace instead of starting a new root trace (used by text chat to
|
||||
stitch every per-turn pipeline into one trace).
|
||||
conversation_type: ``conversation.type`` span attribute value.
|
||||
additional_span_attributes: Extra attributes set on the conversation
|
||||
span (e.g. ``langfuse.trace.name`` to name a stitched trace that
|
||||
has no real root span).
|
||||
"""
|
||||
# Set up pipeline params with audio configuration if provided
|
||||
pipeline_params = PipelineParams(
|
||||
enable_metrics=True,
|
||||
|
|
@ -178,6 +200,9 @@ def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig =
|
|||
enable_tracing=True,
|
||||
enable_rtvi=False,
|
||||
conversation_id=f"{workflow_run_id}",
|
||||
conversation_parent_context=conversation_parent_context,
|
||||
conversation_type=conversation_type,
|
||||
additional_span_attributes=additional_span_attributes,
|
||||
)
|
||||
|
||||
# Check if turn logging is enabled
|
||||
|
|
|
|||
163
api/services/pipecat/realtime_feedback_events.py
Normal file
163
api/services/pipecat/realtime_feedback_events.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Shared helpers for building and ordering realtime feedback events."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pipecat.utils.enums import RealtimeFeedbackType
|
||||
|
||||
|
||||
def build_node_transition_event(
    *,
    node_id: str | None,
    node_name: str | None,
    previous_node_id: str | None,
    previous_node_name: str | None,
    allow_interrupt: bool = False,
) -> dict[str, Any]:
    """Build a NODE_TRANSITION realtime feedback event.

    Args:
        node_id: Identifier of the node being entered.
        node_name: Name of the node being entered.
        previous_node_id: Identifier of the node being left, if any.
        previous_node_name: Name of the node being left, if any.
        allow_interrupt: Value copied verbatim into the payload.

    Returns:
        A dict with a ``type`` and a ``payload`` carrying all five fields
        (``None`` values are kept, not dropped).
    """
    transition_payload: dict[str, Any] = dict(
        node_id=node_id,
        node_name=node_name,
        previous_node_id=previous_node_id,
        previous_node_name=previous_node_name,
        allow_interrupt=allow_interrupt,
    )
    event_type = RealtimeFeedbackType.NODE_TRANSITION.value
    return {"type": event_type, "payload": transition_payload}
|
||||
|
||||
|
||||
def build_user_transcription_event(
    *,
    text: str,
    final: bool,
    timestamp: str | None = None,
    user_id: str | None = None,
) -> dict[str, Any]:
    """Build a USER_TRANSCRIPTION realtime feedback event.

    Args:
        text: Transcribed user text.
        final: Whether this is a final (True) or interim (False) transcription.
        timestamp: Optional timestamp; omitted from the payload when ``None``.
        user_id: Optional user id; omitted from the payload when ``None``.

    Returns:
        A dict with a ``type`` and a ``payload`` holding ``text``, ``final``
        and whichever optional fields were supplied.
    """
    # Optional fields are listed in the order they must appear in the payload.
    optional_fields = (("timestamp", timestamp), ("user_id", user_id))
    payload: dict[str, Any] = {"text": text, "final": final}
    payload.update(
        {key: value for key, value in optional_fields if value is not None}
    )
    return {
        "type": RealtimeFeedbackType.USER_TRANSCRIPTION.value,
        "payload": payload,
    }
|
||||
|
||||
|
||||
def build_bot_text_event(
    *,
    text: str,
    timestamp: str | None = None,
) -> dict[str, Any]:
    """Build a BOT_TEXT realtime feedback event.

    Args:
        text: Text attributed to the bot.
        timestamp: Optional timestamp; omitted from the payload when ``None``.

    Returns:
        A dict with a ``type`` and a ``payload`` holding ``text`` and, when
        given, ``timestamp``.
    """
    if timestamp is None:
        payload: dict[str, Any] = {"text": text}
    else:
        payload = {"text": text, "timestamp": timestamp}
    return {"type": RealtimeFeedbackType.BOT_TEXT.value, "payload": payload}
|
||||
|
||||
|
||||
def build_function_call_start_event(
    *,
    function_name: str | None,
    tool_call_id: str | None,
    arguments: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build a FUNCTION_CALL_START realtime feedback event.

    Args:
        function_name: Name of the function being called.
        tool_call_id: Identifier of the tool call.
        arguments: Optional call arguments; the ``arguments`` key is only
            added when this is not ``None`` (an empty dict is still included).

    Returns:
        A dict with a ``type`` and a ``payload`` describing the call.
    """
    call_payload: dict[str, Any] = dict(
        function_name=function_name,
        tool_call_id=tool_call_id,
    )
    if arguments is not None:
        call_payload["arguments"] = arguments
    event_type = RealtimeFeedbackType.FUNCTION_CALL_START.value
    return {"type": event_type, "payload": call_payload}
|
||||
|
||||
|
||||
def serialize_realtime_feedback_tool_result(result: Any) -> str | None:
|
||||
"""Normalize function-call results to the string shape stored in logs."""
|
||||
if result is None:
|
||||
return None
|
||||
return str(result)
|
||||
|
||||
|
||||
def build_function_call_end_event(
    *,
    function_name: str | None,
    tool_call_id: str | None,
    result: Any,
) -> dict[str, Any]:
    """Build a FUNCTION_CALL_END realtime feedback event.

    Args:
        function_name: Name of the function that finished.
        tool_call_id: Identifier of the tool call.
        result: Raw result; stored after passing through
            ``serialize_realtime_feedback_tool_result``.

    Returns:
        A dict with a ``type`` and a ``payload`` holding ``function_name``,
        ``tool_call_id`` and the serialized ``result``.
    """
    serialized_result = serialize_realtime_feedback_tool_result(result)
    end_payload: dict[str, Any] = dict(
        function_name=function_name,
        tool_call_id=tool_call_id,
        result=serialized_result,
    )
    event_type = RealtimeFeedbackType.FUNCTION_CALL_END.value
    return {"type": event_type, "payload": end_payload}
|
||||
|
||||
|
||||
def build_ttfb_metric_event(
    *,
    ttfb_seconds: float,
    processor: str | None,
    model: str | None,
) -> dict[str, Any]:
    """Build a TTFB_METRIC realtime feedback event.

    Args:
        ttfb_seconds: Time-to-first-byte value to report.
        processor: Name of the processor that produced the metric.
        model: Model name associated with the metric.

    Returns:
        A dict with a ``type`` and a ``payload`` holding all three fields
        (``None`` values are kept).
    """
    metric_payload: dict[str, Any] = dict(
        ttfb_seconds=ttfb_seconds,
        processor=processor,
        model=model,
    )
    event_type = RealtimeFeedbackType.TTFB_METRIC.value
    return {"type": event_type, "payload": metric_payload}
|
||||
|
||||
|
||||
def build_pipeline_error_event(
    *,
    error: str,
    fatal: bool,
    processor: str | None = None,
    extra_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build a PIPELINE_ERROR realtime feedback event.

    Args:
        error: Error description.
        fatal: Whether the error is fatal.
        processor: Optional processor name; omitted when ``None``.
        extra_payload: Optional extra fields merged into the payload last, so
            a key that collides with a base field overrides it.

    Returns:
        A dict with a ``type`` and the assembled ``payload``.
    """
    processor_fields = {} if processor is None else {"processor": processor}
    payload: dict[str, Any] = {
        "error": error,
        "fatal": fatal,
        **processor_fields,
        **(extra_payload or {}),
    }
    return {
        "type": RealtimeFeedbackType.PIPELINE_ERROR.value,
        "payload": payload,
    }
|
||||
|
||||
|
||||
def stamp_realtime_feedback_event(
|
||||
event: dict[str, Any],
|
||||
*,
|
||||
timestamp: str | None = None,
|
||||
turn: int | None = None,
|
||||
node_id: str | None = None,
|
||||
node_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
stamped = dict(event)
|
||||
if timestamp is not None:
|
||||
stamped["timestamp"] = timestamp
|
||||
if turn is not None:
|
||||
stamped["turn"] = turn
|
||||
if node_id is not None:
|
||||
stamped["node_id"] = node_id
|
||||
if node_name is not None:
|
||||
stamped["node_name"] = node_name
|
||||
return stamped
|
||||
|
||||
|
||||
def realtime_feedback_event_sort_key(event: dict[str, Any]) -> str:
    """Return the string key used to order realtime feedback events.

    The payload timestamp is preferred; when it is missing or falsy the
    event's top-level ``timestamp`` is used, and finally ``""``.

    The key is always a ``str``: a non-string timestamp is stringified so
    that ``sorted()`` never raises ``TypeError`` on mixed key types, and a
    ``payload`` that is not a dict is treated as carrying no timestamp
    instead of raising ``AttributeError``.
    """
    payload = event.get("payload")
    payload_timestamp = (
        payload.get("timestamp") if isinstance(payload, dict) else None
    )
    key = payload_timestamp or event.get("timestamp") or ""
    return key if isinstance(key, str) else str(key)
|
||||
|
|
@ -27,6 +27,15 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Set
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.realtime_feedback_events import (
|
||||
build_bot_text_event,
|
||||
build_function_call_end_event,
|
||||
build_function_call_start_event,
|
||||
build_pipeline_error_event,
|
||||
build_ttfb_metric_event,
|
||||
build_user_transcription_event,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.pipecat.in_memory_buffers import InMemoryLogsBuffer
|
||||
|
||||
|
|
@ -211,29 +220,23 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
# Handle user transcriptions (interim) - WebSocket only
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._send_ws(
|
||||
{
|
||||
"type": RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
"payload": {
|
||||
"text": frame.text,
|
||||
"final": False,
|
||||
"user_id": frame.user_id,
|
||||
"timestamp": frame.timestamp,
|
||||
},
|
||||
}
|
||||
build_user_transcription_event(
|
||||
text=frame.text,
|
||||
final=False,
|
||||
user_id=frame.user_id,
|
||||
timestamp=frame.timestamp,
|
||||
)
|
||||
)
|
||||
# Handle user transcriptions (final) - WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._send_ws(
|
||||
{
|
||||
"type": RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
"payload": {
|
||||
"text": frame.text,
|
||||
"final": True,
|
||||
"user_id": frame.user_id,
|
||||
"timestamp": frame.timestamp,
|
||||
},
|
||||
}
|
||||
build_user_transcription_event(
|
||||
text=frame.text,
|
||||
final=True,
|
||||
user_id=frame.user_id,
|
||||
timestamp=frame.timestamp,
|
||||
)
|
||||
)
|
||||
# Handle engine-queued speech (transition/tool messages) marked for
|
||||
# log persistence. The downstream TTSTextFrame(s) from the TTS service
|
||||
|
|
@ -241,23 +244,13 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
# to avoid word-level log entries from word-timestamp providers.
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
if getattr(frame, "persist_to_logs", False):
|
||||
await self._append_to_buffer(
|
||||
{
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {"text": frame.text},
|
||||
}
|
||||
)
|
||||
await self._append_to_buffer(build_bot_text_event(text=frame.text))
|
||||
# Handle bot TTS text - respect pts timing, WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers,
|
||||
# except for frames explicitly flagged persist_to_logs (e.g. recording
|
||||
# transcripts from play_audio) which bypass the aggregator path.
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
message = {
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {
|
||||
"text": frame.text,
|
||||
},
|
||||
}
|
||||
message = build_bot_text_event(text=frame.text)
|
||||
|
||||
# If frame has pts, queue it for timed delivery
|
||||
if frame.pts:
|
||||
|
|
@ -280,13 +273,11 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
and frame_direction == FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await self._send_message(
|
||||
{
|
||||
"type": RealtimeFeedbackType.FUNCTION_CALL_START.value,
|
||||
"payload": {
|
||||
"function_name": frame.function_name,
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
},
|
||||
}
|
||||
build_function_call_start_event(
|
||||
function_name=frame.function_name,
|
||||
tool_call_id=frame.tool_call_id,
|
||||
arguments=dict(frame.arguments or {}),
|
||||
)
|
||||
)
|
||||
# Handle function call result
|
||||
elif (
|
||||
|
|
@ -294,14 +285,11 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
and frame_direction == FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await self._send_message(
|
||||
{
|
||||
"type": RealtimeFeedbackType.FUNCTION_CALL_END.value,
|
||||
"payload": {
|
||||
"function_name": frame.function_name,
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"result": str(frame.result) if frame.result else None,
|
||||
},
|
||||
}
|
||||
build_function_call_end_event(
|
||||
function_name=frame.function_name,
|
||||
tool_call_id=frame.tool_call_id,
|
||||
result=frame.result,
|
||||
)
|
||||
)
|
||||
# Handle TTFB metrics - capture LLM generation time only
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
|
|
@ -311,47 +299,42 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
# Only send TTFB if it's from an LLM processor
|
||||
if metric_data.processor and "LLM" in metric_data.processor:
|
||||
await self._send_message(
|
||||
{
|
||||
"type": RealtimeFeedbackType.TTFB_METRIC.value,
|
||||
"payload": {
|
||||
"ttfb_seconds": metric_data.value,
|
||||
"processor": metric_data.processor,
|
||||
"model": metric_data.model,
|
||||
},
|
||||
}
|
||||
build_ttfb_metric_event(
|
||||
ttfb_seconds=metric_data.value,
|
||||
processor=metric_data.processor,
|
||||
model=metric_data.model,
|
||||
)
|
||||
)
|
||||
# Handle pipeline errors
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
processor_name = str(frame.processor) if frame.processor else None
|
||||
payload = {
|
||||
"error": frame.error,
|
||||
"fatal": frame.fatal,
|
||||
"processor": processor_name,
|
||||
}
|
||||
extra_payload: dict[str, object] = {}
|
||||
# Surface structured fields when the underlying exception carries
|
||||
# them (e.g. google.genai APIError: code=1008, status=None,
|
||||
# message="Your project has been denied access...").
|
||||
exc = frame.exception
|
||||
if exc is not None:
|
||||
exc_type = type(exc).__name__
|
||||
payload["exception_type"] = exc_type
|
||||
payload["exception_message"] = str(exc)
|
||||
extra_payload["exception_type"] = exc_type
|
||||
extra_payload["exception_message"] = str(exc)
|
||||
for attr in ("code", "status", "message", "details"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value is None or attr in payload:
|
||||
if value is None or attr in extra_payload:
|
||||
continue
|
||||
try:
|
||||
# Ensure the value is JSON-serializable; fall back
|
||||
# to str() for opaque objects (e.g. raw response).
|
||||
json.dumps(value)
|
||||
payload[attr] = value
|
||||
extra_payload[attr] = value
|
||||
except (TypeError, ValueError):
|
||||
payload[attr] = str(value)
|
||||
extra_payload[attr] = str(value)
|
||||
await self._send_message(
|
||||
{
|
||||
"type": RealtimeFeedbackType.PIPELINE_ERROR.value,
|
||||
"payload": payload,
|
||||
}
|
||||
build_pipeline_error_event(
|
||||
error=frame.error,
|
||||
fatal=frame.fatal,
|
||||
processor=processor_name,
|
||||
extra_payload=extra_payload or None,
|
||||
)
|
||||
)
|
||||
|
||||
async def _send_ws(self, message: dict):
|
||||
|
|
@ -401,14 +384,11 @@ def register_turn_log_handlers(
|
|||
logs_buffer.increment_turn()
|
||||
try:
|
||||
await logs_buffer.append(
|
||||
{
|
||||
"type": RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
"payload": {
|
||||
"text": message.content,
|
||||
"final": True,
|
||||
"timestamp": message.timestamp,
|
||||
},
|
||||
}
|
||||
build_user_transcription_event(
|
||||
text=message.content,
|
||||
final=True,
|
||||
timestamp=message.timestamp,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to append user turn to logs buffer: {e}")
|
||||
|
|
@ -418,13 +398,10 @@ def register_turn_log_handlers(
|
|||
if message.content:
|
||||
try:
|
||||
await logs_buffer.append(
|
||||
{
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {
|
||||
"text": message.content,
|
||||
"timestamp": message.timestamp,
|
||||
},
|
||||
}
|
||||
build_bot_text_event(
|
||||
text=message.content,
|
||||
timestamp=message.timestamp,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to append assistant turn to logs buffer: {e}")
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
|||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.pipecat.pre_call_fetch import execute_pre_call_fetch
|
||||
from api.services.pipecat.realtime_feedback_events import (
|
||||
build_node_transition_event,
|
||||
)
|
||||
from api.services.pipecat.realtime_feedback_observer import (
|
||||
RealtimeFeedbackObserver,
|
||||
register_turn_log_handlers,
|
||||
|
|
@ -465,16 +468,13 @@ async def _run_pipeline(
|
|||
# Update current node on the buffer so subsequent events are tagged
|
||||
in_memory_logs_buffer.set_current_node(node_id, node_name)
|
||||
|
||||
message = {
|
||||
"type": RealtimeFeedbackType.NODE_TRANSITION.value,
|
||||
"payload": {
|
||||
"node_id": node_id,
|
||||
"node_name": node_name,
|
||||
"previous_node_id": previous_node_id,
|
||||
"previous_node_name": previous_node_name,
|
||||
"allow_interrupt": allow_interrupt,
|
||||
},
|
||||
}
|
||||
message = build_node_transition_event(
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
previous_node_id=previous_node_id,
|
||||
previous_node_name=previous_node_name,
|
||||
allow_interrupt=allow_interrupt,
|
||||
)
|
||||
# Send via WebSocket if available
|
||||
if ws_sender:
|
||||
try:
|
||||
|
|
@ -803,7 +803,6 @@ async def _run_pipeline(
|
|||
pipeline_metrics_aggregator=pipeline_metrics_aggregator,
|
||||
audio_config=audio_config,
|
||||
pre_call_fetch_task=pre_call_fetch_task,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
user_provider_id=user_provider_id,
|
||||
integration_runtime_sessions=integration_runtime_sessions,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -254,6 +254,44 @@ async def handle_langfuse_sync(event):
|
|||
unregister_org_langfuse_credentials(org_id)
|
||||
|
||||
|
||||
def build_remote_parent_context(trace_id: str | None):
|
||||
"""Build an OTEL context whose active span carries ``trace_id``.
|
||||
|
||||
Spans started under the returned context join the Langfuse trace identified
|
||||
by ``trace_id`` (Langfuse groups observations by trace id). The parent span
|
||||
id is a non-existent placeholder, so spans created under it attach at the
|
||||
trace root rather than nesting under a real parent span.
|
||||
|
||||
This is the shared primitive behind both post-call QA tracing and text-chat
|
||||
trace stitching. Returns the context, or ``None`` when tracing is
|
||||
unavailable or ``trace_id`` is missing/invalid.
|
||||
"""
|
||||
if not trace_id:
|
||||
return None
|
||||
if not ensure_tracing():
|
||||
return None
|
||||
try:
|
||||
from opentelemetry.trace import (
|
||||
NonRecordingSpan,
|
||||
SpanContext,
|
||||
TraceFlags,
|
||||
set_span_in_context,
|
||||
)
|
||||
|
||||
parent_span_context = SpanContext(
|
||||
trace_id=int(trace_id, 16),
|
||||
span_id=0x1,
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
return set_span_in_context(NonRecordingSpan(parent_span_context))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to build remote parent context for trace {trace_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_trace_url(trace_id: str, org_id=None) -> str | None:
|
||||
"""Build a Langfuse trace URL, using org-specific host when available."""
|
||||
if org_id is None:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
|
|
@ -63,24 +65,31 @@ async def _update_organization_usage(
|
|||
)
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
async def _get_pricing_organization(workflow_run):
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
organization_id = getattr(workflow, "organization_id", None)
|
||||
if organization_id is None and workflow and workflow.user:
|
||||
organization_id = workflow.user.selected_organization_id
|
||||
if organization_id is None:
|
||||
return None
|
||||
return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
workflow_usage_info = workflow_run.usage_info
|
||||
if not workflow_usage_info:
|
||||
async def _build_usage_cost_snapshot(
|
||||
usage_info: dict | None,
|
||||
*,
|
||||
workflow_run=None,
|
||||
include_telephony_cost: bool = False,
|
||||
organization=None,
|
||||
calculated_at: str | None = None,
|
||||
) -> dict | None:
|
||||
if not usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
# Calculate cost breakdown
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
|
||||
# Fetch telephony call cost
|
||||
if include_telephony_cost and workflow_run is not None:
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
|
|
@ -95,61 +104,127 @@ async def calculate_workflow_run_cost(workflow_run_id: int):
|
|||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# Convert USD to Dograh Tokens (1 cent = 1 token)
|
||||
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
|
||||
total_cost_usd = Decimal(str(cost_breakdown["total"]))
|
||||
dograh_tokens = float(total_cost_usd * Decimal("100"))
|
||||
|
||||
# Get organization to check if it has USD pricing
|
||||
org = None
|
||||
charge_usd = None
|
||||
if (
|
||||
workflow_run.workflow
|
||||
and workflow_run.workflow.user
|
||||
and workflow_run.workflow.user.selected_organization_id
|
||||
):
|
||||
org = await db_client.get_organization_by_id(
|
||||
workflow_run.workflow.user.selected_organization_id
|
||||
)
|
||||
if organization is None and workflow_run is not None:
|
||||
organization = await _get_pricing_organization(workflow_run)
|
||||
|
||||
# Calculate USD cost if organization has pricing configured
|
||||
if org and org.price_per_second_usd:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = duration_seconds * org.price_per_second_usd
|
||||
charge_usd = None
|
||||
if organization and organization.price_per_second_usd:
|
||||
duration_seconds = usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = float(
|
||||
Decimal(str(duration_seconds))
|
||||
* Decimal(str(organization.price_per_second_usd))
|
||||
)
|
||||
|
||||
cost_info = {
|
||||
**workflow_run.cost_info,
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(cost_breakdown["total"]),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": workflow_run.created_at.isoformat(),
|
||||
"call_duration_seconds": workflow_usage_info["call_duration_seconds"],
|
||||
}
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(total_cost_usd),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": calculated_at
|
||||
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
|
||||
}
|
||||
|
||||
# Add USD cost if available
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = org.price_per_second_usd
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = organization.price_per_second_usd
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
return cost_info
|
||||
|
||||
# Update organization usage if applicable
|
||||
if org:
|
||||
try:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
await _update_organization_usage(
|
||||
org, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
    """Compute the cost info dict for a workflow run without persisting it.

    A snapshot is built from the run's ``usage_info`` (telephony cost
    included, stamped with the run's ``created_at``) and layered over the
    run's existing ``cost_info`` so that freshly computed keys win.

    Returns:
        The merged cost info, or ``None`` when no snapshot could be built.
    """
    snapshot = await _build_usage_cost_snapshot(
        workflow_run.usage_info,
        workflow_run=workflow_run,
        include_telephony_cost=True,
        calculated_at=workflow_run.created_at.isoformat(),
    )
    if snapshot is None:
        return None

    merged = dict(workflow_run.cost_info or {})
    merged.update(snapshot)
    return merged
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
|
||||
workflow_run_id: int, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
|
||||
workflow_run, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
    workflow_run, usage_info: dict | None
) -> dict | None:
    """Price a usage delta and record it against the run's organization.

    The pricing organization is resolved from ``workflow_run``; a cost
    snapshot is then built from ``usage_info`` for that organization and its
    token usage, duration and USD charge are applied to the organization.

    Returns:
        The cost snapshot that was applied, or ``None`` when there is no
        pricing organization or no snapshot could be built.
    """
    organization = await _get_pricing_organization(workflow_run)
    if not organization:
        return None

    snapshot = await _build_usage_cost_snapshot(
        usage_info, organization=organization
    )
    if snapshot is not None:
        token_usage = float(snapshot.get("dograh_token_usage") or 0)
        duration_seconds = float(snapshot.get("call_duration_seconds") or 0)
        await _update_organization_usage(
            organization,
            token_usage,
            duration_seconds,
            snapshot.get("charge_usd"),
        )
    return snapshot
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
try:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
await save_workflow_run_cost_info(workflow_run_id, cost_info)
|
||||
|
||||
try:
|
||||
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
|
||||
except Exception as e:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if org:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
# Don't fail the whole task if usage update fails
|
||||
else:
|
||||
logger.error(f"Failed to update organization usage: {e}")
|
||||
# Don't fail the whole cost calculation if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_breakdown['total']:.6f} USD ({dograh_tokens} Dograh Tokens)"
|
||||
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
|
|
@ -7,6 +7,7 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMContextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
|
|
@ -533,7 +534,7 @@ class PipecatEngine:
|
|||
)
|
||||
await self._update_llm_context(system_prompt, functions)
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
async def set_node(self, node_id: str, emit_transition_event: bool = True):
|
||||
"""
|
||||
Simplified set_node implementation according to v2 PRD.
|
||||
"""
|
||||
|
|
@ -556,7 +557,7 @@ class PipecatEngine:
|
|||
nodes_visited.append(node.name)
|
||||
|
||||
# Send node transition event if callback is provided
|
||||
if self._node_transition_callback:
|
||||
if emit_transition_event and self._node_transition_callback:
|
||||
try:
|
||||
await self._node_transition_callback(
|
||||
node_id,
|
||||
|
|
@ -598,8 +599,8 @@ class PipecatEngine:
|
|||
# Setup LLM context with prompts and functions.
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
|
||||
"""Return the greeting info for the start node, or None if not configured.
|
||||
def get_node_greeting(self, node_id: str) -> Optional[tuple[str, Optional[str]]]:
|
||||
"""Return the greeting info for a node, or None if not configured.
|
||||
|
||||
Returns:
|
||||
A tuple of (greeting_type, value) where:
|
||||
|
|
@ -607,20 +608,93 @@ class PipecatEngine:
|
|||
- ("audio", recording_id) for pre-recorded audio greetings
|
||||
Or None if no greeting is configured.
|
||||
"""
|
||||
start_node = self.workflow.nodes.get(self.workflow.start_node_id)
|
||||
if not start_node:
|
||||
node = self.workflow.nodes.get(node_id)
|
||||
if not node:
|
||||
return None
|
||||
|
||||
greeting_type = start_node.greeting_type or "text"
|
||||
greeting_type = node.greeting_type or "text"
|
||||
|
||||
if greeting_type == "audio" and start_node.greeting_recording_id:
|
||||
return ("audio", start_node.greeting_recording_id)
|
||||
if greeting_type == "audio" and node.greeting_recording_id:
|
||||
return ("audio", node.greeting_recording_id)
|
||||
|
||||
if start_node.greeting:
|
||||
return ("text", self._format_prompt(start_node.greeting))
|
||||
if node.greeting:
|
||||
return ("text", self._format_prompt(node.greeting))
|
||||
|
||||
return None
|
||||
|
||||
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
    """Return the greeting info for the workflow's start node.

    Delegates to ``get_node_greeting`` with ``self.workflow.start_node_id``
    and returns its result unchanged (``None`` if no greeting is configured).
    """
    start_node_id = self.workflow.start_node_id
    return self.get_node_greeting(start_node_id)
|
||||
|
||||
async def queue_node_opening(
    self,
    *,
    node_id: str,
    previous_node_id: Optional[str] = None,
    generate_if_no_greeting: bool = False,
) -> Literal["none", "greeting", "llm"]:
    """Queue the opening behavior for a node.

    This is the shared source of truth for how a node begins once the
    engine is ready and the node has already been set on the context.

    Args:
        node_id: Node whose opening should be queued.
        previous_node_id: Node the engine was on before; when it equals
            ``node_id`` the greeting is skipped.
        generate_if_no_greeting: When no greeting was queued, queue an
            initial LLM generation instead (requires ``self.llm`` and
            ``self.context``).

    Returns:
        "greeting" when a text/audio greeting was queued,
        "llm" when an initial LLM generation was queued,
        "none" when nothing was queued.
    """
    if previous_node_id != node_id:
        greeting_info = self.get_node_greeting(node_id)
        if greeting_info:
            greeting_type, greeting_value = greeting_info
            if greeting_type == "audio":
                # For audio greetings the value is a recording id, never
                # speakable text, so it must not reach the TTS branch below.
                if (
                    greeting_value
                    and self._fetch_recording_audio
                    and self._transport_output is not None
                ):
                    logger.debug(f"Playing audio greeting recording: {greeting_value}")
                    result = await self._fetch_recording_audio(
                        recording_pk=int(greeting_value)
                    )
                    if result:
                        await play_audio(
                            result.audio,
                            sample_rate=self._audio_config.pipeline_sample_rate
                            if self._audio_config
                            else 16000,
                            queue_frame=self._transport_output.queue_frame,
                            transcript=result.transcript,
                            append_to_context=True,
                        )
                        return "greeting"
                    logger.warning(
                        f"Failed to fetch audio greeting {greeting_value}, "
                        "falling back to LLM generation"
                    )
                else:
                    logger.warning(
                        f"Audio greeting {greeting_value} cannot be played "
                        "(no recording fetcher or transport output), "
                        "falling back to LLM generation"
                    )
            elif greeting_value and self.task is not None:
                logger.debug("Playing text greeting via TTS")
                # append_to_context=True so the assistant aggregator commits
                # the greeting to the LLM context once TTS finishes; without
                # it the LLM would re-greet on its first generation.
                await self.task.queue_frame(
                    TTSSpeakFrame(greeting_value, append_to_context=True)
                )
                return "greeting"

    if (
        generate_if_no_greeting
        and self.llm is not None
        and self.context is not None
    ):
        logger.debug("Queueing initial LLM generation for node opening")
        # Queue after the voicemail detector in the live pipeline so the
        # detector can gate initial generations when needed.
        await self.llm.queue_frame(LLMContextFrame(self.context))
        return "llm"

    return "none"
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
"""Handle end node execution."""
|
||||
# Setup LLM context with prompts and functions.
|
||||
|
|
|
|||
|
|
@ -511,6 +511,17 @@ class CustomToolManager:
|
|||
workflow_run = await db_client.get_workflow_run_by_id(
|
||||
self._engine._workflow_run_id
|
||||
)
|
||||
if workflow_run.mode == WorkflowRunMode.TEXTCHAT.value:
|
||||
textchat_error_result = {
|
||||
"status": "failed",
|
||||
"message": "I'm sorry, but call transfers are not available in text chat tests.",
|
||||
"action": "transfer_failed",
|
||||
"reason": "textchat_not_supported",
|
||||
}
|
||||
await self._handle_transfer_result(
|
||||
textchat_error_result, function_call_params, properties
|
||||
)
|
||||
return
|
||||
if workflow_run.mode in [
|
||||
WorkflowRunMode.WEBRTC.value,
|
||||
WorkflowRunMode.SMALLWEBRTC.value,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ import re
|
|||
from loguru import logger
|
||||
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.pipecat.tracing_config import get_trace_url
|
||||
from api.services.pipecat.tracing_config import (
|
||||
build_remote_parent_context,
|
||||
get_trace_url,
|
||||
)
|
||||
|
||||
|
||||
def extract_trace_id(gathered_context: dict) -> str | None:
|
||||
|
|
@ -33,36 +36,12 @@ def setup_langfuse_parent_context(workflow_run: WorkflowRunModel):
|
|||
|
||||
Returns the parent context object, or None if tracing is unavailable.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import (
|
||||
NonRecordingSpan,
|
||||
SpanContext,
|
||||
TraceFlags,
|
||||
set_span_in_context,
|
||||
)
|
||||
|
||||
from api.services.pipecat.tracing_config import ensure_tracing
|
||||
|
||||
if not ensure_tracing():
|
||||
return None
|
||||
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse tracing")
|
||||
return None
|
||||
|
||||
parent_span_ctx = SpanContext(
|
||||
trace_id=int(trace_id, 16),
|
||||
span_id=0x1,
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
return set_span_in_context(NonRecordingSpan(parent_span_ctx))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set up Langfuse parent context: {e}")
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse tracing")
|
||||
return None
|
||||
return build_remote_parent_context(trace_id)
|
||||
|
||||
|
||||
def add_qa_span_to_trace(
|
||||
|
|
|
|||
144
api/services/workflow/text_chat_logs.py
Normal file
144
api/services/workflow/text_chat_logs.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Helpers for projecting text-chat session state into run-log snapshots."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.services.pipecat.realtime_feedback_events import (
|
||||
build_bot_text_event,
|
||||
build_function_call_end_event,
|
||||
build_function_call_start_event,
|
||||
build_node_transition_event,
|
||||
build_pipeline_error_event,
|
||||
build_user_transcription_event,
|
||||
realtime_feedback_event_sort_key,
|
||||
stamp_realtime_feedback_event,
|
||||
)
|
||||
|
||||
|
||||
def visible_text_chat_turns(session_data: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the turns on the currently visible branch of a text-chat session.

    After a rewind, `session_data["turns"]` may still hold turns that come
    after the cursor until the next message is sent. Those turns are not part
    of the visible branch, so the list is cut right after the turn whose `id`
    equals `cursor_turn_id`.

    Returns a new list. Every stored turn is returned when no cursor is set or
    when no turn matches the cursor.
    """
    all_turns = list(session_data.get("turns") or [])
    cursor = session_data.get("cursor_turn_id")
    if cursor is None:
        return all_turns

    # Position just past the first turn matching the cursor; fall back to the
    # full length so an unknown cursor keeps every turn.
    cut = next(
        (pos + 1 for pos, item in enumerate(all_turns) if item.get("id") == cursor),
        len(all_turns),
    )
    return all_turns[:cut]
|
||||
|
||||
|
||||
def build_text_chat_realtime_feedback_events(
    session_data: dict[str, Any],
) -> list[dict[str, Any]]:
    """Project text-chat session state into `workflow_runs.logs` event format.

    `workflow_run_text_sessions` holds the authoritative rewindable conversation
    state. Historical run pages and QA helpers read the normalized
    `workflow_runs.logs.realtime_feedback_events` schema instead, so this helper
    rebuilds that snapshot from the currently visible branch.

    Args:
        session_data: Session payload; only the turns returned by
            `visible_text_chat_turns` are projected.

    Returns:
        The rebuilt events, ordered by `realtime_feedback_event_sort_key`.
        Recorded event types other than `node_transition`,
        `tool_call_started`, `tool_call_result` and `execution_error` are
        not projected.
    """
    events: list[dict[str, Any]] = []
    # Node id of the most recent transition that was emitted; used to drop
    # back-to-back transitions into the same node.
    last_emitted_node_id: str | None = None

    for turn_index, turn in enumerate(visible_text_chat_turns(session_data)):
        turn_events = list(turn.get("events") or [])
        for event in turn_events:
            payload = dict(event.get("payload") or {})
            event_type = event.get("type")
            # Events without their own timestamp inherit the turn's.
            timestamp = event.get("created_at") or turn.get("created_at")

            if event_type == "node_transition":
                node_id = payload.get("node_id")
                # Skip a repeat of the node that was emitted last.
                if node_id is not None and node_id == last_emitted_node_id:
                    continue
                snapshot_event = stamp_realtime_feedback_event(
                    build_node_transition_event(
                        node_id=node_id,
                        node_name=payload.get("node_name"),
                        previous_node_id=payload.get("previous_node_id"),
                        previous_node_name=payload.get("previous_node_name"),
                        allow_interrupt=bool(payload.get("allow_interrupt", False)),
                    ),
                    timestamp=timestamp,
                    turn=turn_index,
                    node_id=node_id,
                    node_name=payload.get("node_name"),
                )
                # Transitions without a node id are emitted but never become
                # the dedupe reference.
                if node_id is not None:
                    last_emitted_node_id = node_id
                events.append(snapshot_event)
            elif event_type == "tool_call_started":
                events.append(
                    stamp_realtime_feedback_event(
                        build_function_call_start_event(
                            function_name=payload.get("function_name"),
                            tool_call_id=payload.get("tool_call_id"),
                            arguments=payload.get("arguments"),
                        ),
                        timestamp=timestamp,
                        turn=turn_index,
                    )
                )
            elif event_type == "tool_call_result":
                events.append(
                    stamp_realtime_feedback_event(
                        build_function_call_end_event(
                            function_name=payload.get("function_name"),
                            tool_call_id=payload.get("tool_call_id"),
                            result=payload.get("result"),
                        ),
                        timestamp=timestamp,
                        turn=turn_index,
                    )
                )
            elif event_type == "execution_error":
                events.append(
                    stamp_realtime_feedback_event(
                        build_pipeline_error_event(
                            error=payload.get("message", "Execution error"),
                            fatal=True,
                        ),
                        timestamp=timestamp,
                        turn=turn_index,
                    )
                )

        # The user's message for this turn; turns with no user text
        # contribute no transcription event.
        user_message = turn.get("user_message") or {}
        if user_message.get("text"):
            message_timestamp = user_message.get("created_at") or turn.get("created_at")
            events.append(
                stamp_realtime_feedback_event(
                    build_user_transcription_event(
                        text=user_message["text"],
                        final=True,
                        timestamp=message_timestamp,
                    ),
                    timestamp=message_timestamp,
                    turn=turn_index,
                )
            )

        # The assistant's reply for this turn, when one was recorded.
        assistant_message = turn.get("assistant_message") or {}
        if assistant_message.get("text"):
            message_timestamp = assistant_message.get("created_at") or turn.get(
                "created_at"
            )
            events.append(
                stamp_realtime_feedback_event(
                    build_bot_text_event(
                        text=assistant_message["text"],
                        timestamp=message_timestamp,
                    ),
                    timestamp=message_timestamp,
                    turn=turn_index,
                )
            )

    # `sorted` is stable, so events with equal sort keys keep append order.
    return sorted(events, key=realtime_feedback_event_sort_key)
|
||||
649
api/services/workflow/text_chat_runner.py
Normal file
649
api/services/workflow/text_chat_runner.py
Normal file
|
|
@ -0,0 +1,649 @@
|
|||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMAssistantPushAggregationFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.run_context import set_current_org_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
PipelineMetricsAggregator,
|
||||
)
|
||||
from api.services.pipecat.recording_audio_cache import create_recording_audio_fetcher
|
||||
from api.services.pipecat.service_factory import create_llm_service
|
||||
from api.services.pipecat.tracing_config import (
|
||||
build_remote_parent_context,
|
||||
get_trace_url,
|
||||
)
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
||||
# Value written to the "version" key of every checkpoint built in this module.
TEXT_CHAT_CHECKPOINT_VERSION = 1
# Default upper bound on how long `_wait_for_quiescence` waits for a response
# window to settle before raising TimeoutError.
TEXT_CHAT_TURN_TIMEOUT_SECONDS = 60.0
# How long the pipeline must be free of frame activity, while the response
# window is idle, before a turn is treated as settled.
TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
# Cancel reason used when this module tears down its own per-turn pipeline;
# cancels carrying it are not recorded as "session_cancelled" events.
TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
|
||||
|
||||
|
||||
def text_chat_trace_id(workflow_run_id: int) -> str:
    """Derive the Langfuse trace id shared by all turns of a text-chat run.

    Every turn executes in its own short-lived pipeline, so no single
    long-running task owns the trace the way a voice call does. Hashing the
    run id yields the *same* trace id on every turn, which puts all per-turn
    spans in one trace without persisting extra state between the otherwise
    stateless turn requests.

    Returns the first 32 hex characters of the SHA-256 digest of
    `dograh-text-chat:<workflow_run_id>`.
    """
    seed = f"dograh-text-chat:{workflow_run_id}".encode()
    return hashlib.sha256(seed).hexdigest()[:32]
|
||||
|
||||
|
||||
def default_text_chat_checkpoint() -> dict[str, Any]:
    """Build a blank text-chat checkpoint.

    Each call returns a new dict with its own empty `messages`,
    `gathered_context` and `tool_state` containers, so callers may mutate the
    result freely.
    """
    blank_checkpoint: dict[str, Any] = {
        "version": TEXT_CHAT_CHECKPOINT_VERSION,
        "anchor_turn_id": None,
        "current_node_id": None,
        "messages": [],
        "gathered_context": {},
        "tool_state": {},
    }
    return blank_checkpoint
|
||||
|
||||
|
||||
def normalize_text_chat_checkpoint(
    checkpoint: dict[str, Any] | None,
) -> dict[str, Any]:
    """Overlay `checkpoint` on the default checkpoint and coerce containers.

    Keys missing from `checkpoint` take their default value. `messages` is
    copied into a new list, and `gathered_context` / `tool_state` into new
    dicts; falsy values for those keys become empty containers. The input is
    not modified.
    """
    result = default_text_chat_checkpoint()
    result.update(checkpoint or {})

    result["messages"] = list(result.get("messages") or [])
    for mapping_key in ("gathered_context", "tool_state"):
        result[mapping_key] = dict(result.get(mapping_key) or {})
    return result
|
||||
|
||||
|
||||
@dataclass
class TextChatTurnExecutionResult:
    """Outcome of executing one pending text-chat turn."""

    # Assistant outputs captured during the turn, joined by blank lines;
    # None when nothing was captured.
    assistant_text: str | None
    # ISO-8601 UTC timestamp taken once the turn's pipeline has finished.
    assistant_created_at: str
    # JSON-encoded events recorded while the turn ran (node transitions,
    # tool calls, session end/cancel).
    events: list[dict[str, Any]]
    # JSON-encoded usage metrics reported by the pipeline metrics aggregator.
    usage: dict[str, Any]
    # Checkpoint describing the state after this turn.
    checkpoint: dict[str, Any]
    # JSON-encoded gathered context, plus "trace_url" when one is available.
    gathered_context: dict[str, Any]
    # JSON-encoded initial context handed to the engine for this turn.
    initial_context: dict[str, Any]
    # A WorkflowRunState value: COMPLETED when the engine reports the call as
    # disposed, RUNNING otherwise.
    state: str
    # True when the engine reports the call as disposed.
    is_completed: bool
|
||||
|
||||
|
||||
@dataclass
class _ResponseWindowState:
    """Counters describing the work still outstanding for one response window.

    The window counts as idle only when no context request is pending, no LLM
    completion or assistant segment is active, and no blocking tool call is
    unresolved. Counters never drop below zero through the `note_*` methods.
    """

    active_assistant_segments: int = 0
    active_llm_completions: int = 0
    pending_context_requests: int = 0
    blocking_tool_call_ids: set[str] = field(default_factory=set)
    outputs: list[str] = field(default_factory=list)

    def note_direct_context_request(self) -> None:
        """Count a context request queued directly by the runner."""
        self.pending_context_requests += 1

    def note_upstream_context_request(self) -> None:
        """Count a context request observed travelling upstream."""
        self.pending_context_requests += 1

    def note_llm_start(self) -> None:
        """Turn one pending request (if any) into an active LLM completion."""
        self.pending_context_requests -= int(self.pending_context_requests > 0)
        self.active_llm_completions += 1

    def note_llm_end(self) -> None:
        """Mark one active LLM completion as finished."""
        self.active_llm_completions -= int(self.active_llm_completions > 0)

    def note_assistant_turn_started(self) -> None:
        """Count a newly opened assistant segment."""
        self.active_assistant_segments += 1

    def note_assistant_turn_stopped(self, content: str) -> None:
        """Close one assistant segment and keep its text when non-blank."""
        self.active_assistant_segments -= int(self.active_assistant_segments > 0)
        text = content.strip()
        if not text:
            return
        self.outputs.append(text)

    def note_function_call_in_progress(self, tool_call_id: str, blocking: bool) -> None:
        """Track a tool call that must resolve before the window can idle."""
        if not blocking:
            return
        self.blocking_tool_call_ids.add(tool_call_id)

    def note_function_call_result(self, tool_call_id: str) -> None:
        """Stop tracking a tool call; unknown ids are ignored."""
        self.blocking_tool_call_ids.discard(tool_call_id)

    @property
    def has_blocking_tool_calls(self) -> bool:
        """Whether any blocking tool call is still unresolved."""
        return len(self.blocking_tool_call_ids) > 0

    @property
    def frontier_is_idle(self) -> bool:
        """Whether nothing is pending, active, or blocked in this window."""
        busy = (
            self.pending_context_requests != 0
            or self.active_llm_completions != 0
            or self.active_assistant_segments != 0
            or self.has_blocking_tool_calls
        )
        return not busy
|
||||
|
||||
|
||||
class _TaskQueueProxy:
    """Bare object exposing a single `queue_frame` attribute.

    Passed to the engine as its transport output so that code calling
    `transport_output.queue_frame` reaches the supplied callable instead.
    """

    def __init__(self, queue_frame):
        # Stored unchanged; callers invoke it as `proxy.queue_frame(...)`.
        self.queue_frame = queue_frame
|
||||
|
||||
|
||||
class _TextChatCaptureProcessor(FrameProcessor):
    """Pipeline stage that records activity and events for a text-chat turn.

    It timestamps every frame it sees, updates the shared response-window
    counters, records tool-call and session events, and substitutes for the
    TTS/output-transport frames that a text-only pipeline never produces.
    """

    def __init__(self, response_window: _ResponseWindowState) -> None:
        super().__init__()
        # Monotonic time of the most recent frame; starts at construction.
        self.last_activity_at = time.monotonic()
        # Number of frames seen so far.
        self.activity_count = 0
        # Events recorded during the turn, in arrival order.
        self.events: list[dict[str, Any]] = []
        self._response_window = response_window

    def _touch(self) -> None:
        """Record that a frame was just seen."""
        self.last_activity_at = time.monotonic()
        self.activity_count += 1

    def _append_event(self, event_type: str, payload: dict[str, Any]) -> None:
        """Append an event stamped with the current UTC time."""
        self.events.append(
            {
                "type": event_type,
                "created_at": datetime.now(UTC).isoformat(),
                "payload": jsonable_encoder(payload),
            }
        )

    async def process_frame(self, frame, direction: FrameDirection):
        """Observe `frame`, update bookkeeping, and forward or replace it."""
        await super().process_frame(frame, direction)
        self._touch()

        # A speak request is replaced by a plain text frame followed by an
        # aggregation push; the TTSSpeakFrame itself is not forwarded.
        if isinstance(frame, TTSSpeakFrame):
            text_frame = TextFrame(frame.text)
            # An unset append_to_context is treated as True.
            text_frame.append_to_context = (
                frame.append_to_context if frame.append_to_context is not None else True
            )
            await self.push_frame(text_frame, direction)
            await self.push_frame(LLMAssistantPushAggregationFrame(), direction)
            return

        # Context frames heading upstream count as pending LLM requests; the
        # frame is still forwarded by the final push below.
        if isinstance(frame, LLMContextFrame) and direction == FrameDirection.UPSTREAM:
            self._response_window.note_upstream_context_request()

        if isinstance(frame, TTSStoppedFrame):
            await self.push_frame(frame, direction)
            await self.push_frame(LLMAssistantPushAggregationFrame(), direction)
            return

        if (
            isinstance(frame, LLMFullResponseStartFrame)
            and direction == FrameDirection.DOWNSTREAM
        ):
            self._response_window.note_llm_start()

        if (
            isinstance(frame, LLMFullResponseEndFrame)
            and direction is FrameDirection.DOWNSTREAM
        ):
            self._response_window.note_llm_end()
            await self.push_frame(frame, direction)
            # Text chat has no TTS/output transport, so mixed text+tool responses
            # would otherwise leave function calls waiting forever on a
            # BotStoppedSpeakingFrame that never arrives.
            await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
            return

        if isinstance(frame, FunctionCallInProgressFrame):
            # The frame's cancel_on_interruption flag decides whether this
            # call blocks the response window from going idle.
            self._response_window.note_function_call_in_progress(
                tool_call_id=frame.tool_call_id,
                blocking=frame.cancel_on_interruption,
            )
            self._append_event(
                "tool_call_started",
                {
                    "function_name": frame.function_name,
                    "tool_call_id": frame.tool_call_id,
                    "arguments": dict(frame.arguments or {}),
                },
            )
        elif isinstance(frame, FunctionCallResultFrame):
            self._response_window.note_function_call_result(frame.tool_call_id)
            self._append_event(
                "tool_call_result",
                {
                    "function_name": frame.function_name,
                    "tool_call_id": frame.tool_call_id,
                    "result": frame.result,
                },
            )
        elif isinstance(frame, EndFrame):
            self._append_event("session_end", {"reason": frame.reason})
        elif isinstance(frame, CancelFrame):
            # The runner's own end-of-turn cancel is not a user-visible event.
            if frame.reason != TEXT_CHAT_INTERNAL_CANCEL_REASON:
                self._append_event("session_cancelled", {"reason": frame.reason})

        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
def _merge_usage_info(
|
||||
existing: dict[str, Any] | None,
|
||||
delta: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(existing or {})
|
||||
delta = dict(delta or {})
|
||||
|
||||
merged_llm = dict(merged.get("llm") or {})
|
||||
for key, value in (delta.get("llm") or {}).items():
|
||||
current = dict(merged_llm.get(key) or {})
|
||||
merged_llm[key] = {
|
||||
"prompt_tokens": int(current.get("prompt_tokens") or 0)
|
||||
+ int(value.get("prompt_tokens") or 0),
|
||||
"completion_tokens": int(current.get("completion_tokens") or 0)
|
||||
+ int(value.get("completion_tokens") or 0),
|
||||
"total_tokens": int(current.get("total_tokens") or 0)
|
||||
+ int(value.get("total_tokens") or 0),
|
||||
"cache_read_input_tokens": int(current.get("cache_read_input_tokens") or 0)
|
||||
+ int(value.get("cache_read_input_tokens") or 0),
|
||||
"cache_creation_input_tokens": int(
|
||||
current.get("cache_creation_input_tokens") or 0
|
||||
)
|
||||
+ int(value.get("cache_creation_input_tokens") or 0),
|
||||
}
|
||||
merged["llm"] = merged_llm
|
||||
|
||||
for section in ("tts", "stt"):
|
||||
merged_section = dict(merged.get(section) or {})
|
||||
for key, value in (delta.get(section) or {}).items():
|
||||
merged_section[key] = float(merged_section.get(key) or 0) + float(value)
|
||||
merged[section] = merged_section
|
||||
|
||||
merged["call_duration_seconds"] = int(
|
||||
merged.get("call_duration_seconds") or 0
|
||||
) + int(delta.get("call_duration_seconds") or 0)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def merge_text_chat_usage_info(
|
||||
existing: dict[str, Any] | None,
|
||||
delta: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
return _merge_usage_info(existing, delta)
|
||||
|
||||
|
||||
def _resolve_checkpoint_for_pending_turn(
    session_data: dict[str, Any],
    checkpoint: dict[str, Any] | None,
) -> dict[str, Any]:
    """Pick the checkpoint a pending turn should start from, normalized.

    When the last turn is pending, the `checkpoint_after_turn` stored on the
    most recent completed turn before it is used. In every other case (no
    turns, last turn not pending, no earlier completed turn, or that turn has
    no stored checkpoint) the supplied `checkpoint` is used.
    """
    history = list(session_data.get("turns") or [])
    chosen = checkpoint

    if history and history[-1].get("status") == "pending":
        # Only the nearest completed turn is consulted; older ones are never
        # used as a fallback.
        latest_completed = next(
            (
                earlier
                for earlier in reversed(history[:-1])
                if earlier.get("status") == "completed"
            ),
            None,
        )
        if latest_completed is not None:
            chosen = latest_completed.get("checkpoint_after_turn") or checkpoint

    return normalize_text_chat_checkpoint(chosen)
|
||||
|
||||
|
||||
async def _wait_for_quiescence(
    *,
    capture_processor: _TextChatCaptureProcessor,
    response_window: _ResponseWindowState,
    runner_task: asyncio.Task,
    activity_marker: int,
    timeout_seconds: float = TEXT_CHAT_TURN_TIMEOUT_SECONDS,
) -> None:
    """Block until the response window has settled or the runner has ended.

    The window counts as settled once all three hold: frame activity has been
    seen since `activity_marker`, the response window is idle, and no frame
    has arrived for `TEXT_CHAT_IDLE_SETTLE_SECONDS`. State is polled every
    50 ms.

    Raises:
        TimeoutError: If `timeout_seconds` pass without the window settling.
        Exception: Whatever `runner_task` raised, if it finished first.
    """
    event_loop = asyncio.get_running_loop()
    give_up_at = event_loop.time() + timeout_seconds

    while event_loop.time() < give_up_at:
        if runner_task.done():
            # Awaiting a finished task re-raises its exception, if any.
            await runner_task
            return

        idle = response_window.frontier_is_idle
        saw_activity = capture_processor.activity_count > activity_marker
        quiet_for = time.monotonic() - capture_processor.last_activity_at
        if idle and saw_activity and quiet_for >= TEXT_CHAT_IDLE_SETTLE_SECONDS:
            return

        await asyncio.sleep(0.05)

    raise TimeoutError(
        "Timed out waiting for text chat response window to settle "
        f"(pending_context_requests={response_window.pending_context_requests}, "
        f"active_llm_completions={response_window.active_llm_completions}, "
        f"active_assistant_segments={response_window.active_assistant_segments}, "
        f"blocking_tool_calls={sorted(response_window.blocking_tool_call_ids)})"
    )
|
||||
|
||||
|
||||
async def execute_text_chat_pending_turn(
    *,
    workflow_run_id: int,
    workflow_id: int,
    session_data: dict[str, Any],
    checkpoint: dict[str, Any] | None,
) -> TextChatTurnExecutionResult:
    """Run the pending turn of a text-chat session through a one-off pipeline.

    The last entry of `session_data["turns"]` must have status "pending". A
    fresh pipeline (LLM, capture processor, assistant aggregator, metrics
    aggregator) is built for the turn, seeded from the checkpoint chosen by
    `_resolve_checkpoint_for_pending_turn`, driven until the response window
    settles, and then cancelled.

    Args:
        workflow_run_id: Run whose pinned definition and context are used.
        workflow_id: Workflow the run must belong to.
        session_data: Session payload containing the turns.
        checkpoint: Fallback checkpoint when no completed turn supplies one.

    Returns:
        The assistant text, captured events, usage, updated checkpoint,
        contexts, and run state for this turn.

    Raises:
        ValueError: If there is no pending turn, the run/workflow cannot be
            loaded or is missing its definition or context, or no LLM is
            configured.
        TimeoutError: If the pipeline does not start within 5 seconds or the
            response window does not settle in time.
    """
    turns = list(session_data.get("turns") or [])
    if not turns or turns[-1].get("status") != "pending":
        raise ValueError("Text chat session has no pending turn to execute")

    pending_turn = turns[-1]
    # None means the turn carries no user message at all (an opening turn);
    # an empty string means a user message with blank text.
    pending_user_message = (
        ((pending_turn.get("user_message") or {}).get("text") or "").strip()
        if pending_turn.get("user_message") is not None
        else None
    )

    workflow_run, _ = await db_client.get_workflow_run_with_context(workflow_run_id)
    if not workflow_run or workflow_run.workflow_id != workflow_id:
        raise ValueError("Workflow run not found for text chat execution")
    if workflow_run.definition is None:
        raise ValueError("Workflow run is missing a pinned definition")
    if workflow_run.workflow is None or workflow_run.workflow.user is None:
        raise ValueError("Workflow run is missing workflow context")

    workflow = await db_client.get_workflow(
        workflow_id, organization_id=workflow_run.workflow.organization_id
    )
    if workflow is None:
        raise ValueError("Workflow not found for text chat execution")

    # Stamp the async context so OTEL spans are tagged with this org and routed
    # to its Langfuse project (the voice paths do this in run_pipeline /
    # webrtc_signaling; the text path previously skipped it, so its spans never
    # reached org-specific exporters).
    set_current_org_id(workflow.organization_id)

    run_definition = workflow_run.definition
    run_configs = run_definition.workflow_configurations or {}

    # The run's pinned model overrides are layered over the user's config.
    user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
    user_config = resolve_effective_config(
        user_config, run_configs.get("model_overrides")
    )
    if user_config.llm is None:
        raise ValueError("Text chat requires an LLM configuration")

    llm = create_llm_service(user_config)
    # The same service instance is used for inference.
    inference_llm = llm

    runtime_configuration = {
        "llm_provider": user_config.llm.provider,
        "llm_model": user_config.llm.model,
    }
    initial_context = {
        **(workflow_run.initial_context or {}),
        "runtime_configuration": runtime_configuration,
    }

    workflow_graph = WorkflowGraph(
        ReactFlowDTO.model_validate(run_definition.workflow_json)
    )
    base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)

    response_window = _ResponseWindowState()
    capture_processor = _TextChatCaptureProcessor(response_window)
    context = LLMContext()
    # Resume the conversation from the checkpoint's message history.
    context.set_messages(base_checkpoint["messages"])

    # Alias, not a copy: node transitions land in the same list as the
    # processor's captured events, in arrival order.
    node_transition_events = capture_processor.events

    async def send_node_transition(
        node_id: str,
        node_name: str,
        previous_node_id: str | None,
        previous_node_name: str | None,
        allow_interrupt: bool = False,
    ) -> None:
        """Record a node transition reported by the engine."""
        node_transition_events.append(
            {
                "type": "node_transition",
                "created_at": datetime.now(UTC).isoformat(),
                "payload": {
                    "node_id": node_id,
                    "node_name": node_name,
                    "previous_node_id": previous_node_id,
                    "previous_node_name": previous_node_name,
                    "allow_interrupt": allow_interrupt,
                },
            }
        )

    embeddings_api_key = None
    embeddings_model = None
    embeddings_base_url = None
    if user_config.embeddings:
        embeddings_api_key = user_config.embeddings.api_key
        embeddings_model = user_config.embeddings.model
        # Read defensively; the attribute is treated as optional here.
        embeddings_base_url = getattr(user_config.embeddings, "base_url", None)

    has_recordings = await db_client.has_active_recordings(workflow.organization_id)
    # Read from the live workflow's configuration rather than the run's
    # pinned definition.
    context_compaction_enabled = (workflow.workflow_configurations or {}).get(
        "context_compaction_enabled", False
    )

    engine = PipecatEngine(
        llm=llm,
        inference_llm=inference_llm,
        context=context,
        workflow=workflow_graph,
        call_context_vars=initial_context,
        workflow_run_id=workflow_run_id,
        node_transition_callback=send_node_transition,
        embeddings_api_key=embeddings_api_key,
        embeddings_model=embeddings_model,
        embeddings_base_url=embeddings_base_url,
        has_recordings=has_recordings,
        context_compaction_enabled=context_compaction_enabled,
    )
    # Restore context gathered on earlier turns (copied so the checkpoint
    # dict is not mutated).
    engine._gathered_context = dict(base_checkpoint["gathered_context"])

    assistant_params = LLMAssistantAggregatorParams()
    context_aggregator = LLMContextAggregatorPair(
        context, assistant_params=assistant_params
    )
    assistant_context_aggregator = context_aggregator.assistant()

    @assistant_context_aggregator.event_handler("on_assistant_turn_started")
    async def on_assistant_turn_started(_aggregator):
        response_window.note_assistant_turn_started()

    @assistant_context_aggregator.event_handler("on_assistant_turn_stopped")
    async def on_assistant_turn_stopped(_aggregator, message):
        response_window.note_assistant_turn_stopped(message.content or "")

    # Text chat has no wire transport; reuse the neutral 16 kHz config shape
    # from the browser pipeline so TTS/recording helpers still have sane defaults.
    audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
    pipeline_metrics_aggregator = PipelineMetricsAggregator()

    # Stitch every per-turn pipeline of this session into one Langfuse trace by
    # handing each task the same remote parent context (derived from the run id).
    trace_id = text_chat_trace_id(workflow_run_id)
    conversation_parent_context = build_remote_parent_context(trace_id)
    # The stitched trace has no real root span (each per-turn conversation span
    # hangs off a synthetic remote parent), so Langfuse can't infer a name and
    # shows "Unnamed trace". Name it explicitly via the conversation span.
    trace_span_attributes = {
        "langfuse.trace.name": workflow_run.name or f"text-chat-{workflow_run_id}"
    }

    pipeline = Pipeline(
        [
            llm,
            capture_processor,
            assistant_context_aggregator,
            pipeline_metrics_aggregator,
        ]
    )
    task = create_pipeline_task(
        pipeline,
        workflow_run_id,
        audio_config,
        conversation_parent_context=conversation_parent_context,
        conversation_type="text",
        additional_span_attributes=trace_span_attributes,
    )
    runner = PipelineRunner(handle_sigint=False, handle_sigterm=False)
    runner_task = asyncio.create_task(runner.run(task))

    engine.set_task(task)
    engine.set_audio_config(audio_config)
    # The engine's transport-output calls are routed into the task queue.
    engine.set_transport_output(_TaskQueueProxy(task.queue_frame))
    engine.set_fetch_recording_audio(
        create_recording_audio_fetcher(
            organization_id=workflow.organization_id,
            pipeline_sample_rate=audio_config.pipeline_sample_rate,
        )
    )

    try:
        # Relies on a private attribute of the pipeline task to wait for start.
        await asyncio.wait_for(task._pipeline_start_event.wait(), timeout=5.0)

        await engine.initialize()

        # Resume at the checkpointed node, or enter the start node on the
        # first turn; a transition event is emitted only in the latter case.
        current_node_id = base_checkpoint.get("current_node_id")
        target_node_id = current_node_id or workflow_graph.start_node_id
        await engine.set_node(
            target_node_id,
            emit_transition_event=current_node_id is None,
        )

        opening_marker = capture_processor.activity_count
        # With no user message, an LLM generation is expected whenever the
        # node is being resumed or has no greeting configured.
        opening_expects_llm = pending_user_message is None and (
            current_node_id == target_node_id
            or engine.get_node_greeting(target_node_id) is None
        )
        # Register the request before queueing so the window is not seen as
        # idle while the generation is still on its way.
        if opening_expects_llm:
            response_window.note_direct_context_request()
        opening_action = await engine.queue_node_opening(
            node_id=target_node_id,
            previous_node_id=current_node_id,
            generate_if_no_greeting=pending_user_message is None,
        )
        # Undo the speculative registration if no generation was queued.
        if opening_action != "llm" and opening_expects_llm:
            response_window.pending_context_requests = max(
                0, response_window.pending_context_requests - 1
            )
        if opening_action != "none":
            await _wait_for_quiescence(
                capture_processor=capture_processor,
                response_window=response_window,
                runner_task=runner_task,
                activity_marker=opening_marker,
            )

        if pending_user_message is not None:
            context.add_message({"role": "user", "content": pending_user_message})
            generation_marker = capture_processor.activity_count
            response_window.note_direct_context_request()
            await llm.queue_frame(LLMContextFrame(context))
            await _wait_for_quiescence(
                capture_processor=capture_processor,
                response_window=response_window,
                runner_task=runner_task,
                activity_marker=generation_marker,
            )
    finally:
        # NOTE(review): the nesting of this teardown was reconstructed from a
        # flattened diff view — confirm against the repository.
        if not task.has_finished():
            await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
        try:
            await runner_task
        except Exception:
            logger.exception(
                "Transportless text chat pipeline failed while closing run {}",
                workflow_run_id,
            )
            await engine.cleanup()
            raise
        await engine.cleanup()

    gathered_context = await engine.get_gathered_context()
    # Captured assistant segments are joined with blank lines.
    assistant_text = (
        "\n\n".join(part for part in response_window.outputs if part).strip()
        if response_window.outputs
        else None
    )
    assistant_created_at = datetime.now(UTC).isoformat()
    usage = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
    current_node = getattr(engine, "_current_node", None)

    updated_checkpoint = {
        "version": TEXT_CHAT_CHECKPOINT_VERSION,
        "anchor_turn_id": pending_turn.get("id"),
        "current_node_id": current_node.id if current_node else None,
        "messages": jsonable_encoder(context.get_messages()),
        "gathered_context": jsonable_encoder(gathered_context),
        # tool_state is carried over from the base checkpoint unchanged.
        "tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
    }

    # The trace URL is added only to the returned gathered context, not to
    # the checkpoint's copy.
    encoded_gathered_context = jsonable_encoder(gathered_context)
    trace_url = get_trace_url(trace_id, org_id=workflow.organization_id)
    if trace_url:
        encoded_gathered_context = {**encoded_gathered_context, "trace_url": trace_url}

    return TextChatTurnExecutionResult(
        assistant_text=assistant_text,
        assistant_created_at=assistant_created_at,
        events=jsonable_encoder(capture_processor.events),
        usage=jsonable_encoder(usage),
        checkpoint=updated_checkpoint,
        gathered_context=encoded_gathered_context,
        initial_context=jsonable_encoder(initial_context),
        state=(
            WorkflowRunState.COMPLETED.value
            if engine.is_call_disposed()
            else WorkflowRunState.RUNNING.value
        ),
        is_completed=engine.is_call_disposed(),
    )
|
||||
411
api/services/workflow/text_chat_session_service.py
Normal file
411
api/services/workflow/text_chat_session_service.py
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
"""Service helpers for text-chat session lifecycle orchestration."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
from api.services.workflow.text_chat_runner import (
|
||||
default_text_chat_checkpoint,
|
||||
execute_text_chat_pending_turn,
|
||||
merge_text_chat_usage_info,
|
||||
normalize_text_chat_checkpoint,
|
||||
)
|
||||
|
||||
# Schema version stamped into the "version" field of freshly built session data.
TEXT_CHAT_SESSION_VERSION = 1
|
||||
|
||||
|
||||
class TextChatSessionRevisionConflictError(Exception):
    """Signal that a guarded session write found an unexpected revision.

    Attributes:
        expected_revision: Revision the writer believed the session was at.
        actual_revision: Revision reported for the stored session.
    """

    def __init__(self, expected_revision: int, actual_revision: int):
        message = (
            "Text chat session revision conflict: "
            f"expected {expected_revision}, found {actual_revision}"
        )
        super().__init__(message)
        self.expected_revision = expected_revision
        self.actual_revision = actual_revision
|
||||
|
||||
|
||||
class TextChatSessionExecutionError(Exception):
    """Error for an assistant turn that could not be executed.

    This module also raises it when the session cannot be re-read after an
    update.
    """
|
||||
|
||||
|
||||
class TextChatPendingTurnLostError(Exception):
    """Error for a session whose last turn is no longer pending at completion time."""
|
||||
|
||||
|
||||
class TextChatTurnNotFoundError(Exception):
    """Error for a cursor turn id that matches no turn in the session."""
|
||||
|
||||
|
||||
def default_text_chat_session_data() -> dict[str, Any]:
    """Build the session payload for a brand-new text chat.

    Returns:
        A fresh dict on every call: idle status, no cursor, empty turn and
        discarded-branch lists, and a disabled simulator with empty config.
    """
    simulator_defaults: dict[str, Any] = {
        "enabled": False,
        "config": {},
    }
    session_defaults: dict[str, Any] = {
        "version": TEXT_CHAT_SESSION_VERSION,
        "status": "idle",
        "cursor_turn_id": None,
        "turns": [],
        "discarded_future": [],
        "simulator": simulator_defaults,
    }
    return session_defaults
|
||||
|
||||
|
||||
def normalize_text_chat_session_data(
    session_data: dict[str, Any] | None,
) -> dict[str, Any]:
    """Overlay stored session data on the defaults and coerce container fields.

    Args:
        session_data: Stored session payload; ``None`` or empty is accepted.

    Returns:
        A new top-level dict. ``turns`` and ``discarded_future`` are new list
        objects, but the items inside them are the same objects as in the
        input. ``simulator`` is rebuilt with exactly ``enabled`` (as a bool)
        and ``config`` (as a new dict).
    """
    merged = default_text_chat_session_data()
    merged.update(session_data or {})

    # Missing or null list fields collapse to empty lists.
    for list_field in ("turns", "discarded_future"):
        merged[list_field] = list(merged.get(list_field) or [])

    raw_simulator = merged.get("simulator") or {}
    simulator_enabled = bool(raw_simulator.get("enabled", False))
    simulator_config = dict(raw_simulator.get("config") or {})
    merged["simulator"] = {
        "enabled": simulator_enabled,
        "config": simulator_config,
    }
    return merged
|
||||
|
||||
|
||||
async def initialize_text_chat_session(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionModel:
    """Seed the session with its first pending turn and persist it.

    The turn list is replaced by one pending turn that has no user message,
    the status becomes ``"pending_assistant_turn"``, and the checkpoint anchor
    is recomputed from that new turn list. The write is guarded by the
    revision held by ``text_session``.

    Returns:
        The session as re-read after the update.

    Raises:
        TextChatSessionRevisionConflictError: If the guarded write reports a
            revision conflict.
    """
    session_state = normalize_text_chat_session_data(text_session.session_data)
    checkpoint_state = normalize_text_chat_checkpoint(text_session.checkpoint)

    opening_turns = [build_pending_text_chat_turn(user_text=None)]
    session_state["turns"] = opening_turns
    session_state["status"] = "pending_assistant_turn"
    checkpoint_state["anchor_turn_id"] = latest_completed_text_chat_turn_id(
        opening_turns
    )

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_state,
            checkpoint=checkpoint_state,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as conflict:
        # Re-raise as the service-level error so callers need not import the DB one.
        raise TextChatSessionRevisionConflictError(
            expected_revision=conflict.expected_revision,
            actual_revision=conflict.actual_revision,
        ) from conflict

    return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
||||
async def append_text_chat_user_message(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    user_text: str,
    expected_revision: int | None,
) -> WorkflowRunTextSessionModel:
    """Append a pending turn for a new user message and persist the session.

    Turns after the current cursor (if one is set) are moved to
    ``discarded_future`` before the new pending turn is appended. The cursor
    is then cleared, the status becomes ``"pending_assistant_turn"``, and the
    checkpoint anchor is set to the latest completed turn that remains.

    Returns:
        The session as re-read after the update.

    Raises:
        TextChatTurnNotFoundError: If the stored cursor matches no turn.
        TextChatSessionRevisionConflictError: If the guarded write reports a
            revision conflict.
    """
    session_state = normalize_text_chat_session_data(text_session.session_data)
    checkpoint_state = normalize_text_chat_checkpoint(text_session.checkpoint)

    kept_turns, archived_branches = truncate_text_chat_future_turns(session_state)
    kept_turns.append(build_pending_text_chat_turn(user_text=user_text))

    session_state.update(
        turns=kept_turns,
        discarded_future=archived_branches,
        cursor_turn_id=None,
        status="pending_assistant_turn",
    )
    checkpoint_state["anchor_turn_id"] = latest_completed_text_chat_turn_id(
        kept_turns
    )

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_state,
            checkpoint=checkpoint_state,
            expected_revision=expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as conflict:
        raise TextChatSessionRevisionConflictError(
            expected_revision=conflict.expected_revision,
            actual_revision=conflict.actual_revision,
        ) from conflict

    return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
||||
async def rewind_text_chat_session_state(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    cursor_turn_id: str | None,
    expected_revision: int | None,
) -> WorkflowRunTextSessionModel:
    """Move the session cursor to a turn (or clear it) and refresh run logs.

    A truthy ``cursor_turn_id`` sets the status to ``"rewound"``; otherwise the
    status becomes ``"idle"``. After the session write succeeds, the run's
    ``realtime_feedback_events`` log is rebuilt from the updated session data.

    Returns:
        The session as re-read after the update.

    Raises:
        TextChatTurnNotFoundError: If ``cursor_turn_id`` is not ``None`` and
            matches no turn.
        TextChatSessionRevisionConflictError: If the guarded write reports a
            revision conflict.
    """
    session_state = normalize_text_chat_session_data(text_session.session_data)
    validate_text_chat_turn_cursor(session_state, cursor_turn_id)

    session_state["cursor_turn_id"] = cursor_turn_id
    if cursor_turn_id:
        session_state["status"] = "rewound"
    else:
        session_state["status"] = "idle"

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=session_state,
            expected_revision=expected_revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as conflict:
        raise TextChatSessionRevisionConflictError(
            expected_revision=conflict.expected_revision,
            actual_revision=conflict.actual_revision,
        ) from conflict

    feedback_events = build_text_chat_realtime_feedback_events(session_state)
    await db_client.update_workflow_run(
        run_id,
        logs={"realtime_feedback_events": feedback_events},
    )

    return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
||||
async def execute_pending_text_chat_turn(
    *,
    workflow_id: int,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
) -> WorkflowRunTextSessionModel:
    """Execute the current pending assistant turn and persist its side effects.

    Steps, in order:
      1. Run the pending turn through ``execute_text_chat_pending_turn``.
      2. Mark the last turn completed and store it together with the new
         checkpoint, guarded by ``text_session.revision``.
      3. Update the workflow run (contexts, merged usage, logs, state).
      4. Apply the per-turn usage delta to the organization (best effort) and
         refresh the run's cost info.

    Args:
        workflow_id: Workflow the run belongs to.
        run_id: Workflow run that owns the text session.
        text_session: Session snapshot whose last turn is expected to be
            pending. Its ``session_data`` is not modified by this function.

    Returns:
        The session as re-read after all updates.

    Raises:
        TextChatSessionExecutionError: If executing the turn fails; the
            pending turn is marked failed (best effort) before raising.
        TextChatPendingTurnLostError: If the last turn is not pending.
        TextChatSessionRevisionConflictError: If the guarded session write
            reports a revision conflict.
    """
    session_data = normalize_text_chat_session_data(text_session.session_data)
    checkpoint = normalize_text_chat_checkpoint(text_session.checkpoint)

    try:
        execution = await execute_text_chat_pending_turn(
            workflow_run_id=run_id,
            workflow_id=workflow_id,
            session_data=session_data,
            checkpoint=checkpoint,
        )
    except Exception as e:
        await _mark_pending_turn_failed(
            run_id=run_id,
            text_session=text_session,
            error_message=str(e),
        )
        raise TextChatSessionExecutionError(
            "Failed to execute text chat assistant turn"
        ) from e

    completed_session_data = normalize_text_chat_session_data(text_session.session_data)
    completed_turns = list(completed_session_data.get("turns") or [])
    if not completed_turns or completed_turns[-1].get("status") != "pending":
        raise TextChatPendingTurnLostError(
            "Text chat session lost its pending turn before completion"
        )

    # Normalization copies the turn list but not the turn dicts, so build a
    # replacement dict instead of mutating the one shared with
    # text_session.session_data. Otherwise a revision conflict below would
    # leave the caller's model showing a "completed" turn that was never stored.
    completed_turns[-1] = {
        **completed_turns[-1],
        "status": "completed",
        "assistant_message": (
            {
                "text": execution.assistant_text,
                "created_at": execution.assistant_created_at,
            }
            if execution.assistant_text
            else None
        ),
        "events": execution.events,
        "usage": execution.usage,
        "checkpoint_after_turn": execution.checkpoint,
    }
    completed_session_data["turns"] = completed_turns
    completed_session_data["status"] = "idle"

    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=completed_session_data,
            checkpoint=execution.checkpoint,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError as e:
        raise TextChatSessionRevisionConflictError(
            expected_revision=e.expected_revision,
            actual_revision=e.actual_revision,
        ) from e

    existing_usage_info = text_session.workflow_run.usage_info or {}
    merged_usage_info = merge_text_chat_usage_info(existing_usage_info, execution.usage)
    text_chat_logs = {
        "realtime_feedback_events": build_text_chat_realtime_feedback_events(
            completed_session_data
        )
    }
    await db_client.update_workflow_run(
        run_id,
        initial_context=execution.initial_context,
        usage_info=merged_usage_info,
        gathered_context=execution.gathered_context,
        logs=text_chat_logs,
        state=execution.state,
        is_completed=execution.is_completed,
    )
    workflow_run = await db_client.get_workflow_run_by_id(run_id)
    if workflow_run:
        try:
            # Apply the per-turn delta so org usage tracks cumulative run cost
            # without replaying the full session totals on every turn.
            await apply_usage_delta_to_organization(workflow_run, execution.usage)
        except Exception as e:
            # Usage accounting is best effort; the turn itself already succeeded.
            logger.error(
                f"Failed to update organization usage for text chat run {run_id}: {e}"
            )

        cost_info = await build_workflow_run_cost_info(workflow_run)
        if cost_info is not None:
            await db_client.update_workflow_run(run_id, cost_info=cost_info)

    return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
||||
def validate_text_chat_turn_cursor(
|
||||
session_data: dict[str, Any],
|
||||
cursor_turn_id: str | None,
|
||||
) -> None:
|
||||
if cursor_turn_id is None:
|
||||
return
|
||||
if not any(turn.get("id") == cursor_turn_id for turn in session_data["turns"]):
|
||||
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
||||
|
||||
|
||||
def truncate_text_chat_future_turns(
|
||||
session_data: dict[str, Any],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
cursor_turn_id = session_data.get("cursor_turn_id")
|
||||
turns = list(session_data.get("turns") or [])
|
||||
discarded_future = list(session_data.get("discarded_future") or [])
|
||||
|
||||
if cursor_turn_id is None:
|
||||
return turns, discarded_future
|
||||
|
||||
for index, turn in enumerate(turns):
|
||||
if turn.get("id") == cursor_turn_id:
|
||||
active_turns = turns[: index + 1]
|
||||
future_turns = turns[index + 1 :]
|
||||
if future_turns:
|
||||
discarded_future.append(
|
||||
{
|
||||
"rewound_from_turn_id": cursor_turn_id,
|
||||
"discarded_at": datetime.now(UTC).isoformat(),
|
||||
"turns": future_turns,
|
||||
}
|
||||
)
|
||||
return active_turns, discarded_future
|
||||
|
||||
raise TextChatTurnNotFoundError("Turn not found in text chat session")
|
||||
|
||||
|
||||
def latest_completed_text_chat_turn_id(turns: list[dict[str, Any]]) -> str | None:
|
||||
for turn in reversed(turns):
|
||||
if turn.get("status") == "completed":
|
||||
return turn.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def build_pending_text_chat_turn(*, user_text: str | None) -> dict[str, Any]:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return {
|
||||
"id": f"turn_{uuid4().hex[:12]}",
|
||||
"status": "pending",
|
||||
"created_at": now,
|
||||
"user_message": (
|
||||
{
|
||||
"text": user_text,
|
||||
"created_at": now,
|
||||
}
|
||||
if user_text is not None
|
||||
else None
|
||||
),
|
||||
"assistant_message": None,
|
||||
"events": [],
|
||||
"usage": {},
|
||||
}
|
||||
|
||||
|
||||
async def _mark_pending_turn_failed(
    *,
    run_id: int,
    text_session: WorkflowRunTextSessionModel,
    error_message: str,
) -> None:
    """Best-effort: record that the pending turn failed and persist it.

    Does nothing when the session's last turn is not pending. Otherwise the
    last turn is marked ``"failed"``, an ``execution_error`` event carrying
    ``error_message`` is appended to its events, and the session status is set
    to ``"error"``. The write is guarded by ``text_session.revision``; a
    revision conflict is logged and otherwise ignored.

    ``text_session.session_data`` itself is not modified.
    """
    failed_session_data = normalize_text_chat_session_data(text_session.session_data)
    failed_turns = list(failed_session_data.get("turns") or [])
    if not failed_turns or failed_turns[-1].get("status") != "pending":
        return

    # Normalization copies the turn list but not the turn dicts, so replace
    # the last entry with a new dict instead of mutating the one shared with
    # text_session.session_data.
    pending_turn = failed_turns[-1]
    failed_turns[-1] = {
        **pending_turn,
        "status": "failed",
        "events": [
            *(pending_turn.get("events") or []),
            {
                "type": "execution_error",
                "created_at": datetime.now(UTC).isoformat(),
                "payload": {"message": error_message},
            },
        ],
    }
    failed_session_data["turns"] = failed_turns
    failed_session_data["status"] = "error"
    try:
        await db_client.update_workflow_run_text_session(
            run_id,
            session_data=failed_session_data,
            expected_revision=text_session.revision,
        )
    except WorkflowRunTextSessionRevisionConflictError:
        # Still best effort, but leave a trace instead of dropping the conflict silently.
        logger.warning(
            "Skipped marking pending text chat turn as failed for run {}: "
            "session revision changed",
            run_id,
        )
        return
|
||||
|
||||
|
||||
async def _reload_text_chat_session(run_id: int) -> WorkflowRunTextSessionModel:
    """Re-read the text session for a run, scoped by the run's organization.

    Raises:
        TextChatSessionExecutionError: If the run's organization id or the
            session itself cannot be found.
    """
    org_id = await db_client.get_organization_id_by_workflow_run_id(run_id)
    if org_id is None:
        raise TextChatSessionExecutionError(
            "Workflow run organization not found after update"
        )

    refreshed = await db_client.get_workflow_run_text_session(
        run_id,
        organization_id=org_id,
    )
    if refreshed is not None:
        return refreshed
    raise TextChatSessionExecutionError("Text chat session not found after update")
|
||||
|
||||
|
||||
# Public API of this module: constants, then exceptions, then functions.
__all__ = [
    "TEXT_CHAT_SESSION_VERSION",
    "TextChatPendingTurnLostError",
    "TextChatSessionExecutionError",
    "TextChatSessionRevisionConflictError",
    "TextChatTurnNotFoundError",
    "append_text_chat_user_message",
    "build_pending_text_chat_turn",
    "default_text_chat_checkpoint",
    "default_text_chat_session_data",
    "execute_pending_text_chat_turn",
    "initialize_text_chat_session",
    "latest_completed_text_chat_turn_id",
    "normalize_text_chat_checkpoint",
    "normalize_text_chat_session_data",
    "rewind_text_chat_session_state",
    "truncate_text_chat_future_turns",
    "validate_text_chat_turn_cursor",
]
|
||||
|
|
@ -167,9 +167,7 @@ class TestIsLocalOrCgnatIp:
|
|||
|
||||
class TestKeepCandidate:
|
||||
def test_private_relay_candidate_survives_private_policy(self):
|
||||
candidate = (
|
||||
"candidate:111 1 udp 41885439 192.168.1.50 50000 typ relay raddr 0.0.0.0 rport 0"
|
||||
)
|
||||
candidate = "candidate:111 1 udp 41885439 192.168.1.50 50000 typ relay raddr 0.0.0.0 rport 0"
|
||||
assert _keep_candidate(candidate, NonRelayFilterPolicy.PRIVATE) is True
|
||||
|
||||
def test_private_host_candidate_drops_under_private_policy(self):
|
||||
|
|
|
|||
53
api/tests/test_realtime_feedback_events.py
Normal file
53
api/tests/test_realtime_feedback_events.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from api.services.pipecat.realtime_feedback_events import (
|
||||
build_bot_text_event,
|
||||
build_function_call_end_event,
|
||||
build_node_transition_event,
|
||||
realtime_feedback_event_sort_key,
|
||||
stamp_realtime_feedback_event,
|
||||
)
|
||||
|
||||
|
||||
def test_build_function_call_end_event_serializes_results():
    """The function-call-end event carries the tool result as a string."""
    expected_payload = {
        "function_name": "lookup_contact",
        "tool_call_id": "tool-1",
        "result": "{'contact_id': 42}",
    }

    built = build_function_call_end_event(
        function_name="lookup_contact",
        tool_call_id="tool-1",
        result={"contact_id": 42},
    )

    assert built == {"type": "rtf-function-call-end", "payload": expected_payload}
|
||||
|
||||
|
||||
def test_stamp_and_sort_realtime_feedback_events():
    """Stamped events sort with the bot text ahead of the node transition."""
    raw_transition = build_node_transition_event(
        node_id="node-1",
        node_name="Greeting",
        previous_node_id=None,
        previous_node_name=None,
    )
    node_transition = stamp_realtime_feedback_event(
        raw_transition,
        timestamp="2026-01-01T00:00:03+00:00",
        turn=0,
        node_id="node-1",
        node_name="Greeting",
    )

    raw_bot_text = build_bot_text_event(
        text="Hello there",
        timestamp="2026-01-01T00:00:01+00:00",
    )
    bot_text = stamp_realtime_feedback_event(
        raw_bot_text,
        timestamp="2026-01-01T00:00:02+00:00",
        turn=0,
    )

    ordered = sorted(
        [node_transition, bot_text], key=realtime_feedback_event_sort_key
    )

    assert ordered == [bot_text, node_transition]
    # Stamping also copies the node identity onto the event itself.
    assert node_transition["node_id"] == "node-1"
    assert node_transition["node_name"] == "Greeting"
|
||||
|
|
@ -382,6 +382,105 @@ class TestStartGreeting:
|
|||
result = engine.get_start_greeting()
|
||||
assert result == ("text", "Hello Alice!")
|
||||
|
||||
@pytest.mark.asyncio
async def test_queue_node_opening_queues_text_greeting(
    self, text_workflow: WorkflowGraph
):
    """A start node with a text greeting is spoken via TTS; the LLM is not queued."""
    llm = Mock(queue_frame=AsyncMock())
    task = Mock(queue_frame=AsyncMock())

    engine = PipecatEngine(
        llm=llm,
        context=LLMContext(),
        workflow=text_workflow,
        call_context_vars={},
        workflow_run_id=1,
    )
    engine.set_task(task)

    opening = await engine.queue_node_opening(
        node_id=text_workflow.start_node_id,
        previous_node_id=None,
        generate_if_no_greeting=True,
    )

    assert opening == "greeting"
    llm.queue_frame.assert_not_awaited()

    spoken = task.queue_frame.await_args.args[0]
    assert isinstance(spoken, TTSSpeakFrame)
    assert spoken.text == TEXT_GREETING
    assert spoken.append_to_context is True
|
||||
|
||||
@pytest.mark.asyncio
async def test_queue_node_opening_falls_back_to_llm_without_greeting(self):
    """Without a greeting on the node, the opening is an LLM context frame."""
    start_node = RFNodeDTO(
        id="start",
        type="startCall",
        position=Position(x=0, y=0),
        data=StartCallNodeData(
            name="Start",
            prompt="Prompt",
            is_start=True,
            add_global_prompt=False,
            extraction_enabled=False,
        ),
    )
    end_node = RFNodeDTO(
        id="end",
        type="endCall",
        position=Position(x=0, y=200),
        data=EndCallNodeData(
            name="End",
            prompt="End",
            is_end=True,
            add_global_prompt=False,
            extraction_enabled=False,
        ),
    )
    start_to_end = RFEdgeDTO(
        id="e",
        source="start",
        target="end",
        data=EdgeDataDTO(label="End", condition="End"),
    )
    workflow = WorkflowGraph(
        ReactFlowDTO(nodes=[start_node, end_node], edges=[start_to_end])
    )

    context = LLMContext()
    llm = Mock(queue_frame=AsyncMock())
    task = Mock(queue_frame=AsyncMock())

    engine = PipecatEngine(
        llm=llm,
        context=context,
        workflow=workflow,
        call_context_vars={},
        workflow_run_id=1,
    )
    engine.set_task(task)

    opening = await engine.queue_node_opening(
        node_id=workflow.start_node_id,
        previous_node_id=None,
        generate_if_no_greeting=True,
    )

    assert opening == "llm"
    task.queue_frame.assert_not_awaited()

    context_frame = llm.queue_frame.await_args.args[0]
    assert isinstance(context_frame, LLMContextFrame)
    assert context_frame.context is context
|
||||
|
||||
|
||||
# ─── Tests: Transition Speech (Pipeline) ────────────────────────
|
||||
|
||||
|
|
|
|||
126
api/tests/test_text_chat_logs.py
Normal file
126
api/tests/test_text_chat_logs.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
visible_text_chat_turns,
|
||||
)
|
||||
|
||||
|
||||
def test_visible_text_chat_turns_trims_to_cursor_branch():
    """Turns after the cursor turn are left out of the visible list."""
    all_turns = [{"id": f"turn-{number}"} for number in (1, 2, 3)]
    session_data = {"cursor_turn_id": "turn-2", "turns": all_turns}

    visible = visible_text_chat_turns(session_data)

    assert visible == [{"id": "turn-1"}, {"id": "turn-2"}]
|
||||
|
||||
|
||||
def test_build_text_chat_realtime_feedback_events_uses_visible_branch_and_dedupes_node_transitions():
    """Events come only from turns up to the cursor, with one node transition.

    The fixture has three turns with the cursor on the second. Turn 1 and
    turn 2 both record a transition into the same node; turn 3 holds an
    execution error that lies beyond the cursor.
    """
    session_data = {
        "cursor_turn_id": "turn-2",
        "turns": [
            {
                "id": "turn-1",
                "created_at": "2026-01-01T00:00:00+00:00",
                "events": [
                    {
                        "type": "node_transition",
                        "created_at": "2026-01-01T00:00:00+00:00",
                        "payload": {
                            "node_id": "node-start",
                            "node_name": "Start",
                            "previous_node_id": None,
                            "previous_node_name": None,
                            "allow_interrupt": False,
                        },
                    }
                ],
                "user_message": None,
                "assistant_message": {
                    "text": "Hello",
                    "created_at": "2026-01-01T00:00:01+00:00",
                },
            },
            {
                "id": "turn-2",
                "created_at": "2026-01-01T00:00:02+00:00",
                "events": [
                    # Same node as in turn 1: expected to be dropped as a duplicate.
                    {
                        "type": "node_transition",
                        "created_at": "2026-01-01T00:00:02+00:00",
                        "payload": {
                            "node_id": "node-start",
                            "node_name": "Start",
                            "previous_node_id": None,
                            "previous_node_name": None,
                            "allow_interrupt": False,
                        },
                    },
                    {
                        "type": "tool_call_started",
                        "created_at": "2026-01-01T00:00:03+00:00",
                        "payload": {
                            "function_name": "lookup_contact",
                            "tool_call_id": "tool-1",
                        },
                    },
                    {
                        "type": "tool_call_result",
                        "created_at": "2026-01-01T00:00:04+00:00",
                        "payload": {
                            "function_name": "lookup_contact",
                            "tool_call_id": "tool-1",
                            "result": {"contact_id": 42},
                        },
                    },
                ],
                "user_message": {
                    "text": "Find Abhishek",
                    "created_at": "2026-01-01T00:00:02+00:00",
                },
                "assistant_message": {
                    "text": "I found one match.",
                    "created_at": "2026-01-01T00:00:05+00:00",
                },
            },
            # Beyond the cursor: nothing from this turn should surface.
            {
                "id": "turn-3",
                "created_at": "2026-01-01T00:00:06+00:00",
                "events": [
                    {
                        "type": "execution_error",
                        "created_at": "2026-01-01T00:00:06+00:00",
                        "payload": {"message": "Should be hidden after rewind"},
                    }
                ],
                "user_message": {
                    "text": "This turn is rewound away",
                    "created_at": "2026-01-01T00:00:06+00:00",
                },
                "assistant_message": None,
            },
        ],
    }

    events = build_text_chat_realtime_feedback_events(session_data)

    # The exact type sequence is the primary check: a single node transition
    # and no events from turn 3.
    assert [event["type"] for event in events] == [
        "rtf-node-transition",
        "rtf-bot-text",
        "rtf-user-transcription",
        "rtf-function-call-start",
        "rtf-function-call-end",
        "rtf-bot-text",
    ]
    assert events[0]["payload"]["node_name"] == "Start"
    assert events[2]["payload"]["text"] == "Find Abhishek"
    assert events[4]["payload"]["result"] == "{'contact_id': 42}"
    # NOTE(review): this looks up payload key "error", while the fixture stores
    # the text under "message". Confirm which key the builder emits for
    # execution errors; if it is not "error", this check can never fail.
    assert all(
        event.get("payload", {}).get("error") != "Should be hidden after rewind"
        for event in events
    )
|
||||
91
api/tests/test_text_chat_session_service.py
Normal file
91
api/tests/test_text_chat_session_service.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
import api.services.workflow.text_chat_session_service as text_chat_session_service
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatSessionExecutionError,
|
||||
TextChatTurnNotFoundError,
|
||||
_reload_text_chat_session,
|
||||
build_pending_text_chat_turn,
|
||||
truncate_text_chat_future_turns,
|
||||
validate_text_chat_turn_cursor,
|
||||
)
|
||||
|
||||
|
||||
def test_build_pending_text_chat_turn_sets_pending_shape():
    """A new turn is pending, keeps the user text, and has empty result fields."""
    pending = build_pending_text_chat_turn(user_text="Hello")

    assert pending["id"].startswith("turn_")
    assert pending["status"] == "pending"
    assert pending["user_message"]["text"] == "Hello"
    assert pending["assistant_message"] is None
    assert pending["events"] == []
    assert pending["usage"] == {}
|
||||
|
||||
|
||||
def test_truncate_text_chat_future_turns_moves_rewound_branch_to_discarded_future():
    """Turns after the cursor land in a discarded-branch record."""
    first, second, third = ({"id": f"turn-{number}"} for number in (1, 2, 3))
    session_data = {
        "cursor_turn_id": "turn-2",
        "turns": [first, second, third],
        "discarded_future": [],
    }

    active_turns, discarded_future = truncate_text_chat_future_turns(session_data)

    assert active_turns == [{"id": "turn-1"}, {"id": "turn-2"}]
    archived_branch = discarded_future[0]
    assert archived_branch["rewound_from_turn_id"] == "turn-2"
    assert archived_branch["turns"] == [{"id": "turn-3"}]
|
||||
|
||||
|
||||
def test_validate_text_chat_turn_cursor_raises_for_missing_turn():
    """A cursor id that no turn carries is rejected."""
    session_data = {"turns": [{"id": "turn-1"}]}

    with pytest.raises(TextChatTurnNotFoundError):
        validate_text_chat_turn_cursor(session_data, "turn-404")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reload_text_chat_session_uses_run_id_to_resolve_organization(
    monkeypatch,
):
    """The reload looks up the organization by run id, then fetches the session with it."""
    reloaded_session = WorkflowRunTextSessionModel(workflow_run_id=123)
    org_lookup = AsyncMock(return_value=77)
    session_lookup = AsyncMock(return_value=reloaded_session)

    stubs = {
        "get_organization_id_by_workflow_run_id": org_lookup,
        "get_workflow_run_text_session": session_lookup,
    }
    for attribute, stub in stubs.items():
        monkeypatch.setattr(text_chat_session_service.db_client, attribute, stub)

    outcome = await _reload_text_chat_session(123)

    assert outcome is reloaded_session
    org_lookup.assert_awaited_once_with(123)
    session_lookup.assert_awaited_once_with(123, organization_id=77)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reload_text_chat_session_raises_when_run_organization_is_missing(
    monkeypatch,
):
    """With no organization id for the run, the reload raises an execution error."""
    missing_org_lookup = AsyncMock(return_value=None)
    monkeypatch.setattr(
        text_chat_session_service.db_client,
        "get_organization_id_by_workflow_run_id",
        missing_org_lookup,
    )

    with pytest.raises(TextChatSessionExecutionError, match="organization not found"):
        await _reload_text_chat_session(123)
|
||||
181
api/tests/test_workflow_run_cost.py
Normal file
181
api/tests/test_workflow_run_cost.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
calculate_workflow_run_cost,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
mode="textchat",
|
||||
created_at=datetime.now(UTC),
|
||||
usage_info={
|
||||
"llm": {},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
},
|
||||
cost_info={},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
    """Building cost info returns the cost fields without touching org usage."""
    workflow_run = _make_workflow_run()
    organization = SimpleNamespace(id=42, price_per_second_usd=1.5)
    get_org = AsyncMock(return_value=organization)
    update_usage = AsyncMock()

    db = workflow_run_cost_mod.db_client
    monkeypatch.setattr(db, "get_organization_by_id", get_org)
    monkeypatch.setattr(db, "update_usage_after_run", update_usage)

    cost_info = await build_workflow_run_cost_info(workflow_run)

    assert cost_info is not None
    assert cost_info["call_duration_seconds"] == 7
    for required_key in ("cost_breakdown", "dograh_token_usage"):
        assert required_key in cost_info
    # 10.5 equals the 7 second duration at the 1.5 per-second price — presumably
    # how charge_usd is derived; confirm against the pricing code.
    assert cost_info["charge_usd"] == 10.5
    update_usage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
    monkeypatch,
):
    """The wrapper saves cost info on the run and updates org usage once."""
    workflow_run = _make_workflow_run()
    organization = SimpleNamespace(id=42, price_per_second_usd=None)
    get_run = AsyncMock(return_value=workflow_run)
    get_org = AsyncMock(return_value=organization)
    update_run = AsyncMock()
    update_usage = AsyncMock()

    db = workflow_run_cost_mod.db_client
    monkeypatch.setattr(db, "get_workflow_run_by_id", get_run)
    monkeypatch.setattr(db, "get_organization_by_id", get_org)
    monkeypatch.setattr(db, "update_workflow_run", update_run)
    monkeypatch.setattr(db, "update_usage_after_run", update_usage)

    await calculate_workflow_run_cost(workflow_run.id)

    update_run.assert_awaited_once()
    persisted = update_run.await_args.kwargs
    assert persisted["run_id"] == workflow_run.id
    assert "cost_breakdown" in persisted["cost_info"]
    update_usage.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
    monkeypatch,
):
    """Per-delta organization usage updates must add up to the whole-run cost.

    Two usage deltas (3 s and 4 s of call time) are applied one after the
    other through ``apply_usage_delta_to_organization``. Each call must pass
    only its own increment to ``db_client.update_usage_after_run``, and the
    two increments must sum to what ``build_workflow_run_cost_info`` reports
    for the merged usage.
    """
    workflow_run = _make_workflow_run()
    # NOTE(review): cost_info is seeded here but nothing below asserts on it;
    # confirm whether a "preserved" check was intended.
    workflow_run.cost_info = {"call_id": "preserve-me"}

    usage_delta_one = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 1_000,
                "completion_tokens": 100,
                "total_tokens": 1_100,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {},
        "stt": {},
        "call_duration_seconds": 3,
    }
    usage_delta_two = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 2_000,
                "completion_tokens": 50,
                "total_tokens": 2_050,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {},
        "stt": {},
        "call_duration_seconds": 4,
    }
    # Field-by-field sum of the two deltas above.
    merged_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 3_000,
                "completion_tokens": 150,
                "total_tokens": 3_150,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {},
        "stt": {},
        "call_duration_seconds": 7,
    }

    get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
    update_usage = AsyncMock()

    monkeypatch.setattr(
        workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
    )
    monkeypatch.setattr(
        workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
    )

    first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
    second_delta = await apply_usage_delta_to_organization(
        workflow_run, usage_delta_two
    )
    # Shallow copy of the run carrying the merged usage, so the whole-run cost
    # is computed without modifying the original object's usage_info.
    total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
    total_workflow_run.usage_info = merged_usage
    total_cost = await build_workflow_run_cost_info(total_workflow_run)

    assert first_delta is not None
    assert second_delta is not None
    assert total_cost is not None
    assert update_usage.await_count == 2
    # Each update carries only that delta's own tokens, seconds and charge.
    assert update_usage.await_args_list[0].args == (
        42,
        first_delta["dograh_token_usage"],
        3.0,
        first_delta["charge_usd"],
    )
    assert update_usage.await_args_list[1].args == (
        42,
        second_delta["dograh_token_usage"],
        4.0,
        second_delta["charge_usd"],
    )
    assert (
        first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
    ) == pytest.approx(total_cost["dograh_token_usage"])
    # Compare with a tolerance, like the token-usage check above: a sum of
    # floating-point charges is not guaranteed to be bit-exact.
    assert (
        first_delta["charge_usd"] + second_delta["charge_usd"]
    ) == pytest.approx(total_cost["charge_usd"])
|
||||
1194
api/tests/test_workflow_text_chat.py
Normal file
1194
api/tests/test_workflow_text_chat.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue