mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
fix: chat serialization and deserialization in text runner
Related Issue: #455
This commit is contained in:
parent
1e5e556a4d
commit
811b9e9803
2 changed files with 110 additions and 25 deletions
|
|
@ -6,7 +6,7 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pipecat.bus.serializers.json import JSONMessageSerializer
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
|
|
@ -56,6 +56,46 @@ TEXT_CHAT_IDLE_SETTLE_SECONDS = 0.2
|
|||
TEXT_CHAT_INTERNAL_CANCEL_REASON = "text_chat_turn_complete"
|
||||
|
||||
|
||||
def _pipecat_type_tag(type_: type) -> str:
|
||||
return f"{type_.__module__}.{type_.__name__}"
|
||||
|
||||
|
||||
def _pipecat_json_serializer() -> JSONMessageSerializer:
    """Construct a new Pipecat ``JSONMessageSerializer``.

    A fresh instance is created on every call; nothing is cached here.
    """
    serializer = JSONMessageSerializer()
    return serializer
|
||||
|
||||
|
||||
def _serialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
    """Encode Pipecat context messages for JSONB checkpoint storage.

    The messages are copied into an ``LLMContext`` and passed through the
    Pipecat JSON serializer; the encoded ``messages`` list is then extracted
    from the ``{"__data__": {"messages": [...]}}`` envelope it produces.

    Raises:
        TypeError: If the serializer output is not a dict, its ``__data__``
            entry is not a dict, or the ``messages`` entry is not a list.
    """
    # Pipecat's bus JSON serializer already knows how to preserve LLMContext,
    # LLMSpecificMessage, and binary provider fields such as Gemini signatures.
    # Keep the serializer shape dependency contained to these checkpoint helpers.
    serializer = _pipecat_json_serializer()
    envelope = serializer._serialize_value(LLMContext(messages=list(messages)))

    # Walk the envelope one level at a time, bailing out with the same error
    # as soon as any level does not have the expected container type.
    shape_error = "Pipecat LLMContext serializer returned an unexpected shape"
    if not isinstance(envelope, dict):
        raise TypeError(shape_error)

    payload = envelope.get("__data__")
    if not isinstance(payload, dict):
        raise TypeError(shape_error)

    serialized_messages = payload.get("messages")
    if not isinstance(serialized_messages, list):
        raise TypeError(shape_error)

    return serialized_messages
|
||||
|
||||
|
||||
def _deserialize_text_chat_checkpoint_messages(messages: list[Any]) -> list[Any]:
    """Rebuild Pipecat context message objects from JSONB checkpoint data.

    The stored list is wrapped in a ``__type__`` / ``__data__`` envelope tagged
    as an ``LLMContext`` and handed to the Pipecat JSON serializer; the
    messages of the resulting context are returned.

    Raises:
        TypeError: If the serializer does not hand back an ``LLMContext``.
    """
    serializer = _pipecat_json_serializer()
    decode = serializer._deserialize_value

    # Envelope shape the serializer expects for an encoded LLMContext.
    envelope = {
        "__type__": _pipecat_type_tag(LLMContext),
        "__data__": {"messages": list(messages)},
    }
    context = decode(envelope)

    if isinstance(context, LLMContext):
        return context.get_messages()
    raise TypeError("Pipecat LLMContext deserializer returned an unexpected type")
|
||||
|
||||
|
||||
def text_chat_trace_id(workflow_run_id: int) -> str:
|
||||
"""Deterministic Langfuse trace id for a text-chat session.
|
||||
|
||||
|
|
@ -391,7 +431,6 @@ async def execute_text_chat_pending_turn(
|
|||
if pending_turn.get("user_message") is not None
|
||||
else None
|
||||
)
|
||||
|
||||
workflow_run, _ = await db_client.get_workflow_run_with_context(workflow_run_id)
|
||||
if not workflow_run or workflow_run.workflow_id != workflow_id:
|
||||
raise ValueError("Workflow run not found for text chat execution")
|
||||
|
|
@ -405,7 +444,6 @@ async def execute_text_chat_pending_turn(
|
|||
)
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found for text chat execution")
|
||||
|
||||
# Stamp the async context so OTEL spans are tagged with this org and routed
|
||||
# to its Langfuse project (the voice paths do this in run_pipeline /
|
||||
# webrtc_signaling; the text path previously skipped it, so its spans never
|
||||
|
|
@ -426,7 +464,6 @@ async def execute_text_chat_pending_turn(
|
|||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
|
|
@ -462,9 +499,10 @@ async def execute_text_chat_pending_turn(
|
|||
skip_instance_constraints_for={"trigger"},
|
||||
)
|
||||
base_checkpoint = _resolve_checkpoint_for_pending_turn(session_data, checkpoint)
|
||||
|
||||
context = LLMContext()
|
||||
context.set_messages(base_checkpoint["messages"])
|
||||
context.set_messages(
|
||||
_deserialize_text_chat_checkpoint_messages(base_checkpoint["messages"])
|
||||
)
|
||||
response_window = _ResponseWindowState()
|
||||
capture_processor = _TextChatCaptureProcessor(response_window, context)
|
||||
|
||||
|
|
@ -511,7 +549,6 @@ async def execute_text_chat_pending_turn(
|
|||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
"context_compaction_enabled", False
|
||||
)
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
inference_llm=inference_llm,
|
||||
|
|
@ -557,7 +594,6 @@ async def execute_text_chat_pending_turn(
|
|||
trace_span_attributes = {
|
||||
"langfuse.trace.name": workflow_run.name or f"text-chat-{workflow_run_id}"
|
||||
}
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm,
|
||||
|
|
@ -634,20 +670,13 @@ async def execute_text_chat_pending_turn(
|
|||
activity_marker=generation_marker,
|
||||
)
|
||||
finally:
|
||||
if not task.has_finished():
|
||||
await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
|
||||
try:
|
||||
if not task.has_finished():
|
||||
await task.cancel(reason=TEXT_CHAT_INTERNAL_CANCEL_REASON)
|
||||
await runner_task
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Transportless text chat pipeline failed while closing run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
finally:
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
raise
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
assistant_text = (
|
||||
|
|
@ -658,29 +687,36 @@ async def execute_text_chat_pending_turn(
|
|||
assistant_created_at = datetime.now(UTC).isoformat()
|
||||
usage = pipeline_metrics_aggregator.get_all_usage_metrics_serialized()
|
||||
current_node = getattr(engine, "_current_node", None)
|
||||
context_messages = context.get_messages()
|
||||
encoded_messages = _serialize_text_chat_checkpoint_messages(context_messages)
|
||||
encoded_gathered_context = jsonable_encoder(gathered_context)
|
||||
encoded_tool_state = jsonable_encoder(base_checkpoint.get("tool_state") or {})
|
||||
|
||||
updated_checkpoint = {
|
||||
"version": TEXT_CHAT_CHECKPOINT_VERSION,
|
||||
"anchor_turn_id": pending_turn.get("id"),
|
||||
"current_node_id": current_node.id if current_node else None,
|
||||
"messages": jsonable_encoder(context.get_messages()),
|
||||
"gathered_context": jsonable_encoder(gathered_context),
|
||||
"tool_state": jsonable_encoder(base_checkpoint.get("tool_state") or {}),
|
||||
"messages": encoded_messages,
|
||||
"gathered_context": encoded_gathered_context,
|
||||
"tool_state": encoded_tool_state,
|
||||
}
|
||||
|
||||
encoded_gathered_context = jsonable_encoder(gathered_context)
|
||||
trace_url = get_trace_url(trace_id, org_id=workflow.organization_id)
|
||||
if trace_url:
|
||||
encoded_gathered_context = {**encoded_gathered_context, "trace_url": trace_url}
|
||||
|
||||
encoded_events = jsonable_encoder(capture_processor.events)
|
||||
encoded_usage = jsonable_encoder(usage)
|
||||
encoded_initial_context = jsonable_encoder(initial_context)
|
||||
|
||||
return TextChatTurnExecutionResult(
|
||||
assistant_text=assistant_text,
|
||||
assistant_created_at=assistant_created_at,
|
||||
events=jsonable_encoder(capture_processor.events),
|
||||
usage=jsonable_encoder(usage),
|
||||
events=encoded_events,
|
||||
usage=encoded_usage,
|
||||
checkpoint=updated_checkpoint,
|
||||
gathered_context=encoded_gathered_context,
|
||||
initial_context=jsonable_encoder(initial_context),
|
||||
initial_context=encoded_initial_context,
|
||||
state=(
|
||||
WorkflowRunState.COMPLETED.value
|
||||
if engine.is_call_disposed()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,16 @@
|
|||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_context import LLMSpecificMessage
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.workflow.text_chat_runner import (
|
||||
_deserialize_text_chat_checkpoint_messages,
|
||||
_serialize_text_chat_checkpoint_messages,
|
||||
)
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
|
@ -18,6 +24,49 @@ def _log_texts(logs: dict | None, event_type: str) -> list[str]:
|
|||
]
|
||||
|
||||
|
||||
def test_text_chat_checkpoint_messages_round_trip_google_thought_signature():
    """Checkpoint helpers must round-trip a Google thought signature.

    Covers three things: the encoded form is JSON-serializable, the binary
    signature is stored as a tagged base64 payload, and decoding restores the
    original ``LLMSpecificMessage`` with its raw bytes intact.
    """
    raw_signature = bytes.fromhex("12340a32010c39d6c7f38fd8b8eb6ab0")
    plain_messages = [
        {"role": "assistant", "content": "Hello."},
        {"role": "user", "content": "Hi"},
    ]
    thought_signature = LLMSpecificMessage(
        llm="google",
        message={
            "type": "thought_signature",
            "signature": raw_signature,
            "bookmark": {"text": "Hello."},
        },
    )

    encoded = _serialize_text_chat_checkpoint_messages(
        [*plain_messages, thought_signature]
    )

    # Must not raise: the checkpoint is persisted as JSON.
    json.dumps(encoded)

    expected_specific = {
        "__specific__": True,
        "llm": "google",
        "message": {
            "type": "thought_signature",
            "signature": {
                "__type__": "bytes",
                "__data__": "EjQKMgEMOdbH84/YuOtqsA==",
            },
            "bookmark": {"text": "Hello."},
        },
    }
    assert encoded[-1] == expected_specific

    restored = _deserialize_text_chat_checkpoint_messages(encoded)

    assert restored[:2] == plain_messages
    restored_specific = restored[-1]
    assert isinstance(restored_specific, LLMSpecificMessage)
    assert restored_specific.llm == "google"
    assert restored_specific.message["signature"] == raw_signature
    assert restored_specific.message["bookmark"] == {"text": "Hello."}
|
||||
|
||||
|
||||
async def _create_user_and_workflow(
|
||||
db_session,
|
||||
async_session,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue