diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index bc5aca91e..2db00a03d 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -32,6 +32,7 @@ import { membersAtom } from "@/atoms/members/members-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { Thread } from "@/components/assistant-ui/thread"; import { ChatHeader } from "@/components/new-chat/chat-header"; +import { CreateNotionPageToolUI } from "@/components/tool-ui/create-notion-page"; import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking"; import { DisplayImageToolUI } from "@/components/tool-ui/display-image"; import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; @@ -120,6 +121,7 @@ const TOOLS_WITH_UI = new Set([ "link_preview", "display_image", "scrape_webpage", + "create_notion_page", // "write_todos", // Disabled for now ]); @@ -147,6 +149,11 @@ export default function NewChatPage() { new Map() ); const abortControllerRef = useRef(null); + const [pendingInterrupt, setPendingInterrupt] = useState<{ + threadId: number; + assistantMsgId: string; + interruptData: Record; + } | null>(null); // Get mentioned document IDs from the composer const mentionedDocumentIds = useAtomValue(mentionedDocumentIdsAtom); @@ -545,6 +552,7 @@ export default function NewChatPage() { result?: unknown; }; const contentParts: ContentPart[] = []; + let wasInterrupted = false; // Track the current text segment index (for appending text deltas) let currentTextPartIndex = -1; @@ -816,27 +824,69 @@ export default function NewChatPage() { String(titleData.threadId), ], }); + } + break; + } + + case "data-interrupt-request": { + wasInterrupted = true; + const interruptData = parsed.data as Record; + const actionRequests = (interruptData.action_requests ?? []) as Array<{ + name: string; + args: Record; + }>; + for (const action of actionRequests) { + const existingIdx = Array.from(toolCallIndices.entries()).find( + ([, idx]) => { + const part = contentParts[idx]; + return part?.type === "tool-call" && part.toolName === action.name; + } + ); + if (existingIdx) { + updateToolCall(existingIdx[0], { + result: { __interrupt__: true, ...interruptData }, + }); + } else { + const tcId = `interrupt-${action.name}`; + addToolCall(tcId, action.name, action.args); + updateToolCall(tcId, { + result: { __interrupt__: true, ...interruptData }, + }); + } + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + if (currentThreadId) { + setPendingInterrupt({ + threadId: currentThreadId, + assistantMsgId, + interruptData, + }); } break; } case "error": throw new Error(parsed.errorText || "Server error"); - } - } catch (e) { - if (e instanceof SyntaxError) continue; - throw e; } + } catch (e) { + if (e instanceof SyntaxError) continue; + throw e; } } } - } finally { + } + } finally { reader.releaseLock(); } // Persist assistant message (with thinking steps for restoration on refresh) + // Skip persistence for interrupted messages -- handleResume will persist the final version const finalContent = buildContentForPersistence(); - if (contentParts.length > 0) { + if (contentParts.length > 0 && !wasInterrupted) { try { const savedMessage = await appendMessage(currentThreadId, { role: "assistant", @@ -849,6 +899,13 @@ export default function NewChatPage() { prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) ); + // Update pending interrupt with the new persisted message ID + setPendingInterrupt((prev) => + prev && prev.assistantMsgId === assistantMsgId + ? { ...prev, assistantMsgId: newMsgId } + : prev + ); + // Also update thinking steps map with new ID setMessageThinkingSteps((prev) => { const steps = prev.get(assistantMsgId); @@ -941,6 +998,379 @@ export default function NewChatPage() { ] ); + const handleResume = useCallback( + async (decisions: Array<{ type: string; message?: string }>) => { + if (!pendingInterrupt) return; + const { threadId: resumeThreadId, assistantMsgId } = pendingInterrupt; + setPendingInterrupt(null); + setIsRunning(true); + + const token = getBearerToken(); + if (!token) { + toast.error("Not authenticated. Please log in again."); + setIsRunning(false); + return; + } + + const controller = new AbortController(); + abortControllerRef.current = controller; + + const currentThinkingSteps = new Map( + (messageThinkingSteps.get(assistantMsgId) ?? []).map((s) => [s.id, s]) + ); + + type ContentPart = + | { type: "text"; text: string } + | { + type: "tool-call"; + toolCallId: string; + toolName: string; + args: Record; + result?: unknown; + }; + const contentParts: ContentPart[] = []; + let currentTextPartIndex = -1; + const toolCallIndices = new Map(); + + const existingMsg = messages.find((m) => m.id === assistantMsgId); + if (existingMsg && Array.isArray(existingMsg.content)) { + for (const part of existingMsg.content) { + if (typeof part === "object" && part !== null) { + const p = part as Record; + if (p.type === "text") { + contentParts.push({ type: "text", text: String(p.text ?? "") }); + currentTextPartIndex = contentParts.length - 1; + } else if (p.type === "tool-call") { + toolCallIndices.set(String(p.toolCallId), contentParts.length); + contentParts.push({ + type: "tool-call", + toolCallId: String(p.toolCallId), + toolName: String(p.toolName), + args: (p.args as Record) ?? {}, + result: p.result as unknown, + }); + currentTextPartIndex = -1; + } + } + } + } + + const decisionType = decisions[0]?.type as "approve" | "reject" | undefined; + if (decisionType) { + for (const part of contentParts) { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + "__interrupt__" in (part.result as Record) + ) { + part.result = { + ...(part.result as Record), + __decided__: decisionType, + }; + } + } + } + + const appendText = (delta: string) => { + if (currentTextPartIndex >= 0 && contentParts[currentTextPartIndex]?.type === "text") { + (contentParts[currentTextPartIndex] as { type: "text"; text: string }).text += delta; + } else { + contentParts.push({ type: "text", text: delta }); + currentTextPartIndex = contentParts.length - 1; + } + }; + + const addToolCall = (toolCallId: string, toolName: string, args: Record) => { + if (TOOLS_WITH_UI.has(toolName)) { + contentParts.push({ + type: "tool-call", + toolCallId, + toolName, + args, + }); + toolCallIndices.set(toolCallId, contentParts.length - 1); + currentTextPartIndex = -1; + } + }; + + const updateToolCall = ( + toolCallId: string, + update: { args?: Record; result?: unknown } + ) => { + const index = toolCallIndices.get(toolCallId); + if (index !== undefined && contentParts[index]?.type === "tool-call") { + const tc = contentParts[index] as ContentPart & { type: "tool-call" }; + if (update.args) tc.args = update.args; + if (update.result !== undefined) tc.result = update.result; + } + }; + + const buildContentForUI = (): ThreadMessageLike["content"] => { + const filtered = contentParts.filter((part) => { + if (part.type === "text") return part.text.length > 0; + if (part.type === "tool-call") return TOOLS_WITH_UI.has(part.toolName); + return false; + }); + return filtered.length > 0 + ? (filtered as ThreadMessageLike["content"]) + : [{ type: "text", text: "" }]; + }; + + const buildContentForPersistence = (): unknown[] => { + const parts: unknown[] = []; + if (currentThinkingSteps.size > 0) { + parts.push({ + type: "thinking-steps", + steps: Array.from(currentThinkingSteps.values()), + }); + } + for (const part of contentParts) { + if (part.type === "text" && part.text.length > 0) { + parts.push(part); + } else if (part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName)) { + parts.push(part); + } + } + return parts.length > 0 ? parts : [{ type: "text", text: "" }]; + }; + + try { + const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; + const response = await fetch(`${backendUrl}/api/v1/threads/${resumeThreadId}/resume`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + search_space_id: searchSpaceId, + decisions, + }), + signal: controller.signal, + }); + + if (!response.ok) { + throw new Error(`Backend error: ${response.status}`); + } + + if (!response.body) { + throw new Error("No response body"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const events = buffer.split(/\r?\n\r?\n/); + buffer = events.pop() || ""; + + for (const event of events) { + const lines = event.split(/\r?\n/); + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const data = line.slice(6).trim(); + if (!data || data === "[DONE]") continue; + + try { + const parsed = JSON.parse(data); + + switch (parsed.type) { + case "text-delta": + appendText(parsed.delta); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-input-start": + addToolCall(parsed.toolCallId, parsed.toolName, {}); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-input-available": + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(parsed.toolCallId, { + args: parsed.input || {}, + }); + } else { + addToolCall(parsed.toolCallId, parsed.toolName, parsed.input || {}); + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "tool-output-available": + updateToolCall(parsed.toolCallId, { + result: parsed.output, + }); + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + break; + + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + setMessageThinkingSteps((prev) => { + const newMap = new Map(prev); + newMap.set(assistantMsgId, Array.from(currentThinkingSteps.values())); + return newMap; + }); + } + break; + } + + case "data-interrupt-request": { + const interruptData = parsed.data as Record; + const actionRequests = (interruptData.action_requests ?? []) as Array<{ + name: string; + args: Record; + }>; + for (const action of actionRequests) { + const existingIdx = Array.from(toolCallIndices.entries()).find( + ([, idx]) => { + const part = contentParts[idx]; + return part?.type === "tool-call" && part.toolName === action.name; + } + ); + if (existingIdx) { + updateToolCall(existingIdx[0], { + result: { + __interrupt__: true, + ...interruptData, + }, + }); + } else { + const tcId = `interrupt-${action.name}`; + addToolCall(tcId, action.name, action.args); + updateToolCall(tcId, { + result: { + __interrupt__: true, + ...interruptData, + }, + }); + } + } + setMessages((prev) => + prev.map((m) => + m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m + ) + ); + setPendingInterrupt({ + threadId: resumeThreadId, + assistantMsgId, + interruptData, + }); + break; + } + + case "error": + throw new Error(parsed.errorText || "Server error"); + } + } catch (e) { + if (e instanceof SyntaxError) continue; + throw e; + } + } + } + } + } finally { + reader.releaseLock(); + } + + const finalContent = buildContentForPersistence(); + if (contentParts.length > 0) { + try { + const savedMessage = await appendMessage(resumeThreadId, { + role: "assistant", + content: finalContent, + }); + const newMsgId = `msg-${savedMessage.id}`; + setMessages((prev) => + prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m)) + ); + setMessageThinkingSteps((prev) => { + const steps = prev.get(assistantMsgId); + if (steps) { + const newMap = new Map(prev); + newMap.delete(assistantMsgId); + newMap.set(newMsgId, steps); + return newMap; + } + return prev; + }); + } catch (err) { + console.error("Failed to persist resumed assistant message:", err); + } + } + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + return; + } + console.error("[NewChatPage] Resume error:", error); + toast.error("Failed to resume. Please try again."); + } finally { + setIsRunning(false); + abortControllerRef.current = null; + } + }, + [pendingInterrupt, messages, searchSpaceId, messageThinkingSteps] + ); + + useEffect(() => { + const handler = (e: Event) => { + const detail = (e as CustomEvent).detail as { + decisions: Array<{ type: string; message?: string }>; + }; + if (detail?.decisions && pendingInterrupt) { + const decisionType = detail.decisions[0]?.type as "approve" | "reject"; + setMessages((prev) => + prev.map((m) => { + if (m.id !== pendingInterrupt.assistantMsgId) return m; + const parts = m.content as unknown as Array>; + const newContent = parts.map((part) => { + if ( + part.type === "tool-call" && + typeof part.result === "object" && + part.result !== null && + "__interrupt__" in part.result + ) { + return { + ...part, + result: { ...(part.result as Record), __decided__: decisionType }, + }; + } + return part; + }); + return { ...m, content: newContent as unknown as ThreadMessageLike["content"] }; + }) + ); + handleResume(detail.decisions); + } + }; + window.addEventListener("hitl-decision", handler); + return () => window.removeEventListener("hitl-decision", handler); + }, [handleResume, pendingInterrupt]); + // Convert message (pass through since already in correct format) const convertMessage = useCallback( (message: ThreadMessageLike): ThreadMessageLike => message, @@ -1432,6 +1862,7 @@ export default function NewChatPage() { + {/* Disabled for now */}
; + description?: string; + }>; + review_configs: Array<{ + action_name: string; + allowed_decisions: Array<"approve" | "edit" | "reject">; + }>; +} + +interface SuccessResult { + status: string; + page_id: string; + title: string; + url: string; +} + +type CreateNotionPageResult = InterruptResult | SuccessResult; + +function isInterruptResult(result: unknown): result is InterruptResult { + return ( + typeof result === "object" && + result !== null && + "__interrupt__" in result && + (result as InterruptResult).__interrupt__ === true + ); +} + +function ApprovalCard({ + args, + interruptData, + onDecision, +}: { + args: Record; + interruptData: InterruptResult; + onDecision: (decision: { type: "approve" | "reject"; message?: string }) => void; +}) { + const [decided, setDecided] = useState<"approve" | "reject" | null>( + interruptData.__decided__ ?? null + ); + const reviewConfig = interruptData.review_configs[0]; + const allowedDecisions = reviewConfig?.allowed_decisions ?? ["approve", "reject"]; + + return ( +
+
+
+ +
+
+

Create Notion Page

+

+ Requires your approval to proceed +

+
+
+ +
+ {args.title != null && ( +
+

Title

+

{String(args.title)}

+
+ )} + {args.content != null && ( +
+

Content

+

{String(args.content)}

+
+ )} +
+ +
+ {decided ? ( +

+ {decided === "approve" ? ( + <> + + Approved + + ) : ( + <> + + Rejected + + )} +

+ ) : ( + <> + {allowedDecisions.includes("approve") && ( + + )} + {allowedDecisions.includes("reject") && ( + + )} + + )} +
+
+ ); +} + +function SuccessCard({ result }: { result: SuccessResult }) { + return ( +
+
+
+ +
+
+

{result.title}

+

Notion page created

+
+
+
+ ); +} + +export const CreateNotionPageToolUI = makeAssistantToolUI< + { title: string; content: string }, + CreateNotionPageResult +>({ + toolName: "create_notion_page", + render: function CreateNotionPageUI({ args, result, status }) { + if (status.type === "running") { + return ( +
+ +

Preparing Notion page...

+
+ ); + } + + if (!result) { + return null; + } + + if (isInterruptResult(result)) { + return ( + { + const event = new CustomEvent("hitl-decision", { + detail: { decisions: [decision] }, + }); + window.dispatchEvent(event); + }} + /> + ); + } + + return ; + }, +}); diff --git a/surfsense_web/components/tool-ui/index.ts b/surfsense_web/components/tool-ui/index.ts index 5b4ea0a34..56e5c975b 100644 --- a/surfsense_web/components/tool-ui/index.ts +++ b/surfsense_web/components/tool-ui/index.ts @@ -16,6 +16,7 @@ export { type SerializableArticle, } from "./article"; export { Audio } from "./audio"; +export { CreateNotionPageToolUI } from "./create-notion-page"; export { type DeepAgentThinkingArgs, type DeepAgentThinkingResult,