mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Streaming RAG responses (#568)
* Tech spec for streaming RAG
* Support for streaming Graph/Doc RAG
This commit is contained in:
parent
b1cc724f7d
commit
1948edaa50
20 changed files with 3087 additions and 94 deletions
|
|
@ -112,7 +112,7 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def kg_prompt(self, query, kg, timeout=600):
|
||||
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "kg-prompt",
|
||||
variables = {
|
||||
|
|
@ -123,9 +123,11 @@ class PromptClient(RequestResponse):
|
|||
]
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def document_prompt(self, query, documents, timeout=600):
|
||||
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "document-prompt",
|
||||
variables = {
|
||||
|
|
@ -133,6 +135,8 @@ class PromptClient(RequestResponse):
|
|||
"documents": documents,
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
|
|
|
|||
|
|
@ -5,43 +5,65 @@ from .base import MessageTranslator
|
|||
|
||||
class DocumentRagRequestTranslator(MessageTranslator):
|
||||
"""Translator for DocumentRagQuery schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
|
||||
return DocumentRagQuery(
|
||||
query=data["query"],
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
doc_limit=int(data.get("doc-limit", 20))
|
||||
doc_limit=int(data.get("doc-limit", 20)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
|
||||
return {
|
||||
"query": obj.query,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection,
|
||||
"doc-limit": obj.doc_limit
|
||||
"doc-limit": obj.doc_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
||||
class DocumentRagResponseTranslator(MessageTranslator):
|
||||
"""Translator for DocumentRagResponse schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
||||
return {
|
||||
"response": obj.response
|
||||
}
|
||||
|
||||
result = {}
|
||||
|
||||
# Check if this is a streaming response (has chunk)
|
||||
if hasattr(obj, 'chunk') and obj.chunk:
|
||||
result["chunk"] = obj.chunk
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
else:
|
||||
# Non-streaming response
|
||||
if obj.response:
|
||||
result["response"] = obj.response
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
# For streaming responses, check end_of_stream
|
||||
if hasattr(obj, 'chunk') and obj.chunk:
|
||||
is_final = getattr(obj, 'end_of_stream', False)
|
||||
else:
|
||||
# For non-streaming responses, it's always final
|
||||
is_final = True
|
||||
|
||||
return self.from_pulsar(obj), is_final
|
||||
|
||||
|
||||
class GraphRagRequestTranslator(MessageTranslator):
|
||||
"""Translator for GraphRagQuery schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
|
||||
return GraphRagQuery(
|
||||
query=data["query"],
|
||||
|
|
@ -50,9 +72,10 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
entity_limit=int(data.get("entity-limit", 50)),
|
||||
triple_limit=int(data.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(data.get("max-path-length", 2))
|
||||
max_path_length=int(data.get("max-path-length", 2)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
|
||||
return {
|
||||
"query": obj.query,
|
||||
|
|
@ -61,21 +84,42 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
"entity-limit": obj.entity_limit,
|
||||
"triple-limit": obj.triple_limit,
|
||||
"max-subgraph-size": obj.max_subgraph_size,
|
||||
"max-path-length": obj.max_path_length
|
||||
"max-path-length": obj.max_path_length,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
||||
class GraphRagResponseTranslator(MessageTranslator):
|
||||
"""Translator for GraphRagResponse schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
||||
return {
|
||||
"response": obj.response
|
||||
}
|
||||
|
||||
result = {}
|
||||
|
||||
# Check if this is a streaming response (has chunk)
|
||||
if hasattr(obj, 'chunk') and obj.chunk:
|
||||
result["chunk"] = obj.chunk
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
else:
|
||||
# Non-streaming response
|
||||
if obj.response:
|
||||
result["response"] = obj.response
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
# For streaming responses, check end_of_stream
|
||||
if hasattr(obj, 'chunk') and obj.chunk:
|
||||
is_final = getattr(obj, 'end_of_stream', False)
|
||||
else:
|
||||
# For non-streaming responses, it's always final
|
||||
is_final = True
|
||||
|
||||
return self.from_pulsar(obj), is_final
|
||||
|
|
@ -15,10 +15,13 @@ class GraphRagQuery(Record):
|
|||
triple_limit = Integer()
|
||||
max_subgraph_size = Integer()
|
||||
max_path_length = Integer()
|
||||
streaming = Boolean()
|
||||
|
||||
class GraphRagResponse(Record):
|
||||
error = Error()
|
||||
response = String()
|
||||
chunk = String()
|
||||
end_of_stream = Boolean()
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -29,8 +32,11 @@ class DocumentRagQuery(Record):
|
|||
user = String()
|
||||
collection = String()
|
||||
doc_limit = Integer()
|
||||
streaming = Boolean()
|
||||
|
||||
class DocumentRagResponse(Record):
|
||||
error = Error()
|
||||
response = String()
|
||||
chunk = String()
|
||||
end_of_stream = Boolean()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue