mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-30 21:59:46 +02:00
fix(chat): normalize provider-safe message history
This commit is contained in:
parent
5d5d574550
commit
3dd54230e7
9 changed files with 382 additions and 6 deletions
89
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
89
surfsense_backend/app/tasks/chat/llm_history_normalizer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""Convert persisted chat content into provider-safe LangChain history.
|
||||
|
||||
Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module
|
||||
extracts only model-safe content before prior turns are replayed to a provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
_USER_CONTENT_TYPES = {"text", "image", "image_url"}
|
||||
|
||||
|
||||
def _text_from_block(block: dict[str, Any]) -> str:
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
|
||||
def assistant_content_to_llm_text(content: Any) -> str:
|
||||
"""Return visible assistant text, dropping reasoning/UI/provider blocks."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
return _text_from_block(content)
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
text_chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
text_chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text = _text_from_block(block)
|
||||
if text:
|
||||
text_chunks.append(text)
|
||||
return "\n".join(text_chunks)
|
||||
|
||||
|
||||
def user_content_to_llm_content(
|
||||
content: Any,
|
||||
*,
|
||||
allow_images: bool = True,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Return provider-safe user text/image content for LangChain."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
return _text_from_block(content)
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
parts: list[dict[str, Any]] = []
|
||||
text_chunks: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
text_chunks.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type")
|
||||
if block_type not in _USER_CONTENT_TYPES:
|
||||
continue
|
||||
if block_type == "text":
|
||||
text = _text_from_block(block)
|
||||
if text:
|
||||
parts.append({"type": "text", "text": text})
|
||||
text_chunks.append(text)
|
||||
elif allow_images and block_type == "image":
|
||||
image = block.get("image")
|
||||
if isinstance(image, str) and image.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": image}})
|
||||
elif allow_images and block_type == "image_url":
|
||||
image_url = block.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url")
|
||||
if isinstance(url, str) and url.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": url}})
|
||||
elif isinstance(image_url, str) and image_url.startswith("data:"):
|
||||
parts.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
|
||||
if allow_images and any(part.get("type") == "image_url" for part in parts):
|
||||
return parts
|
||||
return "\n".join(text_chunks)
|
||||
|
||||
86
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
86
surfsense_backend/app/tasks/chat/message_parts_normalizer.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Normalize final LangChain assistant messages into assistant-ui parts.
|
||||
|
||||
Live streaming remains the primary source for rich, incremental UI state.
|
||||
This module is only used after the graph has finished so refresh persistence
|
||||
does not depend on provider-specific streaming chunk shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def _text_from_content(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") != "text":
|
||||
continue
|
||||
value = block.get("text") or block.get("content") or ""
|
||||
if isinstance(value, str) and value:
|
||||
text_parts.append(value)
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]:
|
||||
"""Return user-visible assistant-ui parts for a final AI message.
|
||||
|
||||
We intentionally do not backfill provider ``thinking`` /
|
||||
``reasoning_content`` blocks here. If reasoning streamed live, the
|
||||
``AssistantContentBuilder`` already captured it. If it only exists in the
|
||||
final model payload, persisting it retroactively could expose content the
|
||||
UI never showed during the turn.
|
||||
"""
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
text = _text_from_content(getattr(message, "content", None)).strip()
|
||||
if not text:
|
||||
return []
|
||||
return [{"type": "text", "text": text}]
|
||||
|
||||
|
||||
def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None:
|
||||
if messages is None:
|
||||
return None
|
||||
for message in reversed(list(messages)):
|
||||
if isinstance(message, AIMessage):
|
||||
return message
|
||||
if getattr(message, "type", None) == "ai":
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]:
|
||||
return normalize_ai_message_to_parts(last_ai_message(messages))
|
||||
|
||||
|
||||
def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool:
|
||||
return any(
|
||||
part.get("type") == "text"
|
||||
and isinstance(part.get("text"), str)
|
||||
and bool(part.get("text", "").strip())
|
||||
for part in parts
|
||||
)
|
||||
|
||||
|
||||
def merge_streamed_and_final_parts(
|
||||
streamed_parts: list[dict[str, Any]],
|
||||
final_parts: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Use final-state text only when streaming captured no answer text."""
|
||||
if has_non_empty_text_part(streamed_parts):
|
||||
return streamed_parts
|
||||
if not has_non_empty_text_part(final_parts):
|
||||
return streamed_parts
|
||||
return [*streamed_parts, *final_parts]
|
||||
|
||||
|
|
@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
|||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
all_interrupt_values,
|
||||
)
|
||||
from app.tasks.chat.message_parts_normalizer import (
|
||||
final_assistant_parts_from_messages,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.tasks.chat.streaming.shared.utils import safe_float
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -75,6 +78,9 @@ async def stream_agent_events(
|
|||
|
||||
state = await agent.aget_state(config)
|
||||
state_values = getattr(state, "values", {}) or {}
|
||||
result.final_message_parts = final_assistant_parts_from_messages(
|
||||
state_values.get("messages")
|
||||
)
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ async def finalize_assistant_message(
|
|||
):
|
||||
return
|
||||
|
||||
from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts
|
||||
from app.tasks.chat.persistence import finalize_assistant_turn
|
||||
|
||||
builder_stats: dict[str, int] | None = None
|
||||
|
|
@ -74,6 +75,10 @@ async def finalize_assistant_message(
|
|||
"text": stream_result.accumulated_text or "",
|
||||
}
|
||||
]
|
||||
content_payload = merge_streamed_and_final_parts(
|
||||
content_payload,
|
||||
stream_result.final_message_parts,
|
||||
)
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -35,3 +35,7 @@ class StreamResult:
|
|||
# (``StreamResult`` is logged in some error branches) from dumping a
|
||||
# potentially-large parts list.
|
||||
content_builder: Any | None = field(default=None, repr=False)
|
||||
# User-visible assistant message parts derived from the final LangGraph
|
||||
# state. Used after streaming completes as a provider-agnostic persistence
|
||||
# backfill when no text chunks reached the live stream.
|
||||
final_message_parts: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.tasks.chat.llm_history_normalizer import (
|
||||
assistant_content_to_llm_text,
|
||||
user_content_to_llm_content,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
|
@ -95,17 +100,28 @@ async def bootstrap_history_from_db(
|
|||
langchain_messages: list[HumanMessage | AIMessage] = []
|
||||
|
||||
for msg in db_messages:
|
||||
text_content = extract_text_content(msg.content)
|
||||
if not text_content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
user_content = user_content_to_llm_content(
|
||||
msg.content,
|
||||
allow_images=False,
|
||||
)
|
||||
if not user_content:
|
||||
continue
|
||||
if is_shared:
|
||||
author_name = (
|
||||
msg.author.display_name if msg.author else None
|
||||
) or "A team member"
|
||||
text_content = f"**[{author_name}]:** {text_content}"
|
||||
langchain_messages.append(HumanMessage(content=text_content))
|
||||
if isinstance(user_content, str):
|
||||
user_content = f"**[{author_name}]:** {user_content}"
|
||||
elif user_content and user_content[0].get("type") == "text":
|
||||
user_content[0] = {
|
||||
**user_content[0],
|
||||
"text": f"**[{author_name}]:** {user_content[0].get('text', '')}",
|
||||
}
|
||||
langchain_messages.append(HumanMessage(content=user_content))
|
||||
elif msg.role == "assistant":
|
||||
langchain_messages.append(AIMessage(content=text_content))
|
||||
assistant_text = assistant_content_to_llm_text(msg.content)
|
||||
if assistant_text:
|
||||
langchain_messages.append(AIMessage(content=assistant_text))
|
||||
|
||||
return langchain_messages
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue