Streaming RAG responses (#568)

* Tech spec for streaming RAG

* Support for streaming Graph/Doc RAG
This commit is contained in:
cybermaggedon 2025-11-26 19:47:39 +00:00 committed by GitHub
parent b1cc724f7d
commit 1948edaa50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 3087 additions and 94 deletions

View file

@@ -202,12 +202,16 @@ class StreamingReActParser:
# Find which comes first
if args_idx >= 0 and (newline_idx < 0 or args_idx < newline_idx):
# Args delimiter found first
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
# Only set action_buffer if not already set (to avoid overwriting with empty string)
if not self.action_buffer:
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):].lstrip()
self.state = ParserState.ARGS
elif newline_idx >= 0:
# Newline found, action name complete
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
# Only set action_buffer if not already set
if not self.action_buffer:
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
self.line_buffer = self.line_buffer[newline_idx + 1:]
# Stay in ACTION state or move to ARGS if we find delimiter
# Actually, check if next line has Args:

View file

@@ -68,7 +68,7 @@ class DocumentRag:
async def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
doc_limit=20, streaming=False, chunk_callback=None,
):
if self.verbose:
@@ -86,10 +86,18 @@ class DocumentRag:
logger.debug(f"Documents: {docs}")
logger.debug(f"Query: {query}")
resp = await self.prompt_client.document_prompt(
query = query,
documents = docs
)
if streaming and chunk_callback:
resp = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=chunk_callback
)
else:
resp = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
if self.verbose:
logger.debug("Query processing complete")

View file

@@ -92,20 +92,56 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
async def send_chunk(chunk):
await flow("response").send(
DocumentRagResponse(
chunk=chunk,
end_of_stream=False,
response=None,
error=None
),
properties={"id": id}
)
await flow("response").send(
DocumentRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
# Query with streaming enabled
full_response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit,
streaming=True,
chunk_callback=send_chunk,
)
# Send final message with complete response
await flow("response").send(
DocumentRagResponse(
chunk=None,
end_of_stream=True,
response=full_response,
error=None
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
)
await flow("response").send(
DocumentRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
logger.info("Request processing complete")
@@ -115,14 +151,21 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
await flow("response").send(
DocumentRagResponse(
response = None,
error = Error(
type = "document-rag-error",
message = str(e),
),
# Send error response with end_of_stream flag if streaming was requested
error_response = DocumentRagResponse(
response = None,
error = Error(
type = "document-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
await flow("response").send(
error_response,
properties = {"id": id}
)

View file

@@ -316,7 +316,7 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2,
max_path_length = 2, streaming = False, chunk_callback = None,
):
if self.verbose:
@@ -337,7 +337,14 @@ class GraphRag:
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
resp = await self.prompt_client.kg_prompt(query, kg)
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
if self.verbose:
logger.debug("Query processing complete")

View file

@@ -135,20 +135,56 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
async def send_chunk(chunk):
await flow("response").send(
GraphRagResponse(
chunk=chunk,
end_of_stream=False,
response=None,
error=None
),
properties={"id": id}
)
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
# Query with streaming enabled
full_response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
streaming = True,
chunk_callback = send_chunk,
)
# Send final message with complete response
await flow("response").send(
GraphRagResponse(
chunk=None,
end_of_stream=True,
response=full_response,
error=None
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
logger.info("Request processing complete")
@@ -158,14 +194,21 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
await flow("response").send(
GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
# Send error response with end_of_stream flag if streaming was requested
error_response = GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
await flow("response").send(
error_response,
properties = {"id": id}
)