mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix non streaming RAG problems (#607)
* Fix non-streaming failure in RAG services
* Fix non-streaming failure in API
* Fix agent non-streaming messaging
* Agent messaging unit & contract tests
This commit is contained in:
parent
30ca1d2e8b
commit
807f6cc4e2
10 changed files with 677 additions and 21 deletions
|
|
@ -275,13 +275,17 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
# For text completion, yield just the content
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
# For text completion, return generator that yields content
|
||||
return self._text_completion_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
"""Generator for text completion streaming"""
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
|
||||
def graph_rag(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -308,9 +312,7 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
|
|
@ -336,12 +338,16 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
"""Generator for RAG streaming (graph-rag and document-rag)"""
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
id: str,
|
||||
|
|
@ -360,9 +366,7 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
|
|
|
|||
|
|
@ -48,13 +48,13 @@ class AgentService(FlowProcessor):
|
|||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
# Get ID early so error handler can use it
|
||||
id = msg.properties().get("id", "unknown")
|
||||
|
||||
try:
|
||||
|
||||
request = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
async def respond(resp):
|
||||
|
||||
await flow("response").send(
|
||||
|
|
@ -93,6 +93,8 @@ class AgentService(FlowProcessor):
|
|||
thought = None,
|
||||
observation = None,
|
||||
answer = None,
|
||||
end_of_message = True,
|
||||
end_of_dialog = True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,13 +44,16 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
|
||||
else:
|
||||
# Legacy format
|
||||
# Legacy format (non-streaming)
|
||||
if obj.answer:
|
||||
result["answer"] = obj.answer
|
||||
if obj.thought:
|
||||
result["thought"] = obj.thought
|
||||
if obj.observation:
|
||||
result["observation"] = obj.observation
|
||||
# Include completion flags for legacy format too
|
||||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue