Merge pull request #530 from MODSetter/dev

fix: implement real-time streaming for responses
Rohan Verma · 2025-12-05 00:20:32 -08:00 · committed by GitHub
commit d29bd6c12f
6 changed files with 72 additions and 49 deletions

Changed file 1 of 6

@@ -1440,7 +1440,12 @@ async def handle_qna_workflow(
     }

     # Create the state for the QNA agent (it has a different state structure)
-    qna_state = {"db_session": state.db_session, "chat_history": state.chat_history}
+    # Pass streaming_service so the QNA agent can stream tokens directly
+    qna_state = {
+        "db_session": state.db_session,
+        "chat_history": state.chat_history,
+        "streaming_service": streaming_service,
+    }

     try:
         writer(
@@ -1455,36 +1460,26 @@ async def handle_qna_workflow(
     complete_content = ""
     captured_reranked_documents = []

-    # Call the QNA agent with streaming
-    async for _chunk_type, chunk in qna_agent_graph.astream(
-        qna_state, qna_config, stream_mode=["values"]
+    # Call the QNA agent with both custom and values streaming modes
+    # - "custom" captures token-by-token streams from answer_question via writer()
+    # - "values" captures state updates including final_answer and reranked_documents
+    async for stream_mode, chunk in qna_agent_graph.astream(
+        qna_state, qna_config, stream_mode=["custom", "values"]
     ):
-        if "final_answer" in chunk:
-            new_content = chunk["final_answer"]
-            if new_content and new_content != complete_content:
-                # Extract only the new content (delta)
-                delta = new_content[len(complete_content) :]
-                complete_content = new_content
-
-                # Stream the real-time answer if there's new content
-                if delta:
-                    # Update terminal with progress
-                    word_count = len(complete_content.split())
-                    writer(
-                        {
-                            "yield_value": streaming_service.format_terminal_info_delta(
-                                f"✍️ Writing answer... ({word_count} words)"
-                            )
-                        }
-                    )
-                    writer(
-                        {"yield_value": streaming_service.format_text_chunk(delta)}
-                    )
-
-        # Capture reranked documents from QNA agent for further question generation
-        if "reranked_documents" in chunk:
-            captured_reranked_documents = chunk["reranked_documents"]
+        if stream_mode == "custom":
+            # Handle custom stream events (token chunks from answer_question)
+            if isinstance(chunk, dict) and "yield_value" in chunk:
+                # Forward the streamed token to the parent writer
+                writer(chunk)
+        elif stream_mode == "values" and isinstance(chunk, dict):
+            # Handle state value updates
+            # Capture the final answer from state
+            if chunk.get("final_answer"):
+                complete_content = chunk["final_answer"]
+
+            # Capture reranked documents from QNA agent for further question generation
+            if chunk.get("reranked_documents"):
+                captured_reranked_documents = chunk["reranked_documents"]

     # Set default if no content was received
     if not complete_content:
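
For readers unfamiliar with the pattern above: when LangGraph's astream receives a list of stream modes, it yields (mode, chunk) tuples rather than bare chunks, which is why the loop can branch on stream_mode. Below is a minimal, self-contained sketch of that behavior; the graph, node, and token list are hypothetical stand-ins for illustration, not code from this PR.

import asyncio
from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import StreamWriter


class DemoState(TypedDict):
    final_answer: str


async def demo_node(state: DemoState, writer: StreamWriter) -> DemoState:
    answer = ""
    for token in ["Hello", ", ", "world"]:  # stand-in for LLM tokens
        answer += token
        writer({"yield_value": token})  # surfaces on the "custom" stream
    return {"final_answer": answer}


async def main() -> None:
    graph = StateGraph(DemoState)
    graph.add_node("demo", demo_node)
    graph.add_edge(START, "demo")
    graph.add_edge("demo", END)
    compiled = graph.compile()

    # With a list of modes, each event arrives as a (mode, chunk) tuple
    async for mode, chunk in compiled.astream(
        {"final_answer": ""}, stream_mode=["custom", "values"]
    ):
        if mode == "custom":
            print("token:", chunk["yield_value"])
        elif mode == "values" and chunk.get("final_answer"):
            print("state:", chunk["final_answer"])


asyncio.run(main())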

Changed file 2 of 6

@@ -3,6 +3,7 @@ from typing import Any

 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter
 from sqlalchemy import select

 from app.db import SearchSpace
@@ -129,9 +130,11 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
     return {"reranked_documents": documents}


-async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
+async def answer_question(
+    state: State, config: RunnableConfig, writer: StreamWriter
+) -> dict[str, Any]:
     """
-    Answer the user's question using the provided documents.
+    Answer the user's question using the provided documents with real-time streaming.

     This node takes the relevant documents provided in the configuration and uses
     an LLM to generate a comprehensive answer to the user's question with

@@ -139,6 +142,8 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     documents. If no documents are provided, it will use chat history to generate
     an answer.

+    The response is streamed token-by-token for real-time updates to the frontend.
+
     Returns:
         Dict containing the final answer in the "final_answer" key.
     """
@@ -151,6 +156,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     search_space_id = configuration.search_space_id
     language = configuration.language

+    # Get streaming service from state
+    streaming_service = state.streaming_service
+
     # Fetch search space to get QnA configuration
     result = await state.db_session.execute(
         select(SearchSpace).where(SearchSpace.id == search_space_id)
@@ -279,8 +287,17 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")

-    # Call the LLM and get the response
-    response = await llm.ainvoke(messages_with_chat_history)
-    final_answer = response.content
+    # Stream the LLM response token by token
+    final_answer = ""
+    async for chunk in llm.astream(messages_with_chat_history):
+        # Extract the content from the chunk
+        if hasattr(chunk, "content") and chunk.content:
+            token = chunk.content
+            final_answer += token
+            # Stream the token to the frontend via custom stream
+            if streaming_service:
+                writer({"yield_value": streaming_service.format_text_chunk(token)})

     return {"final_answer": final_answer}

Changed file 3 of 6

@@ -7,6 +7,8 @@ from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

+from app.services.streaming_service import StreamingService
+

 @dataclass
 class State:

@@ -21,6 +23,9 @@ class State:
     # Runtime context
     db_session: AsyncSession
+
+    # Streaming service for real-time token streaming
+    streaming_service: StreamingService | None = None
     chat_history: list[Any] | None = field(default_factory=list)

     # OUTPUT: Populated by agent nodes
     reranked_documents: list[Any] | None = None
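
Since streaming_service defaults to None, call sites that never pass it keep their old behavior, and nodes can gate streaming on its presence (as answer_question does with if streaming_service:). A small sketch, assuming session is an AsyncSession and svc a StreamingService already in hand:

plain_state = State(db_session=session)  # streaming_service stays None
streaming_state = State(db_session=session, streaming_service=svc)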

Changed file 4 of 6

@@ -62,7 +62,7 @@ async def _process_extension_document(
     individual_document_dict, search_space_id: int, user_id: str
 ):
     """Process extension document with new session."""
-    from pydantic import BaseModel
+    from pydantic import BaseModel, ConfigDict, Field

     # Reconstruct the document object from dict
     # You'll need to define the proper model for this

@@ -75,8 +75,9 @@ async def _process_extension_document(
         VisitedWebPageVisitDurationInMilliseconds: str

     class IndividualDocument(BaseModel):
+        model_config = ConfigDict(populate_by_name=True)
         metadata: DocumentMetadata
-        pageContent: str
+        page_content: str = Field(alias="pageContent")

     individual_document = IndividualDocument(**individual_document_dict)
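
The alias change keeps the extension's camelCase payload key working while exposing a snake_case attribute in Python. A self-contained sketch of the pattern under pydantic v2 (the metadata field is omitted here for brevity):

from pydantic import BaseModel, ConfigDict, Field


class IndividualDocument(BaseModel):
    # populate_by_name=True lets validation accept the field name in
    # addition to its alias, so both spellings below are valid
    model_config = ConfigDict(populate_by_name=True)
    page_content: str = Field(alias="pageContent")


assert IndividualDocument(**{"pageContent": "hi"}).page_content == "hi"
assert IndividualDocument(page_content="hi").page_content == "hi"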

Changed file 5 of 6

@@ -17,13 +17,14 @@ export default function ResearcherPage() {
     const { search_space_id } = useParams();
     const router = useRouter();
     const hasSetInitialConnectors = useRef(false);
+    const hasInitiatedResponse = useRef<string | null>(null);
     const activeChatId = useAtomValue(activeChatIdAtom);
     const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom);
     const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom);
     const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom);

     const isNewChat = !activeChatId;

-    // Reset the flag when chat ID changes
+    // Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated)
     useEffect(() => {
         hasSetInitialConnectors.current = false;
     }, [activeChatId]);

@@ -167,10 +168,14 @@ export default function ResearcherPage() {
     if (chatData.messages && Array.isArray(chatData.messages)) {
         if (chatData.messages.length === 1 && chatData.messages[0].role === "user") {
             // Single user message - append to trigger LLM response
-            handler.append({
-                role: "user",
-                content: chatData.messages[0].content,
-            });
+            // Only if we haven't already initiated for this chat and handler doesn't have messages yet
+            if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) {
+                hasInitiatedResponse.current = activeChatId;
+                handler.append({
+                    role: "user",
+                    content: chatData.messages[0].content,
+                });
+            }
         } else if (chatData.messages.length > 1) {
             // Multiple messages - set them all
             handler.setMessages(chatData.messages);

Changed file 6 of 6

@@ -36,25 +36,25 @@ export function useApiKey(): UseApiKeyReturn {
     const fallbackCopyTextToClipboard = (text: string) => {
         const textArea = document.createElement("textarea");
         textArea.value = text;

         // Avoid scrolling to bottom
         textArea.style.top = "0";
         textArea.style.left = "0";
         textArea.style.position = "fixed";
         textArea.style.opacity = "0";

         document.body.appendChild(textArea);
         textArea.focus();
         textArea.select();

         try {
-            const successful = document.execCommand('copy');
+            const successful = document.execCommand("copy");
             document.body.removeChild(textArea);

             if (successful) {
                 setCopied(true);
                 toast.success("API key copied to clipboard");
                 setTimeout(() => {
                     setCopied(false);
                 }, 2000);

@@ -77,7 +77,7 @@ export function useApiKey(): UseApiKeyReturn {
         await navigator.clipboard.writeText(apiKey);
         setCopied(true);
         toast.success("API key copied to clipboard");
         setTimeout(() => {
             setCopied(false);
         }, 2000);

@@ -97,4 +97,4 @@ export function useApiKey(): UseApiKeyReturn {
         copied,
         copyToClipboard,
     };
 }