diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 85a8658ec..854627d4b 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -24,9 +24,9 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.filesystem_selection import ( ClientPlatform, - LocalFilesystemMount, FilesystemMode, FilesystemSelection, + LocalFilesystemMount, ) from app.config import config from app.db import ( @@ -64,6 +64,10 @@ from app.services.token_tracking_service import record_token_usage from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat from app.users import current_active_user from app.utils.rbac import check_permission +from app.utils.user_message_multimodal import ( + split_langchain_human_content, + split_persisted_user_content_parts, +) _logger = logging.getLogger(__name__) _background_tasks: set[asyncio.Task] = set() @@ -1237,6 +1241,10 @@ async def handle_new_chat( # connection (the "Exception terminating connection" errors). await session.close() + image_urls = ( + [p.as_data_url() for p in request.user_images] if request.user_images else None + ) + return StreamingResponse( stream_new_chat( user_query=request.user_query, @@ -1252,6 +1260,7 @@ async def handle_new_chat( disabled_tools=request.disabled_tools, filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), + user_image_data_urls=image_urls, ), media_type="text/event-stream", headers={ @@ -1360,6 +1369,7 @@ async def regenerate_response( target_checkpoint_id = None user_query_to_use = request.user_query + regenerate_image_urls: list[str] = [] # Look through checkpoints to find the right one # We want to find the checkpoint just before the last HumanMessage @@ -1385,9 +1395,13 @@ async def regenerate_response( prev_messages = prev_channel_values.get("messages", []) for msg in reversed(prev_messages): if isinstance(msg, HumanMessage): - user_query_to_use = msg.content + q, imgs = split_langchain_human_content(msg.content) + user_query_to_use = q + regenerate_image_urls = imgs break - if user_query_to_use: + if user_query_to_use is not None and ( + str(user_query_to_use).strip() or regenerate_image_urls + ): break target_checkpoint_id = cp_tuple.config["configurable"][ @@ -1405,7 +1419,9 @@ async def regenerate_response( state_messages = channel_values.get("messages", []) for msg in state_messages: if isinstance(msg, HumanMessage): - user_query_to_use = msg.content + q, imgs = split_langchain_human_content(msg.content) + user_query_to_use = q + regenerate_image_urls = imgs break else: # Use the oldest checkpoint @@ -1431,20 +1447,25 @@ async def regenerate_response( if isinstance(content, str): user_query_to_use = content elif isinstance(content, list): - # Extract text from content parts - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - user_query_to_use = part.get("text", "") - break - elif isinstance(part, str): - user_query_to_use = part - break + plain, imgs = split_persisted_user_content_parts(content) + user_query_to_use = plain + regenerate_image_urls = imgs + + if isinstance(user_query_to_use, list): + user_query_to_use, regenerate_image_urls = split_langchain_human_content( + user_query_to_use + ) if user_query_to_use is None: raise HTTPException( status_code=400, detail="Could not determine user query for regeneration. Please provide a user_query.", ) + if not str(user_query_to_use).strip() and not regenerate_image_urls: + raise HTTPException( + status_code=400, + detail="Could not determine user query for regeneration. Please provide a user_query.", + ) # Get the last two messages to delete AFTER streaming succeeds # This prevents data loss if streaming fails @@ -1483,7 +1504,7 @@ async def regenerate_response( streaming_completed = False try: async for chunk in stream_new_chat( - user_query=user_query_to_use, + user_query=str(user_query_to_use), search_space_id=request.search_space_id, chat_id=thread_id, user_id=str(user.id), @@ -1497,6 +1518,7 @@ async def regenerate_response( disabled_tools=request.disabled_tools, filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), + user_image_data_urls=regenerate_image_urls or None, ): yield chunk streaming_completed = True diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 1222deab2..e757ce178 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -7,12 +7,13 @@ These schemas follow the assistant-ui ThreadHistoryAdapter pattern: """ from datetime import datetime -from typing import Any, Literal +from typing import Any, Literal, Self from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from app.db import ChatVisibility, NewChatMessageRole +from app.utils.user_message_multimodal import decode_base64_image, to_data_url from .base import IDModel, TimestampModel @@ -173,6 +174,26 @@ class LocalFilesystemMountPayload(BaseModel): root_path: str +MAX_NEW_CHAT_IMAGE_BYTES = 8 * 1024 * 1024 +MAX_NEW_CHAT_IMAGES = 4 + + +class NewChatUserImagePart(BaseModel): + """One inline image for a user turn (raw base64 body, no data: URL prefix).""" + + media_type: Literal["image/png", "image/jpeg", "image/webp"] + data: str = Field(..., min_length=1) + + @field_validator("data") + @classmethod + def _validate_payload(cls, v: str) -> str: + decode_base64_image(v, max_bytes=MAX_NEW_CHAT_IMAGE_BYTES) + return v + + def as_data_url(self) -> str: + return to_data_url(self.media_type, self.data) + + class NewChatRequest(BaseModel): """Request schema for the deep agent chat endpoint.""" @@ -192,6 +213,20 @@ class NewChatRequest(BaseModel): filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud" client_platform: Literal["web", "desktop"] = "web" local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None + user_images: list[NewChatUserImagePart] | None = Field( + default=None, + description="Optional images for this user turn", + ) + + @model_validator(mode="after") + def _require_text_or_images(self) -> Self: + has_text = bool(self.user_query.strip()) + has_images = bool(self.user_images) + if not has_text and not has_images: + raise ValueError("Provide non-empty user_query and/or user_images") + if self.user_images is not None and len(self.user_images) > MAX_NEW_CHAT_IMAGES: + raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed") + return self class RegenerateRequest(BaseModel): diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5a6117808..396c7574e 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -31,7 +31,6 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer from app.agents.new_chat.filesystem_selection import FilesystemSelection -from app.config import config from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, @@ -62,6 +61,7 @@ from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService from app.utils.content_utils import bootstrap_history_from_db from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap +from app.utils.user_message_multimodal import build_human_message_content _background_tasks: set[asyncio.Task] = set() _perf_log = get_perf_logger() @@ -1350,6 +1350,7 @@ async def stream_new_chat( disabled_tools: list[str] | None = None, filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, + user_image_data_urls: list[str] | None = None, ) -> AsyncGenerator[str, None]: """ Stream chat responses from the new SurfSense deep agent. @@ -1625,8 +1626,10 @@ async def stream_new_chat( # elif msg.role == "assistant": # langchain_messages.append(AIMessage(content=msg.content)) # else: - # Fallback: just use the current user query with attachment context - langchain_messages.append(HumanMessage(content=final_query)) + human_content = build_human_message_content( + final_query, list(user_image_data_urls or ()) + ) + langchain_messages.append(HumanMessage(content=human_content)) input_state = { # Lets not pass this message atm because we are using the checkpointer to manage the conversation history @@ -1687,8 +1690,13 @@ async def stream_new_chat( action_verb = "Processing" processing_parts = [] - query_text = user_query[:80] + ("..." if len(user_query) > 80 else "") - processing_parts.append(query_text) + if user_query.strip(): + query_text = user_query[:80] + ("..." if len(user_query) > 80 else "") + processing_parts.append(query_text) + elif user_image_data_urls: + processing_parts.append(f"[{len(user_image_data_urls)} image(s)]") + else: + processing_parts.append("(message)") if mentioned_surfsense_docs: doc_names = [] @@ -1750,8 +1758,13 @@ async def stream_new_chat( _turn_accumulator.set(None) + title_seed = user_query.strip() or ( + f"[{len(user_image_data_urls or [])} image(s)]" + if user_image_data_urls + else "" + ) prompt = TITLE_GENERATION_PROMPT.replace( - "{user_query}", user_query[:500] + "{user_query}", title_seed[:500] or "(message)" ) messages = [{"role": "user", "content": prompt}] @@ -1947,10 +1960,15 @@ async def stream_new_chat( # Fire background memory extraction if the agent didn't handle it. # Shared threads write to team memory; private threads write to user memory. if not stream_result.agent_called_update_memory: + memory_seed = user_query.strip() or ( + f"[{len(user_image_data_urls or [])} image(s)]" + if user_image_data_urls + else "(message)" + ) if visibility == ChatVisibility.SEARCH_SPACE: task = asyncio.create_task( extract_and_save_team_memory( - user_message=user_query, + user_message=memory_seed, search_space_id=search_space_id, llm=llm, author_display_name=current_user_display_name, @@ -1961,7 +1979,7 @@ async def stream_new_chat( elif user_id: task = asyncio.create_task( extract_and_save_memory( - user_message=user_query, + user_message=memory_seed, user_id=user_id, llm=llm, ) diff --git a/surfsense_backend/app/utils/user_message_multimodal.py b/surfsense_backend/app/utils/user_message_multimodal.py new file mode 100644 index 000000000..1d0691697 --- /dev/null +++ b/surfsense_backend/app/utils/user_message_multimodal.py @@ -0,0 +1,80 @@ +"""Helpers for multimodal user turns (text + inline images) in LangChain messages.""" + +from __future__ import annotations + +import base64 +import binascii +from typing import Any + + +def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]: + if not image_data_urls: + return final_query + parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}] + for url in image_data_urls: + parts.append({"type": "image_url", "image_url": {"url": url}}) + return parts + + +def split_langchain_human_content(content: str | list[Any]) -> tuple[str, list[str]]: + """Return plain text and data URLs from a LangChain HumanMessage ``content`` value.""" + if isinstance(content, str): + return content, [] + if not isinstance(content, list): + return "", [] + + text_chunks: list[str] = [] + urls: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + btype = block.get("type") + if btype == "text": + t = block.get("text") + if isinstance(t, str) and t: + text_chunks.append(t) + elif btype == "image_url": + iu = block.get("image_url") + if isinstance(iu, dict): + u = iu.get("url") + if isinstance(u, str) and u.startswith("data:"): + urls.append(u) + elif isinstance(iu, str) and iu.startswith("data:"): + urls.append(iu) + return "\n".join(text_chunks), urls + + +def decode_base64_image(data: str, *, max_bytes: int) -> bytes: + raw = data.strip() + if not raw: + raise ValueError("empty image payload") + try: + decoded = base64.b64decode(raw, validate=True) + except binascii.Error as e: + raise ValueError("invalid base64 image data") from e + if len(decoded) > max_bytes: + raise ValueError("image exceeds maximum size") + return decoded + + +def to_data_url(media_type: str, raw_b64: str) -> str: + return f"data:{media_type};base64,{raw_b64.strip()}" + + +def split_persisted_user_content_parts(parts: list[Any]) -> tuple[str, list[str]]: + """Extract plain text and data URLs from persisted assistant-ui style user ``content``.""" + text_chunks: list[str] = [] + urls: list[str] = [] + for block in parts: + if not isinstance(block, dict): + continue + btype = block.get("type") + if btype == "text": + t = block.get("text") + if isinstance(t, str): + text_chunks.append(t) + elif btype == "image": + u = block.get("image") + if isinstance(u, str) and u.startswith("data:"): + urls.append(u) + return "".join(text_chunks), urls