From aecf00f040dc2b685dc1f8b438f4706d883f1a2e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 12 Mar 2026 17:59:02 +0000 Subject: [PATCH] Minor agent tweaks (#692) Update RAG and Agent clients for streaming message handling GraphRAG now sends multiple message types in a stream: - 'explain' messages with explain_id and explain_graph for provenance - 'chunk' messages with response text fragments - end_of_session marker for stream completion Updated all clients to handle this properly: CLI clients (trustgraph-base/trustgraph/clients/): - graph_rag_client.py: Added chunk_callback and explain_callback - document_rag_client.py: Added chunk_callback and explain_callback - agent_client.py: Added think, observe, answer_callback, error_callback Internal clients (trustgraph-base/trustgraph/base/): - graph_rag_client.py: Async callbacks for streaming - agent_client.py: Async callbacks for streaming All clients now: - Route messages by chunk_type/message_type - Stream via optional callbacks for incremental delivery - Wait for proper completion signals (end_of_dialog/end_of_session/end_of_stream) - Accumulate and return complete response for callers not using callbacks Updated callers: - extract/kg/agent/extract.py: Uses new invoke(question=...) API - tests/integration/test_agent_kg_extraction_integration.py: Updated mocks This fixes the agent infinite loop issue where knowledge_query was returning the first 'explain' message (empty response) instead of waiting for the actual answer chunks. Concurrency in triples query --- .../test_agent_kg_extraction_integration.py | 20 +++---- .../trustgraph/base/agent_client.py | 60 ++++++++++++++++--- .../trustgraph/base/graph_rag_client.py | 50 ++++++++++++++-- .../trustgraph/base/triples_query_service.py | 12 +++- .../trustgraph/clients/agent_client.py | 54 +++++++++++++---- .../trustgraph/clients/document_rag_client.py | 46 ++++++++++++-- .../trustgraph/clients/graph_rag_client.py | 46 +++++++++++++- .../trustgraph/extract/kg/agent/extract.py | 16 ----- 8 files changed, 246 insertions(+), 58 deletions(-) diff --git a/tests/integration/test_agent_kg_extraction_integration.py b/tests/integration/test_agent_kg_extraction_integration.py index 1d37960a..c97bd529 100644 --- a/tests/integration/test_agent_kg_extraction_integration.py +++ b/tests/integration/test_agent_kg_extraction_integration.py @@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration: agent_client = AsyncMock() # Mock successful agent response in JSONL format - def mock_agent_response(recipient, question): + def mock_agent_response(question): # Simulate agent processing and return structured JSONL response mock_response = MagicMock() mock_response.error = None @@ -124,7 +124,7 @@ class TestAgentKgExtractionIntegration: # Get agent response (the mock returns a string directly) agent_client = flow("agent-request") - agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt) + agent_response = agent_client.invoke(question=prompt) # Parse and process extraction_data = extractor.parse_jsonl(agent_response) @@ -197,7 +197,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock agent error response agent_client = mock_flow_context("agent-request") - def mock_error_response(recipient, question): + def mock_error_response(question): # Simulate agent error by raising an exception raise RuntimeError("Agent processing failed") @@ -219,7 +219,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock invalid JSON response agent_client = mock_flow_context("agent-request") - def mock_invalid_json_response(recipient, question): + def mock_invalid_json_response(question): return "This is not valid JSON at all" agent_client.invoke = mock_invalid_json_response @@ -244,7 +244,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock empty extraction response agent_client = mock_flow_context("agent-request") - def mock_empty_response(recipient, question): + def mock_empty_response(question): # Return empty JSONL (just empty/whitespace) return '' @@ -271,7 +271,7 @@ class TestAgentKgExtractionIntegration: # Arrange - mock malformed extraction response agent_client = mock_flow_context("agent-request") - def mock_malformed_response(recipient, question): + def mock_malformed_response(question): # JSONL with definition missing required field return '{"type": "definition", "entity": "Missing Definition"}' @@ -297,7 +297,7 @@ class TestAgentKgExtractionIntegration: agent_client = mock_flow_context("agent-request") - def capture_prompt(recipient, question): + def capture_prompt(question): # Verify the prompt contains the test text assert test_text in question return '' # Empty JSONL response @@ -330,7 +330,7 @@ class TestAgentKgExtractionIntegration: agent_client = mock_flow_context("agent-request") responses = [] - def mock_response(recipient, question): + def mock_response(question): response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}' responses.append(response) return response @@ -364,7 +364,7 @@ class TestAgentKgExtractionIntegration: agent_client = mock_flow_context("agent-request") - def mock_unicode_response(recipient, question): + def mock_unicode_response(question): # Verify unicode text was properly decoded and included assert "学习机器" in question assert "人工知能" in question @@ -400,7 +400,7 @@ class TestAgentKgExtractionIntegration: agent_client = mock_flow_context("agent-request") - def mock_large_text_response(recipient, question): + def mock_large_text_response(question): # Verify large text was included assert len(question) > 10000 return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}' diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py index 03939dc3..f48fd024 100644 --- a/trustgraph-base/trustgraph/base/agent_client.py +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -4,10 +4,57 @@ from .. schema import AgentRequest, AgentResponse from .. knowledge import Uri, Literal class AgentClient(RequestResponse): - async def invoke(self, recipient, question, plan=None, state=None, - history=[], timeout=300): - - resp = await self.request( + async def invoke(self, question, plan=None, state=None, + history=[], think=None, observe=None, answer_callback=None, + timeout=300): + """ + Invoke the agent with optional streaming callbacks. + + Args: + question: The question to ask + plan: Optional plan context + state: Optional state context + history: Conversation history + think: Optional async callback(content, end_of_message) for thought chunks + observe: Optional async callback(content, end_of_message) for observation chunks + answer_callback: Optional async callback(content, end_of_message) for answer chunks + timeout: Request timeout in seconds + + Returns: + Complete answer text (accumulated from all answer chunks) + """ + accumulated_answer = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + # Handle thought chunks + if resp.chunk_type == 'thought': + if think: + await think(resp.content, resp.end_of_message) + return False # Continue receiving + + # Handle observation chunks + if resp.chunk_type == 'observation': + if observe: + await observe(resp.content, resp.end_of_message) + return False # Continue receiving + + # Handle answer chunks + if resp.chunk_type == 'answer': + if resp.content: + accumulated_answer.append(resp.content) + if answer_callback: + await answer_callback(resp.content, resp.end_of_message) + + # Complete when dialog ends + if resp.end_of_dialog: + return True + + return False # Continue receiving + + await self.request( AgentRequest( question = question, plan = plan, @@ -18,10 +65,7 @@ class AgentClient(RequestResponse): timeout=timeout, ) - if resp.error: - raise RuntimeError(resp.error.message) - - return resp.answer + return "".join(accumulated_answer) class AgentClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index c4f3f7ab..66dbad1e 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -4,20 +4,58 @@ from .. schema import GraphRagQuery, GraphRagResponse class GraphRagClient(RequestResponse): async def rag(self, query, user="trustgraph", collection="default", + chunk_callback=None, explain_callback=None, timeout=600): - resp = await self.request( + """ + Execute a graph RAG query with optional streaming callbacks. + + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional async callback(text, end_of_stream) for text chunks + explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + # Handle explain notifications + if resp.message_type == 'explain': + if explain_callback and resp.explain_id: + await explain_callback(resp.explain_id, resp.explain_graph) + return False # Continue receiving + + # Handle text chunks + if resp.message_type == 'chunk': + if resp.response: + accumulated_response.append(resp.response) + if chunk_callback: + await chunk_callback(resp.response, resp.end_of_stream) + + # Complete when session ends + if resp.end_of_session: + return True + + return False # Continue receiving + + await self.request( GraphRagQuery( query = query, user = user, collection = collection, ), - timeout=timeout + timeout=timeout, + recipient=recipient, ) - if resp.error: - raise RuntimeError(resp.error.message) - - return resp.response + return "".join(accumulated_response) class GraphRagClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index b8053b01..09f36652 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "triples-query" +default_concurrency = 10 class TriplesQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(TriplesQueryService, self).__init__(**params | { "id": id }) @@ -30,7 +32,8 @@ class TriplesQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = TriplesQueryRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -109,6 +112,13 @@ class TriplesQueryService(FlowProcessor): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index b31b4e36..17ff5a09 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -42,25 +42,59 @@ class AgentClient(BaseClient): question, think=None, observe=None, + answer_callback=None, + error_callback=None, timeout=300 ): + """ + Request an agent query with optional streaming callbacks. + + Args: + question: The question to ask + think: Optional callback(content, end_of_message) for thought chunks + observe: Optional callback(content, end_of_message) for observation chunks + answer_callback: Optional callback(content, end_of_message) for answer chunks + error_callback: Optional callback(content) for error messages + timeout: Request timeout in seconds + + Returns: + Complete answer text (accumulated from all answer chunks) + """ + accumulated_answer = [] def inspect(x): + # Handle errors + if x.chunk_type == 'error' or x.error: + if error_callback: + error_callback(x.content or (x.error.message if x.error else "")) + # Continue to check end_of_dialog - if x.thought and think: - think(x.thought) - return + # Handle thought chunks + elif x.chunk_type == 'thought': + if think: + think(x.content, x.end_of_message) - if x.observation and observe: - observe(x.observation) - return + # Handle observation chunks + elif x.chunk_type == 'observation': + if observe: + observe(x.content, x.end_of_message) - if x.answer: + # Handle answer chunks + elif x.chunk_type == 'answer': + if x.content: + accumulated_answer.append(x.content) + if answer_callback: + answer_callback(x.content, x.end_of_message) + + # Complete when dialog ends + if x.end_of_dialog: return True - return False + return False # Continue receiving - return self.call( + self.call( question=question, inspect=inspect, timeout=timeout - ).answer + ) + + return "".join(accumulated_answer) diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 6cbafa9b..946b1a6c 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -40,9 +40,47 @@ class DocumentRagClient(BaseClient): output_schema=DocumentRagResponse, ) - def request(self, query, timeout=300): + def request(self, query, user="trustgraph", collection="default", + chunk_callback=None, explain_callback=None, timeout=300): + """ + Request a document RAG query with optional streaming callbacks. - return self.call( - query=query, timeout=timeout - ).response + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional callback(text, end_of_stream) for text chunks + explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + def inspect(x): + # Handle explain notifications (response is None/empty, explain_id present) + if x.explain_id and not x.response: + if explain_callback: + explain_callback(x.explain_id, x.explain_graph) + return False # Continue receiving + + # Handle text chunks + if x.response: + accumulated_response.append(x.response) + if chunk_callback: + chunk_callback(x.response, x.end_of_stream) + + # Complete when stream ends + if x.end_of_stream: + return True + + return False # Continue receiving + + self.call( + query=query, user=user, collection=collection, + inspect=inspect, timeout=timeout + ) + + return "".join(accumulated_response) diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 77102e36..42ffce0c 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -42,10 +42,50 @@ class GraphRagClient(BaseClient): def request( self, query, user="trustgraph", collection="default", + chunk_callback=None, + explain_callback=None, timeout=500 ): + """ + Request a graph RAG query with optional streaming callbacks. - return self.call( - user=user, collection=collection, query=query, timeout=timeout - ).response + Args: + query: The question to ask + user: User identifier + collection: Collection identifier + chunk_callback: Optional callback(text, end_of_stream) for text chunks + explain_callback: Optional callback(explain_id, explain_graph) for explain notifications + timeout: Request timeout in seconds + + Returns: + Complete response text (accumulated from all chunks) + """ + accumulated_response = [] + + def inspect(x): + # Handle explain notifications + if x.message_type == 'explain': + if explain_callback and x.explain_id: + explain_callback(x.explain_id, x.explain_graph) + return False # Continue receiving + + # Handle text chunks + if x.message_type == 'chunk': + if x.response: + accumulated_response.append(x.response) + if chunk_callback: + chunk_callback(x.response, x.end_of_stream) + + # Complete when session ends + if x.end_of_session: + return True + + return False # Continue receiving + + self.call( + user=user, collection=collection, query=query, + inspect=inspect, timeout=timeout + ) + + return "".join(accumulated_response) diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index 79d5123d..a00944d6 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -183,24 +183,8 @@ class Processor(FlowProcessor): logger.debug(f"Agent prompt: {prompt}") - async def handle(response): - - logger.debug(f"Agent response: {response}") - - if response.error is not None: - if response.error.message: - raise RuntimeError(str(response.error.message)) - else: - raise RuntimeError(str(response.error)) - - if response.answer is not None: - return True - else: - return False - # Send to agent API agent_response = await flow("agent-request").invoke( - recipient = handle, question = prompt )