# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================


def _extract_text_from_message_content(content) -> str | None:
    """Extract plain text from a stored message content payload.

    Stored content may be a plain string or a list of content parts
    (dicts with a ``"text"`` field, or raw strings). Returns ``None``
    when no text can be found.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
            if isinstance(part, str):
                return part
    return None


@router.post("/threads/{thread_id}/regenerate")
async def regenerate_response(
    thread_id: int,
    request: RegenerateRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Regenerate the AI response for a chat thread.

    This endpoint supports two operations:
    1. **Edit**: Provide a new ``user_query`` to replace the last user message
       and regenerate.
    2. **Reload**: Leave ``user_query`` empty (or ``None``) to regenerate with
       the same query.

    Both operations:
    - Rewind the LangGraph checkpointer to the state before the last AI response
    - Delete the last user message and AI response from the database
    - Stream a new response from that checkpoint

    Access is granted if the user is the creator of the thread or the thread
    visibility is SEARCH_SPACE. Requires CHATS_UPDATE permission.

    Raises:
        HTTPException: 404 if the thread or search space does not exist,
            400 if no history or user query can be determined,
            500 on unexpected errors.
    """
    from langchain_core.messages import HumanMessage

    from app.agents.new_chat.checkpointer import get_checkpointer

    try:
        # Verify the thread exists and the user may update chats in its
        # search space.
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()

        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )

        # Check thread-level access based on visibility.
        await check_thread_access(session, thread, user)

        # Validate the target search space BEFORE any destructive operation.
        # (Previously this lookup happened after the message deletion was
        # committed, so a bad search_space_id destroyed messages and then
        # failed with a 404.)
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # Load the checkpoint history for this thread.
        # CheckpointTuple has: config, checkpoint (dict with channel_values),
        # metadata, parent_config. alist() yields newest first.
        checkpointer = await get_checkpointer()
        config = {"configurable": {"thread_id": str(thread_id)}}
        checkpoint_tuples = [cp_tuple async for cp_tuple in checkpointer.alist(config)]

        if not checkpoint_tuples:
            raise HTTPException(
                status_code=400, detail="No conversation history found for this thread"
            )

        # The checkpointer stores a state after each node execution:
        #   - user sends message -> state ending in a HumanMessage
        #   - agent responds     -> state ending in an AIMessage
        # To regenerate, rewind to the newest state whose last message is NOT
        # a HumanMessage, i.e. the state before the user's last message.
        target_checkpoint_id = None
        user_query_to_use = request.user_query

        for i, cp_tuple in enumerate(checkpoint_tuples):
            channel_values = cp_tuple.checkpoint.get("channel_values", {})
            state_messages = channel_values.get("messages", [])

            if state_messages:
                last_msg = state_messages[-1]
                if not isinstance(last_msg, HumanMessage):
                    # Reload mode: recover the original query from one of the
                    # more recent checkpoints that are being discarded.
                    if user_query_to_use is None and i > 0:
                        for prev_cp_tuple in checkpoint_tuples[:i]:
                            prev_channel_values = prev_cp_tuple.checkpoint.get(
                                "channel_values", {}
                            )
                            prev_messages = prev_channel_values.get("messages", [])
                            for msg in reversed(prev_messages):
                                if isinstance(msg, HumanMessage):
                                    user_query_to_use = msg.content
                                    break
                            if user_query_to_use:
                                break

                    target_checkpoint_id = cp_tuple.config["configurable"][
                        "checkpoint_id"
                    ]
                    break

        # Fallbacks when no suitable checkpoint was found.
        if target_checkpoint_id is None and checkpoint_tuples:
            if len(checkpoint_tuples) == 1:
                # Only one checkpoint - recover the query from it if needed.
                if user_query_to_use is None:
                    channel_values = checkpoint_tuples[0].checkpoint.get(
                        "channel_values", {}
                    )
                    for msg in channel_values.get("messages", []):
                        if isinstance(msg, HumanMessage):
                            user_query_to_use = msg.content
                            break
            else:
                # Otherwise fork from the oldest checkpoint.
                target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
                    "checkpoint_id"
                ]

        # Last resort: recover the query from the persisted chat messages.
        if user_query_to_use is None:
            last_user_msg_result = await session.execute(
                select(NewChatMessage)
                .filter(
                    NewChatMessage.thread_id == thread_id,
                    NewChatMessage.role == NewChatMessageRole.USER,
                )
                .order_by(NewChatMessage.created_at.desc())
                .limit(1)
            )
            last_user_msg = last_user_msg_result.scalars().first()
            if last_user_msg:
                user_query_to_use = _extract_text_from_message_content(
                    last_user_msg.content
                )

        if user_query_to_use is None:
            raise HTTPException(
                status_code=400,
                detail="Could not determine user query for regeneration. Please provide a user_query.",
            )

        # Delete the last user message and assistant response from the
        # database (the last two messages should be user + assistant); the
        # frontend re-persists both after the regenerated stream completes.
        last_messages_result = await session.execute(
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at.desc())
            .limit(2)
        )
        for msg in last_messages_result.scalars().all():
            await session.delete(msg)

        await session.commit()

        # -1 signals "no explicit agent LLM configured" downstream.
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        # Stream the regenerated response, forking from the target checkpoint.
        return StreamingResponse(
            stream_new_chat(
                user_query=user_query_to_use,
                search_space_id=request.search_space_id,
                chat_id=thread_id,
                session=session,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                attachments=request.attachments,
                mentioned_document_ids=request.mentioned_document_ids,
                mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                checkpoint_id=target_checkpoint_id,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during regeneration: {e!s}",
        ) from None


# =============================================================================
# Attachment Processing Endpoint
# =============================================================================
class RegenerateRequest(BaseModel):
    """
    Request payload for regenerating an AI response.

    Two modes are supported:

    1. Edit: supply a new ``user_query`` to replace the last user message
       before regenerating.
    2. Reload: omit ``user_query`` (``None``) to regenerate the last AI
       response with the unchanged query.

    Either way, the endpoint rewinds the LangGraph checkpointer to the
    appropriate state before streaming a new answer.
    """

    # Search space providing the LLM configuration for regeneration.
    search_space_id: int
    # New user query (edit mode); None means reload with the same query.
    user_query: str | None = None
    # Optional attachments to include with the regenerated turn.
    attachments: list[ChatAttachment] | None = None
    # Document IDs mentioned with @ in the chat.
    mentioned_document_ids: list[int] | None = None
    # SurfSense documentation IDs mentioned with @ in the chat.
    mentioned_surfsense_doc_ids: list[int] | None = None
@@ -177,6 +178,7 @@ async def stream_new_chat( attachments: Optional attachments with extracted content mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat + checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations) Yields: str: SSE formatted response strings @@ -325,10 +327,13 @@ async def stream_new_chat( } # Configure LangGraph with thread_id for memory + # If checkpoint_id is provided, fork from that checkpoint (for edit/reload) + configurable = {"thread_id": str(chat_id)} + if checkpoint_id: + configurable["checkpoint_id"] = checkpoint_id + config = { - "configurable": { - "thread_id": str(chat_id), - }, + "configurable": configurable, "recursion_limit": 80, # Increase from default 25 to allow more tool iterations } diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 4d43b7f64..84b3f93ff 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -48,6 +48,7 @@ import { appendMessage, type ChatVisibility, createThread, + getRegenerateUrl, getThreadFull, getThreadMessages, type MessageRecord, @@ -1045,16 +1046,415 @@ export default function NewChatPage() { [] ); - // Handle editing a message - removes messages after the edited one and sends as new + /** + * Handle regeneration (edit or reload) by calling the regenerate endpoint + * and streaming the response. This rewinds the LangGraph checkpointer state. + * + * @param newUserQuery - The new user query (for edit). Pass null/undefined for reload. 
+ */ + const handleRegenerate = useCallback( + async (newUserQuery?: string | null) => { + if (!threadId) { + toast.error("Cannot regenerate: no active chat thread"); + return; + } + + // Abort any previous streaming request + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + + const token = getBearerToken(); + if (!token) { + toast.error("Not authenticated. Please log in again."); + return; + } + + // Extract the original user query BEFORE removing messages (for reload mode) + let userQueryToDisplay = newUserQuery; + let originalUserMessageContent: ThreadMessageLike["content"] | null = null; + let originalUserMessageAttachments: ThreadMessageLike["attachments"] | undefined; + let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined; + + if (!newUserQuery) { + // Reload mode - find and preserve the last user message content + const lastUserMessage = [...messages].reverse().find((m) => m.role === "user"); + if (lastUserMessage) { + originalUserMessageContent = lastUserMessage.content; + originalUserMessageAttachments = lastUserMessage.attachments; + originalUserMessageMetadata = lastUserMessage.metadata; + // Extract text for the API request + for (const part of lastUserMessage.content) { + if (typeof part === "object" && part.type === "text" && "text" in part) { + userQueryToDisplay = part.text; + break; + } + } + } + } + + // Remove the last two messages (user + assistant) from the UI immediately + // The backend will also delete them from the database + setMessages((prev) => { + if (prev.length >= 2) { + return prev.slice(0, -2); + } + return prev; + }); + + // Clear thinking steps for the removed messages + setMessageThinkingSteps((prev) => { + const newMap = new Map(prev); + // Remove thinking steps for the last two messages + const lastTwoIds = messages.slice(-2).map((m) => m.id).filter((id): id is string => !!id); + for (const id of lastTwoIds) { + newMap.delete(id); + } + 
return newMap; + }); + + // Start streaming + setIsRunning(true); + const controller = new AbortController(); + abortControllerRef.current = controller; + + // Add placeholder user message if we have a new query (edit mode) + const userMsgId = `msg-user-${Date.now()}`; + const assistantMsgId = `msg-assistant-${Date.now()}`; + const currentThinkingSteps = new Map(); + + // Content parts tracking (same as onNew) + type ContentPart = + | { type: "text"; text: string } + | { + type: "tool-call"; + toolCallId: string; + toolName: string; + args: Record; + result?: unknown; + }; + const contentParts: ContentPart[] = []; + let currentTextPartIndex = -1; + const toolCallIndices = new Map(); + + const appendText = (delta: string) => { + if (currentTextPartIndex >= 0 && contentParts[currentTextPartIndex]?.type === "text") { + (contentParts[currentTextPartIndex] as { type: "text"; text: string }).text += delta; + } else { + contentParts.push({ type: "text", text: delta }); + currentTextPartIndex = contentParts.length - 1; + } + }; + + const addToolCall = (toolCallId: string, toolName: string, args: Record) => { + if (TOOLS_WITH_UI.has(toolName)) { + contentParts.push({ type: "tool-call", toolCallId, toolName, args }); + toolCallIndices.set(toolCallId, contentParts.length - 1); + currentTextPartIndex = -1; + } + }; + + const updateToolCall = ( + toolCallId: string, + update: { args?: Record; result?: unknown } + ) => { + const index = toolCallIndices.get(toolCallId); + if (index !== undefined && contentParts[index]?.type === "tool-call") { + const tc = contentParts[index] as ContentPart & { type: "tool-call" }; + if (update.args) tc.args = update.args; + if (update.result !== undefined) tc.result = update.result; + } + }; + + const buildContentForUI = (): ThreadMessageLike["content"] => { + const filtered = contentParts.filter((part) => { + if (part.type === "text") return part.text.length > 0; + if (part.type === "tool-call") return TOOLS_WITH_UI.has(part.toolName); + return 
false; + }); + return filtered.length > 0 + ? (filtered as ThreadMessageLike["content"]) + : [{ type: "text", text: "" }]; + }; + + const buildContentForPersistence = (): unknown[] => { + const parts: unknown[] = []; + if (currentThinkingSteps.size > 0) { + parts.push({ + type: "thinking-steps", + steps: Array.from(currentThinkingSteps.values()), + }); + } + for (const part of contentParts) { + if (part.type === "text" && part.text.length > 0) { + parts.push(part); + } else if (part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName)) { + parts.push(part); + } + } + return parts.length > 0 ? parts : [{ type: "text", text: "" }]; + }; + + // Add placeholder messages to UI + // Always add back the user message (with new query for edit, or original content for reload) + const userMessage: ThreadMessageLike = { + id: userMsgId, + role: "user", + content: newUserQuery + ? [{ type: "text", text: newUserQuery }] + : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }], + createdAt: new Date(), + attachments: newUserQuery ? undefined : originalUserMessageAttachments, + metadata: newUserQuery ? 
undefined : originalUserMessageMetadata, + }; + setMessages((prev) => [...prev, userMessage]); + + // Add placeholder assistant message + setMessages((prev) => [ + ...prev, + { + id: assistantMsgId, + role: "assistant", + content: [{ type: "text", text: "" }], + createdAt: new Date(), + }, + ]); + + try { + const response = await fetch(getRegenerateUrl(threadId), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + user_query: newUserQuery || null, + }), + signal: controller.signal, + }); + + if (!response.ok) { + throw new Error(`Backend error: ${response.status}`); + } + + if (!response.body) { + throw new Error("No response body"); + } + + // Parse SSE stream (same logic as onNew) + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const events = buffer.split(/\r?\n\r?\n/); + buffer = events.pop() || ""; + + for (const event of events) { + const lines = event.split(/\r?\n/); + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const data = line.slice(6).trim(); + if (!data || data === "[DONE]") continue; + + try { + const parsed = JSON.parse(data); + + switch (parsed.type) { + case "text-delta": + appendText(parsed.delta); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-input-start": + addToolCall(parsed.toolCallId, parsed.toolName, {}); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? 
{ ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-input-available": + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(parsed.toolCallId, { args: parsed.input || {} }); + } else { + addToolCall(parsed.toolCallId, parsed.toolName, parsed.input || {}); + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-output-available": + updateToolCall(parsed.toolCallId, { result: parsed.output }); + if (parsed.output?.status === "processing" && parsed.output?.task_id) { + const idx = toolCallIndices.get(parsed.toolCallId); + if (idx !== undefined) { + const part = contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(parsed.output.task_id); + } + } + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + setMessageThinkingSteps((prev) => { + const newMap = new Map(prev); + newMap.set(assistantMsgId, Array.from(currentThinkingSteps.values())); + return newMap; + }); + } + break; + } + + case "error": + throw new Error(parsed.errorText || "Server error"); + } + } catch (e) { + if (e instanceof SyntaxError) continue; + throw e; + } + } + } + } + } finally { + reader.releaseLock(); + } + + // Persist messages after streaming completes + const finalContent = buildContentForPersistence(); + if (contentParts.length > 0) { + try { + // Persist user message (for both edit and reload modes, since backend deleted it) + const userContentToPersist = newUserQuery + ? 
[{ type: "text", text: newUserQuery }] + : originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }]; + + const savedUserMessage = await appendMessage(threadId, { + role: "user", + content: userContentToPersist, + }); + + // Update user message ID to database ID + const newUserMsgId = `msg-${savedUserMessage.id}`; + setMessages((prev) => + prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m)) + ); + + // Persist assistant message + const savedMessage = await appendMessage(threadId, { + role: "assistant", + content: finalContent, + }); + + // Update assistant message ID to database ID + const newMsgId = `msg-${savedMessage.id}`; + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + + setMessageThinkingSteps((prev) => { + const steps = prev.get(assistantMsgId); + if (steps) { + const newMap = new Map(prev); + newMap.delete(assistantMsgId); + newMap.set(newMsgId, steps); + return newMap; + } + return prev; + }); + + // Track successful response + trackChatResponseReceived(searchSpaceId, threadId); + } catch (err) { + console.error("Failed to persist regenerated message:", err); + } + } + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + return; + } + console.error("[NewChatPage] Regeneration error:", error); + trackChatError( + searchSpaceId, + threadId, + error instanceof Error ? error.message : "Unknown error" + ); + toast.error("Failed to regenerate response. Please try again."); + // Update assistant message with error + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId + ? { + ...m, + content: [ + { type: "text", text: "Sorry, there was an error. Please try again." 
}, + ], + } + : m + ) + ); + } finally { + setIsRunning(false); + abortControllerRef.current = null; + } + }, + [threadId, searchSpaceId, messages, setMessageThinkingSteps] + ); + + // Handle editing a message - truncates history and regenerates with new query const onEdit = useCallback( async (message: AppendMessage) => { - // Find the message being edited by looking at the parentId - // The parentId tells us which message's response we're editing - // For now, we'll just treat edits like new messages - // A more sophisticated implementation would truncate the history - await onNew(message); + // Extract the new user query from the message content + let newUserQuery = ""; + for (const part of message.content) { + if (part.type === "text") { + newUserQuery += part.text; + } + } + + if (!newUserQuery.trim()) { + toast.error("Cannot edit with empty message"); + return; + } + + // Call regenerate with the new query + await handleRegenerate(newUserQuery.trim()); }, - [onNew] + [handleRegenerate] + ); + + // Handle reloading/refreshing the last AI response + const onReload = useCallback( + async (parentId: string | null) => { + // parentId is the ID of the message to reload from (the user message) + // We call regenerate without a query to use the same query + await handleRegenerate(null); + }, + [handleRegenerate] ); // Create external store runtime with attachment support @@ -1063,6 +1463,7 @@ export default function NewChatPage() { isRunning, onNew, onEdit, + onReload, convertMessage, onCancel: cancelRun, adapters: { diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index 738d1062f..08c08ba78 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -160,6 +160,30 @@ export async function getThreadFull(threadId: number): Promise { return baseApiService.get(`/api/v1/threads/${threadId}/full`); } +/** + * Regeneration request parameters + */ +export interface 
RegenerateParams { + searchSpaceId: number; + userQuery?: string | null; // New user query (for edit). Null/undefined = reload with same query + attachments?: Array<{ + id: string; + name: string; + type: string; + content: string; + }>; + mentionedDocumentIds?: number[]; + mentionedSurfsenseDocIds?: number[]; +} + +/** + * Get the URL for the regenerate endpoint (for streaming fetch) + */ +export function getRegenerateUrl(threadId: number): string { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + return `${backendUrl}/api/v1/threads/${threadId}/regenerate`; +} + // ============================================================================= // Thread List Manager (for thread list sidebar) // =============================================================================