mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
Extend new chat streaming for multimodal user turns
This commit is contained in:
parent
c9477d13fc
commit
d1080b1298
4 changed files with 178 additions and 23 deletions
|
|
@ -24,9 +24,9 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.agents.new_chat.filesystem_selection import (
|
||||
ClientPlatform,
|
||||
LocalFilesystemMount,
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
|
|
@ -64,6 +64,10 @@ from app.services.token_tracking_service import record_token_usage
|
|||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.user_message_multimodal import (
|
||||
split_langchain_human_content,
|
||||
split_persisted_user_content_parts,
|
||||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
|
@ -1237,6 +1241,10 @@ async def handle_new_chat(
|
|||
# connection (the "Exception terminating connection" errors).
|
||||
await session.close()
|
||||
|
||||
image_urls = (
|
||||
[p.as_data_url() for p in request.user_images] if request.user_images else None
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
|
|
@ -1252,6 +1260,7 @@ async def handle_new_chat(
|
|||
disabled_tools=request.disabled_tools,
|
||||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=image_urls,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -1360,6 +1369,7 @@ async def regenerate_response(
|
|||
|
||||
target_checkpoint_id = None
|
||||
user_query_to_use = request.user_query
|
||||
regenerate_image_urls: list[str] = []
|
||||
|
||||
# Look through checkpoints to find the right one
|
||||
# We want to find the checkpoint just before the last HumanMessage
|
||||
|
|
@ -1385,9 +1395,13 @@ async def regenerate_response(
|
|||
prev_messages = prev_channel_values.get("messages", [])
|
||||
for msg in reversed(prev_messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
user_query_to_use = msg.content
|
||||
q, imgs = split_langchain_human_content(msg.content)
|
||||
user_query_to_use = q
|
||||
regenerate_image_urls = imgs
|
||||
break
|
||||
if user_query_to_use:
|
||||
if user_query_to_use is not None and (
|
||||
str(user_query_to_use).strip() or regenerate_image_urls
|
||||
):
|
||||
break
|
||||
|
||||
target_checkpoint_id = cp_tuple.config["configurable"][
|
||||
|
|
@ -1405,7 +1419,9 @@ async def regenerate_response(
|
|||
state_messages = channel_values.get("messages", [])
|
||||
for msg in state_messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
user_query_to_use = msg.content
|
||||
q, imgs = split_langchain_human_content(msg.content)
|
||||
user_query_to_use = q
|
||||
regenerate_image_urls = imgs
|
||||
break
|
||||
else:
|
||||
# Use the oldest checkpoint
|
||||
|
|
@ -1431,20 +1447,25 @@ async def regenerate_response(
|
|||
if isinstance(content, str):
|
||||
user_query_to_use = content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content parts
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
user_query_to_use = part.get("text", "")
|
||||
break
|
||||
elif isinstance(part, str):
|
||||
user_query_to_use = part
|
||||
break
|
||||
plain, imgs = split_persisted_user_content_parts(content)
|
||||
user_query_to_use = plain
|
||||
regenerate_image_urls = imgs
|
||||
|
||||
if isinstance(user_query_to_use, list):
|
||||
user_query_to_use, regenerate_image_urls = split_langchain_human_content(
|
||||
user_query_to_use
|
||||
)
|
||||
|
||||
if user_query_to_use is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||
)
|
||||
if not str(user_query_to_use).strip() and not regenerate_image_urls:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||
)
|
||||
|
||||
# Get the last two messages to delete AFTER streaming succeeds
|
||||
# This prevents data loss if streaming fails
|
||||
|
|
@ -1483,7 +1504,7 @@ async def regenerate_response(
|
|||
streaming_completed = False
|
||||
try:
|
||||
async for chunk in stream_new_chat(
|
||||
user_query=user_query_to_use,
|
||||
user_query=str(user_query_to_use),
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=thread_id,
|
||||
user_id=str(user.id),
|
||||
|
|
@ -1497,6 +1518,7 @@ async def regenerate_response(
|
|||
disabled_tools=request.disabled_tools,
|
||||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=regenerate_image_urls or None,
|
||||
):
|
||||
yield chunk
|
||||
streaming_completed = True
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ These schemas follow the assistant-ui ThreadHistoryAdapter pattern:
|
|||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Self
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from app.db import ChatVisibility, NewChatMessageRole
|
||||
from app.utils.user_message_multimodal import decode_base64_image, to_data_url
|
||||
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
|
|
@ -173,6 +174,26 @@ class LocalFilesystemMountPayload(BaseModel):
|
|||
root_path: str
|
||||
|
||||
|
||||
MAX_NEW_CHAT_IMAGE_BYTES = 8 * 1024 * 1024
|
||||
MAX_NEW_CHAT_IMAGES = 4
|
||||
|
||||
|
||||
class NewChatUserImagePart(BaseModel):
|
||||
"""One inline image for a user turn (raw base64 body, no data: URL prefix)."""
|
||||
|
||||
media_type: Literal["image/png", "image/jpeg", "image/webp"]
|
||||
data: str = Field(..., min_length=1)
|
||||
|
||||
@field_validator("data")
|
||||
@classmethod
|
||||
def _validate_payload(cls, v: str) -> str:
|
||||
decode_base64_image(v, max_bytes=MAX_NEW_CHAT_IMAGE_BYTES)
|
||||
return v
|
||||
|
||||
def as_data_url(self) -> str:
|
||||
return to_data_url(self.media_type, self.data)
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
"""Request schema for the deep agent chat endpoint."""
|
||||
|
||||
|
|
@ -192,6 +213,20 @@ class NewChatRequest(BaseModel):
|
|||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
user_images: list[NewChatUserImagePart] | None = Field(
|
||||
default=None,
|
||||
description="Optional images for this user turn",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _require_text_or_images(self) -> Self:
|
||||
has_text = bool(self.user_query.strip())
|
||||
has_images = bool(self.user_images)
|
||||
if not has_text and not has_images:
|
||||
raise ValueError("Provide non-empty user_query and/or user_images")
|
||||
if self.user_images is not None and len(self.user_images) > MAX_NEW_CHAT_IMAGES:
|
||||
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
||||
return self
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from sqlalchemy.orm import selectinload
|
|||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemSelection
|
||||
from app.config import config
|
||||
from app.agents.new_chat.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
|
|
@ -62,6 +61,7 @@ from app.services.connector_service import ConnectorService
|
|||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
from app.utils.user_message_multimodal import build_human_message_content
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
_perf_log = get_perf_logger()
|
||||
|
|
@ -1350,6 +1350,7 @@ async def stream_new_chat(
|
|||
disabled_tools: list[str] | None = None,
|
||||
filesystem_selection: FilesystemSelection | None = None,
|
||||
request_id: str | None = None,
|
||||
user_image_data_urls: list[str] | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -1625,8 +1626,10 @@ async def stream_new_chat(
|
|||
# elif msg.role == "assistant":
|
||||
# langchain_messages.append(AIMessage(content=msg.content))
|
||||
# else:
|
||||
# Fallback: just use the current user query with attachment context
|
||||
langchain_messages.append(HumanMessage(content=final_query))
|
||||
human_content = build_human_message_content(
|
||||
final_query, list(user_image_data_urls or ())
|
||||
)
|
||||
langchain_messages.append(HumanMessage(content=human_content))
|
||||
|
||||
input_state = {
|
||||
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
|
||||
|
|
@ -1687,8 +1690,13 @@ async def stream_new_chat(
|
|||
action_verb = "Processing"
|
||||
|
||||
processing_parts = []
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
if user_query.strip():
|
||||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
elif user_image_data_urls:
|
||||
processing_parts.append(f"[{len(user_image_data_urls)} image(s)]")
|
||||
else:
|
||||
processing_parts.append("(message)")
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
doc_names = []
|
||||
|
|
@ -1750,8 +1758,13 @@ async def stream_new_chat(
|
|||
|
||||
_turn_accumulator.set(None)
|
||||
|
||||
title_seed = user_query.strip() or (
|
||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
||||
if user_image_data_urls
|
||||
else ""
|
||||
)
|
||||
prompt = TITLE_GENERATION_PROMPT.replace(
|
||||
"{user_query}", user_query[:500]
|
||||
"{user_query}", title_seed[:500] or "(message)"
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
|
|
@ -1947,10 +1960,15 @@ async def stream_new_chat(
|
|||
# Fire background memory extraction if the agent didn't handle it.
|
||||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
memory_seed = user_query.strip() or (
|
||||
f"[{len(user_image_data_urls or [])} image(s)]"
|
||||
if user_image_data_urls
|
||||
else "(message)"
|
||||
)
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
user_message=memory_seed,
|
||||
search_space_id=search_space_id,
|
||||
llm=llm,
|
||||
author_display_name=current_user_display_name,
|
||||
|
|
@ -1961,7 +1979,7 @@ async def stream_new_chat(
|
|||
elif user_id:
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_message=memory_seed,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
|
|
|
|||
80
surfsense_backend/app/utils/user_message_multimodal.py
Normal file
80
surfsense_backend/app/utils/user_message_multimodal.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Helpers for multimodal user turns (text + inline images) in LangChain messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]:
|
||||
if not image_data_urls:
|
||||
return final_query
|
||||
parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}]
|
||||
for url in image_data_urls:
|
||||
parts.append({"type": "image_url", "image_url": {"url": url}})
|
||||
return parts
|
||||
|
||||
|
||||
def split_langchain_human_content(content: str | list[Any]) -> tuple[str, list[str]]:
|
||||
"""Return plain text and data URLs from a LangChain HumanMessage ``content`` value."""
|
||||
if isinstance(content, str):
|
||||
return content, []
|
||||
if not isinstance(content, list):
|
||||
return "", []
|
||||
|
||||
text_chunks: list[str] = []
|
||||
urls: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
btype = block.get("type")
|
||||
if btype == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str) and t:
|
||||
text_chunks.append(t)
|
||||
elif btype == "image_url":
|
||||
iu = block.get("image_url")
|
||||
if isinstance(iu, dict):
|
||||
u = iu.get("url")
|
||||
if isinstance(u, str) and u.startswith("data:"):
|
||||
urls.append(u)
|
||||
elif isinstance(iu, str) and iu.startswith("data:"):
|
||||
urls.append(iu)
|
||||
return "\n".join(text_chunks), urls
|
||||
|
||||
|
||||
def decode_base64_image(data: str, *, max_bytes: int) -> bytes:
|
||||
raw = data.strip()
|
||||
if not raw:
|
||||
raise ValueError("empty image payload")
|
||||
try:
|
||||
decoded = base64.b64decode(raw, validate=True)
|
||||
except binascii.Error as e:
|
||||
raise ValueError("invalid base64 image data") from e
|
||||
if len(decoded) > max_bytes:
|
||||
raise ValueError("image exceeds maximum size")
|
||||
return decoded
|
||||
|
||||
|
||||
def to_data_url(media_type: str, raw_b64: str) -> str:
|
||||
return f"data:{media_type};base64,{raw_b64.strip()}"
|
||||
|
||||
|
||||
def split_persisted_user_content_parts(parts: list[Any]) -> tuple[str, list[str]]:
|
||||
"""Extract plain text and data URLs from persisted assistant-ui style user ``content``."""
|
||||
text_chunks: list[str] = []
|
||||
urls: list[str] = []
|
||||
for block in parts:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
btype = block.get("type")
|
||||
if btype == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str):
|
||||
text_chunks.append(t)
|
||||
elif btype == "image":
|
||||
u = block.get("image")
|
||||
if isinstance(u, str) and u.startswith("data:"):
|
||||
urls.append(u)
|
||||
return "".join(text_chunks), urls
|
||||
Loading…
Add table
Add a link
Reference in a new issue