mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
fix(chat): normalize provider-safe message history
This commit is contained in:
parent
5d5d574550
commit
3dd54230e7
9 changed files with 382 additions and 6 deletions
89
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
89
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""Convert persisted chat content into provider-safe LangChain history.
|
||||||
|
|
||||||
|
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
|
||||||
|
extracts only model-safe content before prior turns are replayed to a provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Block ``type`` values that user_content_to_llm_content keeps; a typed block
# in a user message whose type is not listed here is skipped.
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
|
||||||
|
|
||||||
|
|
||||||
|
def _text_from_block(block: dict[str, Any]) -> str:
|
||||||
|
value = block.get("text") or block.get("content") or ""
|
||||||
|
return value if isinstance(value, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
def assistant_content_to_llm_text(content: Any) -> str:
    """Return visible assistant text, dropping reasoning/UI/provider blocks.

    Args:
        content: Persisted assistant content: a plain string, a single block
            dict, or a list mixing strings and block dicts. Any other value
            is treated as empty.

    Returns:
        The text of every kept block joined with newlines, or ``""`` when
        nothing qualifies.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # A lone block must not bypass the filter applied to list entries:
        # a dict that declares a non-text type (reasoning, tool call, UI
        # data, ...) is dropped even if it happens to carry a "text" or
        # "content" key. Dicts without any ``type`` are still accepted.
        block_type = content.get("type")
        if block_type is not None and block_type != "text":
            return ""
        return _text_from_block(content)
    if not isinstance(content, list):
        return ""

    text_chunks: list[str] = []
    for block in content:
        if isinstance(block, str):
            text = block
        elif isinstance(block, dict) and block.get("type") == "text":
            text = _text_from_block(block)
        else:
            # Non-text blocks and unexpected values never reach the model.
            continue
        if text:
            text_chunks.append(text)
    return "\n".join(text_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def _inline_image_url(block: dict[str, Any]) -> str | None:
    """Return the ``data:`` URL carried by an image block, or ``None``.

    Handles ``{"type": "image", "image": url}`` as well as
    ``{"type": "image_url", "image_url": url | {"url": url}}``. URLs that do
    not start with ``data:`` are rejected.
    """
    if block.get("type") == "image":
        candidate = block.get("image")
    else:
        candidate = block.get("image_url")
        if isinstance(candidate, dict):
            candidate = candidate.get("url")
    if isinstance(candidate, str) and candidate.startswith("data:"):
        return candidate
    return None


def user_content_to_llm_content(
    content: Any,
    *,
    allow_images: bool = True,
) -> str | list[dict[str, Any]]:
    """Return provider-safe user text/image content for LangChain.

    Args:
        content: Persisted user content: a plain string, a single block dict,
            or a list mixing strings and block dicts. Any other value is
            treated as empty.
        allow_images: When false, image blocks are ignored and the result is
            always a string.

    Returns:
        A list of ``text`` / ``image_url`` parts when at least one inline
        image was kept; otherwise the kept text joined with newlines
        (``""`` when there is none).
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return _text_from_block(content)
    if not isinstance(content, list):
        return ""

    parts: list[dict[str, Any]] = []
    text_chunks: list[str] = []
    has_image = False
    for block in content:
        if isinstance(block, str):
            if block:
                # Bare strings must land in ``parts`` too; otherwise they
                # vanish whenever an image switches the return value to the
                # multimodal list.
                parts.append({"type": "text", "text": block})
                text_chunks.append(block)
            continue
        if not isinstance(block, dict):
            continue
        block_type = block.get("type")
        if block_type not in _USER_CONTENT_TYPES:
            continue
        if block_type == "text":
            text = _text_from_block(block)
            if text:
                parts.append({"type": "text", "text": text})
                text_chunks.append(text)
        elif allow_images:
            url = _inline_image_url(block)
            if url is not None:
                parts.append({"type": "image_url", "image_url": {"url": url}})
                has_image = True

    if has_image:
        return parts
    return "\n".join(text_chunks)
|
||||||
|
|
||||||
86
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
86
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""Normalize final LangChain assistant messages into assistant-ui parts.
|
||||||
|
|
||||||
|
Live streaming remains the primary source for rich, incremental UI state.
|
||||||
|
This module is only used after the graph has finished so refresh persistence
|
||||||
|
does not depend on provider-specific streaming chunk shapes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _text_from_content(content: Any) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
continue
|
||||||
|
if block.get("type") != "text":
|
||||||
|
continue
|
||||||
|
value = block.get("text") or block.get("content") or ""
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
text_parts.append(value)
|
||||||
|
return "".join(text_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]:
    """Build the user-visible assistant-ui parts for a final AI message.

    Only answer text is returned. Provider ``thinking`` /
    ``reasoning_content`` blocks are deliberately not backfilled: reasoning
    that streamed live was already captured by ``AssistantContentBuilder``,
    and reasoning that exists only in the final payload was never shown
    during the turn, so persisting it afterwards could expose it.

    Returns:
        A single ``text`` part holding the stripped message text, or an empty
        list when ``message`` is ``None`` or has no visible text.
    """
    if message is None:
        return []

    raw_content = getattr(message, "content", None)
    visible = _text_from_content(raw_content).strip()
    return [{"type": "text", "text": visible}] if visible else []
|
||||||
|
|
||||||
|
|
||||||
|
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
    """Return the most recent AI message in ``messages``, or ``None``.

    A message qualifies when it is an ``AIMessage`` instance or exposes a
    ``type`` attribute equal to ``"ai"``.
    """
    if messages is None:
        return None

    # Materialise first so any iterable can be scanned newest-to-oldest.
    newest_first = reversed(list(messages))
    return next(
        (
            candidate
            for candidate in newest_first
            if isinstance(candidate, AIMessage)
            or getattr(candidate, "type", None) == "ai"
        ),
        None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]:
    """Return the visible text parts of the newest AI message in ``messages``."""
    newest_ai = last_ai_message(messages)
    return normalize_ai_message_to_parts(newest_ai)
|
||||||
|
|
||||||
|
|
||||||
|
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
    """Report whether any ``text`` part holds non-whitespace string content."""
    for part in parts:
        if part.get("type") != "text":
            continue
        text = part.get("text")
        if isinstance(text, str) and text.strip():
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def merge_streamed_and_final_parts(
    streamed_parts: list[dict[str, Any]],
    final_parts: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Append final-state text only when streaming captured no answer text.

    Returns ``streamed_parts`` itself when it already contains answer text or
    when ``final_parts`` has none to offer; otherwise a new list holding the
    streamed parts followed by the final parts.
    """
    needs_backfill = not has_non_empty_text_part(
        streamed_parts
    ) and has_non_empty_text_part(final_parts)
    if needs_backfill:
        return [*streamed_parts, *final_parts]
    return streamed_parts
|
||||||
|
|
||||||
|
|
@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
||||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||||
all_interrupt_values,
|
all_interrupt_values,
|
||||||
)
|
)
|
||||||
|
from app.tasks.chat.message_parts_normalizer import (
|
||||||
|
final_assistant_parts_from_messages,
|
||||||
|
)
|
||||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||||
from app.tasks.chat.streaming.shared.utils import safe_float
|
from app.tasks.chat.streaming.shared.utils import safe_float
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
@ -75,6 +78,9 @@ async def stream_agent_events(
|
||||||
|
|
||||||
state = await agent.aget_state(config)
|
state = await agent.aget_state(config)
|
||||||
state_values = getattr(state, "values", {}) or {}
|
state_values = getattr(state, "values", {}) or {}
|
||||||
|
result.final_message_parts = final_assistant_parts_from_messages(
|
||||||
|
state_values.get("messages")
|
||||||
|
)
|
||||||
|
|
||||||
# Safety net: if astream_events was cancelled before
|
# Safety net: if astream_events was cancelled before
|
||||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ async def finalize_assistant_message(
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
|
||||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||||
|
|
||||||
builder_stats: dict[str, int] | None = None
|
builder_stats: dict[str, int] | None = None
|
||||||
|
|
@ -74,6 +75,10 @@ async def finalize_assistant_message(
|
||||||
"text": stream_result.accumulated_text or "",
|
"text": stream_result.accumulated_text or "",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
content_payload = merge_streamed_and_final_parts(
|
||||||
|
content_payload,
|
||||||
|
stream_result.final_message_parts,
|
||||||
|
)
|
||||||
|
|
||||||
if builder_stats is not None:
|
if builder_stats is not None:
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
|
||||||
|
|
@ -35,3 +35,7 @@ class StreamResult:
|
||||||
# (``StreamResult`` is logged in some error branches) from dumping a
|
# (``StreamResult`` is logged in some error branches) from dumping a
|
||||||
# potentially-large parts list.
|
# potentially-large parts list.
|
||||||
content_builder: Any | None = field(default=None, repr=False)
|
content_builder: Any | None = field(default=None, repr=False)
|
||||||
|
# User-visible assistant message parts derived from the final LangGraph
|
||||||
|
# state. Used after streaming completes as a provider-agnostic persistence
|
||||||
|
# backfill when no text chunks reached the live stream.
|
||||||
|
final_message_parts: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,11 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from app.tasks.chat.llm_history_normalizer import (
|
||||||
|
assistant_content_to_llm_text,
|
||||||
|
user_content_to_llm_content,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
|
|
||||||
|
|
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
|
||||||
langchain_messages: list[HumanMessage | AIMessage] = []
|
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||||
|
|
||||||
for msg in db_messages:
|
for msg in db_messages:
|
||||||
text_content = extract_text_content(msg.content)
|
|
||||||
if not text_content:
|
|
||||||
continue
|
|
||||||
if msg.role == "user":
|
if msg.role == "user":
|
||||||
|
user_content = user_content_to_llm_content(
|
||||||
|
msg.content,
|
||||||
|
allow_images=False,
|
||||||
|
)
|
||||||
|
if not user_content:
|
||||||
|
continue
|
||||||
if is_shared:
|
if is_shared:
|
||||||
author_name = (
|
author_name = (
|
||||||
msg.author.display_name if msg.author else None
|
msg.author.display_name if msg.author else None
|
||||||
) or "A team member"
|
) or "A team member"
|
||||||
text_content = f"**[{author_name}]:** {text_content}"
|
if isinstance(user_content, str):
|
||||||
langchain_messages.append(HumanMessage(content=text_content))
|
user_content = f"**[{author_name}]:** {user_content}"
|
||||||
|
elif user_content and user_content[0].get("type") == "text":
|
||||||
|
user_content[0] = {
|
||||||
|
**user_content[0],
|
||||||
|
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
|
||||||
|
}
|
||||||
|
langchain_messages.append(HumanMessage(content=user_content))
|
||||||
elif msg.role == "assistant":
|
elif msg.role == "assistant":
|
||||||
langchain_messages.append(AIMessage(content=text_content))
|
assistant_text = assistant_content_to_llm_text(msg.content)
|
||||||
|
if assistant_text:
|
||||||
|
langchain_messages.append(AIMessage(content=assistant_text))
|
||||||
|
|
||||||
return langchain_messages
|
return langchain_messages
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""Regression tests for model-boundary message sanitization."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.agents.chat.runtime.llm_config import _sanitize_messages
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None:
    """Sanitizing keeps only the text block and leaves the input untouched."""
    blocks = [
        {"type": "thinking", "thinking": "private reasoning"},
        {"type": "text", "text": "visible answer"},
    ]
    # Hand the message its own copies so in-place mutation would be detected.
    original = AIMessage(content=[dict(block) for block in blocks])

    sanitized = _sanitize_messages([original])

    assert sanitized[0].content == "visible answer"
    assert original.content == blocks
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None:
    """A tool-call-only AI message gets ``content=None``; the input keeps ``""``."""
    search_call = {"name": "search", "args": {"q": "x"}, "id": "call_1"}
    message = AIMessage(content="", tool_calls=[search_call])

    sanitized = _sanitize_messages([message])
    first = sanitized[0]

    assert first.content is None
    assert message.content == ""
|
||||||
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Unit tests for provider-safe LLM history normalization."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.llm_history_normalizer import (
|
||||||
|
assistant_content_to_llm_text,
|
||||||
|
user_content_to_llm_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None:
    """UI-only ``data-thinking-steps`` parts are excluded from LLM history."""
    thinking_steps = {"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]}
    answer = {"type": "text", "text": "visible answer"}

    result = assistant_content_to_llm_text([thinking_steps, answer])

    assert result == "visible answer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None:
    """Provider ``thinking`` blocks are dropped; only the text block survives."""
    reasoning = {"type": "thinking", "thinking": "private reasoning"}
    answer = {"type": "text", "text": "final answer"}

    result = assistant_content_to_llm_text([reasoning, answer])

    assert result == "final answer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_assistant_blocks_are_dropped() -> None:
    """Blocks whose type is not ``text`` contribute nothing to the output."""
    ignored = [
        {"type": "redacted_thinking", "data": "hidden"},
        {"type": "tool_use", "name": "search"},
    ]
    kept = {"type": "text", "text": "kept"}

    result = assistant_content_to_llm_text([*ignored, kept])

    assert result == "kept"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None:
    """An ``image`` part with a data URL becomes an ``image_url`` part."""
    data_url = "data:image/png;base64,abc"
    text_part = {"type": "text", "text": "look"}

    converted = user_content_to_llm_content(
        [dict(text_part), {"type": "image", "image": data_url}],
        allow_images=True,
    )

    assert converted == [
        text_part,
        {"type": "image_url", "image_url": {"url": data_url}},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_images_can_be_dropped_for_text_only_history() -> None:
    """With ``allow_images=False`` only the text of the message is returned."""
    text_part = {"type": "text", "text": "look"}
    image_part = {"type": "image", "image": "data:image/png;base64,abc"}

    result = user_content_to_llm_content([text_part, image_part], allow_images=False)

    assert result == "look"
|
||||||
|
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
"""Unit tests for final assistant message part normalization."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.tasks.chat.message_parts_normalizer import (
|
||||||
|
final_assistant_parts_from_messages,
|
||||||
|
merge_streamed_and_final_parts,
|
||||||
|
normalize_ai_message_to_parts,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_ai_message_content_becomes_text_part() -> None:
    """Plain string content is wrapped in a single ``text`` part."""
    parts = normalize_ai_message_to_parts(AIMessage(content="hello"))

    assert parts == [{"type": "text", "text": "hello"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None:
    """Reasoning in content blocks or ``additional_kwargs`` is not backfilled."""
    hidden = "hidden reasoning"
    message = AIMessage(
        content=[
            {"type": "thinking", "thinking": hidden},
            {"type": "text", "text": "Yo bro! What's up?"},
        ],
        additional_kwargs={"reasoning_content": hidden},
    )

    parts = normalize_ai_message_to_parts(message)

    assert parts == [{"type": "text", "text": "Yo bro! What's up?"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None:
    """The newest AI message wins even when tool messages follow it."""
    conversation = [
        HumanMessage(content="ask"),
        AIMessage(content="draft"),
        ToolMessage(content="tool output", tool_call_id="tc-1"),
        AIMessage(content=[{"type": "text", "text": "final answer"}]),
        ToolMessage(content="trailing tool noise", tool_call_id="tc-2"),
    ]

    parts = final_assistant_parts_from_messages(conversation)

    assert parts == [{"type": "text", "text": "final answer"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None:
    """Final text is appended when the stream produced no text part."""
    thinking_steps = {
        "type": "data-thinking-steps",
        "data": [{"id": "thinking-1", "status": "completed"}],
    }
    streamed = [thinking_steps]
    final = [{"type": "text", "text": "visible answer"}]

    merged = merge_streamed_and_final_parts(streamed, final)

    # Build the expectation after the call so in-place mutation would show up.
    assert merged == streamed + final
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_does_not_duplicate_when_stream_already_has_text() -> None:
    """Streamed text takes precedence; final text is not appended to it."""
    streamed = [{"type": "text", "text": "streamed answer"}]
    final = [{"type": "text", "text": "final answer"}]

    merged = merge_streamed_and_final_parts(streamed, final)

    assert merged == streamed
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue