diff --git a/api/services/workflow/text_chat_runner.py b/api/services/workflow/text_chat_runner.py
index b9fa06ab..662943e6 100644
--- a/api/services/workflow/text_chat_runner.py
+++ b/api/services/workflow/text_chat_runner.py
@@ -6,7 +6,7 @@ from datetime import UTC, datetime
 from typing import Any
 
 from fastapi.encoders import jsonable_encoder
-from loguru import logger
+from pipecat.bus.serializers.json import JSONMessageSerializer
 from pipecat.frames.frames import (
     BotStoppedSpeakingFrame,
     CancelFrame,
@@ -56,6 +56,46 @@ TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
 TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
 
 
+def _pipecat_type_tag(type_: type) -> str:
+    return f"{type_.__module__}.{type_.__name__}"
+
+
+def _pipecat_json_serializer() -> JSONMessageSerializer:
+    return JSONMessageSerializer()
+
+
+def _serialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
+    """Serialize Pipecat context messages for JSONB checkpoint storage."""
+    # Pipecat's bus JSON serializer already knows how to preserve LLMContext,
+    # LLMSpecificMessage, and binary provider fields such as Gemini signatures.
+    # Keep the serializer shape dependency contained to these checkpoint helpers.
+    encoded_context = _pipecat_json_serializer()._serialize_value(
+        LLMContext(messages=list(messages))
+    )
+    encoded_data = (
+        encoded_context.get("__data__") if isinstance(encoded_context, dict) else None
+    )
+    encoded_messages = (
+        encoded_data.get("messages") if isinstance(encoded_data, dict) else None
+    )
+    if not isinstance(encoded_messages, list):
+        raise TypeError("Pipecat LLMContext serializer returned an unexpected shape")
+    return encoded_messages
+
+
+def _deserialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
+    """Restore JSONB checkpoint messages to Pipecat context message objects."""
+    restored_context = _pipecat_json_serializer()._deserialize_value(
+        {
+            "__type__": _pipecat_type_tag(LLMContext),
+            "__data__": {"messages": list(messages)},
+        }
+    )
+    if not isinstance(restored_context, LLMContext):
+        raise TypeError("Pipecat LLMContext deserializer returned an unexpected type")
+    return restored_context.get_messages()
+
+
 def text_chat_trace_id(workflow_run_id: int) -> str:
     """Deterministic Langfuse trace id for a text-chat session.
 
@@ -391,7 +431,6 @@ async def execute_text_chat_pending_turn(
         if pending_turn.get("user_message") is not None
         else None
     )
-
     workflow_run, _ = await db_client.get_workflow_run_with_context(workflow_run_id)
     if not workflow_run or workflow_run.workflow_id != workflow_id:
         raise ValueError("Workflow run not found for text chat execution")
@@ -405,7 +444,6 @@ async def execute_text_chat_pending_turn(
     )
     if workflow is None:
         raise ValueError("Workflow not found for text chat execution")
-
     # Stamp the async context so OTEL spans are tagged with this org and routed
     # to its Langfuse project (the voice paths do this in run_pipeline /
     # webrtc_signaling; the text path previously skipped it, so its spans never
@@ -426,7 +464,6 @@ async def execute_text_chat_pending_turn(
     )
     if user_config.llm is None:
         raise ValueError("Text chat requires an LLM configuration")
-
     from api.services.managed_model_services import (
         MPS_CORRELATION_ID_CONTEXT_KEY,
         ensure_mps_correlation_id,
@@ -462,9 +499,10 @@ async def execute_text_chat_pending_turn(
         skip_instance_constraints_for={"trigger"},
     )
     base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
-
     context = LLMContext()
-    context.set_messages(base_checkpoint["messages"])
+    context.set_messages(
+        _deserialize_text_chat_checkpoint_messages(base_checkpoint["messages"])
+    )
 
     response_window = _ResponseWindowState()
     capture_processor = _TextChatCaptureProcessor(response_window, context)
@@ -511,7 +549,6 @@ async def execute_text_chat_pending_turn(
     context_compaction_enabled = (workflow.workflow_configurations or {}).get(
         "context_compaction_enabled", False
     )
-
     engine = PipecatEngine(
         llm=llm,
         inference_llm=inference_llm,
@@ -557,7 +594,6 @@ async def execute_text_chat_pending_turn(
     trace_span_attributes = {
         "langfuse.trace.name": workflow_run.name or f"text-chat-{workflow_run_id}"
     }
-
     pipeline = Pipeline(
         [
             llm,
@@ -634,20 +670,13 @@ async def execute_text_chat_pending_turn(
             activity_marker=generation_marker,
         )
     finally:
-        if not task.has_finished():
-            await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
         try:
+            if not task.has_finished():
+                await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
             await runner_task
-        except Exception:
-            logger.exception(
-                "Transportless text chat pipeline failed while closing run {}",
-                workflow_run_id,
-            )
+        finally:
             await engine.close_mcp_sessions()
             await engine.cleanup()
-            raise
-        await engine.close_mcp_sessions()
-        await engine.cleanup()
 
     gathered_context = await engine.get_gathered_context()
     assistant_text = (
@@ -658,29 +687,36 @@ async def execute_text_chat_pending_turn(
     assistant_created_at = datetime.now(UTC).isoformat()
     usage = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
     current_node = getattr(engine, "_current_node", None)
+    context_messages = context.get_messages()
+    encoded_messages = _serialize_text_chat_checkpoint_messages(context_messages)
+    encoded_gathered_context = jsonable_encoder(gathered_context)
+    encoded_tool_state = jsonable_encoder(base_checkpoint.get("tool_state") or {})
 
     updated_checkpoint = {
         "version": TEXT_CHAT_CHECKPOINT_VERSION,
         "anchor_turn_id": pending_turn.get("id"),
         "current_node_id": current_node.id if current_node else None,
-        "messages": jsonable_encoder(context.get_messages()),
-        "gathered_context": jsonable_encoder(gathered_context),
-        "tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
+        "messages": encoded_messages,
+        "gathered_context": encoded_gathered_context,
+        "tool_state": encoded_tool_state,
     }
 
-    encoded_gathered_context = jsonable_encoder(gathered_context)
     trace_url = get_trace_url(trace_id, org_id=workflow.organization_id)
     if trace_url:
         encoded_gathered_context = {**encoded_gathered_context, "trace_url": trace_url}
 
+    encoded_events = jsonable_encoder(capture_processor.events)
+    encoded_usage = jsonable_encoder(usage)
+    encoded_initial_context = jsonable_encoder(initial_context)
+
     return TextChatTurnExecutionResult(
         assistant_text=assistant_text,
         assistant_created_at=assistant_created_at,
-        events=jsonable_encoder(capture_processor.events),
-        usage=jsonable_encoder(usage),
+        events=encoded_events,
+        usage=encoded_usage,
         checkpoint=updated_checkpoint,
         gathered_context=encoded_gathered_context,
-        initial_context=jsonable_encoder(initial_context),
+        initial_context=encoded_initial_context,
         state=(
             WorkflowRunState.COMPLETED.value
             if engine.is_call_disposed()
diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py
index 40afdcfb..972661bc 100644
--- a/api/tests/test_workflow_text_chat.py
+++ b/api/tests/test_workflow_text_chat.py
@@ -1,10 +1,16 @@
+import json
 from types import SimpleNamespace
 from unittest.mock import AsyncMock, patch
 
 import pytest
+from pipecat.processors.aggregators.llm_context import LLMSpecificMessage
 
 from api.db.models import OrganizationModel, UserModel
 from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
+from api.services.workflow.text_chat_runner import (
+    _deserialize_text_chat_checkpoint_messages,
+    _serialize_text_chat_checkpoint_messages,
+)
 from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
 from pipecat.tests import MockLLMService
 
@@ -18,6 +24,49 @@ def _log_texts(logs: dict | None, event_type: str) -> list[str]:
     ]
 
 
+def test_text_chat_checkpoint_messages_round_trip_google_thought_signature():
+    signature = bytes.fromhex("12340a32010c39d6c7f38fd8b8eb6ab0")
+    messages = [
+        {"role": "assistant", "content": "Hello."},
+        {
+            "role": "user",
+            "content": "Hi",
+        },
+        LLMSpecificMessage(
+            llm="google",
+            message={
+                "type": "thought_signature",
+                "signature": signature,
+                "bookmark": {"text": "Hello."},
+            },
+        ),
+    ]
+
+    encoded = _serialize_text_chat_checkpoint_messages(messages)
+
+    json.dumps(encoded)
+    assert encoded[-1] == {
+        "__specific__": True,
+        "llm": "google",
+        "message": {
+            "type": "thought_signature",
+            "signature": {
+                "__type__": "bytes",
+                "__data__": "EjQKMgEMOdbH84/YuOtqsA==",
+            },
+            "bookmark": {"text": "Hello."},
+        },
+    }
+
+    restored = _deserialize_text_chat_checkpoint_messages(encoded)
+
+    assert restored[:2] == messages[:2]
+    assert isinstance(restored[-1], LLMSpecificMessage)
+    assert restored[-1].llm == "google"
+    assert restored[-1].message["signature"] == signature
+    assert restored[-1].message["bookmark"] == {"text": "Hello."}
+
+
 async def _create_user_and_workflow(
     db_session,
     async_session,