fix(chat): normalize provider-safe message history

This commit is contained in:
Anish Sarkar 2026-06-12 02:17:37 +05:30
parent 5d5d574550
commit 3dd54230e7
9 changed files with 382 additions and 6 deletions

View file

@ -0,0 +1,89 @@
"""Convert persisted chat content into provider-safe LangChain history.
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
extracts only model-safe content before prior turns are replayed to a provider.
"""
from __future__ import annotations
from typing import Any
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
def _text_from_block(block: dict[str, Any]) -> str:
value = block.get("text") or block.get("content") or ""
return value if isinstance(value, str) else ""
def assistant_content_to_llm_text(content: Any) -> str:
"""Return visible assistant text, dropping reasoning/UI/provider blocks."""
if isinstance(content, str):
return content
if isinstance(content, dict):
return _text_from_block(content)
if not isinstance(content, list):
return ""
text_chunks: list[str] = []
for block in content:
if isinstance(block, str):
if block:
text_chunks.append(block)
continue
if not isinstance(block, dict):
continue
if block.get("type") == "text":
text = _text_from_block(block)
if text:
text_chunks.append(text)
return "\n".join(text_chunks)
def user_content_to_llm_content(
content: Any,
*,
allow_images: bool = True,
) -> str | list[dict[str, Any]]:
"""Return provider-safe user text/image content for LangChain."""
if isinstance(content, str):
return content
if isinstance(content, dict):
return _text_from_block(content)
if not isinstance(content, list):
return ""
parts: list[dict[str, Any]] = []
text_chunks: list[str] = []
for block in content:
if isinstance(block, str):
if block:
text_chunks.append(block)
continue
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type not in _USER_CONTENT_TYPES:
continue
if block_type == "text":
text = _text_from_block(block)
if text:
parts.append({"type": "text", "text": text})
text_chunks.append(text)
elif allow_images and block_type == "image":
image = block.get("image")
if isinstance(image, str) and image.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": image}})
elif allow_images and block_type == "image_url":
image_url = block.get("image_url")
if isinstance(image_url, dict):
url = image_url.get("url")
if isinstance(url, str) and url.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": url}})
elif isinstance(image_url, str) and image_url.startswith("data:"):
parts.append({"type": "image_url", "image_url": {"url": image_url}})
if allow_images and any(part.get("type") == "image_url" for part in parts):
return parts
return "\n".join(text_chunks)

View file

@ -0,0 +1,86 @@
"""Normalize final LangChain assistant messages into assistant-ui parts.
Live streaming remains the primary source for rich, incremental UI state.
This module is only used after the graph has finished so refresh persistence
does not depend on provider-specific streaming chunk shapes.
"""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from langchain_core.messages import AIMessage
def _text_from_content(content: Any) -> str:
if isinstance(content, str):
return content
if not isinstance(content, list):
return ""
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "text":
continue
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
return "".join(text_parts)
def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]:
"""Return user-visible assistant-ui parts for a final AI message.
We intentionally do not backfill provider ``thinking`` /
``reasoning_content`` blocks here. If reasoning streamed live, the
``AssistantContentBuilder`` already captured it. If it only exists in the
final model payload, persisting it retroactively could expose content the
UI never showed during the turn.
"""
if message is None:
return []
text = _text_from_content(getattr(message, "content", None)).strip()
if not text:
return []
return [{"type": "text", "text": text}]
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
if messages is None:
return None
for message in reversed(list(messages)):
if isinstance(message, AIMessage):
return message
if getattr(message, "type", None) == "ai":
return message
return None
def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]:
return normalize_ai_message_to_parts(last_ai_message(messages))
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
return any(
part.get("type") == "text"
and isinstance(part.get("text"), str)
and bool(part.get("text", "").strip())
for part in parts
)
def merge_streamed_and_final_parts(
streamed_parts: list[dict[str, Any]],
final_parts: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Use final-state text only when streaming captured no answer text."""
if has_non_empty_text_part(streamed_parts):
return streamed_parts
if not has_non_empty_text_part(final_parts):
return streamed_parts
return [*streamed_parts, *final_parts]

View file

@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
from app.tasks.chat.message_parts_normalizer import (
final_assistant_parts_from_messages,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import safe_float
from app.utils.perf import get_perf_logger
@ -75,6 +78,9 @@ async def stream_agent_events(
state = await agent.aget_state(config)
state_values = getattr(state, "values", {}) or {}
result.final_message_parts = final_assistant_parts_from_messages(
state_values.get("messages")
)
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work

View file

@ -53,6 +53,7 @@ async def finalize_assistant_message(
):
return
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
from app.tasks.chat.persistence import finalize_assistant_turn
builder_stats: dict[str, int] | None = None
@ -74,6 +75,10 @@ async def finalize_assistant_message(
"text": stream_result.accumulated_text or "",
}
]
content_payload = merge_streamed_and_final_parts(
content_payload,
stream_result.final_message_parts,
)
if builder_stats is not None:
_perf_log.info(

View file

@ -35,3 +35,7 @@ class StreamResult:
# (``StreamResult`` is logged in some error branches) from dumping a
# potentially-large parts list.
content_builder: Any | None = field(default=None, repr=False)
# User-visible assistant message parts derived from the final LangGraph
# state. Used after streaming completes as a provider-agnostic persistence
# backfill when no text chunks reached the live stream.
final_message_parts: list[dict[str, Any]] = field(default_factory=list)

View file

@ -18,6 +18,11 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.tasks.chat.llm_history_normalizer import (
assistant_content_to_llm_text,
user_content_to_llm_content,
)
if TYPE_CHECKING:
from app.db import ChatVisibility
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
langchain_messages: list[HumanMessage | AIMessage] = []
for msg in db_messages:
text_content = extract_text_content(msg.content)
if not text_content:
continue
if msg.role == "user":
user_content = user_content_to_llm_content(
msg.content,
allow_images=False,
)
if not user_content:
continue
if is_shared:
author_name = (
msg.author.display_name if msg.author else None
) or "A team member"
text_content = f"**[{author_name}]:** {text_content}"
langchain_messages.append(HumanMessage(content=text_content))
if isinstance(user_content, str):
user_content = f"**[{author_name}]:** {user_content}"
elif user_content and user_content[0].get("type") == "text":
user_content[0] = {
**user_content[0],
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
}
langchain_messages.append(HumanMessage(content=user_content))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=text_content))
assistant_text = assistant_content_to_llm_text(msg.content)
if assistant_text:
langchain_messages.append(AIMessage(content=assistant_text))
return langchain_messages

View file

@ -0,0 +1,40 @@
"""Regression tests for model-boundary message sanitization."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage
from app.agents.chat.runtime.llm_config import _sanitize_messages
pytestmark = pytest.mark.unit
def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None:
original = AIMessage(
content=[
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "visible answer"},
]
)
sanitized = _sanitize_messages([original])
assert sanitized[0].content == "visible answer"
assert original.content == [
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "visible answer"},
]
def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None:
message = AIMessage(
content="",
tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}],
)
sanitized = _sanitize_messages([message])
assert sanitized[0].content is None
assert message.content == ""

View file

@ -0,0 +1,62 @@
"""Unit tests for provider-safe LLM history normalization."""
from __future__ import annotations
import pytest
from app.tasks.chat.llm_history_normalizer import (
assistant_content_to_llm_text,
user_content_to_llm_content,
)
pytestmark = pytest.mark.unit
def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None:
content = [
{"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]},
{"type": "text", "text": "visible answer"},
]
assert assistant_content_to_llm_text(content) == "visible answer"
def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None:
content = [
{"type": "thinking", "thinking": "private reasoning"},
{"type": "text", "text": "final answer"},
]
assert assistant_content_to_llm_text(content) == "final answer"
def test_unknown_assistant_blocks_are_dropped() -> None:
content = [
{"type": "redacted_thinking", "data": "hidden"},
{"type": "tool_use", "name": "search"},
{"type": "text", "text": "kept"},
]
assert assistant_content_to_llm_text(content) == "kept"
def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None:
content = [
{"type": "text", "text": "look"},
{"type": "image", "image": "data:image/png;base64,abc"},
]
assert user_content_to_llm_content(content, allow_images=True) == [
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]
def test_user_images_can_be_dropped_for_text_only_history() -> None:
content = [
{"type": "text", "text": "look"},
{"type": "image", "image": "data:image/png;base64,abc"},
]
assert user_content_to_llm_content(content, allow_images=False) == "look"

View file

@ -0,0 +1,68 @@
"""Unit tests for final assistant message part normalization."""
from __future__ import annotations
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.tasks.chat.message_parts_normalizer import (
final_assistant_parts_from_messages,
merge_streamed_and_final_parts,
normalize_ai_message_to_parts,
)
pytestmark = pytest.mark.unit
def test_string_ai_message_content_becomes_text_part() -> None:
assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [
{"type": "text", "text": "hello"}
]
def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None:
message = AIMessage(
content=[
{"type": "thinking", "thinking": "hidden reasoning"},
{"type": "text", "text": "Yo bro! What's up?"},
],
additional_kwargs={"reasoning_content": "hidden reasoning"},
)
assert normalize_ai_message_to_parts(message) == [
{"type": "text", "text": "Yo bro! What's up?"}
]
def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None:
messages = [
HumanMessage(content="ask"),
AIMessage(content="draft"),
ToolMessage(content="tool output", tool_call_id="tc-1"),
AIMessage(content=[{"type": "text", "text": "final answer"}]),
ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
]
assert final_assistant_parts_from_messages(messages) == [
{"type": "text", "text": "final answer"}
]
def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None:
streamed = [
{
"type": "data-thinking-steps",
"data": [{"id": "thinking-1", "status": "completed"}],
}
]
final = [{"type": "text", "text": "visible answer"}]
assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final]
def test_merge_does_not_duplicate_when_stream_already_has_text() -> None:
streamed = [{"type": "text", "text": "streamed answer"}]
final = [{"type": "text", "text": "final answer"}]
assert merge_streamed_and_final_parts(streamed, final) == streamed