Fix knowledge query ignoring the collection (#467)

* Fix knowledge query ignoring the collection

* Updated the agent_manager.py to properly pass config parameters when instantiating tool implementations

* Added tests for agent collection parameter
This commit is contained in:
cybermaggedon 2025-08-26 19:05:48 +01:00 committed by GitHub
parent 28190fea8a
commit 6e9e2a11b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 135 additions and 5 deletions

View file

@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
# Verify tool was executed
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
@pytest.mark.asyncio
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
@ -272,7 +272,7 @@ Args: {{
# Verify correct service was called
if tool_name == "knowledge_query":
mock_flow_context("graph-rag-request").rag.assert_called()
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
elif tool_name == "text_completion":
mock_flow_context("prompt-request").question.assert_called()
@ -713,4 +713,127 @@ Final Answer: {
# Should not raise JSON serialization errors
json_str = json.dumps(variables, indent=4)
assert len(json_str) > 0
assert len(json_str) > 0
@pytest.mark.asyncio
async def test_knowledge_query_with_default_collection(self, mock_flow_context):
"""Test KnowledgeQueryImpl uses default collection when not specified"""
# Arrange
tool = KnowledgeQueryImpl(mock_flow_context)
# Act
result = await tool.invoke(question="What is AI?")
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
@pytest.mark.asyncio
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
"""Test KnowledgeQueryImpl uses custom collection when specified"""
# Arrange
tool = KnowledgeQueryImpl(mock_flow_context, collection="custom_collection")
# Act
result = await tool.invoke(question="What is machine learning?")
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
@pytest.mark.asyncio
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
"""Test KnowledgeQueryImpl handles None collection properly"""
# Arrange
tool = KnowledgeQueryImpl(mock_flow_context, collection=None)
# Act
result = await tool.invoke(question="Explain neural networks")
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
@pytest.mark.asyncio
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
"""Test agent manager integration with KnowledgeQueryImpl collection parameter"""
# Arrange
custom_tools = {
"knowledge_query_custom": Tool(
name="knowledge_query_custom",
description="Query custom knowledge collection",
arguments=[
Argument(
name="question",
type="string",
description="The question to ask"
)
],
implementation=KnowledgeQueryImpl,
config={"collection": "research_papers"}
),
"knowledge_query_default": Tool(
name="knowledge_query_default",
description="Query default knowledge collection",
arguments=[
Argument(
name="question",
type="string",
description="The question to ask"
)
],
implementation=KnowledgeQueryImpl,
config={}
)
}
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
think_callback = AsyncMock()
observe_callback = AsyncMock()
# Act
action = await agent.react("Find latest research", [], think_callback, observe_callback, mock_flow_context)
# Assert
assert isinstance(action, Action)
assert action.name == "knowledge_query_custom"
# Verify the custom collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
@pytest.mark.asyncio
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
"""Test multiple KnowledgeQueryImpl instances with different collections"""
# Arrange
tools = {
"general_kb": KnowledgeQueryImpl(mock_flow_context, collection="general"),
"technical_kb": KnowledgeQueryImpl(mock_flow_context, collection="technical"),
"research_kb": KnowledgeQueryImpl(mock_flow_context, collection="research")
}
# Act & Assert for each tool
test_cases = [
("general_kb", "What is Python?", "general"),
("technical_kb", "Explain TCP/IP", "technical"),
("research_kb", "Latest ML papers", "research")
]
for tool_name, question, expected_collection in test_cases:
# Reset mock
mock_flow_context("graph-rag-request").reset_mock()
# Invoke tool
await tools[tool_name].invoke(question=question)
# Verify correct collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)