mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-21 05:15:12 +02:00
Minor agent tweaks (#692)
Update RAG and Agent clients for streaming message handling GraphRAG now sends multiple message types in a stream: - 'explain' messages with explain_id and explain_graph for provenance - 'chunk' messages with response text fragments - end_of_session marker for stream completion Updated all clients to handle this properly: CLI clients (trustgraph-base/trustgraph/clients/): - graph_rag_client.py: Added chunk_callback and explain_callback - document_rag_client.py: Added chunk_callback and explain_callback - agent_client.py: Added think, observe, answer_callback, error_callback Internal clients (trustgraph-base/trustgraph/base/): - graph_rag_client.py: Async callbacks for streaming - agent_client.py: Async callbacks for streaming All clients now: - Route messages by chunk_type/message_type - Stream via optional callbacks for incremental delivery - Wait for proper completion signals (end_of_dialog/end_of_session/end_of_stream) - Accumulate and return complete response for callers not using callbacks Updated callers: - extract/kg/agent/extract.py: Uses new invoke(question=...) API - tests/integration/test_agent_kg_extraction_integration.py: Updated mocks This fixes the agent infinite loop issue where knowledge_query was returning the first 'explain' message (empty response) instead of waiting for the actual answer chunks. Concurrency in triples query
This commit is contained in:
parent
45e6ad4abc
commit
aecf00f040
8 changed files with 246 additions and 58 deletions
|
|
@ -42,25 +42,59 @@ class AgentClient(BaseClient):
|
|||
question,
|
||||
think=None,
|
||||
observe=None,
|
||||
answer_callback=None,
|
||||
error_callback=None,
|
||||
timeout=300
|
||||
):
|
||||
"""
|
||||
Request an agent query with optional streaming callbacks.
|
||||
|
||||
Args:
|
||||
question: The question to ask
|
||||
think: Optional callback(content, end_of_message) for thought chunks
|
||||
observe: Optional callback(content, end_of_message) for observation chunks
|
||||
answer_callback: Optional callback(content, end_of_message) for answer chunks
|
||||
error_callback: Optional callback(content) for error messages
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete answer text (accumulated from all answer chunks)
|
||||
"""
|
||||
accumulated_answer = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle errors
|
||||
if x.chunk_type == 'error' or x.error:
|
||||
if error_callback:
|
||||
error_callback(x.content or (x.error.message if x.error else ""))
|
||||
# Continue to check end_of_dialog
|
||||
|
||||
if x.thought and think:
|
||||
think(x.thought)
|
||||
return
|
||||
# Handle thought chunks
|
||||
elif x.chunk_type == 'thought':
|
||||
if think:
|
||||
think(x.content, x.end_of_message)
|
||||
|
||||
if x.observation and observe:
|
||||
observe(x.observation)
|
||||
return
|
||||
# Handle observation chunks
|
||||
elif x.chunk_type == 'observation':
|
||||
if observe:
|
||||
observe(x.content, x.end_of_message)
|
||||
|
||||
if x.answer:
|
||||
# Handle answer chunks
|
||||
elif x.chunk_type == 'answer':
|
||||
if x.content:
|
||||
accumulated_answer.append(x.content)
|
||||
if answer_callback:
|
||||
answer_callback(x.content, x.end_of_message)
|
||||
|
||||
# Complete when dialog ends
|
||||
if x.end_of_dialog:
|
||||
return True
|
||||
|
||||
return False
|
||||
return False # Continue receiving
|
||||
|
||||
return self.call(
|
||||
self.call(
|
||||
question=question, inspect=inspect, timeout=timeout
|
||||
).answer
|
||||
)
|
||||
|
||||
return "".join(accumulated_answer)
|
||||
|
||||
|
|
|
|||
|
|
@ -40,9 +40,47 @@ class DocumentRagClient(BaseClient):
|
|||
output_schema=DocumentRagResponse,
|
||||
)
|
||||
|
||||
def request(self, query, timeout=300):
|
||||
def request(self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None, explain_callback=None, timeout=300):
|
||||
"""
|
||||
Request a document RAG query with optional streaming callbacks.
|
||||
|
||||
return self.call(
|
||||
query=query, timeout=timeout
|
||||
).response
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle explain notifications (response is None/empty, explain_id present)
|
||||
if x.explain_id and not x.response:
|
||||
if explain_callback:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if x.response:
|
||||
accumulated_response.append(x.response)
|
||||
if chunk_callback:
|
||||
chunk_callback(x.response, x.end_of_stream)
|
||||
|
||||
# Complete when stream ends
|
||||
if x.end_of_stream:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
self.call(
|
||||
query=query, user=user, collection=collection,
|
||||
inspect=inspect, timeout=timeout
|
||||
)
|
||||
|
||||
return "".join(accumulated_response)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,10 +42,50 @@ class GraphRagClient(BaseClient):
|
|||
|
||||
def request(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None,
|
||||
explain_callback=None,
|
||||
timeout=500
|
||||
):
|
||||
"""
|
||||
Request a graph RAG query with optional streaming callbacks.
|
||||
|
||||
return self.call(
|
||||
user=user, collection=collection, query=query, timeout=timeout
|
||||
).response
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle explain notifications
|
||||
if x.message_type == 'explain':
|
||||
if explain_callback and x.explain_id:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if x.message_type == 'chunk':
|
||||
if x.response:
|
||||
accumulated_response.append(x.response)
|
||||
if chunk_callback:
|
||||
chunk_callback(x.response, x.end_of_stream)
|
||||
|
||||
# Complete when session ends
|
||||
if x.end_of_session:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
self.call(
|
||||
user=user, collection=collection, query=query,
|
||||
inspect=inspect, timeout=timeout
|
||||
)
|
||||
|
||||
return "".join(accumulated_response)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue