diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py index c53e3348f..4b2f4b0e6 100644 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ b/surfsense_backend/app/agents/researcher/nodes.py @@ -1440,7 +1440,12 @@ async def handle_qna_workflow( } # Create the state for the QNA agent (it has a different state structure) - qna_state = {"db_session": state.db_session, "chat_history": state.chat_history} + # Pass streaming_service so the QNA agent can stream tokens directly + qna_state = { + "db_session": state.db_session, + "chat_history": state.chat_history, + "streaming_service": streaming_service, + } try: writer( @@ -1455,36 +1460,26 @@ async def handle_qna_workflow( complete_content = "" captured_reranked_documents = [] - # Call the QNA agent with streaming - async for _chunk_type, chunk in qna_agent_graph.astream( - qna_state, qna_config, stream_mode=["values"] + # Call the QNA agent with both custom and values streaming modes + # - "custom" captures token-by-token streams from answer_question via writer() + # - "values" captures state updates including final_answer and reranked_documents + async for stream_mode, chunk in qna_agent_graph.astream( + qna_state, qna_config, stream_mode=["custom", "values"] ): - if "final_answer" in chunk: - new_content = chunk["final_answer"] - if new_content and new_content != complete_content: - # Extract only the new content (delta) - delta = new_content[len(complete_content) :] - complete_content = new_content + if stream_mode == "custom": + # Handle custom stream events (token chunks from answer_question) + if isinstance(chunk, dict) and "yield_value" in chunk: + # Forward the streamed token to the parent writer + writer(chunk) + elif stream_mode == "values" and isinstance(chunk, dict): + # Handle state value updates + # Capture the final answer from state + if chunk.get("final_answer"): + complete_content = chunk["final_answer"] - # Stream the real-time answer if there's new content - if delta: - # Update terminal with progress - word_count = len(complete_content.split()) - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"✍️ Writing answer... ({word_count} words)" - ) - } - ) - - writer( - {"yield_value": streaming_service.format_text_chunk(delta)} - ) - - # Capture reranked documents from QNA agent for further question generation - if "reranked_documents" in chunk: - captured_reranked_documents = chunk["reranked_documents"] + # Capture reranked documents from QNA agent for further question generation + if chunk.get("reranked_documents"): + captured_reranked_documents = chunk["reranked_documents"] # Set default if no content was received if not complete_content: diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py index 37bdbc362..35f01146b 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py @@ -3,6 +3,7 @@ from typing import Any from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig +from langgraph.types import StreamWriter from sqlalchemy import select from app.db import SearchSpace @@ -129,9 +130,11 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An return {"reranked_documents": documents} -async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]: +async def answer_question( + state: State, config: RunnableConfig, writer: StreamWriter +) -> dict[str, Any]: """ - Answer the user's question using the provided documents. + Answer the user's question using the provided documents with real-time streaming. This node takes the relevant documents provided in the configuration and uses an LLM to generate a comprehensive answer to the user's question with @@ -139,6 +142,8 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any documents. If no documents are provided, it will use chat history to generate an answer. + The response is streamed token-by-token for real-time updates to the frontend. + Returns: Dict containing the final answer in the "final_answer" key. """ @@ -151,6 +156,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any search_space_id = configuration.search_space_id language = configuration.language + # Get streaming service from state + streaming_service = state.streaming_service + # Fetch search space to get QnA configuration result = await state.db_session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -279,8 +287,17 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any total_tokens = calculate_token_count(messages_with_chat_history, llm.model) print(f"Final token count: {total_tokens}") - # Call the LLM and get the response - response = await llm.ainvoke(messages_with_chat_history) - final_answer = response.content + # Stream the LLM response token by token + final_answer = "" + + async for chunk in llm.astream(messages_with_chat_history): + # Extract the content from the chunk + if hasattr(chunk, "content") and chunk.content: + token = chunk.content + final_answer += token + + # Stream the token to the frontend via custom stream + if streaming_service: + writer({"yield_value": streaming_service.format_text_chunk(token)}) return {"final_answer": final_answer} diff --git a/surfsense_backend/app/agents/researcher/qna_agent/state.py b/surfsense_backend/app/agents/researcher/qna_agent/state.py index f6cc7b1ba..4113b9286 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/state.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/state.py @@ -7,6 +7,8 @@ from typing import Any from sqlalchemy.ext.asyncio import AsyncSession +from app.services.streaming_service import StreamingService + @dataclass class State: @@ -21,6 +23,9 @@ class State: # Runtime context db_session: AsyncSession + # Streaming service for real-time token streaming + streaming_service: StreamingService | None = None + chat_history: list[Any] | None = field(default_factory=list) # OUTPUT: Populated by agent nodes reranked_documents: list[Any] | None = None diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index a7b750673..5b7f9ce13 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -62,7 +62,7 @@ async def _process_extension_document( individual_document_dict, search_space_id: int, user_id: str ): """Process extension document with new session.""" - from pydantic import BaseModel + from pydantic import BaseModel, ConfigDict, Field # Reconstruct the document object from dict # You'll need to define the proper model for this @@ -75,8 +75,9 @@ async def _process_extension_document( VisitedWebPageVisitDurationInMilliseconds: str class IndividualDocument(BaseModel): + model_config = ConfigDict(populate_by_name=True) metadata: DocumentMetadata - pageContent: str + page_content: str = Field(alias="pageContent") individual_document = IndividualDocument(**individual_document_dict) diff --git a/surfsense_web/app/dashboard/[search_space_id]/researcher/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/researcher/[[...chat_id]]/page.tsx index 7481ddaa2..1a9a607fb 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/researcher/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/researcher/[[...chat_id]]/page.tsx @@ -17,13 +17,14 @@ export default function ResearcherPage() { const { search_space_id } = useParams(); const router = useRouter(); const hasSetInitialConnectors = useRef(false); + const hasInitiatedResponse = useRef(null); const activeChatId = useAtomValue(activeChatIdAtom); const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom); const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom); const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom); const isNewChat = !activeChatId; - // Reset the flag when chat ID changes + // Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated) useEffect(() => { hasSetInitialConnectors.current = false; }, [activeChatId]); @@ -167,10 +168,14 @@ export default function ResearcherPage() { if (chatData.messages && Array.isArray(chatData.messages)) { if (chatData.messages.length === 1 && chatData.messages[0].role === "user") { // Single user message - append to trigger LLM response - handler.append({ - role: "user", - content: chatData.messages[0].content, - }); + // Only if we haven't already initiated for this chat and handler doesn't have messages yet + if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) { + hasInitiatedResponse.current = activeChatId; + handler.append({ + role: "user", + content: chatData.messages[0].content, + }); + } } else if (chatData.messages.length > 1) { // Multiple messages - set them all handler.setMessages(chatData.messages);