mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00

Merge pull request #530 from MODSetter/dev

fix: implement real-time streaming for responses

commit d29bd6c12f
6 changed files with 72 additions and 49 deletions
@@ -1440,7 +1440,12 @@ async def handle_qna_workflow(
     }

     # Create the state for the QNA agent (it has a different state structure)
-    qna_state = {"db_session": state.db_session, "chat_history": state.chat_history}
+    # Pass streaming_service so the QNA agent can stream tokens directly
+    qna_state = {
+        "db_session": state.db_session,
+        "chat_history": state.chat_history,
+        "streaming_service": streaming_service,
+    }

     try:
         writer(
@@ -1455,36 +1460,26 @@ async def handle_qna_workflow(
     complete_content = ""
     captured_reranked_documents = []

-    # Call the QNA agent with streaming
-    async for _chunk_type, chunk in qna_agent_graph.astream(
-        qna_state, qna_config, stream_mode=["values"]
+    # Call the QNA agent with both custom and values streaming modes
+    # - "custom" captures token-by-token streams from answer_question via writer()
+    # - "values" captures state updates including final_answer and reranked_documents
+    async for stream_mode, chunk in qna_agent_graph.astream(
+        qna_state, qna_config, stream_mode=["custom", "values"]
     ):
-        if "final_answer" in chunk:
-            new_content = chunk["final_answer"]
-            if new_content and new_content != complete_content:
-                # Extract only the new content (delta)
-                delta = new_content[len(complete_content) :]
-                complete_content = new_content
-
-                # Stream the real-time answer if there's new content
-                if delta:
-                    # Update terminal with progress
-                    word_count = len(complete_content.split())
-                    writer(
-                        {
-                            "yield_value": streaming_service.format_terminal_info_delta(
-                                f"✍️ Writing answer... ({word_count} words)"
-                            )
-                        }
-                    )
-
-                    writer(
-                        {"yield_value": streaming_service.format_text_chunk(delta)}
-                    )
-
-        # Capture reranked documents from QNA agent for further question generation
-        if "reranked_documents" in chunk:
-            captured_reranked_documents = chunk["reranked_documents"]
+        if stream_mode == "custom":
+            # Handle custom stream events (token chunks from answer_question)
+            if isinstance(chunk, dict) and "yield_value" in chunk:
+                # Forward the streamed token to the parent writer
+                writer(chunk)
+        elif stream_mode == "values" and isinstance(chunk, dict):
+            # Handle state value updates
+            # Capture the final answer from state
+            if chunk.get("final_answer"):
+                complete_content = chunk["final_answer"]
+
+            # Capture reranked documents from QNA agent for further question generation
+            if chunk.get("reranked_documents"):
+                captured_reranked_documents = chunk["reranked_documents"]

     # Set default if no content was received
     if not complete_content:
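For context on the new consumption pattern: when astream is given a list of stream modes, LangGraph yields (mode, chunk) tuples, and a node can emit on the "custom" channel through an injected writer. A minimal self-contained sketch, assuming a recent langgraph; the graph, node, and payload below are hypothetical stand-ins, not SurfSense code:

# Hypothetical demo graph (not SurfSense code), assuming a recent langgraph.
# Shows why the handler above receives (stream_mode, chunk) tuples.
import asyncio
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.types import StreamWriter


class DemoState(TypedDict):
    final_answer: str


async def answer(state: DemoState, writer: StreamWriter) -> dict:
    text = ""
    for token in ["Hello", ", ", "world"]:
        text += token
        writer({"yield_value": token})  # arrives on the "custom" stream
    return {"final_answer": text}  # arrives on the "values" stream


builder = StateGraph(DemoState)
builder.add_node("answer", answer)
builder.add_edge(START, "answer")
builder.add_edge("answer", END)
graph = builder.compile()


async def main():
    async for mode, chunk in graph.astream(
        {"final_answer": ""}, stream_mode=["custom", "values"]
    ):
        print(mode, chunk)  # e.g. ('custom', {'yield_value': 'Hello'})


asyncio.run(main())

Forwarding the pre-formatted "custom" chunks unchanged (the writer(chunk) call in the hunk above) avoids re-wrapping tokens that answer_question has already shaped into yield_value payloads.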
@@ -3,6 +3,7 @@ from typing import Any

 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter
 from sqlalchemy import select

 from app.db import SearchSpace
@@ -129,9 +130,11 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, An
     return {"reranked_documents": documents}


-async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
+async def answer_question(
+    state: State, config: RunnableConfig, writer: StreamWriter
+) -> dict[str, Any]:
     """
-    Answer the user's question using the provided documents.
+    Answer the user's question using the provided documents with real-time streaming.

     This node takes the relevant documents provided in the configuration and uses
     an LLM to generate a comprehensive answer to the user's question with
@@ -139,6 +142,8 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
     documents. If no documents are provided, it will use chat history to generate
     an answer.

+    The response is streamed token-by-token for real-time updates to the frontend.
+
     Returns:
         Dict containing the final answer in the "final_answer" key.
     """
@@ -151,6 +156,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
     search_space_id = configuration.search_space_id
     language = configuration.language

+    # Get streaming service from state
+    streaming_service = state.streaming_service
+
     # Fetch search space to get QnA configuration
     result = await state.db_session.execute(
         select(SearchSpace).where(SearchSpace.id == search_space_id)
@@ -279,8 +287,17 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
     total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")

-    # Call the LLM and get the response
-    response = await llm.ainvoke(messages_with_chat_history)
-    final_answer = response.content
+    # Stream the LLM response token by token
+    final_answer = ""
+
+    async for chunk in llm.astream(messages_with_chat_history):
+        # Extract the content from the chunk
+        if hasattr(chunk, "content") and chunk.content:
+            token = chunk.content
+            final_answer += token
+
+            # Stream the token to the frontend via custom stream
+            if streaming_service:
+                writer({"yield_value": streaming_service.format_text_chunk(token)})

     return {"final_answer": final_answer}
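The hunk above swaps ainvoke for an accumulate-and-forward loop: each delta is appended to final_answer for the returned state and pushed to the writer at the same time. A runnable sketch of just that pattern; langchain_openai, the model name, and emit() are assumptions for illustration, not the project's actual LLM wiring:

# Sketch of the accumulate-and-forward pattern above. langchain_openai and
# the model name are assumptions; SurfSense builds its llm object elsewhere.
import asyncio

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI


async def stream_answer(emit) -> str:
    llm = ChatOpenAI(model="gpt-4o-mini")
    final_answer = ""
    # astream yields AIMessageChunk objects; .content holds the new delta
    async for chunk in llm.astream([HumanMessage(content="Say hi in five words.")]):
        if chunk.content:
            final_answer += chunk.content  # keep the full text for the caller
            emit(chunk.content)  # push the delta out immediately
    return final_answer


# Usage: print tokens as they arrive, then keep the assembled answer.
answer = asyncio.run(stream_answer(lambda t: print(t, end="", flush=True)))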
@@ -7,6 +7,8 @@ from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

+from app.services.streaming_service import StreamingService
+

 @dataclass
 class State:
@@ -21,6 +23,9 @@ class State:
     # Runtime context
     db_session: AsyncSession

+    # Streaming service for real-time token streaming
+    streaming_service: StreamingService | None = None
+
     chat_history: list[Any] | None = field(default_factory=list)
     # OUTPUT: Populated by agent nodes
     reranked_documents: list[Any] | None = None
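A quick standalone check that the new field slots in legally (a defaulted field may follow the non-defaulted db_session and precede the other defaulted fields); the types are stand-ins, since AsyncSession and StreamingService come from the app:

# Stand-in sketch: Any replaces AsyncSession / StreamingService so it runs
# anywhere. Field order is the point being checked here.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class State:
    db_session: Any  # AsyncSession in the real code
    streaming_service: Any | None = None  # StreamingService | None
    chat_history: list[Any] | None = field(default_factory=list)
    reranked_documents: list[Any] | None = None


# Mirrors the keys of the qna_state dict built in handle_qna_workflow above.
state = State(db_session=object(), chat_history=[], streaming_service=object())
assert state.streaming_service is not None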
@@ -62,7 +62,7 @@ async def _process_extension_document(
     individual_document_dict, search_space_id: int, user_id: str
 ):
     """Process extension document with new session."""
-    from pydantic import BaseModel
+    from pydantic import BaseModel, ConfigDict, Field

     # Reconstruct the document object from dict
     # You'll need to define the proper model for this
@@ -75,8 +75,9 @@ async def _process_extension_document(
         VisitedWebPageVisitDurationInMilliseconds: str

     class IndividualDocument(BaseModel):
+        model_config = ConfigDict(populate_by_name=True)
         metadata: DocumentMetadata
-        pageContent: str
+        page_content: str = Field(alias="pageContent")

     individual_document = IndividualDocument(**individual_document_dict)
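The alias change is easy to verify in isolation: with Pydantic v2, Field(alias="pageContent") accepts the camelCase key sent by the browser extension, and populate_by_name=True keeps the snake_case name usable too. A minimal sketch (nested DocumentMetadata omitted, sample payload made up):

# Pydantic v2 sketch; the nested DocumentMetadata model is omitted and the
# payload is made up. Both key spellings validate to the same field.
from pydantic import BaseModel, ConfigDict, Field


class IndividualDocument(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    page_content: str = Field(alias="pageContent")


doc_a = IndividualDocument(**{"pageContent": "hello"})  # camelCase from the extension
doc_b = IndividualDocument(page_content="hello")  # snake_case via populate_by_name
assert doc_a.page_content == doc_b.page_content == "hello"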
@@ -17,13 +17,14 @@ export default function ResearcherPage() {
   const { search_space_id } = useParams();
   const router = useRouter();
   const hasSetInitialConnectors = useRef(false);
+  const hasInitiatedResponse = useRef<string | null>(null);
   const activeChatId = useAtomValue(activeChatIdAtom);
   const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom);
   const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom);
   const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom);
   const isNewChat = !activeChatId;

-  // Reset the flag when chat ID changes
+  // Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated)
   useEffect(() => {
     hasSetInitialConnectors.current = false;
   }, [activeChatId]);
@@ -167,10 +168,14 @@ export default function ResearcherPage() {
     if (chatData.messages && Array.isArray(chatData.messages)) {
       if (chatData.messages.length === 1 && chatData.messages[0].role === "user") {
         // Single user message - append to trigger LLM response
-        handler.append({
-          role: "user",
-          content: chatData.messages[0].content,
-        });
+        // Only if we haven't already initiated for this chat and handler doesn't have messages yet
+        if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) {
+          hasInitiatedResponse.current = activeChatId;
+          handler.append({
+            role: "user",
+            content: chatData.messages[0].content,
+          });
+        }
       } else if (chatData.messages.length > 1) {
         // Multiple messages - set them all
         handler.setMessages(chatData.messages);
@@ -48,7 +48,7 @@ export function useApiKey(): UseApiKeyReturn {
     textArea.select();

     try {
-      const successful = document.execCommand('copy');
+      const successful = document.execCommand("copy");
       document.body.removeChild(textArea);

       if (successful) {