mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Minor agent tweaks (#692)
Update RAG and Agent clients for streaming message handling

GraphRAG now sends multiple message types in a stream:
- 'explain' messages with explain_id and explain_graph for provenance
- 'chunk' messages with response text fragments
- end_of_session marker for stream completion

Updated all clients to handle this properly:

CLI clients (trustgraph-base/trustgraph/clients/):
- graph_rag_client.py: Added chunk_callback and explain_callback
- document_rag_client.py: Added chunk_callback and explain_callback
- agent_client.py: Added think, observe, answer_callback, error_callback

Internal clients (trustgraph-base/trustgraph/base/):
- graph_rag_client.py: Async callbacks for streaming
- agent_client.py: Async callbacks for streaming

All clients now:
- Route messages by chunk_type/message_type
- Stream via optional callbacks for incremental delivery
- Wait for proper completion signals (end_of_dialog/end_of_session/end_of_stream)
- Accumulate and return complete response for callers not using callbacks

Updated callers:
- extract/kg/agent/extract.py: Uses new invoke(question=...) API
- tests/integration/test_agent_kg_extraction_integration.py: Updated mocks

This fixes the agent infinite loop issue where knowledge_query was returning
the first 'explain' message (empty response) instead of waiting for the
actual answer chunks.

Concurrency in triples query
This commit is contained in:
parent
45e6ad4abc
commit
aecf00f040
8 changed files with 246 additions and 58 deletions
|
|
@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = AsyncMock()
|
||||
|
||||
# Mock successful agent response in JSONL format
|
||||
def mock_agent_response(recipient, question):
|
||||
def mock_agent_response(question):
|
||||
# Simulate agent processing and return structured JSONL response
|
||||
mock_response = MagicMock()
|
||||
mock_response.error = None
|
||||
|
|
@ -124,7 +124,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
# Get agent response (the mock returns a string directly)
|
||||
agent_client = flow("agent-request")
|
||||
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
|
||||
agent_response = agent_client.invoke(question=prompt)
|
||||
|
||||
# Parse and process
|
||||
extraction_data = extractor.parse_jsonl(agent_response)
|
||||
|
|
@ -197,7 +197,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock agent error response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_error_response(recipient, question):
|
||||
def mock_error_response(question):
|
||||
# Simulate agent error by raising an exception
|
||||
raise RuntimeError("Agent processing failed")
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock invalid JSON response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_invalid_json_response(recipient, question):
|
||||
def mock_invalid_json_response(question):
|
||||
return "This is not valid JSON at all"
|
||||
|
||||
agent_client.invoke = mock_invalid_json_response
|
||||
|
|
@ -244,7 +244,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock empty extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_empty_response(recipient, question):
|
||||
def mock_empty_response(question):
|
||||
# Return empty JSONL (just empty/whitespace)
|
||||
return ''
|
||||
|
||||
|
|
@ -271,7 +271,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock malformed extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_malformed_response(recipient, question):
|
||||
def mock_malformed_response(question):
|
||||
# JSONL with definition missing required field
|
||||
return '{"type": "definition", "entity": "Missing Definition"}'
|
||||
|
||||
|
|
@ -297,7 +297,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def capture_prompt(recipient, question):
|
||||
def capture_prompt(question):
|
||||
# Verify the prompt contains the test text
|
||||
assert test_text in question
|
||||
return '' # Empty JSONL response
|
||||
|
|
@ -330,7 +330,7 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = mock_flow_context("agent-request")
|
||||
responses = []
|
||||
|
||||
def mock_response(recipient, question):
|
||||
def mock_response(question):
|
||||
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
|
||||
responses.append(response)
|
||||
return response
|
||||
|
|
@ -364,7 +364,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_unicode_response(recipient, question):
|
||||
def mock_unicode_response(question):
|
||||
# Verify unicode text was properly decoded and included
|
||||
assert "学习机器" in question
|
||||
assert "人工知能" in question
|
||||
|
|
@ -400,7 +400,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_large_text_response(recipient, question):
|
||||
def mock_large_text_response(question):
|
||||
# Verify large text was included
|
||||
assert len(question) > 10000
|
||||
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'
|
||||
|
|
|
|||
|
|
@ -4,10 +4,57 @@ from .. schema import AgentRequest, AgentResponse
|
|||
from .. knowledge import Uri, Literal
|
||||
|
||||
class AgentClient(RequestResponse):
|
||||
async def invoke(self, recipient, question, plan=None, state=None,
|
||||
history=[], timeout=300):
|
||||
async def invoke(self, question, plan=None, state=None,
|
||||
history=[], think=None, observe=None, answer_callback=None,
|
||||
timeout=300):
|
||||
"""
|
||||
Invoke the agent with optional streaming callbacks.
|
||||
|
||||
resp = await self.request(
|
||||
Args:
|
||||
question: The question to ask
|
||||
plan: Optional plan context
|
||||
state: Optional state context
|
||||
history: Conversation history
|
||||
think: Optional async callback(content, end_of_message) for thought chunks
|
||||
observe: Optional async callback(content, end_of_message) for observation chunks
|
||||
answer_callback: Optional async callback(content, end_of_message) for answer chunks
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete answer text (accumulated from all answer chunks)
|
||||
"""
|
||||
accumulated_answer = []
|
||||
|
||||
async def recipient(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Handle thought chunks
|
||||
if resp.chunk_type == 'thought':
|
||||
if think:
|
||||
await think(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle observation chunks
|
||||
if resp.chunk_type == 'observation':
|
||||
if observe:
|
||||
await observe(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle answer chunks
|
||||
if resp.chunk_type == 'answer':
|
||||
if resp.content:
|
||||
accumulated_answer.append(resp.content)
|
||||
if answer_callback:
|
||||
await answer_callback(resp.content, resp.end_of_message)
|
||||
|
||||
# Complete when dialog ends
|
||||
if resp.end_of_dialog:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
await self.request(
|
||||
AgentRequest(
|
||||
question = question,
|
||||
plan = plan,
|
||||
|
|
@ -18,10 +65,7 @@ class AgentClient(RequestResponse):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.answer
|
||||
return "".join(accumulated_answer)
|
||||
|
||||
class AgentClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -4,20 +4,58 @@ from .. schema import GraphRagQuery, GraphRagResponse
|
|||
|
||||
class GraphRagClient(RequestResponse):
|
||||
async def rag(self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None, explain_callback=None,
|
||||
timeout=600):
|
||||
resp = await self.request(
|
||||
"""
|
||||
Execute a graph RAG query with optional streaming callbacks.
|
||||
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
async def recipient(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Handle explain notifications
|
||||
if resp.message_type == 'explain':
|
||||
if explain_callback and resp.explain_id:
|
||||
await explain_callback(resp.explain_id, resp.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if resp.message_type == 'chunk':
|
||||
if resp.response:
|
||||
accumulated_response.append(resp.response)
|
||||
if chunk_callback:
|
||||
await chunk_callback(resp.response, resp.end_of_stream)
|
||||
|
||||
# Complete when session ends
|
||||
if resp.end_of_session:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
await self.request(
|
||||
GraphRagQuery(
|
||||
query = query,
|
||||
user = user,
|
||||
collection = collection,
|
||||
),
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
recipient=recipient,
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.response
|
||||
return "".join(accumulated_response)
|
||||
|
||||
class GraphRagClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-query"
|
||||
default_concurrency = 10
|
||||
|
||||
class TriplesQueryService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
super(TriplesQueryService, self).__init__(**params | { "id": id })
|
||||
|
||||
|
|
@ -30,7 +32,8 @@ class TriplesQueryService(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = TriplesQueryRequest,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -109,6 +112,13 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -42,25 +42,59 @@ class AgentClient(BaseClient):
|
|||
question,
|
||||
think=None,
|
||||
observe=None,
|
||||
answer_callback=None,
|
||||
error_callback=None,
|
||||
timeout=300
|
||||
):
|
||||
"""
|
||||
Request an agent query with optional streaming callbacks.
|
||||
|
||||
Args:
|
||||
question: The question to ask
|
||||
think: Optional callback(content, end_of_message) for thought chunks
|
||||
observe: Optional callback(content, end_of_message) for observation chunks
|
||||
answer_callback: Optional callback(content, end_of_message) for answer chunks
|
||||
error_callback: Optional callback(content) for error messages
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete answer text (accumulated from all answer chunks)
|
||||
"""
|
||||
accumulated_answer = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle errors
|
||||
if x.chunk_type == 'error' or x.error:
|
||||
if error_callback:
|
||||
error_callback(x.content or (x.error.message if x.error else ""))
|
||||
# Continue to check end_of_dialog
|
||||
|
||||
if x.thought and think:
|
||||
think(x.thought)
|
||||
return
|
||||
# Handle thought chunks
|
||||
elif x.chunk_type == 'thought':
|
||||
if think:
|
||||
think(x.content, x.end_of_message)
|
||||
|
||||
if x.observation and observe:
|
||||
observe(x.observation)
|
||||
return
|
||||
# Handle observation chunks
|
||||
elif x.chunk_type == 'observation':
|
||||
if observe:
|
||||
observe(x.content, x.end_of_message)
|
||||
|
||||
if x.answer:
|
||||
# Handle answer chunks
|
||||
elif x.chunk_type == 'answer':
|
||||
if x.content:
|
||||
accumulated_answer.append(x.content)
|
||||
if answer_callback:
|
||||
answer_callback(x.content, x.end_of_message)
|
||||
|
||||
# Complete when dialog ends
|
||||
if x.end_of_dialog:
|
||||
return True
|
||||
|
||||
return False
|
||||
return False # Continue receiving
|
||||
|
||||
return self.call(
|
||||
self.call(
|
||||
question=question, inspect=inspect, timeout=timeout
|
||||
).answer
|
||||
)
|
||||
|
||||
return "".join(accumulated_answer)
|
||||
|
||||
|
|
|
|||
|
|
@ -40,9 +40,47 @@ class DocumentRagClient(BaseClient):
|
|||
output_schema=DocumentRagResponse,
|
||||
)
|
||||
|
||||
def request(self, query, timeout=300):
|
||||
def request(self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None, explain_callback=None, timeout=300):
|
||||
"""
|
||||
Request a document RAG query with optional streaming callbacks.
|
||||
|
||||
return self.call(
|
||||
query=query, timeout=timeout
|
||||
).response
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle explain notifications (response is None/empty, explain_id present)
|
||||
if x.explain_id and not x.response:
|
||||
if explain_callback:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if x.response:
|
||||
accumulated_response.append(x.response)
|
||||
if chunk_callback:
|
||||
chunk_callback(x.response, x.end_of_stream)
|
||||
|
||||
# Complete when stream ends
|
||||
if x.end_of_stream:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
self.call(
|
||||
query=query, user=user, collection=collection,
|
||||
inspect=inspect, timeout=timeout
|
||||
)
|
||||
|
||||
return "".join(accumulated_response)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,10 +42,50 @@ class GraphRagClient(BaseClient):
|
|||
|
||||
def request(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None,
|
||||
explain_callback=None,
|
||||
timeout=500
|
||||
):
|
||||
"""
|
||||
Request a graph RAG query with optional streaming callbacks.
|
||||
|
||||
return self.call(
|
||||
user=user, collection=collection, query=query, timeout=timeout
|
||||
).response
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
def inspect(x):
|
||||
# Handle explain notifications
|
||||
if x.message_type == 'explain':
|
||||
if explain_callback and x.explain_id:
|
||||
explain_callback(x.explain_id, x.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if x.message_type == 'chunk':
|
||||
if x.response:
|
||||
accumulated_response.append(x.response)
|
||||
if chunk_callback:
|
||||
chunk_callback(x.response, x.end_of_stream)
|
||||
|
||||
# Complete when session ends
|
||||
if x.end_of_session:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
self.call(
|
||||
user=user, collection=collection, query=query,
|
||||
inspect=inspect, timeout=timeout
|
||||
)
|
||||
|
||||
return "".join(accumulated_response)
|
||||
|
||||
|
|
|
|||
|
|
@ -183,24 +183,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug(f"Agent prompt: {prompt}")
|
||||
|
||||
async def handle(response):
|
||||
|
||||
logger.debug(f"Agent response: {response}")
|
||||
|
||||
if response.error is not None:
|
||||
if response.error.message:
|
||||
raise RuntimeError(str(response.error.message))
|
||||
else:
|
||||
raise RuntimeError(str(response.error))
|
||||
|
||||
if response.answer is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# Send to agent API
|
||||
agent_response = await flow("agent-request").invoke(
|
||||
recipient = handle,
|
||||
question = prompt
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue