Minor agent tweaks (#692)

Update RAG and Agent clients for streaming message handling

GraphRAG now sends multiple message types in a stream:
- 'explain' messages with explain_id and explain_graph for
  provenance
- 'chunk' messages with response text fragments
- end_of_session marker for stream completion

Updated all clients to handle this properly:

CLI clients (trustgraph-base/trustgraph/clients/):
- graph_rag_client.py: Added chunk_callback and explain_callback
- document_rag_client.py: Added chunk_callback and explain_callback
- agent_client.py: Added think, observe, answer_callback, and
  error_callback callbacks

Internal clients (trustgraph-base/trustgraph/base/):
- graph_rag_client.py: Async callbacks for streaming
- agent_client.py: Async callbacks for streaming

All clients now:
- Route messages by chunk_type/message_type
- Stream via optional callbacks for incremental delivery
- Wait for proper completion signals
(end_of_dialog/end_of_session/end_of_stream)
- Accumulate and return complete response for callers not using
  callbacks

Updated callers:
- extract/kg/agent/extract.py: Uses new invoke(question=...) API
- tests/integration/test_agent_kg_extraction_integration.py:
  Updated mocks

This fixes the agent infinite-loop issue in which knowledge_query
returned the first 'explain' message (an empty response) instead of
waiting for the actual answer chunks.

Concurrency in triples query
This commit is contained in:
cybermaggedon 2026-03-12 17:59:02 +00:00 committed by GitHub
parent 45e6ad4abc
commit aecf00f040
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 246 additions and 58 deletions

View file

@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration:
agent_client = AsyncMock() agent_client = AsyncMock()
# Mock successful agent response in JSONL format # Mock successful agent response in JSONL format
def mock_agent_response(recipient, question): def mock_agent_response(question):
# Simulate agent processing and return structured JSONL response # Simulate agent processing and return structured JSONL response
mock_response = MagicMock() mock_response = MagicMock()
mock_response.error = None mock_response.error = None
@ -124,7 +124,7 @@ class TestAgentKgExtractionIntegration:
# Get agent response (the mock returns a string directly) # Get agent response (the mock returns a string directly)
agent_client = flow("agent-request") agent_client = flow("agent-request")
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt) agent_response = agent_client.invoke(question=prompt)
# Parse and process # Parse and process
extraction_data = extractor.parse_jsonl(agent_response) extraction_data = extractor.parse_jsonl(agent_response)
@ -197,7 +197,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock agent error response # Arrange - mock agent error response
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_error_response(recipient, question): def mock_error_response(question):
# Simulate agent error by raising an exception # Simulate agent error by raising an exception
raise RuntimeError("Agent processing failed") raise RuntimeError("Agent processing failed")
@ -219,7 +219,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock invalid JSON response # Arrange - mock invalid JSON response
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_invalid_json_response(recipient, question): def mock_invalid_json_response(question):
return "This is not valid JSON at all" return "This is not valid JSON at all"
agent_client.invoke = mock_invalid_json_response agent_client.invoke = mock_invalid_json_response
@ -244,7 +244,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock empty extraction response # Arrange - mock empty extraction response
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_empty_response(recipient, question): def mock_empty_response(question):
# Return empty JSONL (just empty/whitespace) # Return empty JSONL (just empty/whitespace)
return '' return ''
@ -271,7 +271,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock malformed extraction response # Arrange - mock malformed extraction response
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_malformed_response(recipient, question): def mock_malformed_response(question):
# JSONL with definition missing required field # JSONL with definition missing required field
return '{"type": "definition", "entity": "Missing Definition"}' return '{"type": "definition", "entity": "Missing Definition"}'
@ -297,7 +297,7 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def capture_prompt(recipient, question): def capture_prompt(question):
# Verify the prompt contains the test text # Verify the prompt contains the test text
assert test_text in question assert test_text in question
return '' # Empty JSONL response return '' # Empty JSONL response
@ -330,7 +330,7 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
responses = [] responses = []
def mock_response(recipient, question): def mock_response(question):
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}' response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
responses.append(response) responses.append(response)
return response return response
@ -364,7 +364,7 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_unicode_response(recipient, question): def mock_unicode_response(question):
# Verify unicode text was properly decoded and included # Verify unicode text was properly decoded and included
assert "学习机器" in question assert "学习机器" in question
assert "人工知能" in question assert "人工知能" in question
@ -400,7 +400,7 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request") agent_client = mock_flow_context("agent-request")
def mock_large_text_response(recipient, question): def mock_large_text_response(question):
# Verify large text was included # Verify large text was included
assert len(question) > 10000 assert len(question) > 10000
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}' return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'

View file

@ -4,10 +4,57 @@ from .. schema import AgentRequest, AgentResponse
from .. knowledge import Uri, Literal from .. knowledge import Uri, Literal
class AgentClient(RequestResponse): class AgentClient(RequestResponse):
async def invoke(self, recipient, question, plan=None, state=None, async def invoke(self, question, plan=None, state=None,
history=[], timeout=300): history=[], think=None, observe=None, answer_callback=None,
timeout=300):
resp = await self.request( """
Invoke the agent with optional streaming callbacks.
Args:
question: The question to ask
plan: Optional plan context
state: Optional state context
history: Conversation history
think: Optional async callback(content, end_of_message) for thought chunks
observe: Optional async callback(content, end_of_message) for observation chunks
answer_callback: Optional async callback(content, end_of_message) for answer chunks
timeout: Request timeout in seconds
Returns:
Complete answer text (accumulated from all answer chunks)
"""
accumulated_answer = []
async def recipient(resp):
if resp.error:
raise RuntimeError(resp.error.message)
# Handle thought chunks
if resp.chunk_type == 'thought':
if think:
await think(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle observation chunks
if resp.chunk_type == 'observation':
if observe:
await observe(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle answer chunks
if resp.chunk_type == 'answer':
if resp.content:
accumulated_answer.append(resp.content)
if answer_callback:
await answer_callback(resp.content, resp.end_of_message)
# Complete when dialog ends
if resp.end_of_dialog:
return True
return False # Continue receiving
await self.request(
AgentRequest( AgentRequest(
question = question, question = question,
plan = plan, plan = plan,
@ -18,10 +65,7 @@ class AgentClient(RequestResponse):
timeout=timeout, timeout=timeout,
) )
if resp.error: return "".join(accumulated_answer)
raise RuntimeError(resp.error.message)
return resp.answer
class AgentClientSpec(RequestResponseSpec): class AgentClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -4,20 +4,58 @@ from .. schema import GraphRagQuery, GraphRagResponse
class GraphRagClient(RequestResponse): class GraphRagClient(RequestResponse):
async def rag(self, query, user="trustgraph", collection="default", async def rag(self, query, user="trustgraph", collection="default",
chunk_callback=None, explain_callback=None,
timeout=600): timeout=600):
resp = await self.request( """
Execute a graph RAG query with optional streaming callbacks.
Args:
query: The question to ask
user: User identifier
collection: Collection identifier
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications
timeout: Request timeout in seconds
Returns:
Complete response text (accumulated from all chunks)
"""
accumulated_response = []
async def recipient(resp):
if resp.error:
raise RuntimeError(resp.error.message)
# Handle explain notifications
if resp.message_type == 'explain':
if explain_callback and resp.explain_id:
await explain_callback(resp.explain_id, resp.explain_graph)
return False # Continue receiving
# Handle text chunks
if resp.message_type == 'chunk':
if resp.response:
accumulated_response.append(resp.response)
if chunk_callback:
await chunk_callback(resp.response, resp.end_of_stream)
# Complete when session ends
if resp.end_of_session:
return True
return False # Continue receiving
await self.request(
GraphRagQuery( GraphRagQuery(
query = query, query = query,
user = user, user = user,
collection = collection, collection = collection,
), ),
timeout=timeout timeout=timeout,
recipient=recipient,
) )
if resp.error: return "".join(accumulated_response)
raise RuntimeError(resp.error.message)
return resp.response
class GraphRagClientSpec(RequestResponseSpec): class GraphRagClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_ident = "triples-query" default_ident = "triples-query"
default_concurrency = 10
class TriplesQueryService(FlowProcessor): class TriplesQueryService(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
id = params.get("id") id = params.get("id")
concurrency = params.get("concurrency", default_concurrency)
super(TriplesQueryService, self).__init__(**params | { "id": id }) super(TriplesQueryService, self).__init__(**params | { "id": id })
@ -30,7 +32,8 @@ class TriplesQueryService(FlowProcessor):
ConsumerSpec( ConsumerSpec(
name = "request", name = "request",
schema = TriplesQueryRequest, schema = TriplesQueryRequest,
handler = self.on_message handler = self.on_message,
concurrency = concurrency,
) )
) )
@ -109,6 +112,13 @@ class TriplesQueryService(FlowProcessor):
FlowProcessor.add_args(parser) FlowProcessor.add_args(parser)
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Number of concurrent requests (default: {default_concurrency})'
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -42,25 +42,59 @@ class AgentClient(BaseClient):
question, question,
think=None, think=None,
observe=None, observe=None,
answer_callback=None,
error_callback=None,
timeout=300 timeout=300
): ):
"""
Request an agent query with optional streaming callbacks.
Args:
question: The question to ask
think: Optional callback(content, end_of_message) for thought chunks
observe: Optional callback(content, end_of_message) for observation chunks
answer_callback: Optional callback(content, end_of_message) for answer chunks
error_callback: Optional callback(content) for error messages
timeout: Request timeout in seconds
Returns:
Complete answer text (accumulated from all answer chunks)
"""
accumulated_answer = []
def inspect(x): def inspect(x):
# Handle errors
if x.chunk_type == 'error' or x.error:
if error_callback:
error_callback(x.content or (x.error.message if x.error else ""))
# Continue to check end_of_dialog
if x.thought and think: # Handle thought chunks
think(x.thought) elif x.chunk_type == 'thought':
return if think:
think(x.content, x.end_of_message)
if x.observation and observe: # Handle observation chunks
observe(x.observation) elif x.chunk_type == 'observation':
return if observe:
observe(x.content, x.end_of_message)
if x.answer: # Handle answer chunks
elif x.chunk_type == 'answer':
if x.content:
accumulated_answer.append(x.content)
if answer_callback:
answer_callback(x.content, x.end_of_message)
# Complete when dialog ends
if x.end_of_dialog:
return True return True
return False return False # Continue receiving
return self.call( self.call(
question=question, inspect=inspect, timeout=timeout question=question, inspect=inspect, timeout=timeout
).answer )
return "".join(accumulated_answer)

View file

@ -40,9 +40,47 @@ class DocumentRagClient(BaseClient):
output_schema=DocumentRagResponse, output_schema=DocumentRagResponse,
) )
def request(self, query, timeout=300): def request(self, query, user="trustgraph", collection="default",
chunk_callback=None, explain_callback=None, timeout=300):
"""
Request a document RAG query with optional streaming callbacks.
return self.call( Args:
query=query, timeout=timeout query: The question to ask
).response user: User identifier
collection: Collection identifier
chunk_callback: Optional callback(text, end_of_stream) for text chunks
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
timeout: Request timeout in seconds
Returns:
Complete response text (accumulated from all chunks)
"""
accumulated_response = []
def inspect(x):
# Handle explain notifications (response is None/empty, explain_id present)
if x.explain_id and not x.response:
if explain_callback:
explain_callback(x.explain_id, x.explain_graph)
return False # Continue receiving
# Handle text chunks
if x.response:
accumulated_response.append(x.response)
if chunk_callback:
chunk_callback(x.response, x.end_of_stream)
# Complete when stream ends
if x.end_of_stream:
return True
return False # Continue receiving
self.call(
query=query, user=user, collection=collection,
inspect=inspect, timeout=timeout
)
return "".join(accumulated_response)

View file

@ -42,10 +42,50 @@ class GraphRagClient(BaseClient):
def request( def request(
self, query, user="trustgraph", collection="default", self, query, user="trustgraph", collection="default",
chunk_callback=None,
explain_callback=None,
timeout=500 timeout=500
): ):
"""
Request a graph RAG query with optional streaming callbacks.
return self.call( Args:
user=user, collection=collection, query=query, timeout=timeout query: The question to ask
).response user: User identifier
collection: Collection identifier
chunk_callback: Optional callback(text, end_of_stream) for text chunks
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
timeout: Request timeout in seconds
Returns:
Complete response text (accumulated from all chunks)
"""
accumulated_response = []
def inspect(x):
# Handle explain notifications
if x.message_type == 'explain':
if explain_callback and x.explain_id:
explain_callback(x.explain_id, x.explain_graph)
return False # Continue receiving
# Handle text chunks
if x.message_type == 'chunk':
if x.response:
accumulated_response.append(x.response)
if chunk_callback:
chunk_callback(x.response, x.end_of_stream)
# Complete when session ends
if x.end_of_session:
return True
return False # Continue receiving
self.call(
user=user, collection=collection, query=query,
inspect=inspect, timeout=timeout
)
return "".join(accumulated_response)

View file

@ -183,24 +183,8 @@ class Processor(FlowProcessor):
logger.debug(f"Agent prompt: {prompt}") logger.debug(f"Agent prompt: {prompt}")
async def handle(response):
logger.debug(f"Agent response: {response}")
if response.error is not None:
if response.error.message:
raise RuntimeError(str(response.error.message))
else:
raise RuntimeError(str(response.error))
if response.answer is not None:
return True
else:
return False
# Send to agent API # Send to agent API
agent_response = await flow("agent-request").invoke( agent_response = await flow("agent-request").invoke(
recipient = handle,
question = prompt question = prompt
) )