fix: chat serialization and deserialization in text runner

Related Issue: #455
This commit is contained in:
Abhishek Kumar 2026-06-24 17:53:54 +05:30
parent 1e5e556a4d
commit 811b9e9803
2 changed files with 110 additions and 25 deletions

View file

@ -6,7 +6,7 @@ from datetime import UTC, datetime
from typing import Any
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pipecat.bus.serializers.json import JSONMessageSerializer
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
@ -56,6 +56,46 @@ TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
def _pipecat_type_tag(type_: type) -> str:
return f"{type_.__module__}.{type_.__name__}"
def _pipecat_json_serializer() -> JSONMessageSerializer:
return JSONMessageSerializer()
def _serialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
"""Serialize Pipecat context messages for JSONB checkpoint storage."""
# Pipecat's bus JSON serializer already knows how to preserve LLMContext,
# LLMSpecificMessage, and binary provider fields such as Gemini signatures.
# Keep the serializer shape dependency contained to these checkpoint helpers.
encoded_context = _pipecat_json_serializer()._serialize_value(
LLMContext(messages=list(messages))
)
encoded_data = (
encoded_context.get("__data__") if isinstance(encoded_context, dict) else None
)
encoded_messages = (
encoded_data.get("messages") if isinstance(encoded_data, dict) else None
)
if not isinstance(encoded_messages, list):
raise TypeError("Pipecat LLMContext serializer returned an unexpected shape")
return encoded_messages
def _deserialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
"""Restore JSONB checkpoint messages to Pipecat context message objects."""
restored_context = _pipecat_json_serializer()._deserialize_value(
{
"__type__": _pipecat_type_tag(LLMContext),
"__data__": {"messages": list(messages)},
}
)
if not isinstance(restored_context, LLMContext):
raise TypeError("Pipecat LLMContext deserializer returned an unexpected type")
return restored_context.get_messages()
def text_chat_trace_id(workflow_run_id: int) -> str:
"""Deterministic Langfuse trace id for a text-chat session.
@ -391,7 +431,6 @@ async def execute_text_chat_pending_turn(
if pending_turn.get("user_message") is not None
else None
)
workflow_run, _ = await db_client.get_workflow_run_with_context(workflow_run_id)
if not workflow_run or workflow_run.workflow_id != workflow_id:
raise ValueError("Workflow run not found for text chat execution")
@ -405,7 +444,6 @@ async def execute_text_chat_pending_turn(
)
if workflow is None:
raise ValueError("Workflow not found for text chat execution")
# Stamp the async context so OTEL spans are tagged with this org and routed
# to its Langfuse project (the voice paths do this in run_pipeline /
# webrtc_signaling; the text path previously skipped it, so its spans never
@ -426,7 +464,6 @@ async def execute_text_chat_pending_turn(
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
@ -462,9 +499,10 @@ async def execute_text_chat_pending_turn(
skip_instance_constraints_for={"trigger"},
)
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
context = LLMContext()
context.set_messages(base_checkpoint["messages"])
context.set_messages(
_deserialize_text_chat_checkpoint_messages(base_checkpoint["messages"])
)
response_window = _ResponseWindowState()
capture_processor = _TextChatCaptureProcessor(response_window, context)
@ -511,7 +549,6 @@ async def execute_text_chat_pending_turn(
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
"context_compaction_enabled", False
)
engine = PipecatEngine(
llm=llm,
inference_llm=inference_llm,
@ -557,7 +594,6 @@ async def execute_text_chat_pending_turn(
trace_span_attributes = {
"langfuse.trace.name": workflow_run.name or f"text-chat-{workflow_run_id}"
}
pipeline = Pipeline(
[
llm,
@ -634,20 +670,13 @@ async def execute_text_chat_pending_turn(
activity_marker=generation_marker,
)
finally:
if not task.has_finished():
await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
try:
if not task.has_finished():
await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
await runner_task
except Exception:
logger.exception(
"Transportless text chat pipeline failed while closing run {}",
workflow_run_id,
)
finally:
await engine.close_mcp_sessions()
await engine.cleanup()
raise
await engine.close_mcp_sessions()
await engine.cleanup()
gathered_context = await engine.get_gathered_context()
assistant_text = (
@ -658,29 +687,36 @@ async def execute_text_chat_pending_turn(
assistant_created_at = datetime.now(UTC).isoformat()
usage = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
current_node = getattr(engine, "_current_node", None)
context_messages = context.get_messages()
encoded_messages = _serialize_text_chat_checkpoint_messages(context_messages)
encoded_gathered_context = jsonable_encoder(gathered_context)
encoded_tool_state = jsonable_encoder(base_checkpoint.get("tool_state") or {})
updated_checkpoint = {
"version": TEXT_CHAT_CHECKPOINT_VERSION,
"anchor_turn_id": pending_turn.get("id"),
"current_node_id": current_node.id if current_node else None,
"messages": jsonable_encoder(context.get_messages()),
"gathered_context": jsonable_encoder(gathered_context),
"tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
"messages": encoded_messages,
"gathered_context": encoded_gathered_context,
"tool_state": encoded_tool_state,
}
encoded_gathered_context = jsonable_encoder(gathered_context)
trace_url = get_trace_url(trace_id, org_id=workflow.organization_id)
if trace_url:
encoded_gathered_context = {**encoded_gathered_context, "trace_url": trace_url}
encoded_events = jsonable_encoder(capture_processor.events)
encoded_usage = jsonable_encoder(usage)
encoded_initial_context = jsonable_encoder(initial_context)
return TextChatTurnExecutionResult(
assistant_text=assistant_text,
assistant_created_at=assistant_created_at,
events=jsonable_encoder(capture_processor.events),
usage=jsonable_encoder(usage),
events=encoded_events,
usage=encoded_usage,
checkpoint=updated_checkpoint,
gathered_context=encoded_gathered_context,
initial_context=jsonable_encoder(initial_context),
initial_context=encoded_initial_context,
state=(
WorkflowRunState.COMPLETED.value
if engine.is_call_disposed()

View file

@ -1,10 +1,16 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.processors.aggregators.llm_context import LLMSpecificMessage
from api.db.models import OrganizationModel, UserModel
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.workflow.text_chat_runner import (
_deserialize_text_chat_checkpoint_messages,
_serialize_text_chat_checkpoint_messages,
)
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
from pipecat.tests import MockLLMService
@ -18,6 +24,49 @@ def _log_texts(logs: dict | None, event_type: str) -> list[str]:
]
def test_text_chat_checkpoint_messages_round_trip_google_thought_signature():
signature = bytes.fromhex("12340a32010c39d6c7f38fd8b8eb6ab0")
messages = [
{"role": "assistant", "content": "Hello."},
{
"role": "user",
"content": "Hi",
},
LLMSpecificMessage(
llm="google",
message={
"type": "thought_signature",
"signature": signature,
"bookmark": {"text": "Hello."},
},
),
]
encoded = _serialize_text_chat_checkpoint_messages(messages)
json.dumps(encoded)
assert encoded[-1] == {
"__specific__": True,
"llm": "google",
"message": {
"type": "thought_signature",
"signature": {
"__type__": "bytes",
"__data__": "EjQKMgEMOdbH84/YuOtqsA==",
},
"bookmark": {"text": "Hello."},
},
}
restored = _deserialize_text_chat_checkpoint_messages(encoded)
assert restored[:2] == messages[:2]
assert isinstance(restored[-1], LLMSpecificMessage)
assert restored[-1].llm == "google"
assert restored[-1].message["signature"] == signature
assert restored[-1].message["bookmark"] == {"text": "Hello."}
async def _create_user_and_workflow(
db_session,
async_session,