feat: simplify TesterPanel design

This commit is contained in:
Abhishek Kumar 2026-05-19 08:24:39 +05:30
parent f929a332bb
commit b243e97502
15 changed files with 2461 additions and 565 deletions

View file

@ -60,8 +60,8 @@ class WorkflowRunTextSessionClient(BaseDBClient):
async def get_workflow_run_text_session(
self,
workflow_run_id: int,
user_id: int | None = None,
organization_id: int | None = None,
*,
organization_id: int,
) -> WorkflowRunTextSessionModel | None:
async with self.async_session() as session:
query = (
@ -74,13 +74,9 @@ class WorkflowRunTextSessionClient(BaseDBClient):
.join(WorkflowRunTextSessionModel.workflow_run)
.join(WorkflowRunModel.workflow)
.where(WorkflowRunTextSessionModel.workflow_run_id == workflow_run_id)
.where(WorkflowModel.organization_id == organization_id)
)
if organization_id is not None:
query = query.where(WorkflowModel.organization_id == organization_id)
elif user_id is not None:
query = query.where(WorkflowModel.user_id == user_id)
result = await session.execute(query)
return result.scalars().first()

View file

@ -12,6 +12,10 @@ from api.db.workflow_run_text_session_client import (
)
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.workflow.text_chat_runner import (
execute_text_chat_pending_turn,
merge_text_chat_usage_info,
)
router = APIRouter(prefix="/workflow", tags=["workflow-text-chat"])
@ -129,25 +133,6 @@ def _build_response(
)
def _build_response_from_run_and_session(workflow_run, text_session):
return WorkflowRunTextSessionResponse(
workflow_run_id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
name=workflow_run.name,
mode=workflow_run.mode,
state=_get_state_value(workflow_run.state),
is_completed=workflow_run.is_completed,
revision=text_session.revision,
initial_context=workflow_run.initial_context,
gathered_context=workflow_run.gathered_context,
annotations=workflow_run.annotations,
session_data=_normalize_session_data(text_session.session_data),
checkpoint=_normalize_checkpoint(text_session.checkpoint),
created_at=text_session.created_at,
updated_at=text_session.updated_at,
)
def _validate_turn_cursor(
session_data: Dict[str, Any], cursor_turn_id: str | None
) -> None:
@ -188,16 +173,50 @@ def _truncate_future_turns(
def _latest_completed_turn_id(turns: list[Dict[str, Any]]) -> str | None:
for turn in reversed(turns):
if turn.get("status") == "completed" and turn.get("assistant_message"):
if turn.get("status") == "completed":
return turn.get("id")
return None
def _build_pending_turn(*, user_text: str | None) -> Dict[str, Any]:
now = datetime.now(UTC).isoformat()
return {
"id": f"turn_{uuid4().hex[:12]}",
"status": "pending",
"created_at": now,
"user_message": (
{
"text": user_text,
"created_at": now,
}
if user_text is not None
else None
),
"assistant_message": None,
"events": [],
"usage": {},
}
def _revision_conflict_detail(
e: WorkflowRunTextSessionRevisionConflictError,
) -> dict[str, Any]:
return {
"message": "Text chat session revision conflict",
"expected_revision": e.expected_revision,
"actual_revision": e.actual_revision,
}
async def _load_text_session_or_404(
workflow_id: int,
run_id: int,
user: UserModel,
) -> WorkflowRunTextSessionModel:
if user.selected_organization_id is None:
raise HTTPException(
status_code=403, detail="Organization context is required"
)
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=user.selected_organization_id
)
@ -212,6 +231,98 @@ async def _load_text_session_or_404(
return text_session
async def _execute_pending_turn_and_build_response(
*,
workflow_id: int,
run_id: int,
text_session: WorkflowRunTextSessionModel,
user: UserModel,
) -> WorkflowRunTextSessionResponse:
try:
execution = await execute_text_chat_pending_turn(
workflow_run_id=run_id,
workflow_id=workflow_id,
session_data=_normalize_session_data(text_session.session_data),
checkpoint=_normalize_checkpoint(text_session.checkpoint),
)
except Exception as e:
failed_session_data = _normalize_session_data(text_session.session_data)
failed_turns = list(failed_session_data.get("turns") or [])
if failed_turns and failed_turns[-1].get("status") == "pending":
failed_turns[-1]["status"] = "failed"
failed_turns[-1]["events"] = [
*(failed_turns[-1].get("events") or []),
{
"type": "execution_error",
"created_at": datetime.now(UTC).isoformat(),
"payload": {"message": str(e)},
},
]
failed_session_data["turns"] = failed_turns
failed_session_data["status"] = "error"
try:
await db_client.update_workflow_run_text_session(
run_id,
session_data=failed_session_data,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError:
pass
raise HTTPException(
status_code=500, detail="Failed to execute text chat assistant turn"
)
completed_session_data = _normalize_session_data(text_session.session_data)
completed_turns = list(completed_session_data.get("turns") or [])
if not completed_turns or completed_turns[-1].get("status") != "pending":
raise HTTPException(
status_code=500,
detail="Text chat session lost its pending turn before completion",
)
completed_turns[-1]["status"] = "completed"
completed_turns[-1]["assistant_message"] = (
{
"text": execution.assistant_text,
"created_at": execution.assistant_created_at,
}
if execution.assistant_text
else None
)
completed_turns[-1]["events"] = execution.events
completed_turns[-1]["usage"] = execution.usage
completed_turns[-1]["checkpoint_after_turn"] = execution.checkpoint
completed_session_data["turns"] = completed_turns
completed_session_data["status"] = "idle"
try:
await db_client.update_workflow_run_text_session(
run_id,
session_data=completed_session_data,
checkpoint=execution.checkpoint,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
existing_usage_info = text_session.workflow_run.usage_info or {}
merged_usage_info = merge_text_chat_usage_info(
existing_usage_info,
execution.usage,
)
await db_client.update_workflow_run(
run_id,
initial_context=execution.initial_context,
usage_info=merged_usage_info,
gathered_context=execution.gathered_context,
state=execution.state,
is_completed=execution.is_completed,
)
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
@router.post(
"/{workflow_id}/text-chat/sessions",
response_model=WorkflowRunTextSessionResponse,
@ -252,7 +363,30 @@ async def create_text_chat_session(
session_data=_default_session_data(),
checkpoint=_default_checkpoint(),
)
return _build_response_from_run_and_session(workflow_run, text_session)
session_data = _normalize_session_data(text_session.session_data)
checkpoint = _normalize_checkpoint(text_session.checkpoint)
session_data["turns"] = [_build_pending_turn(user_text=None)]
session_data["status"] = "pending_assistant_turn"
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(session_data["turns"])
try:
await db_client.update_workflow_run_text_session(
workflow_run.id,
session_data=session_data,
checkpoint=checkpoint,
expected_revision=text_session.revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
text_session = await _load_text_session_or_404(workflow_id, workflow_run.id, user)
return await _execute_pending_turn_and_build_response(
workflow_id=workflow_id,
run_id=workflow_run.id,
text_session=text_session,
user=user,
)
@router.get(
@ -283,22 +417,7 @@ async def append_text_chat_message(
checkpoint = _normalize_checkpoint(text_session.checkpoint)
active_turns, discarded_future = _truncate_future_turns(session_data)
now = datetime.now(UTC).isoformat()
turn_id = f"turn_{uuid4().hex[:12]}"
active_turns.append(
{
"id": turn_id,
"status": "pending",
"created_at": now,
"user_message": {
"text": request.text,
"created_at": now,
},
"assistant_message": None,
"events": [],
"usage": {},
}
)
active_turns.append(_build_pending_turn(user_text=request.text))
session_data["turns"] = active_turns
session_data["discarded_future"] = discarded_future
@ -307,24 +426,22 @@ async def append_text_chat_message(
checkpoint["anchor_turn_id"] = _latest_completed_turn_id(active_turns)
try:
text_session = await db_client.update_workflow_run_text_session(
await db_client.update_workflow_run_text_session(
run_id,
session_data=session_data,
checkpoint=checkpoint,
expected_revision=request.expected_revision,
)
except WorkflowRunTextSessionRevisionConflictError as e:
raise HTTPException(
status_code=409,
detail={
"message": "Text chat session revision conflict",
"expected_revision": e.expected_revision,
"actual_revision": e.actual_revision,
},
)
raise HTTPException(status_code=409, detail=_revision_conflict_detail(e))
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
return await _execute_pending_turn_and_build_response(
workflow_id=workflow_id,
run_id=run_id,
text_session=text_session,
user=user,
)
@router.post(

View file

@ -6,7 +6,7 @@ from api.db import db_client
from api.enums import PostHogEvent, WorkflowRunState
from api.services.campaign.circuit_breaker import circuit_breaker
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.audio_playback import play_audio, play_audio_loop
from api.services.pipecat.audio_playback import play_audio_loop
from api.services.pipecat.in_memory_buffers import (
InMemoryAudioBuffer,
InMemoryLogsBuffer,
@ -19,8 +19,6 @@ from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
@ -68,7 +66,6 @@ def register_event_handlers(
pipeline_metrics_aggregator: PipelineMetricsAggregator,
audio_config=AudioConfig,
pre_call_fetch_task: asyncio.Task | None = None,
fetch_recording_audio=None,
user_provider_id: str | None = None,
):
"""Register all event handlers for transport and task events.
@ -97,20 +94,11 @@ def register_event_handlers(
"initial_response_triggered": False,
}
async def queue_initial_llm_context():
# Queue LLMContextFrame after the VoicemailDetector since the detector
# gates LLMContextFrames until voicemail detection completes. We also
# don't want to trigger the Voicemail LLM with this initial frame.
await engine.llm.queue_frame(LLMContextFrame(engine.context))
async def maybe_trigger_initial_response():
"""Start the conversation after both pipeline_started and client_connected events.
If a pre-call fetch is in progress, plays a ringer while waiting for the
response, then merges the result into the call context before proceeding.
If the start node has a greeting configured, play it directly via TTS.
Otherwise, trigger an LLM generation for the opening message.
"""
if (
ready_state["pipeline_started"]
@ -165,46 +153,11 @@ def register_event_handlers(
# Set the start node now (after pre-call fetch data is merged)
# so that render_template() has the complete _call_context_vars.
await engine.set_node(engine.workflow.start_node_id)
greeting_info = engine.get_start_greeting()
if greeting_info:
greeting_type, greeting_value = greeting_info
if (
greeting_type == "audio"
and greeting_value
and fetch_recording_audio
):
logger.debug(f"Playing audio greeting recording: {greeting_value}")
result = await fetch_recording_audio(
recording_pk=int(greeting_value)
)
if result:
await play_audio(
result.audio,
sample_rate=audio_config.pipeline_sample_rate or 16000,
queue_frame=transport.output().queue_frame,
transcript=result.transcript,
append_to_context=True,
)
else:
logger.warning(
f"Failed to fetch audio greeting {greeting_value}, "
"falling back to LLM generation"
)
await queue_initial_llm_context()
else:
logger.debug("Playing text greeting via TTS")
# append_to_context=True so the assistant aggregator commits
# the greeting to the LLM context once TTS finishes; without
# it the LLM would re-greet on its first generation.
await task.queue_frame(
TTSSpeakFrame(greeting_value, append_to_context=True)
)
else:
logger.debug(
"Both pipeline_started and client_connected received - triggering initial LLM generation"
)
await queue_initial_llm_context()
await engine.queue_node_opening(
node_id=engine.workflow.start_node_id,
previous_node_id=None,
generate_if_no_greeting=True,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(_transport, _participant):

View file

@ -779,7 +779,6 @@ async def _run_pipeline(
pipeline_metrics_aggregator=pipeline_metrics_aggregator,
audio_config=audio_config,
pre_call_fetch_task=pre_call_fetch_task,
fetch_recording_audio=fetch_audio,
user_provider_id=user_provider_id,
)

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
from typing import TYPE_CHECKING, Awaitable, Callable, Literal, Optional, Union
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
@ -7,6 +7,7 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
LLMContextFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
@ -590,8 +591,8 @@ class PipecatEngine:
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
"""Return the greeting info for the start node, or None if not configured.
def get_node_greeting(self, node_id: str) -> Optional[tuple[str, Optional[str]]]:
"""Return the greeting info for a node, or None if not configured.
Returns:
A tuple of (greeting_type, value) where:
@ -599,20 +600,89 @@ class PipecatEngine:
- ("audio", recording_id) for pre-recorded audio greetings
Or None if no greeting is configured.
"""
start_node = self.workflow.nodes.get(self.workflow.start_node_id)
if not start_node:
node = self.workflow.nodes.get(node_id)
if not node:
return None
greeting_type = start_node.greeting_type or "text"
greeting_type = node.greeting_type or "text"
if greeting_type == "audio" and start_node.greeting_recording_id:
return ("audio", start_node.greeting_recording_id)
if greeting_type == "audio" and node.greeting_recording_id:
return ("audio", node.greeting_recording_id)
if start_node.greeting:
return ("text", self._format_prompt(start_node.greeting))
if node.greeting:
return ("text", self._format_prompt(node.greeting))
return None
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
"""Return the greeting info for the start node, or None if not configured."""
return self.get_node_greeting(self.workflow.start_node_id)
async def queue_node_opening(
self,
*,
node_id: str,
previous_node_id: Optional[str] = None,
generate_if_no_greeting: bool = False,
) -> Literal["none", "greeting", "llm"]:
"""Queue the opening behavior for a node.
This is the shared source of truth for how a node begins once the
engine is ready and the node has already been set on the context.
Returns:
"greeting" when a text/audio greeting was queued,
"llm" when an initial LLM generation was queued,
"none" when nothing was queued.
"""
if previous_node_id != node_id:
greeting_info = self.get_node_greeting(node_id)
if greeting_info:
greeting_type, greeting_value = greeting_info
if (
greeting_type == "audio"
and greeting_value
and self._fetch_recording_audio
and self._transport_output is not None
):
logger.debug(f"Playing audio greeting recording: {greeting_value}")
result = await self._fetch_recording_audio(
recording_pk=int(greeting_value)
)
if result:
await play_audio(
result.audio,
sample_rate=self._audio_config.pipeline_sample_rate
if self._audio_config
else 16000,
queue_frame=self._transport_output.queue_frame,
transcript=result.transcript,
append_to_context=True,
)
return "greeting"
logger.warning(
f"Failed to fetch audio greeting {greeting_value}, "
"falling back to LLM generation"
)
elif greeting_value and self.task is not None:
logger.debug("Playing text greeting via TTS")
# append_to_context=True so the assistant aggregator commits
# the greeting to the LLM context once TTS finishes; without
# it the LLM would re-greet on its first generation.
await self.task.queue_frame(
TTSSpeakFrame(greeting_value, append_to_context=True)
)
return "greeting"
if generate_if_no_greeting and self.llm is not None and self.context is not None:
logger.debug("Queueing initial LLM generation for node opening")
# Queue after the voicemail detector in the live pipeline so the
# detector can gate initial generations when needed.
await self.llm.queue_frame(LLMContextFrame(self.context))
return "llm"
return "none"
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:

View file

@ -431,6 +431,17 @@ class CustomToolManager:
workflow_run = await db_client.get_workflow_run_by_id(
self._engine._workflow_run_id
)
if workflow_run.mode == WorkflowRunMode.TEXTCHAT.value:
textchat_error_result = {
"status": "failed",
"message": "I'm sorry, but call transfers are not available in text chat tests.",
"action": "transfer_failed",
"reason": "textchat_not_supported",
}
await self._handle_transfer_result(
textchat_error_result, function_call_params, properties
)
return
if workflow_run.mode in [
WorkflowRunMode.WEBRTC.value,
WorkflowRunMode.SMALLWEBRTC.value,

View file

@ -0,0 +1,601 @@
import asyncio
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMContextFrame,
LLMAssistantPushAggregationFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
TTSTextFrame,
TTSStoppedFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from api.db import db_client
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.configuration.resolve import resolve_effective_config
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.pipeline_builder import create_pipeline_task
from api.services.pipecat.pipeline_metrics_aggregator import (
PipelineMetricsAggregator,
)
from api.services.pipecat.recording_audio_cache import create_recording_audio_fetcher
from api.services.pipecat.service_factory import create_llm_service
from api.services.workflow.dto import ReactFlowDTO
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow_graph import WorkflowGraph
TEXT_CHAT_CHECKPOINT_VERSION = 1
TEXT_CHAT_TURN_TIMEOUT_SECONDS = 60.0
TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
def default_text_chat_checkpoint() -> dict[str, Any]:
return {
"version": TEXT_CHAT_CHECKPOINT_VERSION,
"anchor_turn_id": None,
"current_node_id": None,
"messages": [],
"gathered_context": {},
"tool_state": {},
}
def normalize_text_chat_checkpoint(
checkpoint: dict[str, Any] | None,
) -> dict[str, Any]:
normalized = {
**default_text_chat_checkpoint(),
**(checkpoint or {}),
}
normalized["messages"] = list(normalized.get("messages") or [])
normalized["gathered_context"] = dict(normalized.get("gathered_context") or {})
normalized["tool_state"] = dict(normalized.get("tool_state") or {})
return normalized
@dataclass
class TextChatTurnExecutionResult:
assistant_text: str | None
assistant_created_at: str
events: list[dict[str, Any]]
usage: dict[str, Any]
checkpoint: dict[str, Any]
gathered_context: dict[str, Any]
initial_context: dict[str, Any]
state: str
is_completed: bool
@dataclass
class _ResponseWindowState:
active_assistant_segments: int = 0
active_llm_completions: int = 0
pending_context_requests: int = 0
blocking_tool_call_ids: set[str] = field(default_factory=set)
outputs: list[str] = field(default_factory=list)
def note_direct_context_request(self) -> None:
self.pending_context_requests += 1
def note_upstream_context_request(self) -> None:
self.pending_context_requests += 1
def note_llm_start(self) -> None:
if self.pending_context_requests > 0:
self.pending_context_requests -= 1
self.active_llm_completions += 1
def note_llm_end(self) -> None:
if self.active_llm_completions > 0:
self.active_llm_completions -= 1
def note_assistant_turn_started(self) -> None:
self.active_assistant_segments += 1
def note_assistant_turn_stopped(self, content: str) -> None:
if self.active_assistant_segments > 0:
self.active_assistant_segments -= 1
normalized_content = content.strip()
if normalized_content:
self.outputs.append(normalized_content)
def note_function_call_in_progress(self, tool_call_id: str, blocking: bool) -> None:
if blocking:
self.blocking_tool_call_ids.add(tool_call_id)
def note_function_call_result(self, tool_call_id: str) -> None:
self.blocking_tool_call_ids.discard(tool_call_id)
@property
def has_blocking_tool_calls(self) -> bool:
return bool(self.blocking_tool_call_ids)
@property
def frontier_is_idle(self) -> bool:
return (
self.pending_context_requests == 0
and self.active_llm_completions == 0
and self.active_assistant_segments == 0
and not self.has_blocking_tool_calls
)
class _TaskQueueProxy:
def __init__(self, queue_frame):
self.queue_frame = queue_frame
class _TextChatCaptureProcessor(FrameProcessor):
def __init__(self, response_window: _ResponseWindowState) -> None:
super().__init__()
self.last_activity_at = time.monotonic()
self.activity_count = 0
self.events: list[dict[str, Any]] = []
self._response_window = response_window
def _touch(self) -> None:
self.last_activity_at = time.monotonic()
self.activity_count += 1
def _append_event(self, event_type: str, payload: dict[str, Any]) -> None:
self.events.append(
{
"type": event_type,
"created_at": datetime.now(UTC).isoformat(),
"payload": jsonable_encoder(payload),
}
)
async def process_frame(self, frame, direction: FrameDirection):
await super().process_frame(frame, direction)
self._touch()
if isinstance(frame, TTSSpeakFrame):
text_frame = TextFrame(frame.text)
text_frame.append_to_context = (
frame.append_to_context
if frame.append_to_context is not None
else True
)
await self.push_frame(text_frame, direction)
await self.push_frame(LLMAssistantPushAggregationFrame(), direction)
return
if isinstance(frame, LLMContextFrame) and direction == FrameDirection.UPSTREAM:
self._response_window.note_upstream_context_request()
if isinstance(frame, TTSStoppedFrame):
await self.push_frame(frame, direction)
await self.push_frame(LLMAssistantPushAggregationFrame(), direction)
return
if (
isinstance(frame, LLMFullResponseStartFrame)
and direction == FrameDirection.DOWNSTREAM
):
self._response_window.note_llm_start()
if isinstance(frame, LLMFullResponseEndFrame) and direction is FrameDirection.DOWNSTREAM:
self._response_window.note_llm_end()
await self.push_frame(frame, direction)
# Text chat has no TTS/output transport, so mixed text+tool responses
# would otherwise leave function calls waiting forever on a
# BotStoppedSpeakingFrame that never arrives.
await self.push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
return
if isinstance(frame, FunctionCallInProgressFrame):
self._response_window.note_function_call_in_progress(
tool_call_id=frame.tool_call_id,
blocking=frame.cancel_on_interruption,
)
self._append_event(
"tool_call_started",
{
"function_name": frame.function_name,
"tool_call_id": frame.tool_call_id,
"arguments": dict(frame.arguments or {}),
},
)
elif isinstance(frame, FunctionCallResultFrame):
self._response_window.note_function_call_result(frame.tool_call_id)
self._append_event(
"tool_call_result",
{
"function_name": frame.function_name,
"tool_call_id": frame.tool_call_id,
"result": frame.result,
},
)
elif isinstance(frame, EndFrame):
self._append_event("session_end", {"reason": frame.reason})
elif isinstance(frame, CancelFrame):
if frame.reason != TEXT_CHAT_INTERNAL_CANCEL_REASON:
self._append_event("session_cancelled", {"reason": frame.reason})
await self.push_frame(frame, direction)
def _merge_usage_info(
existing: dict[str, Any] | None,
delta: dict[str, Any] | None,
) -> dict[str, Any]:
merged = dict(existing or {})
delta = dict(delta or {})
merged_llm = dict(merged.get("llm") or {})
for key, value in (delta.get("llm") or {}).items():
current = dict(merged_llm.get(key) or {})
merged_llm[key] = {
"prompt_tokens": int(current.get("prompt_tokens") or 0)
+ int(value.get("prompt_tokens") or 0),
"completion_tokens": int(current.get("completion_tokens") or 0)
+ int(value.get("completion_tokens") or 0),
"total_tokens": int(current.get("total_tokens") or 0)
+ int(value.get("total_tokens") or 0),
"cache_read_input_tokens": int(
current.get("cache_read_input_tokens") or 0
)
+ int(value.get("cache_read_input_tokens") or 0),
"cache_creation_input_tokens": int(
current.get("cache_creation_input_tokens") or 0
)
+ int(value.get("cache_creation_input_tokens") or 0),
}
merged["llm"] = merged_llm
for section in ("tts", "stt"):
merged_section = dict(merged.get(section) or {})
for key, value in (delta.get(section) or {}).items():
merged_section[key] = float(merged_section.get(key) or 0) + float(value)
merged[section] = merged_section
merged["call_duration_seconds"] = int(merged.get("call_duration_seconds") or 0) + int(
delta.get("call_duration_seconds") or 0
)
return merged
def merge_text_chat_usage_info(
existing: dict[str, Any] | None,
delta: dict[str, Any] | None,
) -> dict[str, Any]:
return _merge_usage_info(existing, delta)
def _resolve_checkpoint_for_pending_turn(
session_data: dict[str, Any],
checkpoint: dict[str, Any] | None,
) -> dict[str, Any]:
turns = list(session_data.get("turns") or [])
if not turns:
return normalize_text_chat_checkpoint(checkpoint)
pending_turn = turns[-1]
if pending_turn.get("status") != "pending":
return normalize_text_chat_checkpoint(checkpoint)
for turn in reversed(turns[:-1]):
if turn.get("status") != "completed":
continue
stored_checkpoint = turn.get("checkpoint_after_turn")
if stored_checkpoint:
return normalize_text_chat_checkpoint(stored_checkpoint)
break
return normalize_text_chat_checkpoint(checkpoint)
async def _wait_for_quiescence(
*,
capture_processor: _TextChatCaptureProcessor,
response_window: _ResponseWindowState,
runner_task: asyncio.Task,
activity_marker: int,
timeout_seconds: float = TEXT_CHAT_TURN_TIMEOUT_SECONDS,
) -> None:
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout_seconds
while loop.time() < deadline:
if runner_task.done():
await runner_task
return
if (
capture_processor.activity_count <= activity_marker
and response_window.frontier_is_idle
):
await asyncio.sleep(0.05)
continue
if response_window.frontier_is_idle and (
time.monotonic() - capture_processor.last_activity_at
) >= TEXT_CHAT_IDLE_SETTLE_SECONDS:
return
await asyncio.sleep(0.05)
raise TimeoutError(
"Timed out waiting for text chat response window to settle "
f"(pending_context_requests={response_window.pending_context_requests}, "
f"active_llm_completions={response_window.active_llm_completions}, "
f"active_assistant_segments={response_window.active_assistant_segments}, "
f"blocking_tool_calls={sorted(response_window.blocking_tool_call_ids)})"
)
async def execute_text_chat_pending_turn(
*,
workflow_run_id: int,
workflow_id: int,
session_data: dict[str, Any],
checkpoint: dict[str, Any] | None,
) -> TextChatTurnExecutionResult:
turns = list(session_data.get("turns") or [])
if not turns or turns[-1].get("status") != "pending":
raise ValueError("Text chat session has no pending turn to execute")
pending_turn = turns[-1]
pending_user_message = (
((pending_turn.get("user_message") or {}).get("text") or "").strip()
if pending_turn.get("user_message") is not None
else None
)
workflow_run, _ = await db_client.get_workflow_run_with_context(workflow_run_id)
if not workflow_run or workflow_run.workflow_id != workflow_id:
raise ValueError("Workflow run not found for text chat execution")
if workflow_run.definition is None:
raise ValueError("Workflow run is missing a pinned definition")
if workflow_run.workflow is None or workflow_run.workflow.user is None:
raise ValueError("Workflow run is missing workflow context")
workflow = await db_client.get_workflow(
workflow_id, organization_id=workflow_run.workflow.organization_id
)
if workflow is None:
raise ValueError("Workflow not found for text chat execution")
run_definition = workflow_run.definition
run_configs = run_definition.workflow_configurations or {}
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
inference_llm = llm
runtime_configuration = {
"llm_provider": user_config.llm.provider,
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
"runtime_configuration": runtime_configuration,
}
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)
)
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
response_window = _ResponseWindowState()
capture_processor = _TextChatCaptureProcessor(response_window)
context = LLMContext()
context.set_messages(base_checkpoint["messages"])
node_transition_events = capture_processor.events
async def send_node_transition(
node_id: str,
node_name: str,
previous_node_id: str | None,
previous_node_name: str | None,
allow_interrupt: bool = False,
) -> None:
node_transition_events.append(
{
"type": "node_transition",
"created_at": datetime.now(UTC).isoformat(),
"payload": {
"node_id": node_id,
"node_name": node_name,
"previous_node_id": previous_node_id,
"previous_node_name": previous_node_name,
"allow_interrupt": allow_interrupt,
},
}
)
embeddings_api_key = None
embeddings_model = None
embeddings_base_url = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
"context_compaction_enabled", False
)
engine = PipecatEngine(
llm=llm,
inference_llm=inference_llm,
context=context,
workflow=workflow_graph,
call_context_vars=initial_context,
workflow_run_id=workflow_run_id,
node_transition_callback=send_node_transition,
embeddings_api_key=embeddings_api_key,
embeddings_model=embeddings_model,
embeddings_base_url=embeddings_base_url,
has_recordings=has_recordings,
context_compaction_enabled=context_compaction_enabled,
)
engine._gathered_context = dict(base_checkpoint["gathered_context"])
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
assistant_context_aggregator = context_aggregator.assistant()
@assistant_context_aggregator.event_handler("on_assistant_turn_started")
async def on_assistant_turn_started(_aggregator):
response_window.note_assistant_turn_started()
@assistant_context_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message):
response_window.note_assistant_turn_stopped(message.content or "")
# Text chat has no wire transport; reuse the neutral 16 kHz config shape
# from the browser pipeline so TTS/recording helpers still have sane defaults.
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
pipeline_metrics_aggregator = PipelineMetricsAggregator()
pipeline = Pipeline(
[
llm,
capture_processor,
assistant_context_aggregator,
pipeline_metrics_aggregator,
]
)
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
runner = PipelineRunner(handle_sigint=False, handle_sigterm=False)
runner_task = asyncio.create_task(runner.run(task))
engine.set_task(task)
engine.set_audio_config(audio_config)
engine.set_transport_output(_TaskQueueProxy(task.queue_frame))
engine.set_fetch_recording_audio(
create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
)
try:
await asyncio.wait_for(task._pipeline_start_event.wait(), timeout=5.0)
await engine.initialize()
current_node_id = base_checkpoint.get("current_node_id")
target_node_id = current_node_id or workflow_graph.start_node_id
await engine.set_node(target_node_id)
opening_marker = capture_processor.activity_count
opening_expects_llm = (
pending_user_message is None
and (
current_node_id == target_node_id
or engine.get_node_greeting(target_node_id) is None
)
)
if opening_expects_llm:
response_window.note_direct_context_request()
opening_action = await engine.queue_node_opening(
node_id=target_node_id,
previous_node_id=current_node_id,
generate_if_no_greeting=pending_user_message is None,
)
if opening_action != "llm" and opening_expects_llm:
response_window.pending_context_requests = max(
0, response_window.pending_context_requests - 1
)
if opening_action != "none":
await _wait_for_quiescence(
capture_processor=capture_processor,
response_window=response_window,
runner_task=runner_task,
activity_marker=opening_marker,
)
if pending_user_message is not None:
context.add_message({"role": "user", "content": pending_user_message})
generation_marker = capture_processor.activity_count
response_window.note_direct_context_request()
await llm.queue_frame(LLMContextFrame(context))
await _wait_for_quiescence(
capture_processor=capture_processor,
response_window=response_window,
runner_task=runner_task,
activity_marker=generation_marker,
)
finally:
if not task.has_finished():
await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
try:
await runner_task
except Exception:
logger.exception(
"Transportless text chat pipeline failed while closing run {}",
workflow_run_id,
)
await engine.cleanup()
raise
await engine.cleanup()
gathered_context = await engine.get_gathered_context()
assistant_text = (
"\n\n".join(part for part in response_window.outputs if part).strip()
if response_window.outputs
else None
)
assistant_created_at = datetime.now(UTC).isoformat()
usage = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
current_node = getattr(engine, "_current_node", None)
updated_checkpoint = {
"version": TEXT_CHAT_CHECKPOINT_VERSION,
"anchor_turn_id": pending_turn.get("id"),
"current_node_id": current_node.id if current_node else None,
"messages": jsonable_encoder(context.get_messages()),
"gathered_context": jsonable_encoder(gathered_context),
"tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
}
return TextChatTurnExecutionResult(
assistant_text=assistant_text,
assistant_created_at=assistant_created_at,
events=jsonable_encoder(capture_processor.events),
usage=jsonable_encoder(usage),
checkpoint=updated_checkpoint,
gathered_context=jsonable_encoder(gathered_context),
initial_context=jsonable_encoder(initial_context),
state=(
WorkflowRunState.COMPLETED.value
if engine.is_call_disposed()
else WorkflowRunState.RUNNING.value
),
is_completed=engine.is_call_disposed(),
)

View file

@ -375,6 +375,103 @@ class TestStartGreeting:
result = engine.get_start_greeting()
assert result == ("text", "Hello Alice!")
@pytest.mark.asyncio
async def test_queue_node_opening_queues_text_greeting(
self, text_workflow: WorkflowGraph
):
"""Fresh node entry with a greeting should queue TTS and skip LLM bootstrap."""
llm = Mock()
llm.queue_frame = AsyncMock()
task = Mock()
task.queue_frame = AsyncMock()
engine = PipecatEngine(
llm=llm,
context=LLMContext(),
workflow=text_workflow,
call_context_vars={},
workflow_run_id=1,
)
engine.set_task(task)
result = await engine.queue_node_opening(
node_id=text_workflow.start_node_id,
previous_node_id=None,
generate_if_no_greeting=True,
)
assert result == "greeting"
llm.queue_frame.assert_not_awaited()
queued_frame = task.queue_frame.await_args.args[0]
assert isinstance(queued_frame, TTSSpeakFrame)
assert queued_frame.text == TEXT_GREETING
assert queued_frame.append_to_context is True
@pytest.mark.asyncio
async def test_queue_node_opening_falls_back_to_llm_without_greeting(self):
"""When a node has no greeting, the engine should queue initial LLM generation."""
dto = ReactFlowDTO(
nodes=[
StartCallRFNode(
id="start",
position=Position(x=0, y=0),
data=StartCallNodeData(
name="Start",
prompt="Prompt",
is_start=True,
add_global_prompt=False,
extraction_enabled=False,
),
),
EndCallRFNode(
id="end",
position=Position(x=0, y=200),
data=EndCallNodeData(
name="End",
prompt="End",
is_end=True,
add_global_prompt=False,
extraction_enabled=False,
),
),
],
edges=[
RFEdgeDTO(
id="e",
source="start",
target="end",
data=EdgeDataDTO(label="End", condition="End"),
),
],
)
workflow = WorkflowGraph(dto)
context = LLMContext()
llm = Mock()
llm.queue_frame = AsyncMock()
task = Mock()
task.queue_frame = AsyncMock()
engine = PipecatEngine(
llm=llm,
context=context,
workflow=workflow,
call_context_vars={},
workflow_run_id=1,
)
engine.set_task(task)
result = await engine.queue_node_opening(
node_id=workflow.start_node_id,
previous_node_id=None,
generate_if_no_greeting=True,
)
assert result == "llm"
task.queue_frame.assert_not_awaited()
queued_frame = llm.queue_frame.await_args.args[0]
assert isinstance(queued_frame, LLMContextFrame)
assert queued_frame.context is context
# ─── Tests: Transition Speech (Pipeline) ────────────────────────

View file

@ -0,0 +1,849 @@
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.tests import MockLLMService
from api.db.models import OrganizationModel, UserModel
from api.schemas.user_configuration import UserConfiguration
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
async def _create_user_and_workflow(
db_session,
async_session,
*,
workflow_definition: dict,
suffix: str,
):
org = OrganizationModel(provider_id=f"textchat-org-{suffix}")
async_session.add(org)
await async_session.flush()
user = UserModel(
provider_id=f"textchat-user-{suffix}",
selected_organization_id=org.id,
)
async_session.add(user)
await async_session.flush()
await db_session.update_user_configuration(
user_id=user.id,
configuration=UserConfiguration.model_validate(USER_CONFIGURATION),
)
workflow = await db_session.create_workflow(
name=f"Text Chat Workflow {suffix}",
workflow_definition=workflow_definition,
user_id=user.id,
organization_id=org.id,
)
return user, workflow
@pytest.mark.asyncio
async def test_text_chat_session_creation_executes_initial_assistant_turn(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are a helpful assistant.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "Wrap up the conversation.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End Call", "condition": "When the task is done."},
}
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="bootstrap",
)
llm = MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
chunk_delay=0.001,
)
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
return_value=llm,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
created = create_response.json()
turns = created["session_data"]["turns"]
assert created["revision"] == 2
assert created["session_data"]["status"] == "idle"
assert len(turns) == 1
assert turns[0]["status"] == "completed"
assert turns[0]["user_message"] is None
assert turns[0]["assistant_message"]["text"] == "Hello from the workflow tester."
assert turns[0]["checkpoint_after_turn"]["current_node_id"] == "start"
assert created["checkpoint"]["current_node_id"] == "start"
assert created["state"] == "running"
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
@pytest.mark.asyncio
async def test_text_chat_message_executes_assistant_turn(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are a helpful assistant.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting_type": "text",
"greeting": "Welcome to the workflow tester.",
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "Wrap up the conversation.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End Call", "condition": "When the task is done."},
}
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="basic",
)
llm_responses = [
MockLLMService(mock_steps=[], chunk_delay=0.001),
MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
chunk_delay=0.001,
),
]
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
side_effect=llm_responses,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
created = create_response.json()
message_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}/messages",
json={
"text": "Hi there",
"expected_revision": created["revision"],
},
)
assert message_response.status_code == 200
payload = message_response.json()
turns = payload["session_data"]["turns"]
assert payload["revision"] == 4
assert payload["session_data"]["status"] == "idle"
assert len(turns) == 2
assert turns[0]["user_message"] is None
assert turns[0]["assistant_message"]["text"] == "Welcome to the workflow tester."
assert turns[1]["status"] == "completed"
assert turns[1]["user_message"]["text"] == "Hi there"
assert turns[1]["assistant_message"]["text"] == "Hello from the workflow tester."
assert turns[1]["checkpoint_after_turn"]["current_node_id"] == "start"
assert payload["checkpoint"]["current_node_id"] == "start"
assert payload["state"] == "running"
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
@pytest.mark.asyncio
async def test_text_chat_executes_deferred_tool_calls_after_text_response(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are at the start node.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting_type": "text",
"greeting": "Welcome to the workflow tester.",
},
},
{
"id": "agent1",
"type": "agentNode",
"position": {"x": 0, "y": 200},
"data": {
"name": "Agent One",
"prompt": "You are in agent one.",
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-agent1",
"source": "start",
"target": "agent1",
"data": {
"label": "Go To Agent One",
"condition": "Move to agent one.",
},
}
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="mixed-tool-turn",
)
llm_responses = [
MockLLMService(mock_steps=[], chunk_delay=0.001),
MockLLMService(
mock_steps=[
MockLLMService.create_mixed_chunks(
"Let me transfer you.",
"go_to_agent_one",
{},
tool_call_id="call_agent_one",
),
MockLLMService.create_text_chunks("Agent one here."),
],
chunk_delay=0.001,
),
]
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
side_effect=llm_responses,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
session = create_response.json()
message_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "Please transfer me",
"expected_revision": session["revision"],
},
)
assert message_response.status_code == 200
payload = message_response.json()
assistant_text = payload["session_data"]["turns"][1]["assistant_message"]["text"]
assert "Let me transfer you." in assistant_text
assert "Agent one here." in assistant_text
assert payload["checkpoint"]["current_node_id"] == "agent1"
assert any(
event["type"] == "tool_call_started"
and event["payload"]["function_name"] == "go_to_agent_one"
for event in payload["session_data"]["turns"][1]["events"]
)
@pytest.mark.asyncio
async def test_text_chat_chains_multiple_follow_up_completions_in_one_turn(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are at the start node.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting_type": "text",
"greeting": "Welcome to the workflow tester.",
},
},
{
"id": "agent1",
"type": "agentNode",
"position": {"x": 0, "y": 200},
"data": {
"name": "Agent One",
"prompt": "You are in agent one.",
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "agent2",
"type": "agentNode",
"position": {"x": 0, "y": 400},
"data": {
"name": "Agent Two",
"prompt": "You are in agent two.",
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-agent1",
"source": "start",
"target": "agent1",
"data": {
"label": "Go To Agent One",
"condition": "Move to agent one.",
},
},
{
"id": "agent1-agent2",
"source": "agent1",
"target": "agent2",
"data": {
"label": "Go To Agent Two",
"condition": "Move to agent two.",
},
},
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="multi-hop-turn",
)
llm_responses = [
MockLLMService(mock_steps=[], chunk_delay=0.001),
MockLLMService(
mock_steps=[
MockLLMService.create_mixed_chunks(
"Moving to agent one.",
"go_to_agent_one",
{},
tool_call_id="call_agent_one",
),
MockLLMService.create_mixed_chunks(
"Moving to agent two.",
"go_to_agent_two",
{},
tool_call_id="call_agent_two",
),
MockLLMService.create_text_chunks("Agent two here."),
],
chunk_delay=0.001,
),
]
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
side_effect=llm_responses,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
session = create_response.json()
message_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "Please route me through the flow",
"expected_revision": session["revision"],
},
)
assert message_response.status_code == 200
payload = message_response.json()
assistant_text = payload["session_data"]["turns"][1]["assistant_message"]["text"]
assert "Moving to agent one." in assistant_text
assert "Moving to agent two." in assistant_text
assert "Agent two here." in assistant_text
assert payload["checkpoint"]["current_node_id"] == "agent2"
assert sum(
1
for event in payload["session_data"]["turns"][1]["events"]
if event["type"] == "tool_call_started"
) == 2
@pytest.mark.asyncio
async def test_text_chat_greeting_only_plays_on_fresh_node_entry(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are a helpful assistant.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting_type": "text",
"greeting": "Welcome to the workflow tester.",
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "Wrap up the conversation.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End Call", "condition": "When the task is done."},
}
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="greeting-once",
)
llm_responses = [
MockLLMService(mock_steps=[], chunk_delay=0.001),
MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("First answer.")],
chunk_delay=0.001,
),
MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("Second answer.")],
chunk_delay=0.001,
),
]
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
side_effect=llm_responses,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
session = create_response.json()
opening_text = session["session_data"]["turns"][0]["assistant_message"]["text"]
first_message = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "First turn",
"expected_revision": session["revision"],
},
)
assert first_message.status_code == 200
first_payload = first_message.json()
second_message = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "Second turn",
"expected_revision": first_payload["revision"],
},
)
assert second_message.status_code == 200
first_text = first_payload["session_data"]["turns"][1]["assistant_message"]["text"]
second_text = second_message.json()["session_data"]["turns"][2]["assistant_message"][
"text"
]
assert opening_text == "Welcome to the workflow tester."
assert "Welcome to the workflow tester." not in first_text
assert "First answer." in first_text
assert "Welcome to the workflow tester." not in second_text
assert "Second answer." in second_text
@pytest.mark.asyncio
async def test_text_chat_rewind_reuses_checkpoint_snapshot(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are at the start node.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting_type": "text",
"greeting": "Welcome to the rewind test.",
},
},
{
"id": "agent1",
"type": "agentNode",
"position": {"x": 0, "y": 200},
"data": {
"name": "Agent One",
"prompt": "You are in agent one.",
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "agent2",
"type": "agentNode",
"position": {"x": 0, "y": 400},
"data": {
"name": "Agent Two",
"prompt": "You are in agent two.",
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 600},
"data": {
"name": "End",
"prompt": "You are at the end node.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-agent1",
"source": "start",
"target": "agent1",
"data": {
"label": "Go To Agent One",
"condition": "Move to agent one.",
},
},
{
"id": "agent1-agent2",
"source": "agent1",
"target": "agent2",
"data": {
"label": "Go To Agent Two",
"condition": "Move to agent two.",
},
},
{
"id": "agent2-end",
"source": "agent2",
"target": "end",
"data": {"label": "Finish", "condition": "End the flow."},
},
],
}
user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="rewind",
)
llm_responses = [
MockLLMService(mock_steps=[], chunk_delay=0.001),
MockLLMService(
mock_steps=[
MockLLMService.create_function_call_chunks(
"go_to_agent_one",
{},
tool_call_id="call_agent_one",
),
MockLLMService.create_text_chunks("Agent one here."),
],
chunk_delay=0.001,
),
MockLLMService(
mock_steps=[
MockLLMService.create_function_call_chunks(
"go_to_agent_two",
{},
tool_call_id="call_agent_two",
),
MockLLMService.create_text_chunks("Agent two here."),
],
chunk_delay=0.001,
),
MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("Back in agent one.")],
chunk_delay=0.001,
),
]
async with test_client_factory(user) as client:
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
side_effect=llm_responses,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
session = create_response.json()
first_message = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "First turn",
"expected_revision": session["revision"],
},
)
assert first_message.status_code == 200
first_payload = first_message.json()
first_turn_id = first_payload["session_data"]["turns"][1]["id"]
assert first_payload["checkpoint"]["current_node_id"] == "agent1"
second_message = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "Second turn",
"expected_revision": first_payload["revision"],
},
)
assert second_message.status_code == 200
second_payload = second_message.json()
assert second_payload["checkpoint"]["current_node_id"] == "agent2"
rewind_response = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/rewind",
json={
"cursor_turn_id": first_turn_id,
"expected_revision": second_payload["revision"],
},
)
assert rewind_response.status_code == 200
rewound = rewind_response.json()
assert rewound["session_data"]["cursor_turn_id"] == first_turn_id
third_message = await client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{session['workflow_run_id']}/messages",
json={
"text": "Third turn after rewind",
"expected_revision": rewound["revision"],
},
)
assert third_message.status_code == 200
payload = third_message.json()
assert payload["checkpoint"]["current_node_id"] == "agent1"
assert payload["session_data"]["discarded_future"]
assert len(payload["session_data"]["turns"]) == 3
assert payload["session_data"]["turns"][1]["id"] == first_turn_id
assert (
payload["session_data"]["turns"][2]["assistant_message"]["text"]
== "Back in agent one."
)
@pytest.mark.asyncio
async def test_text_chat_session_is_not_accessible_from_another_org(
db_session,
async_session,
test_client_factory,
):
workflow_definition = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are a helpful assistant.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "Wrap up the conversation.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End Call", "condition": "When the task is done."},
}
],
}
owner_user, workflow = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="owner",
)
other_user, _ = await _create_user_and_workflow(
db_session,
async_session,
workflow_definition=workflow_definition,
suffix="other",
)
async with test_client_factory(owner_user) as owner_client:
llm = MockLLMService(
mock_steps=[MockLLMService.create_text_chunks("Hello from the workflow tester.")],
chunk_delay=0.001,
)
with patch(
"api.services.workflow.text_chat_runner.create_llm_service",
return_value=llm,
), patch(
"api.services.workflow.text_chat_runner.db_client.has_active_recordings",
new=AsyncMock(return_value=False),
):
create_response = await owner_client.post(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions",
json={},
)
assert create_response.status_code == 200
created = create_response.json()
async with test_client_factory(other_user) as other_client:
get_response = await other_client.get(
f"/api/v1/workflow/{workflow.id}/text-chat/sessions/{created['workflow_run_id']}"
)
assert get_response.status_code == 404