mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
Merge pull request #530 from MODSetter/dev
fix: implement real-time streaming for responses
This commit is contained in:
commit
d29bd6c12f
6 changed files with 72 additions and 49 deletions
|
|
@ -1440,7 +1440,12 @@ async def handle_qna_workflow(
|
|||
}
|
||||
|
||||
# Create the state for the QNA agent (it has a different state structure)
|
||||
qna_state = {"db_session": state.db_session, "chat_history": state.chat_history}
|
||||
# Pass streaming_service so the QNA agent can stream tokens directly
|
||||
qna_state = {
|
||||
"db_session": state.db_session,
|
||||
"chat_history": state.chat_history,
|
||||
"streaming_service": streaming_service,
|
||||
}
|
||||
|
||||
try:
|
||||
writer(
|
||||
|
|
@ -1455,36 +1460,26 @@ async def handle_qna_workflow(
|
|||
complete_content = ""
|
||||
captured_reranked_documents = []
|
||||
|
||||
# Call the QNA agent with streaming
|
||||
async for _chunk_type, chunk in qna_agent_graph.astream(
|
||||
qna_state, qna_config, stream_mode=["values"]
|
||||
# Call the QNA agent with both custom and values streaming modes
|
||||
# - "custom" captures token-by-token streams from answer_question via writer()
|
||||
# - "values" captures state updates including final_answer and reranked_documents
|
||||
async for stream_mode, chunk in qna_agent_graph.astream(
|
||||
qna_state, qna_config, stream_mode=["custom", "values"]
|
||||
):
|
||||
if "final_answer" in chunk:
|
||||
new_content = chunk["final_answer"]
|
||||
if new_content and new_content != complete_content:
|
||||
# Extract only the new content (delta)
|
||||
delta = new_content[len(complete_content) :]
|
||||
complete_content = new_content
|
||||
if stream_mode == "custom":
|
||||
# Handle custom stream events (token chunks from answer_question)
|
||||
if isinstance(chunk, dict) and "yield_value" in chunk:
|
||||
# Forward the streamed token to the parent writer
|
||||
writer(chunk)
|
||||
elif stream_mode == "values" and isinstance(chunk, dict):
|
||||
# Handle state value updates
|
||||
# Capture the final answer from state
|
||||
if chunk.get("final_answer"):
|
||||
complete_content = chunk["final_answer"]
|
||||
|
||||
# Stream the real-time answer if there's new content
|
||||
if delta:
|
||||
# Update terminal with progress
|
||||
word_count = len(complete_content.split())
|
||||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_terminal_info_delta(
|
||||
f"✍️ Writing answer... ({word_count} words)"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
writer(
|
||||
{"yield_value": streaming_service.format_text_chunk(delta)}
|
||||
)
|
||||
|
||||
# Capture reranked documents from QNA agent for further question generation
|
||||
if "reranked_documents" in chunk:
|
||||
captured_reranked_documents = chunk["reranked_documents"]
|
||||
# Capture reranked documents from QNA agent for further question generation
|
||||
if chunk.get("reranked_documents"):
|
||||
captured_reranked_documents = chunk["reranked_documents"]
|
||||
|
||||
# Set default if no content was received
|
||||
if not complete_content:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any
|
|||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import SearchSpace
|
||||
|
|
@ -129,9 +130,11 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
|
|||
return {"reranked_documents": documents}
|
||||
|
||||
|
||||
async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
async def answer_question(
|
||||
state: State, config: RunnableConfig, writer: StreamWriter
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Answer the user's question using the provided documents.
|
||||
Answer the user's question using the provided documents with real-time streaming.
|
||||
|
||||
This node takes the relevant documents provided in the configuration and uses
|
||||
an LLM to generate a comprehensive answer to the user's question with
|
||||
|
|
@ -139,6 +142,8 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
|||
documents. If no documents are provided, it will use chat history to generate
|
||||
an answer.
|
||||
|
||||
The response is streamed token-by-token for real-time updates to the frontend.
|
||||
|
||||
Returns:
|
||||
Dict containing the final answer in the "final_answer" key.
|
||||
"""
|
||||
|
|
@ -151,6 +156,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
|||
search_space_id = configuration.search_space_id
|
||||
language = configuration.language
|
||||
|
||||
# Get streaming service from state
|
||||
streaming_service = state.streaming_service
|
||||
|
||||
# Fetch search space to get QnA configuration
|
||||
result = await state.db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
|
|
@ -279,8 +287,17 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
|||
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
|
||||
print(f"Final token count: {total_tokens}")
|
||||
|
||||
# Call the LLM and get the response
|
||||
response = await llm.ainvoke(messages_with_chat_history)
|
||||
final_answer = response.content
|
||||
# Stream the LLM response token by token
|
||||
final_answer = ""
|
||||
|
||||
async for chunk in llm.astream(messages_with_chat_history):
|
||||
# Extract the content from the chunk
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
token = chunk.content
|
||||
final_answer += token
|
||||
|
||||
# Stream the token to the frontend via custom stream
|
||||
if streaming_service:
|
||||
writer({"yield_value": streaming_service.format_text_chunk(token)})
|
||||
|
||||
return {"final_answer": final_answer}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from typing import Any
|
|||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.streaming_service import StreamingService
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
|
|
@ -21,6 +23,9 @@ class State:
|
|||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
|
||||
# Streaming service for real-time token streaming
|
||||
streaming_service: StreamingService | None = None
|
||||
|
||||
chat_history: list[Any] | None = field(default_factory=list)
|
||||
# OUTPUT: Populated by agent nodes
|
||||
reranked_documents: list[Any] | None = None
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ async def _process_extension_document(
|
|||
individual_document_dict, search_space_id: int, user_id: str
|
||||
):
|
||||
"""Process extension document with new session."""
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# Reconstruct the document object from dict
|
||||
# You'll need to define the proper model for this
|
||||
|
|
@ -75,8 +75,9 @@ async def _process_extension_document(
|
|||
VisitedWebPageVisitDurationInMilliseconds: str
|
||||
|
||||
class IndividualDocument(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
metadata: DocumentMetadata
|
||||
pageContent: str
|
||||
page_content: str = Field(alias="pageContent")
|
||||
|
||||
individual_document = IndividualDocument(**individual_document_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@ export default function ResearcherPage() {
|
|||
const { search_space_id } = useParams();
|
||||
const router = useRouter();
|
||||
const hasSetInitialConnectors = useRef(false);
|
||||
const hasInitiatedResponse = useRef<string | null>(null);
|
||||
const activeChatId = useAtomValue(activeChatIdAtom);
|
||||
const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom);
|
||||
const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom);
|
||||
const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom);
|
||||
const isNewChat = !activeChatId;
|
||||
|
||||
// Reset the flag when chat ID changes
|
||||
// Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated)
|
||||
useEffect(() => {
|
||||
hasSetInitialConnectors.current = false;
|
||||
}, [activeChatId]);
|
||||
|
|
@ -167,10 +168,14 @@ export default function ResearcherPage() {
|
|||
if (chatData.messages && Array.isArray(chatData.messages)) {
|
||||
if (chatData.messages.length === 1 && chatData.messages[0].role === "user") {
|
||||
// Single user message - append to trigger LLM response
|
||||
handler.append({
|
||||
role: "user",
|
||||
content: chatData.messages[0].content,
|
||||
});
|
||||
// Only if we haven't already initiated for this chat and handler doesn't have messages yet
|
||||
if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) {
|
||||
hasInitiatedResponse.current = activeChatId;
|
||||
handler.append({
|
||||
role: "user",
|
||||
content: chatData.messages[0].content,
|
||||
});
|
||||
}
|
||||
} else if (chatData.messages.length > 1) {
|
||||
// Multiple messages - set them all
|
||||
handler.setMessages(chatData.messages);
|
||||
|
|
|
|||
|
|
@ -36,25 +36,25 @@ export function useApiKey(): UseApiKeyReturn {
|
|||
const fallbackCopyTextToClipboard = (text: string) => {
|
||||
const textArea = document.createElement("textarea");
|
||||
textArea.value = text;
|
||||
|
||||
|
||||
// Avoid scrolling to bottom
|
||||
textArea.style.top = "0";
|
||||
textArea.style.left = "0";
|
||||
textArea.style.position = "fixed";
|
||||
textArea.style.opacity = "0";
|
||||
|
||||
|
||||
document.body.appendChild(textArea);
|
||||
textArea.focus();
|
||||
textArea.select();
|
||||
|
||||
|
||||
try {
|
||||
const successful = document.execCommand('copy');
|
||||
const successful = document.execCommand("copy");
|
||||
document.body.removeChild(textArea);
|
||||
|
||||
|
||||
if (successful) {
|
||||
setCopied(true);
|
||||
toast.success("API key copied to clipboard");
|
||||
|
||||
|
||||
setTimeout(() => {
|
||||
setCopied(false);
|
||||
}, 2000);
|
||||
|
|
@ -77,7 +77,7 @@ export function useApiKey(): UseApiKeyReturn {
|
|||
await navigator.clipboard.writeText(apiKey);
|
||||
setCopied(true);
|
||||
toast.success("API key copied to clipboard");
|
||||
|
||||
|
||||
setTimeout(() => {
|
||||
setCopied(false);
|
||||
}, 2000);
|
||||
|
|
@ -97,4 +97,4 @@ export function useApiKey(): UseApiKeyReturn {
|
|||
copied,
|
||||
copyToClipboard,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue