mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 09:46:25 +02:00
263 lines
7.3 KiB
TypeScript
263 lines
7.3 KiB
TypeScript
/**
 * Custom ChatModelAdapter for the new-chat feature using LocalRuntime.
 * Connects directly to the FastAPI backend using the Vercel AI SDK Data Stream Protocol.
 */
|
|
|
|
import type { ChatModelAdapter, ChatModelRunOptions } from "@assistant-ui/react";
|
|
import { getBearerToken } from "@/lib/auth-utils";
|
|
|
|
/**
 * Identifiers the adapter attaches to every request it sends to the backend.
 */
interface NewChatAdapterConfig {
	/** Sent in the request body as `search_space_id`. */
	searchSpaceId: number;
	/** Sent in the request body as `chat_id`. */
	chatId: number;
}
|
|
|
|
/**
 * Snapshot of a single tool call, whether still in progress or already completed.
 */
interface ToolCallState {
	/** Identifier used to correlate the stream parts that belong to this call. */
	toolCallId: string;
	/** Name of the tool being invoked. */
	toolName: string;
	/** Tool input; starts out empty and is replaced once the input is available. */
	args: Record<string, unknown>;
	/** Tool output; absent until the tool has produced one. */
	result?: unknown;
}
|
|
|
|
/**
 * Names of the tools whose calls are rendered with custom UI in the chat.
 * Calls to any tool not listed here (e.g. search_knowledge_base) stay hidden from the UI.
 */
const TOOLS_WITH_UI = new Set<string>(["generate_podcast"]);
|
|
|
|
/**
 * Creates a ChatModelAdapter that connects to the FastAPI new_chat endpoint.
 *
 * The backend expects:
 * - POST /api/v1/new_chat
 * - Body: { chat_id: number, user_query: string, search_space_id: number }
 * - Returns: SSE stream with Vercel AI SDK Data Stream Protocol
 *
 * Each `run` call sends the text of the most recent user message and yields a
 * fresh `{ content }` snapshot (accumulated text plus the tool calls listed in
 * TOOLS_WITH_UI) whenever a stream part changes the accumulated state.
 *
 * @param config - Chat and search-space identifiers sent with every request.
 * @returns An adapter whose `run` async generator yields content snapshots.
 * @throws Error (from `run`) when no user message or no non-blank user text is
 *   found, when no bearer token is available, when the backend answers with a
 *   non-OK status or without a body, or when the stream carries an `error` part.
 */
export function createNewChatAdapter(config: NewChatAdapterConfig): ChatModelAdapter {
	// Strip trailing slashes so a base URL such as "https://host/" does not yield "//api/...".
	const backendUrl = (
		process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"
	).replace(/\/+$/, "");

	// Hoisted so the patterns are not re-created for every chunk.
	// Events are separated by a blank line; both \n and \r\n line endings are accepted.
	const EVENT_SEPARATOR = /\r?\n\r?\n/;
	const LINE_SEPARATOR = /\r?\n/;

	return {
		async *run({ messages, abortSignal }: ChatModelRunOptions) {
			// Only the most recent user message is sent, together with the chat id.
			const lastUserMessage = messages.filter((m) => m.role === "user").pop();
			if (!lastUserMessage) {
				throw new Error("No user message found");
			}

			// Concatenate the text parts of the message; non-text parts are ignored.
			let userQuery = "";
			for (const part of lastUserMessage.content) {
				if (part.type === "text") {
					userQuery += part.text;
				}
			}
			if (!userQuery.trim()) {
				throw new Error("User query cannot be empty");
			}

			const token = getBearerToken();
			if (!token) {
				throw new Error("Not authenticated. Please log in again.");
			}

			const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
				method: "POST",
				headers: {
					"Content-Type": "application/json",
					Authorization: `Bearer ${token}`,
				},
				body: JSON.stringify({
					chat_id: config.chatId,
					user_query: userQuery.trim(),
					search_space_id: config.searchSpaceId,
				}),
				signal: abortSignal,
			});

			if (!response.ok) {
				const errorText = await response.text().catch(() => "Unknown error");
				throw new Error(`Backend error (${response.status}): ${errorText}`);
			}
			if (!response.body) {
				throw new Error("No response body");
			}

			type ContentPart =
				| { type: "text"; text: string }
				| {
						type: "tool-call";
						toolCallId: string;
						toolName: string;
						args: Record<string, unknown>;
						result?: unknown;
				  };

			// State accumulated across the whole stream.
			let accumulatedText = "";
			const toolCalls = new Map<string, ToolCallState>();

			/**
			 * Build the content array with text and tool calls.
			 * Only includes tools that have custom UI (defined in TOOLS_WITH_UI);
			 * other tools (like search_knowledge_base) are hidden from the UI.
			 */
			const buildContent = (): ContentPart[] => {
				const content: ContentPart[] = [];
				if (accumulatedText) {
					content.push({ type: "text", text: accumulatedText });
				}
				for (const toolCall of toolCalls.values()) {
					if (TOOLS_WITH_UI.has(toolCall.toolName)) {
						content.push({
							type: "tool-call",
							toolCallId: toolCall.toolCallId,
							toolName: toolCall.toolName,
							args: toolCall.args,
							result: toolCall.result,
						});
					}
				}
				return content;
			};

			/** True for plain JSON objects (not null, not arrays). */
			const isRecord = (value: unknown): value is Record<string, unknown> =>
				typeof value === "object" && value !== null && !Array.isArray(value);

			/**
			 * Apply one line of an SSE event to the accumulated state.
			 *
			 * @returns true when the state changed and a new snapshot should be yielded.
			 * @throws Error when the line carries an `error` part.
			 */
			const applyDataLine = (line: string): boolean => {
				// The space after "data:" is optional in the SSE format; trim() drops it when present.
				if (!line.startsWith("data:")) return false;
				const data = line.slice(5).trim();
				if (!data || data === "[DONE]") return false;

				let part: unknown;
				try {
					part = JSON.parse(data);
				} catch {
					// Skip non-JSON lines.
					return false;
				}
				// Valid JSON that is not an object (null, number, ...) carries no stream part.
				if (!isRecord(part)) return false;

				switch (part.type) {
					case "text-delta": {
						if (typeof part.delta !== "string") return false;
						accumulatedText += part.delta;
						return true;
					}

					case "tool-input-start": {
						// Tool call is starting: register it with empty args (running state).
						const { toolCallId, toolName } = part;
						if (typeof toolCallId !== "string" || typeof toolName !== "string") return false;
						toolCalls.set(toolCallId, { toolCallId, toolName, args: {} });
						return true;
					}

					case "tool-input-available": {
						// Tool input is complete: update the args.
						const { toolCallId, toolName, input } = part;
						if (typeof toolCallId !== "string") return false;
						const args = isRecord(input) ? input : {};
						const existing = toolCalls.get(toolCallId);
						if (existing) {
							existing.args = args;
							return true;
						}
						// Create a new entry if tool-input-start was missed.
						if (typeof toolName !== "string") return false;
						toolCalls.set(toolCallId, { toolCallId, toolName, args });
						return true;
					}

					case "tool-output-available": {
						// Tool execution is complete: attach the result to the known call.
						const { toolCallId, output } = part;
						const existing = typeof toolCallId === "string" ? toolCalls.get(toolCallId) : undefined;
						if (!existing) return false;
						existing.result = output;
						return true;
					}

					case "error": {
						const { errorText } = part;
						throw new Error(
							typeof errorText === "string" && errorText ? errorText : "Unknown error from server"
						);
					}

					// Other types like text-start, text-end, start-step, finish-step, etc.
					// do not change the accumulated content.
					default:
						return false;
				}
			};

			// Parse the SSE stream (Vercel AI SDK Data Stream Protocol).
			const reader = response.body.getReader();
			const decoder = new TextDecoder();
			let buffer = "";
			let streamFinished = false;

			try {
				while (true) {
					const { done, value } = await reader.read();
					if (done) {
						streamFinished = true;
						break;
					}

					buffer += decoder.decode(value, { stream: true });

					// Everything before the last separator is complete; the tail may still be partial.
					const events = buffer.split(EVENT_SEPARATOR);
					buffer = events.pop() ?? "";

					for (const event of events) {
						// Each event can have multiple lines; only data lines are relevant.
						for (const line of event.split(LINE_SEPARATOR)) {
							if (applyDataLine(line)) {
								yield { content: buildContent() };
							}
						}
					}
				}

				// Flush bytes the decoder may still hold, then process a final event that was
				// not terminated by a blank line. It goes through the same handler as every
				// other event, so trailing tool and error parts are not dropped.
				buffer += decoder.decode();
				for (const line of buffer.split(LINE_SEPARATOR)) {
					if (applyDataLine(line)) {
						yield { content: buildContent() };
					}
				}
			} finally {
				if (!streamFinished) {
					// Leaving early (consumer stopped, or an error was thrown): tell the stream
					// no more data is wanted. Cancel may reject on an already-errored stream.
					await reader.cancel().catch(() => undefined);
				}
				reader.releaseLock();
			}
		},
	};
}
|