Extend new chat streaming for multimodal user turns

This commit is contained in:
CREDO23 2026-04-24 18:48:02 +02:00
parent c9477d13fc
commit d1080b1298
4 changed files with 178 additions and 23 deletions

View file

@ -24,9 +24,9 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
LocalFilesystemMount,
FilesystemMode,
FilesystemSelection,
LocalFilesystemMount,
)
from app.config import config
from app.db import (
@ -64,6 +64,10 @@ from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
split_persisted_user_content_parts,
)
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
@ -1237,6 +1241,10 @@ async def handle_new_chat(
# connection (the "Exception terminating connection" errors).
await session.close()
image_urls = (
[p.as_data_url() for p in request.user_images] if request.user_images else None
)
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
@ -1252,6 +1260,7 @@ async def handle_new_chat(
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=image_urls,
),
media_type="text/event-stream",
headers={
@ -1360,6 +1369,7 @@ async def regenerate_response(
target_checkpoint_id = None
user_query_to_use = request.user_query
regenerate_image_urls: list[str] = []
# Look through checkpoints to find the right one
# We want to find the checkpoint just before the last HumanMessage
@ -1385,9 +1395,13 @@ async def regenerate_response(
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
user_query_to_use = msg.content
q, imgs = split_langchain_human_content(msg.content)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use:
if user_query_to_use is not None and (
str(user_query_to_use).strip() or regenerate_image_urls
):
break
target_checkpoint_id = cp_tuple.config["configurable"][
@ -1405,7 +1419,9 @@ async def regenerate_response(
state_messages = channel_values.get("messages", [])
for msg in state_messages:
if isinstance(msg, HumanMessage):
user_query_to_use = msg.content
q, imgs = split_langchain_human_content(msg.content)
user_query_to_use = q
regenerate_image_urls = imgs
break
else:
# Use the oldest checkpoint
@ -1431,20 +1447,25 @@ async def regenerate_response(
if isinstance(content, str):
user_query_to_use = content
elif isinstance(content, list):
# Extract text from content parts
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
user_query_to_use = part.get("text", "")
break
elif isinstance(part, str):
user_query_to_use = part
break
plain, imgs = split_persisted_user_content_parts(content)
user_query_to_use = plain
regenerate_image_urls = imgs
if isinstance(user_query_to_use, list):
user_query_to_use, regenerate_image_urls = split_langchain_human_content(
user_query_to_use
)
if user_query_to_use is None:
raise HTTPException(
status_code=400,
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
if not str(user_query_to_use).strip() and not regenerate_image_urls:
raise HTTPException(
status_code=400,
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
# Get the last two messages to delete AFTER streaming succeeds
# This prevents data loss if streaming fails
@ -1483,7 +1504,7 @@ async def regenerate_response(
streaming_completed = False
try:
async for chunk in stream_new_chat(
user_query=user_query_to_use,
user_query=str(user_query_to_use),
search_space_id=request.search_space_id,
chat_id=thread_id,
user_id=str(user.id),
@ -1497,6 +1518,7 @@ async def regenerate_response(
disabled_tools=request.disabled_tools,
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
):
yield chunk
streaming_completed = True

View file

@ -7,12 +7,13 @@ These schemas follow the assistant-ui ThreadHistoryAdapter pattern:
"""
from datetime import datetime
from typing import Any, Literal
from typing import Any, Literal, Self
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from app.db import ChatVisibility, NewChatMessageRole
from app.utils.user_message_multimodal import decode_base64_image, to_data_url
from .base import IDModel, TimestampModel
@ -173,6 +174,26 @@ class LocalFilesystemMountPayload(BaseModel):
root_path: str
# Upper bound, in decoded bytes, for a single inline user image (8 MiB).
# Passed as ``max_bytes`` when the base64 payload of an image part is validated.
MAX_NEW_CHAT_IMAGE_BYTES = 8 * 1024 * 1024
# Maximum number of images accepted on one user turn.
MAX_NEW_CHAT_IMAGES = 4
class NewChatUserImagePart(BaseModel):
    """One inline image for a user turn (raw base64 body, no data: URL prefix).

    ``media_type`` is limited to the PNG, JPEG and WebP MIME types listed in
    the ``Literal`` below; ``data`` must be a non-empty string.
    """

    media_type: Literal["image/png", "image/jpeg", "image/webp"]
    data: str = Field(..., min_length=1)

    @field_validator("data")
    @classmethod
    def _validate_payload(cls, v: str) -> str:
        """Check that ``data`` decodes as base64 within the per-image size limit.

        The decoded bytes are discarded; the original string is returned
        unchanged. ``decode_base64_image`` raises ``ValueError`` on failure.
        """
        # Called only for its validation side effect (raises on bad input).
        decode_base64_image(v, max_bytes=MAX_NEW_CHAT_IMAGE_BYTES)
        return v

    def as_data_url(self) -> str:
        """Return this image as a data URL built from ``media_type`` and ``data``."""
        return to_data_url(self.media_type, self.data)
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -192,6 +213,20 @@ class NewChatRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
user_images: list[NewChatUserImagePart] | None = Field(
default=None,
description="Optional images for this user turn",
)
@model_validator(mode="after")
def _require_text_or_images(self) -> Self:
has_text = bool(self.user_query.strip())
has_images = bool(self.user_images)
if not has_text and not has_images:
raise ValueError("Provide non-empty user_query and/or user_images")
if self.user_images is not None and len(self.user_images) > MAX_NEW_CHAT_IMAGES:
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
return self
class RegenerateRequest(BaseModel):

View file

@ -31,7 +31,6 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.filesystem_selection import FilesystemSelection
from app.config import config
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
@ -62,6 +61,7 @@ from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
from app.utils.user_message_multimodal import build_human_message_content
_background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
@ -1350,6 +1350,7 @@ async def stream_new_chat(
disabled_tools: list[str] | None = None,
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
user_image_data_urls: list[str] | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@ -1625,8 +1626,10 @@ async def stream_new_chat(
# elif msg.role == "assistant":
# langchain_messages.append(AIMessage(content=msg.content))
# else:
# Fallback: just use the current user query with attachment context
langchain_messages.append(HumanMessage(content=final_query))
human_content = build_human_message_content(
final_query, list(user_image_data_urls or ())
)
langchain_messages.append(HumanMessage(content=human_content))
input_state = {
# Let's not pass this message at the moment because we are using the checkpointer to manage the conversation history
@ -1687,8 +1690,13 @@ async def stream_new_chat(
action_verb = "Processing"
processing_parts = []
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
processing_parts.append(query_text)
if user_query.strip():
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
processing_parts.append(query_text)
elif user_image_data_urls:
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
else:
processing_parts.append("(message)")
if mentioned_surfsense_docs:
doc_names = []
@ -1750,8 +1758,13 @@ async def stream_new_chat(
_turn_accumulator.set(None)
title_seed = user_query.strip() or (
f"[{len(user_image_data_urls or [])} image(s)]"
if user_image_data_urls
else ""
)
prompt = TITLE_GENERATION_PROMPT.replace(
"{user_query}", user_query[:500]
"{user_query}", title_seed[:500] or "(message)"
)
messages = [{"role": "user", "content": prompt}]
@ -1947,10 +1960,15 @@ async def stream_new_chat(
# Fire background memory extraction if the agent didn't handle it.
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
memory_seed = user_query.strip() or (
f"[{len(user_image_data_urls or [])} image(s)]"
if user_image_data_urls
else "(message)"
)
if visibility == ChatVisibility.SEARCH_SPACE:
task = asyncio.create_task(
extract_and_save_team_memory(
user_message=user_query,
user_message=memory_seed,
search_space_id=search_space_id,
llm=llm,
author_display_name=current_user_display_name,
@ -1961,7 +1979,7 @@ async def stream_new_chat(
elif user_id:
task = asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_message=memory_seed,
user_id=user_id,
llm=llm,
)

View file

@ -0,0 +1,80 @@
"""Helpers for multimodal user turns (text + inline images) in LangChain messages."""
from __future__ import annotations
import base64
import binascii
from typing import Any
def build_human_message_content(
    final_query: str, image_data_urls: list[str]
) -> str | list[dict[str, Any]]:
    """Build the ``content`` value for a LangChain ``HumanMessage``.

    Args:
        final_query: The text of the user turn (may be blank for an
            image-only turn).
        image_data_urls: Image URLs to attach, in order. When empty, the
            turn is text-only.

    Returns:
        ``final_query`` unchanged when there are no images; otherwise a list
        of content blocks: an optional ``text`` block followed by one
        ``image_url`` block per URL.
    """
    if not image_data_urls:
        return final_query

    parts: list[dict[str, Any]] = []
    # Image-only turns carry blank text. Emitting an empty text block is
    # pointless and is rejected by some chat APIs, so only add the text
    # block when there is something in it.
    if final_query.strip():
        parts.append({"type": "text", "text": final_query})
    parts.extend(
        {"type": "image_url", "image_url": {"url": url}} for url in image_data_urls
    )
    return parts
def split_langchain_human_content(content: str | list[Any]) -> tuple[str, list[str]]:
    """Return plain text and data URLs from a LangChain HumanMessage ``content`` value.

    Args:
        content: Either a plain string or a list of content parts. List
            items may be bare strings or dict blocks with a ``type`` key.

    Returns:
        A ``(text, urls)`` tuple. ``text`` is the non-empty text parts joined
        with newlines; ``urls`` holds every ``image_url`` value that starts
        with ``data:``. Any other kind of ``content`` yields ``("", [])``.
    """
    if isinstance(content, str):
        return content, []
    if not isinstance(content, list):
        return "", []

    text_chunks: list[str] = []
    urls: list[str] = []
    for block in content:
        # LangChain content lists may mix bare strings with dict blocks;
        # a bare string is plain text and must not be dropped.
        if isinstance(block, str):
            if block:
                text_chunks.append(block)
            continue
        if not isinstance(block, dict):
            continue

        btype = block.get("type")
        if btype == "text":
            t = block.get("text")
            if isinstance(t, str) and t:
                text_chunks.append(t)
        elif btype == "image_url":
            iu = block.get("image_url")
            # ``image_url`` appears either as {"url": ...} or as a bare string.
            u = iu.get("url") if isinstance(iu, dict) else iu
            if isinstance(u, str) and u.startswith("data:"):
                urls.append(u)
    return "\n".join(text_chunks), urls
def decode_base64_image(data: str, *, max_bytes: int) -> bytes:
    """Decode a base64 image payload, enforcing a maximum decoded size.

    Args:
        data: Base64 text; surrounding whitespace is ignored.
        max_bytes: Largest decoded size, in bytes, that is accepted.

    Returns:
        The decoded bytes.

    Raises:
        ValueError: If the payload is empty, is not valid base64, or is
            larger than ``max_bytes``.
    """
    raw = data.strip()
    if not raw:
        raise ValueError("empty image payload")

    # Base64 encodes every 3 bytes as 4 characters, so anything longer than
    # this cannot decode to <= max_bytes. Rejecting here avoids allocating
    # the decoded buffer for oversized uploads.
    max_encoded_len = 4 * ((max_bytes + 2) // 3)
    if len(raw) > max_encoded_len:
        raise ValueError("image exceeds maximum size")

    try:
        decoded = base64.b64decode(raw, validate=True)
    except (binascii.Error, ValueError) as e:
        # Non-ASCII input raises a plain ValueError rather than binascii.Error;
        # report both with the same message.
        raise ValueError("invalid base64 image data") from e

    # Padding means the length pre-check is only an upper bound; confirm exactly.
    if len(decoded) > max_bytes:
        raise ValueError("image exceeds maximum size")
    return decoded
def to_data_url(media_type: str, raw_b64: str) -> str:
    """Format a base64 payload as a ``data:`` URL for the given media type.

    Whitespace around ``raw_b64`` is removed before it is embedded.
    """
    payload = raw_b64.strip()
    return "data:{};base64,{}".format(media_type, payload)
def split_persisted_user_content_parts(parts: list[Any]) -> tuple[str, list[str]]:
    """Extract plain text and data URLs from persisted assistant-ui style user ``content``.

    Args:
        parts: Content parts of a stored user message. Items may be bare
            strings, ``{"type": "text", "text": ...}`` blocks or
            ``{"type": "image", "image": ...}`` blocks; anything else is
            ignored.

    Returns:
        A ``(text, urls)`` tuple. ``text`` is all text parts concatenated in
        order with no separator; ``urls`` holds every ``image`` value that
        starts with ``data:``.
    """
    text_chunks: list[str] = []
    urls: list[str] = []
    for block in parts:
        # Bare string parts are plain text and are kept rather than silently
        # dropped.
        if isinstance(block, str):
            text_chunks.append(block)
            continue
        if not isinstance(block, dict):
            continue

        btype = block.get("type")
        if btype == "text":
            t = block.get("text")
            if isinstance(t, str):
                text_chunks.append(t)
        elif btype == "image":
            u = block.get("image")
            if isinstance(u, str) and u.startswith("data:"):
                urls.append(u)
    return "".join(text_chunks), urls