Merge pull request #530 from MODSetter/dev

fix: implement real-time streaming for responses
Rohan Verma 2025-12-05 00:20:32 -08:00 committed by GitHub
commit d29bd6c12f
6 changed files with 72 additions and 49 deletions


@@ -1440,7 +1440,12 @@ async def handle_qna_workflow(
        }

    # Create the state for the QNA agent (it has a different state structure)
-    qna_state = {"db_session": state.db_session, "chat_history": state.chat_history}
+    # Pass streaming_service so the QNA agent can stream tokens directly
+    qna_state = {
+        "db_session": state.db_session,
+        "chat_history": state.chat_history,
+        "streaming_service": streaming_service,
+    }

    try:
        writer(
@@ -1455,36 +1460,26 @@ async def handle_qna_workflow(
    complete_content = ""
    captured_reranked_documents = []

-    # Call the QNA agent with streaming
-    async for _chunk_type, chunk in qna_agent_graph.astream(
-        qna_state, qna_config, stream_mode=["values"]
+    # Call the QNA agent with both custom and values streaming modes
+    # - "custom" captures token-by-token streams from answer_question via writer()
+    # - "values" captures state updates including final_answer and reranked_documents
+    async for stream_mode, chunk in qna_agent_graph.astream(
+        qna_state, qna_config, stream_mode=["custom", "values"]
    ):
-        if "final_answer" in chunk:
-            new_content = chunk["final_answer"]
-            if new_content and new_content != complete_content:
-                # Extract only the new content (delta)
-                delta = new_content[len(complete_content) :]
-                complete_content = new_content
-
-                # Stream the real-time answer if there's new content
-                if delta:
-                    # Update terminal with progress
-                    word_count = len(complete_content.split())
-                    writer(
-                        {
-                            "yield_value": streaming_service.format_terminal_info_delta(
-                                f"✍️ Writing answer... ({word_count} words)"
-                            )
-                        }
-                    )
-                    writer(
-                        {"yield_value": streaming_service.format_text_chunk(delta)}
-                    )
-
-        # Capture reranked documents from QNA agent for further question generation
-        if "reranked_documents" in chunk:
-            captured_reranked_documents = chunk["reranked_documents"]
+        if stream_mode == "custom":
+            # Handle custom stream events (token chunks from answer_question)
+            if isinstance(chunk, dict) and "yield_value" in chunk:
+                # Forward the streamed token to the parent writer
+                writer(chunk)
+        elif stream_mode == "values" and isinstance(chunk, dict):
+            # Handle state value updates
+            # Capture the final answer from state
+            if chunk.get("final_answer"):
+                complete_content = chunk["final_answer"]
+
+            # Capture reranked documents from QNA agent for further question generation
+            if chunk.get("reranked_documents"):
+                captured_reranked_documents = chunk["reranked_documents"]

    # Set default if no content was received
    if not complete_content:
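
For context on the new loop shape: when stream_mode is passed as a list, LangGraph's astream yields (mode, chunk) tuples rather than bare chunks, which is why the handler now dispatches on stream_mode. A minimal, self-contained sketch of that behavior follows; the graph, node, and state names are illustrative, not from this PR.

# Sketch only: demonstrates ("custom", ...) and ("values", ...) events.
import asyncio
from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import StreamWriter


class DemoState(TypedDict):
    final_answer: str


async def answer(state: DemoState, writer: StreamWriter) -> dict:
    # Each writer() call surfaces upstream as a ("custom", payload) event.
    for token in ["Hello", ", ", "world"]:
        writer({"yield_value": token})
    # The returned state update surfaces as a ("values", state) event.
    return {"final_answer": "Hello, world"}


builder = StateGraph(DemoState)
builder.add_node("answer", answer)
builder.add_edge(START, "answer")
builder.add_edge("answer", END)
graph = builder.compile()


async def main() -> None:
    async for mode, chunk in graph.astream(
        {"final_answer": ""}, stream_mode=["custom", "values"]
    ):
        print(mode, chunk)  # e.g. ("custom", {"yield_value": "Hello"})


asyncio.run(main())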


@@ -3,6 +3,7 @@ from typing import Any

 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter
 from sqlalchemy import select

 from app.db import SearchSpace
@@ -129,9 +130,11 @@ async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
     return {"reranked_documents": documents}


-async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
+async def answer_question(
+    state: State, config: RunnableConfig, writer: StreamWriter
+) -> dict[str, Any]:
     """
-    Answer the user's question using the provided documents.
+    Answer the user's question using the provided documents with real-time streaming.

     This node takes the relevant documents provided in the configuration and uses
     an LLM to generate a comprehensive answer to the user's question with
@@ -139,6 +142,8 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     documents. If no documents are provided, it will use chat history to generate
     an answer.

+    The response is streamed token-by-token for real-time updates to the frontend.
+
     Returns:
         Dict containing the final answer in the "final_answer" key.
     """
@@ -151,6 +156,9 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     search_space_id = configuration.search_space_id
     language = configuration.language

+    # Get streaming service from state
+    streaming_service = state.streaming_service
+
     # Fetch search space to get QnA configuration
     result = await state.db_session.execute(
         select(SearchSpace).where(SearchSpace.id == search_space_id)
@@ -279,8 +287,17 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any]:
     total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")

-    # Call the LLM and get the response
-    response = await llm.ainvoke(messages_with_chat_history)
-    final_answer = response.content
+    # Stream the LLM response token by token
+    final_answer = ""
+    async for chunk in llm.astream(messages_with_chat_history):
+        # Extract the content from the chunk
+        if hasattr(chunk, "content") and chunk.content:
+            token = chunk.content
+            final_answer += token
+            # Stream the token to the frontend via custom stream
+            if streaming_service:
+                writer({"yield_value": streaming_service.format_text_chunk(token)})

     return {"final_answer": final_answer}
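
The switch from ainvoke to astream relies on LangChain chat models yielding message chunks whose .content carries the token delta. A minimal sketch of that interface, assuming langchain-openai is installed and OPENAI_API_KEY is set; the model name is illustrative.

# Sketch only: accumulate and print streamed tokens.
import asyncio

from langchain_openai import ChatOpenAI


async def stream_demo() -> str:
    llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model choice
    answer = ""
    async for chunk in llm.astream("Reply with a short greeting."):
        # Some chunks (e.g. tool-call deltas) carry empty content; skip them.
        if chunk.content:
            answer += chunk.content
            print(chunk.content, end="", flush=True)
    return answer


asyncio.run(stream_demo())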


@@ -7,6 +7,8 @@ from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession

+from app.services.streaming_service import StreamingService
+

 @dataclass
 class State:
@@ -21,6 +23,9 @@ class State:
     # Runtime context
     db_session: AsyncSession
+
+    # Streaming service for real-time token streaming
+    streaming_service: StreamingService | None = None
     chat_history: list[Any] | None = field(default_factory=list)

     # OUTPUT: Populated by agent nodes
     reranked_documents: list[Any] | None = None
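
Since streaming_service defaults to None, call sites that build the state without a streaming service keep working, and answer_question only streams when one is present. A small sketch, assuming the State class above; the session and service objects are hypothetical stand-ins.

# Sketch only: the defaulted field preserves older call sites.
from unittest.mock import AsyncMock, Mock

session = AsyncMock()  # stands in for an AsyncSession
service = Mock()       # stands in for a StreamingService

state = State(db_session=session)                             # streaming off
assert state.streaming_service is None
state = State(db_session=session, streaming_service=service)  # streaming on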


@@ -62,7 +62,7 @@ async def _process_extension_document(
     individual_document_dict, search_space_id: int, user_id: str
 ):
     """Process extension document with new session."""
-    from pydantic import BaseModel
+    from pydantic import BaseModel, ConfigDict, Field

     # Reconstruct the document object from dict
     # You'll need to define the proper model for this
@@ -75,8 +75,9 @@ async def _process_extension_document(
         VisitedWebPageVisitDurationInMilliseconds: str

     class IndividualDocument(BaseModel):
+        model_config = ConfigDict(populate_by_name=True)
+
         metadata: DocumentMetadata
-        pageContent: str
+        page_content: str = Field(alias="pageContent")

     individual_document = IndividualDocument(**individual_document_dict)
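
For reference, populate_by_name=True is what lets the model accept both the camelCase wire name and the snake_case attribute; with only the alias set, pageContent alone would validate. A minimal sketch of the Pydantic v2 behavior, trimmed to one field.

# Sketch only: both the alias and the field name populate page_content.
from pydantic import BaseModel, ConfigDict, Field


class IndividualDocument(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    page_content: str = Field(alias="pageContent")


doc_from_extension = IndividualDocument(pageContent="hello")  # alias (wire shape)
doc_internal = IndividualDocument(page_content="hello")       # field name
assert doc_from_extension.page_content == doc_internal.page_content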


@@ -17,13 +17,14 @@ export default function ResearcherPage() {
 	const { search_space_id } = useParams();
 	const router = useRouter();
 	const hasSetInitialConnectors = useRef(false);
+	const hasInitiatedResponse = useRef<string | null>(null);
 	const activeChatId = useAtomValue(activeChatIdAtom);
 	const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom);
 	const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom);
 	const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom);
 	const isNewChat = !activeChatId;

-	// Reset the flag when chat ID changes
+	// Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated)
 	useEffect(() => {
 		hasSetInitialConnectors.current = false;
 	}, [activeChatId]);
@@ -167,10 +168,14 @@ export default function ResearcherPage() {
 			if (chatData.messages && Array.isArray(chatData.messages)) {
 				if (chatData.messages.length === 1 && chatData.messages[0].role === "user") {
 					// Single user message - append to trigger LLM response
-					handler.append({
-						role: "user",
-						content: chatData.messages[0].content,
-					});
+					// Only if we haven't already initiated for this chat and handler doesn't have messages yet
+					if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) {
+						hasInitiatedResponse.current = activeChatId;
+						handler.append({
+							role: "user",
+							content: chatData.messages[0].content,
+						});
+					}
 				} else if (chatData.messages.length > 1) {
 					// Multiple messages - set them all
 					handler.setMessages(chatData.messages);


@@ -48,7 +48,7 @@ export function useApiKey(): UseApiKeyReturn {
 			textArea.select();

 			try {
-				const successful = document.execCommand('copy');
+				const successful = document.execCommand("copy");
 				document.body.removeChild(textArea);

 				if (successful) {