feat: improved agent streaming

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-29 07:20:31 -07:00
parent afb4b09cde
commit c110f5b955
60 changed files with 8068 additions and 303 deletions

View file

@ -14,6 +14,13 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { z } from "zod";
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
import {
agentActionsByChatTurnIdAtom,
markAgentActionRevertedAtom,
resetAgentActionMapAtom,
updateAgentActionReversibleAtom,
upsertAgentActionAtom,
} from "@/atoms/chat/agent-actions.atom";
import {
clearTargetCommentIdAtom,
currentThreadAtom,
@ -36,6 +43,11 @@ import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { membersAtom } from "@/atoms/members/members-query.atoms";
import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import {
EditMessageDialog,
type EditMessageDialogChoice,
} from "@/components/assistant-ui/edit-message-dialog";
import { StepSeparatorDataUI } from "@/components/assistant-ui/step-separator";
import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps";
import { Thread } from "@/components/assistant-ui/thread";
import {
@ -55,14 +67,19 @@ import {
setActivePodcastTaskId,
} from "@/lib/chat/podcast-state";
import {
addStepSeparator,
addToolCall,
appendReasoning,
appendText,
buildContentForPersistence,
buildContentForUI,
type ContentPartsState,
endReasoning,
FrameBatchedUpdater,
findToolCallIdByLcId,
readSSEStream,
type ThinkingStepData,
type ToolUIGate,
updateThinkingSteps,
updateToolCall,
} from "@/lib/chat/streaming-state";
@ -161,44 +178,38 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] {
}
/**
* Tools that should render custom UI in the chat.
* Every tool call renders a card. The legacy
* ``BASE_TOOLS_WITH_UI`` allowlist used to drop unknown tool calls on the
* floor; we now route everything through ``ToolFallback``. Persisted
* payload size stays bounded because the backend's
* ``format_thinking_step`` summarisation and the
* ``result_length``-only default for unknown tools (see
* ``stream_new_chat.py``) keep the JSON from ballooning.
*/
const BASE_TOOLS_WITH_UI = new Set([
"web_search",
"generate_podcast",
"generate_report",
"generate_resume",
"generate_video_presentation",
"display_image",
"generate_image",
"delete_notion_page",
"create_notion_page",
"update_notion_page",
"create_linear_issue",
"update_linear_issue",
"delete_linear_issue",
"create_google_drive_file",
"delete_google_drive_file",
"create_onedrive_file",
"delete_onedrive_file",
"create_dropbox_file",
"delete_dropbox_file",
"create_calendar_event",
"update_calendar_event",
"delete_calendar_event",
"create_gmail_draft",
"update_gmail_draft",
"send_gmail_email",
"trash_gmail_email",
"create_jira_issue",
"update_jira_issue",
"delete_jira_issue",
"create_confluence_page",
"update_confluence_page",
"delete_confluence_page",
"execute",
// "write_todos", // Disabled for now
]);
const TOOLS_WITH_UI_ALL: ToolUIGate = "all";
/**
* When a streamed message is persisted, the backend returns the durable
* ``turn_id`` (``configurable.turn_id`` from the agent run). Merge it
* into the assistant-ui message metadata so the per-turn "Revert turn"
* button can scope to this turn's actions even after a full chat reload.
*/
function mergeChatTurnIdIntoMessage(
msg: ThreadMessageLike,
turnId: string | null | undefined
): ThreadMessageLike {
if (!turnId) return msg;
const existingMeta = (msg.metadata ?? {}) as { custom?: Record<string, unknown> };
const existingCustom = existingMeta.custom ?? {};
if ((existingCustom as { chatTurnId?: string }).chatTurnId === turnId) return msg;
return {
...msg,
metadata: {
...existingMeta,
custom: { ...existingCustom, chatTurnId: turnId },
},
};
}
export default function NewChatPage() {
const params = useParams();
@ -215,7 +226,7 @@ export default function NewChatPage() {
assistantMsgId: string;
interruptData: Record<string, unknown>;
} | null>(null);
const toolsWithUI = useMemo(() => new Set([...BASE_TOOLS_WITH_UI]), []);
const toolsWithUI = TOOLS_WITH_UI_ALL;
// Get disabled tools from the tool toggle UI
const disabledTools = useAtomValue(disabledToolsAtom);
@ -235,6 +246,25 @@ export default function NewChatPage() {
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
// Agent action log SSE side-channel.
const upsertAgentAction = useSetAtom(upsertAgentActionAtom);
const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom);
const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom);
const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom);
// Chat-turn-keyed action map for the edit-from-position pre-flight
// that decides whether to show the confirmation dialog.
const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom);
// Edit dialog state. Holds the message id being edited and
// the (already extracted) regenerate args so we can resume the edit
// after the user picks "revert all" / "continue" / "cancel".
const [editDialogState, setEditDialogState] = useState<{
fromMessageId: number;
userQuery: string | null;
userMessageContent: ThreadMessageLike["content"];
userImages: NewChatUserImagePayload[];
downstreamReversibleCount: number;
downstreamTotalCount: number;
} | null>(null);
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
@ -327,6 +357,7 @@ export default function NewChatPage() {
clearPlanOwnerRegistry();
closeReportPanel();
closeEditorPanel();
resetAgentActionMap();
try {
if (urlChatId > 0) {
@ -395,6 +426,7 @@ export default function NewChatPage() {
removeChatTab,
searchSpaceId,
tokenUsageStore,
resetAgentActionMap,
]);
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
@ -655,11 +687,14 @@ export default function NewChatPage() {
const contentPartsState: ContentPartsState = {
contentParts: [],
currentTextPartIndex: -1,
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
const { contentParts, toolCallIndices } = contentPartsState;
let wasInterrupted = false;
let tokenUsageData: Record<string, unknown> | null = null;
// Captured from ``data-turn-info`` at stream start.
let streamedChatTurnId: string | null = null;
// Add placeholder assistant message
setMessages((prev) => [
@ -752,21 +787,52 @@ export default function NewChatPage() {
scheduleFlush();
break;
case "reasoning-delta":
appendReasoning(contentPartsState, parsed.delta);
scheduleFlush();
break;
case "reasoning-end":
endReasoning(contentPartsState);
scheduleFlush();
break;
case "start-step":
addStepSeparator(contentPartsState);
scheduleFlush();
break;
case "finish-step":
break;
case "tool-input-start":
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
{},
false,
parsed.langchainToolCallId
);
batcher.flush();
break;
case "tool-input-available": {
if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {},
langchainToolCallId: parsed.langchainToolCallId,
});
} else {
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
parsed.input || {},
false,
parsed.langchainToolCallId
);
}
batcher.flush();
@ -774,7 +840,10 @@ export default function NewChatPage() {
}
case "tool-output-available": {
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId,
});
markInterruptsCompleted(contentParts);
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
const idx = toolCallIndices.get(parsed.toolCallId);
@ -880,6 +949,50 @@ export default function NewChatPage() {
break;
}
case "data-action-log": {
const al = parsed.data;
const matchedToolCallId = al.lc_tool_call_id
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
: null;
upsertAgentAction({
action: {
id: al.id,
threadId: currentThreadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break;
}
case "data-action-log-updated": {
updateAgentActionReversible({
id: parsed.data.id,
reversible: parsed.data.reversible,
});
break;
}
case "data-turn-info": {
streamedChatTurnId = parsed.data.chat_turn_id || null;
if (streamedChatTurnId) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
)
);
}
break;
}
case "data-token-usage":
tokenUsageData = parsed.data;
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
@ -900,13 +1013,18 @@ export default function NewChatPage() {
role: "assistant",
content: finalContent,
token_usage: tokenUsageData ?? undefined,
turn_id: streamedChatTurnId,
});
// Update message ID from temporary to database ID so comments work immediately
const newMsgId = `msg-${savedMessage.id}`;
tokenUsageStore.rename(assistantMsgId, newMsgId);
setMessages((prev) =>
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
prev.map((m) =>
m.id === assistantMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
: m
)
);
// Update pending interrupt with the new persisted message ID
@ -929,7 +1047,9 @@ export default function NewChatPage() {
const hasContent = contentParts.some(
(part) =>
(part.type === "text" && part.text.length > 0) ||
(part.type === "tool-call" && toolsWithUI.has(part.toolName))
(part.type === "reasoning" && part.text.length > 0) ||
(part.type === "tool-call" &&
(toolsWithUI === "all" || toolsWithUI.has(part.toolName)))
);
if (hasContent && currentThreadId) {
const partialContent = buildContentForPersistence(contentPartsState, toolsWithUI);
@ -937,12 +1057,17 @@ export default function NewChatPage() {
const savedMessage = await appendMessage(currentThreadId, {
role: "assistant",
content: partialContent,
turn_id: streamedChatTurnId,
});
// Update message ID from temporary to database ID
const newMsgId = `msg-${savedMessage.id}`;
setMessages((prev) =>
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
prev.map((m) =>
m.id === assistantMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
: m
)
);
} catch (err) {
console.error("Failed to persist partial assistant message:", err);
@ -1030,10 +1155,13 @@ export default function NewChatPage() {
const contentPartsState: ContentPartsState = {
contentParts: [],
currentTextPartIndex: -1,
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
const { contentParts, toolCallIndices } = contentPartsState;
let tokenUsageData: Record<string, unknown> | null = null;
// Captured from ``data-turn-info`` at stream start.
let streamedChatTurnId: string | null = null;
const existingMsg = messages.find((m) => m.id === assistantMsgId);
if (existingMsg && Array.isArray(existingMsg.content)) {
@ -1136,8 +1264,34 @@ export default function NewChatPage() {
scheduleFlush();
break;
case "reasoning-delta":
appendReasoning(contentPartsState, parsed.delta);
scheduleFlush();
break;
case "reasoning-end":
endReasoning(contentPartsState);
scheduleFlush();
break;
case "start-step":
addStepSeparator(contentPartsState);
scheduleFlush();
break;
case "finish-step":
break;
case "tool-input-start":
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
{},
false,
parsed.langchainToolCallId
);
batcher.flush();
break;
@ -1145,6 +1299,7 @@ export default function NewChatPage() {
if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {},
langchainToolCallId: parsed.langchainToolCallId,
});
} else {
addToolCall(
@ -1152,7 +1307,9 @@ export default function NewChatPage() {
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
parsed.input || {},
false,
parsed.langchainToolCallId
);
}
batcher.flush();
@ -1161,6 +1318,7 @@ export default function NewChatPage() {
case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId,
});
markInterruptsCompleted(contentParts);
batcher.flush();
@ -1222,6 +1380,50 @@ export default function NewChatPage() {
break;
}
case "data-action-log": {
const al = parsed.data;
const matchedToolCallId = al.lc_tool_call_id
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
: null;
upsertAgentAction({
action: {
id: al.id,
threadId: resumeThreadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break;
}
case "data-action-log-updated": {
updateAgentActionReversible({
id: parsed.data.id,
reversible: parsed.data.reversible,
});
break;
}
case "data-turn-info": {
streamedChatTurnId = parsed.data.chat_turn_id || null;
if (streamedChatTurnId) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
)
);
}
break;
}
case "data-token-usage":
tokenUsageData = parsed.data;
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
@ -1241,11 +1443,16 @@ export default function NewChatPage() {
role: "assistant",
content: finalContent,
token_usage: tokenUsageData ?? undefined,
turn_id: streamedChatTurnId,
});
const newMsgId = `msg-${savedMessage.id}`;
tokenUsageStore.rename(assistantMsgId, newMsgId);
setMessages((prev) =>
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
prev.map((m) =>
m.id === assistantMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
: m
)
);
} catch (err) {
console.error("Failed to persist resumed assistant message:", err);
@ -1340,6 +1547,12 @@ export default function NewChatPage() {
editExtras?: {
userMessageContent: ThreadMessageLike["content"];
userImages: NewChatUserImagePayload[];
},
editFromPosition?: {
/** Message id (numeric, parsed from ``msg-<n>``) to rewind to. */
fromMessageId?: number | null;
/** When true, revert reversible downstream actions before stream. */
revertActions?: boolean;
}
) => {
if (!threadId) {
@ -1384,9 +1597,20 @@ export default function NewChatPage() {
userQueryToDisplay = newUserQuery;
}
// Remove the last two messages (user + assistant) from the UI immediately
// The backend will also delete them from the database
// Remove downstream messages from the UI immediately. The
// backend will also delete them from the database.
//
// When an explicit ``fromMessageId`` is passed, slice from
// that message forward; otherwise fall back to the legacy
// "drop the last 2" behaviour.
setMessages((prev) => {
if (editFromPosition?.fromMessageId != null) {
const targetId = `msg-${editFromPosition.fromMessageId}`;
const sliceIndex = prev.findIndex((m) => m.id === targetId);
if (sliceIndex >= 0) {
return prev.slice(0, sliceIndex);
}
}
if (prev.length >= 2) {
return prev.slice(0, -2);
}
@ -1406,11 +1630,16 @@ export default function NewChatPage() {
const contentPartsState: ContentPartsState = {
contentParts: [],
currentTextPartIndex: -1,
currentReasoningPartIndex: -1,
toolCallIndices: new Map(),
};
const { contentParts, toolCallIndices } = contentPartsState;
const batcher = new FrameBatchedUpdater();
let tokenUsageData: Record<string, unknown> | null = null;
// Captured from ``data-turn-info`` at stream start; stamped
// onto persisted messages so future edits can locate the
// right LangGraph checkpoint.
let streamedChatTurnId: string | null = null;
// Add placeholder messages to UI
// Always add back the user message (with new query for edit, or original content for reload)
@ -1449,6 +1678,16 @@ export default function NewChatPage() {
if (isEdit) {
requestBody.user_images = editExtras?.userImages ?? [];
}
// Explicit edit-from-arbitrary-position. Only send
// ``from_message_id`` / ``revert_actions`` when the
// caller asked for them; otherwise the backend keeps the
// legacy "last 2 messages" behaviour for back-compat.
if (editFromPosition?.fromMessageId != null) {
requestBody.from_message_id = editFromPosition.fromMessageId;
if (editFromPosition.revertActions) {
requestBody.revert_actions = true;
}
}
const response = await fetch(getRegenerateUrl(threadId), {
method: "POST",
headers: {
@ -1481,28 +1720,62 @@ export default function NewChatPage() {
scheduleFlush();
break;
case "reasoning-delta":
appendReasoning(contentPartsState, parsed.delta);
scheduleFlush();
break;
case "reasoning-end":
endReasoning(contentPartsState);
scheduleFlush();
break;
case "start-step":
addStepSeparator(contentPartsState);
scheduleFlush();
break;
case "finish-step":
break;
case "tool-input-start":
addToolCall(contentPartsState, toolsWithUI, parsed.toolCallId, parsed.toolName, {});
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
{},
false,
parsed.langchainToolCallId
);
batcher.flush();
break;
case "tool-input-available":
if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} });
updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {},
langchainToolCallId: parsed.langchainToolCallId,
});
} else {
addToolCall(
contentPartsState,
toolsWithUI,
parsed.toolCallId,
parsed.toolName,
parsed.input || {}
parsed.input || {},
false,
parsed.langchainToolCallId
);
}
batcher.flush();
break;
case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output });
updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId,
});
markInterruptsCompleted(contentParts);
if (parsed.output?.status === "pending" && parsed.output?.podcast_id) {
const idx = toolCallIndices.get(parsed.toolCallId);
@ -1528,6 +1801,82 @@ export default function NewChatPage() {
break;
}
case "data-action-log": {
const al = parsed.data;
const matchedToolCallId = al.lc_tool_call_id
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
: null;
upsertAgentAction({
action: {
id: al.id,
threadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break;
}
case "data-action-log-updated": {
updateAgentActionReversible({
id: parsed.data.id,
reversible: parsed.data.reversible,
});
break;
}
case "data-turn-info": {
streamedChatTurnId = parsed.data.chat_turn_id || null;
if (streamedChatTurnId) {
setMessages((prev) =>
prev.map((m) =>
m.id === assistantMsgId ? mergeChatTurnIdIntoMessage(m, streamedChatTurnId) : m
)
);
}
break;
}
case "data-revert-results": {
const summary = parsed.data;
// failureCount must include every "not undone" bucket
// (not_reversible, permission_denied, failed) so the
// toast's "X could not be rolled back" math matches
// the response invariant ``total === sum(counters)``.
// ``skipped`` rows are batch revert artefacts (revert
// rows themselves) and are not user-facing failures.
const failureCount =
summary.failed + summary.not_reversible + (summary.permission_denied ?? 0);
if (failureCount > 0) {
toast.warning(
`Pre-revert: ${summary.reverted}/${summary.total} undone, ${failureCount} could not be rolled back.`
);
} else if (summary.reverted > 0) {
toast.success(
summary.reverted === 1
? "Reverted 1 downstream action before regenerating."
: `Reverted ${summary.reverted} downstream actions before regenerating.`
);
}
for (const r of summary.results) {
if (r.status === "reverted" || r.status === "already_reverted") {
markAgentActionReverted({
id: r.action_id,
newActionId: r.new_action_id ?? null,
});
}
}
break;
}
case "data-token-usage":
tokenUsageData = parsed.data;
tokenUsageStore.set(assistantMsgId, parsed.data as TokenUsageData);
@ -1552,12 +1901,17 @@ export default function NewChatPage() {
const savedUserMessage = await appendMessage(threadId, {
role: "user",
content: userContentToPersist,
turn_id: streamedChatTurnId,
});
// Update user message ID to database ID
const newUserMsgId = `msg-${savedUserMessage.id}`;
setMessages((prev) =>
prev.map((m) => (m.id === userMsgId ? { ...m, id: newUserMsgId } : m))
prev.map((m) =>
m.id === userMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newUserMsgId }, savedUserMessage.turn_id)
: m
)
);
// Persist assistant message
@ -1565,12 +1919,17 @@ export default function NewChatPage() {
role: "assistant",
content: finalContent,
token_usage: tokenUsageData ?? undefined,
turn_id: streamedChatTurnId,
});
const newMsgId = `msg-${savedMessage.id}`;
tokenUsageStore.rename(assistantMsgId, newMsgId);
setMessages((prev) =>
prev.map((m) => (m.id === assistantMsgId ? { ...m, id: newMsgId } : m))
prev.map((m) =>
m.id === assistantMsgId
? mergeChatTurnIdIntoMessage({ ...m, id: newMsgId }, savedMessage.turn_id)
: m
)
);
trackChatResponseReceived(searchSpaceId, threadId);
@ -1608,7 +1967,14 @@ export default function NewChatPage() {
[threadId, searchSpaceId, messages, disabledTools, tokenUsageStore, toolsWithUI]
);
// Handle editing a message - truncates history and regenerates with new query
// Handle editing a message - truncates history and regenerates with new query.
//
// When ``message.sourceId`` is set (the assistant-ui way to say
// "this edit replaces an older message"), we pin
// ``from_message_id`` so the backend rewinds to the right LangGraph
// checkpoint instead of relying on the legacy "last 2 messages"
// rewind. We also count downstream reversible actions and prompt the
// user to revert / continue / cancel before regenerating.
const onEdit = useCallback(
async (message: AppendMessage) => {
const { userQuery, userImages } = extractUserTurnForNewChatApi(message, []);
@ -1619,9 +1985,95 @@ export default function NewChatPage() {
}
const userMessageContent = message.content as unknown as ThreadMessageLike["content"];
await handleRegenerate(queryForApi, { userMessageContent, userImages });
// ``sourceId`` per @assistant-ui/core's ``AppendMessage`` is
// "the ID of the message that was edited". Parse the numeric
// suffix so we can map it back to a DB row.
const sourceId = (message as { sourceId?: string }).sourceId;
const fromMessageId =
sourceId && /^msg-\d+$/.test(sourceId)
? Number.parseInt(sourceId.replace(/^msg-/, ""), 10)
: null;
if (fromMessageId == null) {
// No source id (or non-DB id) — fall back to today's
// last-2 behaviour. The user gets the legacy edit flow.
await handleRegenerate(queryForApi, { userMessageContent, userImages });
return;
}
// Pre-flight: count reversible downstream actions so we can
// auto-skip the dialog for harmless edits.
//
// "Downstream" means messages AFTER the edited one. The
// previous slice ``messages.slice(editedIndex)`` included
// the edited message itself in both the total
// count and the reversibility scan (any actions on the
// edited turn would be double-counted). Slice from
// ``editedIndex + 1`` so the dialog text matches reality:
// "N downstream messages will be dropped".
const editedIndex = messages.findIndex((m) => m.id === `msg-${fromMessageId}`);
let downstreamReversibleCount = 0;
let downstreamTotalCount = 0;
if (editedIndex >= 0) {
const downstream = messages.slice(editedIndex + 1);
downstreamTotalCount = downstream.length;
const seenTurns = new Set<string>();
for (const m of downstream) {
const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } };
const tid = meta.custom?.chatTurnId;
if (!tid || seenTurns.has(tid)) continue;
seenTurns.add(tid);
const turnActions = agentActionsByChatTurnId.get(tid) ?? [];
for (const a of turnActions) {
if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) {
downstreamReversibleCount += 1;
}
}
}
}
if (downstreamReversibleCount === 0) {
// Nothing to revert — submit silently.
await handleRegenerate(
queryForApi,
{ userMessageContent, userImages },
{ fromMessageId, revertActions: false }
);
return;
}
setEditDialogState({
fromMessageId,
userQuery: queryForApi,
userMessageContent,
userImages,
downstreamReversibleCount,
downstreamTotalCount,
});
},
[handleRegenerate]
[handleRegenerate, messages, agentActionsByChatTurnId]
);
const handleEditDialogChoice = useCallback(
async (choice: EditMessageDialogChoice) => {
const pending = editDialogState;
if (!pending) return;
setEditDialogState(null);
if (choice === "cancel") return;
await handleRegenerate(
pending.userQuery,
{
userMessageContent: pending.userMessageContent,
userImages: pending.userImages,
},
{
fromMessageId: pending.fromMessageId,
revertActions: choice === "revert",
}
);
},
[editDialogState, handleRegenerate]
);
// Handle reloading/refreshing the last AI response
@ -1671,6 +2123,7 @@ export default function NewChatPage() {
<TokenUsageProvider store={tokenUsageStore}>
<AssistantRuntimeProvider runtime={runtime}>
<ThinkingStepsDataUI />
<StepSeparatorDataUI />
<div key={searchSpaceId} className="flex h-full overflow-hidden">
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
<Thread />
@ -1679,6 +2132,15 @@ export default function NewChatPage() {
<MobileEditorPanel />
<MobileHitlEditPanel />
</div>
<EditMessageDialog
open={editDialogState !== null}
onOpenChange={(open) => {
if (!open) setEditDialogState(null);
}}
downstreamReversibleCount={editDialogState?.downstreamReversibleCount ?? 0}
downstreamTotalCount={editDialogState?.downstreamTotalCount ?? 0}
onChoose={handleEditDialogChoice}
/>
</AssistantRuntimeProvider>
</TokenUsageProvider>
);