mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat(chat): add regenerate endpoint for chat threads to support editing and reloading responses
This commit is contained in:
parent
14b6001489
commit
ad475397c4
5 changed files with 691 additions and 10 deletions
|
|
@ -45,6 +45,7 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadUpdate,
|
||||
NewChatThreadVisibilityUpdate,
|
||||
NewChatThreadWithMessages,
|
||||
RegenerateRequest,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
|
|
@ -1013,6 +1014,238 @@ async def handle_new_chat(
|
|||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Regeneration Endpoint (Edit/Reload)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/regenerate")
|
||||
async def regenerate_response(
|
||||
thread_id: int,
|
||||
request: RegenerateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Regenerate the AI response for a chat thread.
|
||||
|
||||
This endpoint supports two operations:
|
||||
1. **Edit**: Provide a new `user_query` to replace the last user message and regenerate
|
||||
2. **Reload**: Leave `user_query` empty (or None) to regenerate with the same query
|
||||
|
||||
Both operations:
|
||||
- Rewind the LangGraph checkpointer to the state before the last AI response
|
||||
- Delete the last user message and AI response from the database
|
||||
- Stream a new response from that checkpoint
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||
|
||||
try:
|
||||
# Verify thread exists and user has permission
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
# Get the checkpointer and state history
|
||||
checkpointer = await get_checkpointer()
|
||||
|
||||
config = {"configurable": {"thread_id": str(thread_id)}}
|
||||
|
||||
# Collect checkpoint tuples from the async iterator
|
||||
# CheckpointTuple has: config, checkpoint (dict with channel_values), metadata, parent_config
|
||||
checkpoint_tuples = []
|
||||
async for cp_tuple in checkpointer.alist(config):
|
||||
checkpoint_tuples.append(cp_tuple)
|
||||
|
||||
if not checkpoint_tuples:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No conversation history found for this thread"
|
||||
)
|
||||
|
||||
# Find the checkpoint to rewind to
|
||||
# Checkpoints are in reverse chronological order (newest first)
|
||||
# We need to find a checkpoint before the last user message was added
|
||||
#
|
||||
# The checkpointer stores states after each node execution.
|
||||
# For a typical conversation flow:
|
||||
# - User sends message -> state 1 (with HumanMessage)
|
||||
# - Agent responds -> state 2 (with HumanMessage + AIMessage)
|
||||
#
|
||||
# To regenerate, we need the state BEFORE the last HumanMessage was processed
|
||||
|
||||
target_checkpoint_id = None
|
||||
user_query_to_use = request.user_query
|
||||
|
||||
# Look through checkpoints to find the right one
|
||||
# We want to find the checkpoint just before the last HumanMessage
|
||||
for i, cp_tuple in enumerate(checkpoint_tuples):
|
||||
# Access the checkpoint's channel_values which contains "messages"
|
||||
checkpoint_data = cp_tuple.checkpoint
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
state_messages = channel_values.get("messages", [])
|
||||
|
||||
if state_messages:
|
||||
last_msg = state_messages[-1]
|
||||
# Find a checkpoint where the last message is NOT a HumanMessage
|
||||
# This means we're at a state before the user's last message
|
||||
if not isinstance(last_msg, HumanMessage):
|
||||
# If no new user_query provided (reload), extract from a later checkpoint
|
||||
if user_query_to_use is None and i > 0:
|
||||
# Get the user query from a more recent checkpoint
|
||||
for prev_cp_tuple in checkpoint_tuples[:i]:
|
||||
prev_checkpoint_data = prev_cp_tuple.checkpoint
|
||||
prev_channel_values = prev_checkpoint_data.get(
|
||||
"channel_values", {}
|
||||
)
|
||||
prev_messages = prev_channel_values.get("messages", [])
|
||||
for msg in reversed(prev_messages):
|
||||
if isinstance(msg, HumanMessage):
|
||||
user_query_to_use = msg.content
|
||||
break
|
||||
if user_query_to_use:
|
||||
break
|
||||
|
||||
target_checkpoint_id = cp_tuple.config["configurable"][
|
||||
"checkpoint_id"
|
||||
]
|
||||
break
|
||||
|
||||
# If we couldn't find a good checkpoint, try alternative approaches
|
||||
if target_checkpoint_id is None and checkpoint_tuples:
|
||||
if len(checkpoint_tuples) == 1:
|
||||
# Only one checkpoint - get the user query from it if not provided
|
||||
if user_query_to_use is None:
|
||||
checkpoint_data = checkpoint_tuples[0].checkpoint
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
state_messages = channel_values.get("messages", [])
|
||||
for msg in state_messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
user_query_to_use = msg.content
|
||||
break
|
||||
else:
|
||||
# Use the oldest checkpoint
|
||||
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
|
||||
"checkpoint_id"
|
||||
]
|
||||
|
||||
# If we still don't have a user query, get it from the database
|
||||
if user_query_to_use is None:
|
||||
# Get the last user message from the database
|
||||
last_user_msg_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.filter(
|
||||
NewChatMessage.thread_id == thread_id,
|
||||
NewChatMessage.role == NewChatMessageRole.USER,
|
||||
)
|
||||
.order_by(NewChatMessage.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
last_user_msg = last_user_msg_result.scalars().first()
|
||||
if last_user_msg:
|
||||
content = last_user_msg.content
|
||||
if isinstance(content, str):
|
||||
user_query_to_use = content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content parts
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
user_query_to_use = part.get("text", "")
|
||||
break
|
||||
elif isinstance(part, str):
|
||||
user_query_to_use = part
|
||||
break
|
||||
|
||||
if user_query_to_use is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||
)
|
||||
|
||||
# Delete the last user message and assistant response from the database
|
||||
# Get the last two messages (should be user + assistant)
|
||||
last_messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at.desc())
|
||||
.limit(2)
|
||||
)
|
||||
last_messages = last_messages_result.scalars().all()
|
||||
|
||||
for msg in last_messages:
|
||||
await session.delete(msg)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Get search space for LLM config
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
)
|
||||
search_space = search_space_result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
raise HTTPException(status_code=404, detail="Search space not found")
|
||||
|
||||
llm_config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Return streaming response with checkpoint_id for rewinding
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=user_query_to_use,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=thread_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
attachments=request.attachments,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred during regeneration: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Attachment Processing Endpoint
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -184,3 +184,21 @@ class NewChatRequest(BaseModel):
|
|||
mentioned_surfsense_doc_ids: list[int] | None = (
|
||||
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
||||
)
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
"""
|
||||
Request schema for regenerating an AI response.
|
||||
|
||||
This supports two operations:
|
||||
1. Edit: Provide a new user_query to replace the last user message and regenerate
|
||||
2. Reload: Leave user_query empty to regenerate the last AI response with the same query
|
||||
|
||||
Both operations rewind the LangGraph checkpointer to the appropriate state.
|
||||
"""
|
||||
|
||||
search_space_id: int
|
||||
user_query: str | None = None # New user query (for edit). None = reload with same query
|
||||
attachments: list[ChatAttachment] | None = None
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
|
|
|
|||
|
|
@ -159,6 +159,7 @@ async def stream_new_chat(
|
|||
attachments: list[ChatAttachment] | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream chat responses from the new SurfSense deep agent.
|
||||
|
|
@ -177,6 +178,7 @@ async def stream_new_chat(
|
|||
attachments: Optional attachments with extracted content
|
||||
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
|
||||
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
|
||||
checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
|
||||
|
||||
Yields:
|
||||
str: SSE formatted response strings
|
||||
|
|
@ -325,10 +327,13 @@ async def stream_new_chat(
|
|||
}
|
||||
|
||||
# Configure LangGraph with thread_id for memory
|
||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||
configurable = {"thread_id": str(chat_id)}
|
||||
if checkpoint_id:
|
||||
configurable["checkpoint_id"] = checkpoint_id
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": str(chat_id),
|
||||
},
|
||||
"configurable": configurable,
|
||||
"recursion_limit": 80, # Increase from default 25 to allow more tool iterations
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ import {
|
|||
appendMessage,
|
||||
type ChatVisibility,
|
||||
createThread,
|
||||
getRegenerateUrl,
|
||||
getThreadFull,
|
||||
getThreadMessages,
|
||||
type MessageRecord,
|
||||
|
|
@ -1045,16 +1046,415 @@ export default function NewChatPage() {
|
|||
[]
|
||||
);
|
||||
|
||||
// Handle editing a message - removes messages after the edited one and sends as new
|
||||
/**
|
||||
* Handle regeneration (edit or reload) by calling the regenerate endpoint
|
||||
* and streaming the response. This rewinds the LangGraph checkpointer state.
|
||||
*
|
||||
* @param newUserQuery - The new user query (for edit). Pass null/undefined for reload.
|
||||
*/
|
||||
const handleRegenerate = useCallback(
|
||||
async (newUserQuery?: string | null) => {
|
||||
if (!threadId) {
|
||||
toast.error("Cannot regenerate: no active chat thread");
|
||||
return;
|
||||
}
|
||||
|
||||
// Abort any previous streaming request
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
|
||||
const token = getBearerToken();
|
||||
if (!token) {
|
||||
toast.error("Not authenticated. Please log in again.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the original user query BEFORE removing messages (for reload mode)
|
||||
let userQueryToDisplay = newUserQuery;
|
||||
let originalUserMessageContent: ThreadMessageLike["content"] | null = null;
|
||||
let originalUserMessageAttachments: ThreadMessageLike["attachments"] | undefined;
|
||||
let originalUserMessageMetadata: ThreadMessageLike["metadata"] | undefined;
|
||||
|
||||
if (!newUserQuery) {
|
||||
// Reload mode - find and preserve the last user message content
|
||||
const lastUserMessage = [...messages].reverse().find((m) => m.role === "user");
|
||||
if (lastUserMessage) {
|
||||
originalUserMessageContent = lastUserMessage.content;
|
||||
originalUserMessageAttachments = lastUserMessage.attachments;
|
||||
originalUserMessageMetadata = lastUserMessage.metadata;
|
||||
// Extract text for the API request
|
||||
for (const part of lastUserMessage.content) {
|
||||
if (typeof part === "object" && part.type === "text" && "text" in part) {
|
||||
userQueryToDisplay = part.text;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the last two messages (user + assistant) from the UI immediately
|
||||
// The backend will also delete them from the database
|
||||
setMessages((prev) => {
|
||||
if (prev.length >= 2) {
|
||||
return prev.slice(0, -2);
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
|
||||
// Clear thinking steps for the removed messages
|
||||
setMessageThinkingSteps((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
// Remove thinking steps for the last two messages
|
||||
const lastTwoIds = messages.slice(-2).map((m) => m.id).filter((id): id is string => !!id);
|
||||
for (const id of lastTwoIds) {
|
||||
newMap.delete(id);
|
||||
}
|
||||
return newMap;
|
||||
});
|
||||
|
||||
// Start streaming
|
||||
setIsRunning(true);
|
||||
const controller = new AbortController();
|
||||
abortControllerRef.current = controller;
|
||||
|
||||
// Add placeholder user message if we have a new query (edit mode)
|
||||
const userMsgId = `msg-user-${Date.now()}`;
|
||||
const assistantMsgId = `msg-assistant-${Date.now()}`;
|
||||
const currentThinkingSteps = new Map<string, ThinkingStepData>();
|
||||
|
||||
// Content parts tracking (same as onNew)
|
||||
type ContentPart =
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "tool-call";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: unknown;
|
||||
};
|
||||
const contentParts: ContentPart[] = [];
|
||||
let currentTextPartIndex = -1;
|
||||
const toolCallIndices = new Map<string, number>();
|
||||
|
||||
const appendText = (delta: string) => {
|
||||
if (currentTextPartIndex >= 0 && contentParts[currentTextPartIndex]?.type === "text") {
|
||||
(contentParts[currentTextPartIndex] as { type: "text"; text: string }).text += delta;
|
||||
} else {
|
||||
contentParts.push({ type: "text", text: delta });
|
||||
currentTextPartIndex = contentParts.length - 1;
|
||||
}
|
||||
};
|
||||
|
||||
const addToolCall = (toolCallId: string, toolName: string, args: Record<string, unknown>) => {
|
||||
if (TOOLS_WITH_UI.has(toolName)) {
|
||||
contentParts.push({ type: "tool-call", toolCallId, toolName, args });
|
||||
toolCallIndices.set(toolCallId, contentParts.length - 1);
|
||||
currentTextPartIndex = -1;
|
||||
}
|
||||
};
|
||||
|
||||
const updateToolCall = (
|
||||
toolCallId: string,
|
||||
update: { args?: Record<string, unknown>; result?: unknown }
|
||||
) => {
|
||||
const index = toolCallIndices.get(toolCallId);
|
||||
if (index !== undefined && contentParts[index]?.type === "tool-call") {
|
||||
const tc = contentParts[index] as ContentPart & { type: "tool-call" };
|
||||
if (update.args) tc.args = update.args;
|
||||
if (update.result !== undefined) tc.result = update.result;
|
||||
}
|
||||
};
|
||||
|
||||
const buildContentForUI = (): ThreadMessageLike["content"] => {
|
||||
const filtered = contentParts.filter((part) => {
|
||||
if (part.type === "text") return part.text.length > 0;
|
||||
if (part.type === "tool-call") return TOOLS_WITH_UI.has(part.toolName);
|
||||
return false;
|
||||
});
|
||||
return filtered.length > 0
|
||||
? (filtered as ThreadMessageLike["content"])
|
||||
: [{ type: "text", text: "" }];
|
||||
};
|
||||
|
||||
const buildContentForPersistence = (): unknown[] => {
|
||||
const parts: unknown[] = [];
|
||||
if (currentThinkingSteps.size > 0) {
|
||||
parts.push({
|
||||
type: "thinking-steps",
|
||||
steps: Array.from(currentThinkingSteps.values()),
|
||||
});
|
||||
}
|
||||
for (const part of contentParts) {
|
||||
if (part.type === "text" && part.text.length > 0) {
|
||||
parts.push(part);
|
||||
} else if (part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName)) {
|
||||
parts.push(part);
|
||||
}
|
||||
}
|
||||
return parts.length > 0 ? parts : [{ type: "text", text: "" }];
|
||||
};
|
||||
|
||||
// Add placeholder messages to UI
|
||||
// Always add back the user message (with new query for edit, or original content for reload)
|
||||
const userMessage: ThreadMessageLike = {
|
||||
id: userMsgId,
|
||||
role: "user",
|
||||
content: newUserQuery
|
||||
? [{ type: "text", text: newUserQuery }]
|
||||
: originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }],
|
||||
createdAt: new Date(),
|
||||
attachments: newUserQuery ? undefined : originalUserMessageAttachments,
|
||||
metadata: newUserQuery ? undefined : originalUserMessageMetadata,
|
||||
};
|
||||
setMessages((prev) => [...prev, userMessage]);
|
||||
|
||||
// Add placeholder assistant message
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: assistantMsgId,
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "" }],
|
||||
createdAt: new Date(),
|
||||
},
|
||||
]);
|
||||
|
||||
try {
|
||||
const response = await fetch(getRegenerateUrl(threadId), {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
search_space_id: searchSpaceId,
|
||||
user_query: newUserQuery || null,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Backend error: ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
// Parse SSE stream (same logic as onNew)
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const events = buffer.split(/\r?\n\r?\n/);
|
||||
buffer = events.pop() || "";
|
||||
|
||||
for (const event of events) {
|
||||
const lines = event.split(/\r?\n/);
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data: ")) continue;
|
||||
const data = line.slice(6).trim();
|
||||
if (!data || data === "[DONE]") continue;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
|
||||
switch (parsed.type) {
|
||||
case "text-delta":
|
||||
appendText(parsed.delta);
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "tool-input-start":
|
||||
addToolCall(parsed.toolCallId, parsed.toolName, {});
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "tool-input-available":
|
||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||
updateToolCall(parsed.toolCallId, { args: parsed.input || {} });
|
||||
} else {
|
||||
addToolCall(parsed.toolCallId, parsed.toolName, parsed.input || {});
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "tool-output-available":
|
||||
updateToolCall(parsed.toolCallId, { result: parsed.output });
|
||||
if (parsed.output?.status === "processing" && parsed.output?.task_id) {
|
||||
const idx = toolCallIndices.get(parsed.toolCallId);
|
||||
if (idx !== undefined) {
|
||||
const part = contentParts[idx];
|
||||
if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
|
||||
setActivePodcastTaskId(parsed.output.task_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
|
||||
)
|
||||
);
|
||||
break;
|
||||
|
||||
case "data-thinking-step": {
|
||||
const stepData = parsed.data as ThinkingStepData;
|
||||
if (stepData?.id) {
|
||||
currentThinkingSteps.set(stepData.id, stepData);
|
||||
setMessageThinkingSteps((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.set(assistantMsgId, Array.from(currentThinkingSteps.values()));
|
||||
return newMap;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "error":
|
||||
throw new Error(parsed.errorText || "Server error");
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof SyntaxError) continue;
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Persist messages after streaming completes
|
||||
const finalContent = buildContentForPersistence();
|
||||
if (contentParts.length > 0) {
|
||||
try {
|
||||
// Persist user message (for both edit and reload modes, since backend deleted it)
|
||||
const userContentToPersist = newUserQuery
|
||||
? [{ type: "text", text: newUserQuery }]
|
||||
: originalUserMessageContent || [{ type: "text", text: userQueryToDisplay || "" }];
|
||||
|
||||
const savedUserMessage = await appendMessage(threadId, {
|
||||
role: "user",
|
||||
content: userContentToPersist,
|
||||
});
|
||||
|
||||
// Update user message ID to database ID
|
||||
const newUserMsgId = `msg-${savedUserMessage.id}`;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m))
|
||||
);
|
||||
|
||||
// Persist assistant message
|
||||
const savedMessage = await appendMessage(threadId, {
|
||||
role: "assistant",
|
||||
content: finalContent,
|
||||
});
|
||||
|
||||
// Update assistant message ID to database ID
|
||||
const newMsgId = `msg-${savedMessage.id}`;
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
|
||||
);
|
||||
|
||||
setMessageThinkingSteps((prev) => {
|
||||
const steps = prev.get(assistantMsgId);
|
||||
if (steps) {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(assistantMsgId);
|
||||
newMap.set(newMsgId, steps);
|
||||
return newMap;
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
|
||||
// Track successful response
|
||||
trackChatResponseReceived(searchSpaceId, threadId);
|
||||
} catch (err) {
|
||||
console.error("Failed to persist regenerated message:", err);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
console.error("[NewChatPage] Regeneration error:", error);
|
||||
trackChatError(
|
||||
searchSpaceId,
|
||||
threadId,
|
||||
error instanceof Error ? error.message : "Unknown error"
|
||||
);
|
||||
toast.error("Failed to regenerate response. Please try again.");
|
||||
// Update assistant message with error
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.id === assistantMsgId
|
||||
? {
|
||||
...m,
|
||||
content: [
|
||||
{ type: "text", text: "Sorry, there was an error. Please try again." },
|
||||
],
|
||||
}
|
||||
: m
|
||||
)
|
||||
);
|
||||
} finally {
|
||||
setIsRunning(false);
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
},
|
||||
[threadId, searchSpaceId, messages, setMessageThinkingSteps]
|
||||
);
|
||||
|
||||
// Handle editing a message - truncates history and regenerates with new query
|
||||
const onEdit = useCallback(
|
||||
async (message: AppendMessage) => {
|
||||
// Find the message being edited by looking at the parentId
|
||||
// The parentId tells us which message's response we're editing
|
||||
// For now, we'll just treat edits like new messages
|
||||
// A more sophisticated implementation would truncate the history
|
||||
await onNew(message);
|
||||
// Extract the new user query from the message content
|
||||
let newUserQuery = "";
|
||||
for (const part of message.content) {
|
||||
if (part.type === "text") {
|
||||
newUserQuery += part.text;
|
||||
}
|
||||
}
|
||||
|
||||
if (!newUserQuery.trim()) {
|
||||
toast.error("Cannot edit with empty message");
|
||||
return;
|
||||
}
|
||||
|
||||
// Call regenerate with the new query
|
||||
await handleRegenerate(newUserQuery.trim());
|
||||
},
|
||||
[onNew]
|
||||
[handleRegenerate]
|
||||
);
|
||||
|
||||
// Handle reloading/refreshing the last AI response
|
||||
const onReload = useCallback(
|
||||
async (parentId: string | null) => {
|
||||
// parentId is the ID of the message to reload from (the user message)
|
||||
// We call regenerate without a query to use the same query
|
||||
await handleRegenerate(null);
|
||||
},
|
||||
[handleRegenerate]
|
||||
);
|
||||
|
||||
// Create external store runtime with attachment support
|
||||
|
|
@ -1063,6 +1463,7 @@ export default function NewChatPage() {
|
|||
isRunning,
|
||||
onNew,
|
||||
onEdit,
|
||||
onReload,
|
||||
convertMessage,
|
||||
onCancel: cancelRun,
|
||||
adapters: {
|
||||
|
|
|
|||
|
|
@ -160,6 +160,30 @@ export async function getThreadFull(threadId: number): Promise<ThreadRecord> {
|
|||
return baseApiService.get<ThreadRecord>(`/api/v1/threads/${threadId}/full`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Regeneration request parameters
|
||||
*/
|
||||
export interface RegenerateParams {
|
||||
searchSpaceId: number;
|
||||
userQuery?: string | null; // New user query (for edit). Null/undefined = reload with same query
|
||||
attachments?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
type: string;
|
||||
content: string;
|
||||
}>;
|
||||
mentionedDocumentIds?: number[];
|
||||
mentionedSurfsenseDocIds?: number[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the regenerate endpoint (for streaming fetch)
|
||||
*/
|
||||
export function getRegenerateUrl(threadId: number): string {
|
||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||
return `${backendUrl}/api/v1/threads/${threadId}/regenerate`;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Thread List Manager (for thread list sidebar)
|
||||
// =============================================================================
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue