mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Minor agent tweaks (#692)
Update RAG and Agent clients for streaming message handling GraphRAG now sends multiple message types in a stream: - 'explain' messages with explain_id and explain_graph for provenance - 'chunk' messages with response text fragments - end_of_session marker for stream completion Updated all clients to handle this properly: CLI clients (trustgraph-base/trustgraph/clients/): - graph_rag_client.py: Added chunk_callback and explain_callback - document_rag_client.py: Added chunk_callback and explain_callback - agent_client.py: Added think, observe, answer_callback, error_callback Internal clients (trustgraph-base/trustgraph/base/): - graph_rag_client.py: Async callbacks for streaming - agent_client.py: Async callbacks for streaming All clients now: - Route messages by chunk_type/message_type - Stream via optional callbacks for incremental delivery - Wait for proper completion signals (end_of_dialog/end_of_session/end_of_stream) - Accumulate and return complete response for callers not using callbacks Updated callers: - extract/kg/agent/extract.py: Uses new invoke(question=...) API - tests/integration/test_agent_kg_extraction_integration.py: Updated mocks This fixes the agent infinite loop issue where knowledge_query was returning the first 'explain' message (empty response) instead of waiting for the actual answer chunks. Concurrency in triples query
This commit is contained in:
parent
45e6ad4abc
commit
aecf00f040
8 changed files with 246 additions and 58 deletions
|
|
@ -4,10 +4,57 @@ from .. schema import AgentRequest, AgentResponse
|
|||
from .. knowledge import Uri, Literal
|
||||
|
||||
class AgentClient(RequestResponse):
|
||||
async def invoke(self, recipient, question, plan=None, state=None,
|
||||
history=[], timeout=300):
|
||||
|
||||
resp = await self.request(
|
||||
async def invoke(self, question, plan=None, state=None,
|
||||
history=[], think=None, observe=None, answer_callback=None,
|
||||
timeout=300):
|
||||
"""
|
||||
Invoke the agent with optional streaming callbacks.
|
||||
|
||||
Args:
|
||||
question: The question to ask
|
||||
plan: Optional plan context
|
||||
state: Optional state context
|
||||
history: Conversation history
|
||||
think: Optional async callback(content, end_of_message) for thought chunks
|
||||
observe: Optional async callback(content, end_of_message) for observation chunks
|
||||
answer_callback: Optional async callback(content, end_of_message) for answer chunks
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete answer text (accumulated from all answer chunks)
|
||||
"""
|
||||
accumulated_answer = []
|
||||
|
||||
async def recipient(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Handle thought chunks
|
||||
if resp.chunk_type == 'thought':
|
||||
if think:
|
||||
await think(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle observation chunks
|
||||
if resp.chunk_type == 'observation':
|
||||
if observe:
|
||||
await observe(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle answer chunks
|
||||
if resp.chunk_type == 'answer':
|
||||
if resp.content:
|
||||
accumulated_answer.append(resp.content)
|
||||
if answer_callback:
|
||||
await answer_callback(resp.content, resp.end_of_message)
|
||||
|
||||
# Complete when dialog ends
|
||||
if resp.end_of_dialog:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
await self.request(
|
||||
AgentRequest(
|
||||
question = question,
|
||||
plan = plan,
|
||||
|
|
@ -18,10 +65,7 @@ class AgentClient(RequestResponse):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.answer
|
||||
return "".join(accumulated_answer)
|
||||
|
||||
class AgentClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -4,20 +4,58 @@ from .. schema import GraphRagQuery, GraphRagResponse
|
|||
|
||||
class GraphRagClient(RequestResponse):
|
||||
async def rag(self, query, user="trustgraph", collection="default",
|
||||
chunk_callback=None, explain_callback=None,
|
||||
timeout=600):
|
||||
resp = await self.request(
|
||||
"""
|
||||
Execute a graph RAG query with optional streaming callbacks.
|
||||
|
||||
Args:
|
||||
query: The question to ask
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
chunk_callback: Optional async callback(text, end_of_stream) for text chunks
|
||||
explain_callback: Optional async callback(explain_id, explain_graph) for explain notifications
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Complete response text (accumulated from all chunks)
|
||||
"""
|
||||
accumulated_response = []
|
||||
|
||||
async def recipient(resp):
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Handle explain notifications
|
||||
if resp.message_type == 'explain':
|
||||
if explain_callback and resp.explain_id:
|
||||
await explain_callback(resp.explain_id, resp.explain_graph)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle text chunks
|
||||
if resp.message_type == 'chunk':
|
||||
if resp.response:
|
||||
accumulated_response.append(resp.response)
|
||||
if chunk_callback:
|
||||
await chunk_callback(resp.response, resp.end_of_stream)
|
||||
|
||||
# Complete when session ends
|
||||
if resp.end_of_session:
|
||||
return True
|
||||
|
||||
return False # Continue receiving
|
||||
|
||||
await self.request(
|
||||
GraphRagQuery(
|
||||
query = query,
|
||||
user = user,
|
||||
collection = collection,
|
||||
),
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
recipient=recipient,
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.response
|
||||
return "".join(accumulated_response)
|
||||
|
||||
class GraphRagClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "triples-query"
|
||||
default_concurrency = 10
|
||||
|
||||
class TriplesQueryService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
super(TriplesQueryService, self).__init__(**params | { "id": id })
|
||||
|
||||
|
|
@ -30,7 +32,8 @@ class TriplesQueryService(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = TriplesQueryRequest,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -109,6 +112,13 @@ class TriplesQueryService(FlowProcessor):
|
|||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue