fix(chat): normalize provider-safe message history

This commit is contained in:
Anish Sarkar 2026-06-12 02:17:37 +05:30
parent 5d5d574550
commit 3dd54230e7
9 changed files with 382 additions and 6 deletions

View file

@ -0,0 +1,89 @@
"""Convert persisted chat content into provider-safe LangChain history.
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
extracts only model-safe content before prior turns are replayed to a provider.
"""
from __future__ import annotations
from typing import Any
# Block ``type`` values that ``user_content_to_llm_content`` keeps when it
# walks a list of user content blocks; blocks of any other type are skipped.
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
def _text_from_block(block: dict[str, Any]) -> str:
value = block.get("text") or block.get("content") or ""
return value if isinstance(value, str) else ""
def assistant_content_to_llm_text(content: Any) -> str:
    """Return visible assistant text, dropping reasoning/UI/provider blocks.

    Accepted shapes:

    * ``str`` - returned unchanged.
    * ``dict`` - treated as a single block. A block whose ``type`` is set to
      anything other than ``"text"`` is dropped, matching the filter applied
      to list items; an untyped dict is still read for backward
      compatibility.
    * ``list`` - bare strings and ``type == "text"`` blocks are kept, in
      order; everything else is skipped.

    Args:
        content: Persisted assistant message content in any of the shapes
            above. Any other value yields an empty string.

    Returns:
        The kept text chunks joined with newlines, or ``""`` when nothing
        qualifies.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # Without this guard a lone typed non-text block that carries a
        # "text"/"content" key would be replayed to the provider even though
        # the same block inside a list is filtered out.
        if content.get("type") not in (None, "text"):
            return ""
        return _text_from_block(content)
    if not isinstance(content, list):
        return ""

    text_chunks: list[str] = []
    for block in content:
        if isinstance(block, str):
            text = block
        elif isinstance(block, dict) and block.get("type") == "text":
            text = _text_from_block(block)
        else:
            continue
        if text:
            text_chunks.append(text)
    return "\n".join(text_chunks)
def _inline_image_url(block: dict[str, Any]) -> str | None:
    """Return the ``data:`` URL carried by an image block, or ``None``.

    ``"image"`` blocks hold the URL directly under ``"image"``.
    ``"image_url"`` blocks hold it under ``"image_url"``, either as a string
    or as a ``{"url": ...}`` mapping. Only ``data:`` URLs are accepted.
    """
    if block.get("type") == "image":
        candidate = block.get("image")
    else:
        candidate = block.get("image_url")
        if isinstance(candidate, dict):
            candidate = candidate.get("url")
    if isinstance(candidate, str) and candidate.startswith("data:"):
        return candidate
    return None


def user_content_to_llm_content(
    content: Any,
    *,
    allow_images: bool = True,
) -> str | list[dict[str, Any]]:
    """Return provider-safe user text/image content for LangChain.

    Accepted shapes:

    * ``str`` - returned unchanged.
    * ``dict`` - treated as a single block and reduced to its text.
    * ``list`` - bare strings and blocks whose ``type`` is listed in
      ``_USER_CONTENT_TYPES`` are kept; everything else is skipped.

    Args:
        content: Persisted user message content in any of the shapes above.
            Any other value yields an empty string.
        allow_images: When true, inline ``data:`` images are emitted as
            ``image_url`` parts. When false, image blocks are ignored.

    Returns:
        A list of ``text`` / ``image_url`` parts, in source order, when at
        least one image was kept; otherwise the kept text joined with
        newlines (possibly ``""``).
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return _text_from_block(content)
    if not isinstance(content, list):
        return ""

    parts: list[dict[str, Any]] = []
    text_chunks: list[str] = []
    has_image = False

    for block in content:
        if isinstance(block, str):
            text = block
        elif not isinstance(block, dict):
            continue
        elif block.get("type") not in _USER_CONTENT_TYPES:
            continue
        elif block["type"] == "text":
            text = _text_from_block(block)
        else:
            # "image" / "image_url" block.
            url = _inline_image_url(block) if allow_images else None
            if url is not None:
                parts.append({"type": "image_url", "image_url": {"url": url}})
                has_image = True
            continue

        if text:
            # Bare strings are recorded as text parts too; otherwise they
            # would vanish whenever an image makes this function return
            # ``parts`` instead of the joined text.
            parts.append({"type": "text", "text": text})
            text_chunks.append(text)

    if has_image:
        return parts
    return "\n".join(text_chunks)

View file

@ -0,0 +1,86 @@
"""Normalize final LangChain assistant messages into assistant-ui parts.
Live streaming remains the primary source for rich, incremental UI state.
This module is only used after the graph has finished so refresh persistence
does not depend on provider-specific streaming chunk shapes.
"""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from langchain_core.messages import AIMessage
def _text_from_content(content: Any) -> str:
if isinstance(content, str):
return content
if not isinstance(content, list):
return ""
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "text":
continue
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
return "".join(text_parts)
def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]:
    """Build the assistant-ui parts shown to the user for a final AI message.

    Provider ``thinking`` / ``reasoning_content`` blocks are deliberately not
    backfilled. Reasoning that streamed live was already captured by the
    ``AssistantContentBuilder``; reasoning present only in the final payload
    was never displayed during the turn, so persisting it afterwards could
    expose content the UI never showed.

    Returns:
        A single-element list holding one ``text`` part with the stripped
        message text, or an empty list when the message is ``None`` or has
        no visible text.
    """
    if message is None:
        return []

    raw_content = getattr(message, "content", None)
    visible_text = _text_from_content(raw_content).strip()
    return [{"type": "text", "text": visible_text}] if visible_text else []
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
    """Return the most recent AI message in ``messages``, if any.

    A message qualifies when it is an ``AIMessage`` instance or exposes a
    ``type`` attribute equal to ``"ai"``. The search runs from the end of the
    sequence; ``None`` is returned for a ``None`` input or when no message
    qualifies.
    """
    if messages is None:
        return None

    newest_first = reversed(list(messages))
    return next(
        (
            candidate
            for candidate in newest_first
            if isinstance(candidate, AIMessage)
            or getattr(candidate, "type", None) == "ai"
        ),
        None,
    )
def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]:
    """Return assistant-ui parts for the last AI message in ``messages``.

    Combines ``last_ai_message`` and ``normalize_ai_message_to_parts``; the
    result is empty when there is no AI message or it has no visible text.
    """
    final_message = last_ai_message(messages)
    return normalize_ai_message_to_parts(final_message)
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
    """Report whether any part is a ``text`` part with non-whitespace text."""
    for part in parts:
        if part.get("type") != "text":
            continue
        text = part.get("text")
        if isinstance(text, str) and text.strip():
            return True
    return False
def merge_streamed_and_final_parts(
    streamed_parts: list[dict[str, Any]],
    final_parts: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Append final-state text only when streaming captured no answer text.

    Returns ``streamed_parts`` unchanged (the same list object) when it
    already holds non-empty text, or when ``final_parts`` has none to offer.
    Otherwise returns a new list with ``final_parts`` appended after
    ``streamed_parts``.
    """
    # ``or`` short-circuits, so ``final_parts`` is inspected only when the
    # stream produced no text.
    if has_non_empty_text_part(streamed_parts) or not has_non_empty_text_part(final_parts):
        return streamed_parts

    merged = list(streamed_parts)
    merged.extend(final_parts)
    return merged

View file

@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.tasks.chat.message_parts_normalizer import (
final_assistant_parts_from_messages,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import safe_float
from app.utils.perf import get_perf_logger
@ -75,6 +78,9 @@ async def stream_agent_events(
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
result.final_message_parts = final_assistant_parts_from_messages(
state_values.get("messages")
)
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work

View file

@ -53,6 +53,7 @@ async def finalize_assistant_message(
):
return
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
@ -74,6 +75,10 @@ async def finalize_assistant_message(
"text": stream_result.accumulated_text or "",
}
]
content_payload = merge_streamed_and_final_parts(
content_payload,
stream_result.final_message_parts,
)
if builder_stats is not None:
_perf_log.info(

View file

@ -35,3 +35,7 @@ class StreamResult:
# (``StreamResult`` is logged in some error branches) from dumping a
# potentially-large parts list.
content_builder: Any | None = field(default=None, repr=False)
# User-visible assistant message parts derived from the final LangGraph
# state. Used after streaming completes as a provider-agnostic persistence
# backfill when no text chunks reached the live stream.
final_message_parts: list[dict[str, Any]] = field(default_factory=list)

View file

@ -18,6 +18,11 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.tasks.chat.llm_history_normalizer import (
assistant_content_to_llm_text,
user_content_to_llm_content,
)
if TYPE_CHECKING:
from app.db import ChatVisibility
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
langchain_messages: list[HumanMessage | AIMessage] = []
for msg in db_messages:
text_content = extract_text_content(msg.content)
if not text_content:
continue
if msg.role == "user":
user_content = user_content_to_llm_content(
msg.content,
allow_images=False,
)
if not user_content:
continue
if is_shared:
author_name = (
msg.author.display_name if msg.author else None
) or "A team member"
text_content = f"**[{author_name}]:** {text_content}"
langchain_messages.append(HumanMessage(content=text_content))
if isinstance(user_content, str):
user_content = f"**[{author_name}]:** {user_content}"
elif user_content and user_content[0].get("type") == "text":
user_content[0] = {
**user_content[0],
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
}
langchain_messages.append(HumanMessage(content=user_content))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=text_content))
assistant_text = assistant_content_to_llm_text(msg.content)
if assistant_text:
langchain_messages.append(AIMessage(content=assistant_text))
return langchain_messages