mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix knowledge query ignoring the collection (#467)
* Fix knowledge query ignoring the collection * Updated the agent_manager.py to properly pass config parameters when instantiating tool implementations * Added tests for agent collection parameter
This commit is contained in:
parent
28190fea8a
commit
6e9e2a11b1
3 changed files with 135 additions and 5 deletions
|
|
@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
|
||||
# Verify tool was executed
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
|
|
@ -272,7 +272,7 @@ Args: {{
|
|||
|
||||
# Verify correct service was called
|
||||
if tool_name == "knowledge_query":
|
||||
mock_flow_context("graph-rag-request").rag.assert_called()
|
||||
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
|
||||
elif tool_name == "text_completion":
|
||||
mock_flow_context("prompt-request").question.assert_called()
|
||||
|
||||
|
|
@ -713,4 +713,127 @@ Final Answer: {
|
|||
|
||||
# Should not raise JSON serialization errors
|
||||
json_str = json.dumps(variables, indent=4)
|
||||
assert len(json_str) > 0
|
||||
assert len(json_str) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_default_collection(self, mock_flow_context):
|
||||
"""Test KnowledgeQueryImpl uses default collection when not specified"""
|
||||
# Arrange
|
||||
tool = KnowledgeQueryImpl(mock_flow_context)
|
||||
|
||||
# Act
|
||||
result = await tool.invoke(question="What is AI?")
|
||||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
|
||||
"""Test KnowledgeQueryImpl uses custom collection when specified"""
|
||||
# Arrange
|
||||
tool = KnowledgeQueryImpl(mock_flow_context, collection="custom_collection")
|
||||
|
||||
# Act
|
||||
result = await tool.invoke(question="What is machine learning?")
|
||||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
|
||||
"""Test KnowledgeQueryImpl handles None collection properly"""
|
||||
# Arrange
|
||||
tool = KnowledgeQueryImpl(mock_flow_context, collection=None)
|
||||
|
||||
# Act
|
||||
result = await tool.invoke(question="Explain neural networks")
|
||||
|
||||
# Assert
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
|
||||
"""Test agent manager integration with KnowledgeQueryImpl collection parameter"""
|
||||
# Arrange
|
||||
custom_tools = {
|
||||
"knowledge_query_custom": Tool(
|
||||
name="knowledge_query_custom",
|
||||
description="Query custom knowledge collection",
|
||||
arguments=[
|
||||
Argument(
|
||||
name="question",
|
||||
type="string",
|
||||
description="The question to ask"
|
||||
)
|
||||
],
|
||||
implementation=KnowledgeQueryImpl,
|
||||
config={"collection": "research_papers"}
|
||||
),
|
||||
"knowledge_query_default": Tool(
|
||||
name="knowledge_query_default",
|
||||
description="Query default knowledge collection",
|
||||
arguments=[
|
||||
Argument(
|
||||
name="question",
|
||||
type="string",
|
||||
description="The question to ask"
|
||||
)
|
||||
],
|
||||
implementation=KnowledgeQueryImpl,
|
||||
config={}
|
||||
)
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent.react("Find latest research", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.name == "knowledge_query_custom"
|
||||
|
||||
# Verify the custom collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
|
||||
"""Test multiple KnowledgeQueryImpl instances with different collections"""
|
||||
# Arrange
|
||||
tools = {
|
||||
"general_kb": KnowledgeQueryImpl(mock_flow_context, collection="general"),
|
||||
"technical_kb": KnowledgeQueryImpl(mock_flow_context, collection="technical"),
|
||||
"research_kb": KnowledgeQueryImpl(mock_flow_context, collection="research")
|
||||
}
|
||||
|
||||
# Act & Assert for each tool
|
||||
test_cases = [
|
||||
("general_kb", "What is Python?", "general"),
|
||||
("technical_kb", "Explain TCP/IP", "technical"),
|
||||
("research_kb", "Latest ML papers", "research")
|
||||
]
|
||||
|
||||
for tool_name, question, expected_collection in test_cases:
|
||||
# Reset mock
|
||||
mock_flow_context("graph-rag-request").reset_mock()
|
||||
|
||||
# Invoke tool
|
||||
await tools[tool_name].invoke(question=question)
|
||||
|
||||
# Verify correct collection was used
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)
|
||||
|
|
@ -269,7 +269,13 @@ class AgentManager:
|
|||
|
||||
logger.debug(f"TOOL>>> {act}")
|
||||
|
||||
resp = await action.implementation(context).invoke(
|
||||
# Instantiate the tool implementation with context and config
|
||||
if action.config:
|
||||
tool_instance = action.implementation(context, **action.config)
|
||||
else:
|
||||
tool_instance = action.implementation(context)
|
||||
|
||||
resp = await tool_instance.invoke(
|
||||
**act.arguments
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ class KnowledgeQueryImpl:
|
|||
client = self.context("graph-rag-request")
|
||||
logger.debug("Graph RAG question...")
|
||||
return await client.rag(
|
||||
arguments.get("question")
|
||||
arguments.get("question"),
|
||||
collection=self.collection if self.collection else "default"
|
||||
)
|
||||
|
||||
# This tool implementation knows how to do text completion. This uses
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue