Streaming rag responses (#568)

* Tech spec for streaming RAG

* Support for streaming Graph/Doc RAG
This commit is contained in:
cybermaggedon 2025-11-26 19:47:39 +00:00 committed by GitHub
parent b1cc724f7d
commit 1948edaa50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 3087 additions and 94 deletions

View file

@ -112,7 +112,7 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600):
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
@ -123,9 +123,11 @@ class PromptClient(RequestResponse):
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600):
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",
variables = {
@ -133,6 +135,8 @@ class PromptClient(RequestResponse):
"documents": documents,
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):

View file

@ -5,43 +5,65 @@ from .base import MessageTranslator
class DocumentRagRequestTranslator(MessageTranslator):
"""Translator for DocumentRagQuery schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
return DocumentRagQuery(
query=data["query"],
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20))
doc_limit=int(data.get("doc-limit", 20)),
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
return {
"query": obj.query,
"user": obj.user,
"collection": obj.collection,
"doc-limit": obj.doc_limit
"doc-limit": obj.doc_limit,
"streaming": getattr(obj, "streaming", False)
}
class DocumentRagResponseTranslator(MessageTranslator):
"""Translator for DocumentRagResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
return {
"response": obj.response
}
result = {}
# Check if this is a streaming response (has chunk)
if hasattr(obj, 'chunk') and obj.chunk:
result["chunk"] = obj.chunk
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
else:
# Non-streaming response
if obj.response:
result["response"] = obj.response
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
return result
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# For streaming responses, check end_of_stream
if hasattr(obj, 'chunk') and obj.chunk:
is_final = getattr(obj, 'end_of_stream', False)
else:
# For non-streaming responses, it's always final
is_final = True
return self.from_pulsar(obj), is_final
class GraphRagRequestTranslator(MessageTranslator):
"""Translator for GraphRagQuery schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
return GraphRagQuery(
query=data["query"],
@ -50,9 +72,10 @@ class GraphRagRequestTranslator(MessageTranslator):
entity_limit=int(data.get("entity-limit", 50)),
triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2))
max_path_length=int(data.get("max-path-length", 2)),
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
return {
"query": obj.query,
@ -61,21 +84,42 @@ class GraphRagRequestTranslator(MessageTranslator):
"entity-limit": obj.entity_limit,
"triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length
"max-path-length": obj.max_path_length,
"streaming": getattr(obj, "streaming", False)
}
class GraphRagResponseTranslator(MessageTranslator):
"""Translator for GraphRagResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
return {
"response": obj.response
}
result = {}
# Check if this is a streaming response (has chunk)
if hasattr(obj, 'chunk') and obj.chunk:
result["chunk"] = obj.chunk
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
else:
# Non-streaming response
if obj.response:
result["response"] = obj.response
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
return result
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# For streaming responses, check end_of_stream
if hasattr(obj, 'chunk') and obj.chunk:
is_final = getattr(obj, 'end_of_stream', False)
else:
# For non-streaming responses, it's always final
is_final = True
return self.from_pulsar(obj), is_final

View file

@ -15,10 +15,13 @@ class GraphRagQuery(Record):
triple_limit = Integer()
max_subgraph_size = Integer()
max_path_length = Integer()
streaming = Boolean()
class GraphRagResponse(Record):
error = Error()
response = String()
chunk = String()
end_of_stream = Boolean()
############################################################################
@ -29,8 +32,11 @@ class DocumentRagQuery(Record):
user = String()
collection = String()
doc_limit = Integer()
streaming = Boolean()
class DocumentRagResponse(Record):
error = Error()
response = String()
chunk = String()
end_of_stream = Boolean()