mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Streaming RAG responses (#568)
* Tech spec for streaming RAG
* Support for streaming Graph/Doc RAG
This commit is contained in:
parent
b1cc724f7d
commit
1948edaa50
20 changed files with 3087 additions and 94 deletions
|
|
@ -202,12 +202,16 @@ class StreamingReActParser:
|
|||
# Find which comes first
|
||||
if args_idx >= 0 and (newline_idx < 0 or args_idx < newline_idx):
|
||||
# Args delimiter found first
|
||||
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
|
||||
# Only set action_buffer if not already set (to avoid overwriting with empty string)
|
||||
if not self.action_buffer:
|
||||
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
|
||||
self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):].lstrip()
|
||||
self.state = ParserState.ARGS
|
||||
elif newline_idx >= 0:
|
||||
# Newline found, action name complete
|
||||
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
|
||||
# Only set action_buffer if not already set
|
||||
if not self.action_buffer:
|
||||
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
|
||||
self.line_buffer = self.line_buffer[newline_idx + 1:]
|
||||
# Stay in ACTION state or move to ARGS if we find delimiter
|
||||
# Actually, check if next line has Args:
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class DocumentRag:
|
|||
|
||||
async def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20,
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -86,10 +86,18 @@ class DocumentRag:
|
|||
logger.debug(f"Documents: {docs}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query = query,
|
||||
documents = docs
|
||||
)
|
||||
if streaming and chunk_callback:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
)
|
||||
else:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
|
|||
|
|
@ -92,20 +92,56 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
response = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
async def send_chunk(chunk):
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
chunk=chunk,
|
||||
end_of_stream=False,
|
||||
response=None,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
# Query with streaming enabled
|
||||
full_response = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
)
|
||||
|
||||
# Send final message with complete response
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
chunk=None,
|
||||
end_of_stream=True,
|
||||
response=full_response,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
response = await self.rag.query(
|
||||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
|
||||
logger.info("Request processing complete")
|
||||
|
||||
|
|
@ -115,14 +151,21 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug("Sending error response...")
|
||||
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "document-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
# Send error response with end_of_stream flag if streaming was requested
|
||||
error_response = DocumentRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "document-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
)
|
||||
|
||||
# If streaming was requested, indicate stream end
|
||||
if v.streaming:
|
||||
error_response.end_of_stream = True
|
||||
|
||||
await flow("response").send(
|
||||
error_response,
|
||||
properties = {"id": id}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -316,7 +316,7 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2,
|
||||
max_path_length = 2, streaming = False, chunk_callback = None,
|
||||
):
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -337,7 +337,14 @@ class GraphRag:
|
|||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
resp = await self.prompt_client.kg_prompt(query, kg)
|
||||
if streaming and chunk_callback:
|
||||
resp = await self.prompt_client.kg_prompt(
|
||||
query, kg,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
)
|
||||
else:
|
||||
resp = await self.prompt_client.kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
|
|
|||
|
|
@ -135,20 +135,56 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
)
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
async def send_chunk(chunk):
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
chunk=chunk,
|
||||
end_of_stream=False,
|
||||
response=None,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
# Query with streaming enabled
|
||||
full_response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
)
|
||||
|
||||
# Send final message with complete response
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
chunk=None,
|
||||
end_of_stream=True,
|
||||
response=full_response,
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
)
|
||||
|
||||
logger.info("Request processing complete")
|
||||
|
||||
|
|
@ -158,14 +194,21 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug("Sending error response...")
|
||||
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "graph-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
# Send error response with end_of_stream flag if streaming was requested
|
||||
error_response = GraphRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "graph-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
)
|
||||
|
||||
# If streaming was requested, indicate stream end
|
||||
if v.streaming:
|
||||
error_response.end_of_stream = True
|
||||
|
||||
await flow("response").send(
|
||||
error_response,
|
||||
properties = {"id": id}
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue