diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py index ae852714..29a301ae 100644 --- a/tests/integration/test_agent_manager_integration.py +++ b/tests/integration/test_agent_manager_integration.py @@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn # Verify tool was executed graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("What is machine learning?") + graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default") @pytest.mark.asyncio async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context): @@ -272,7 +272,7 @@ Args: {{ # Verify correct service was called if tool_name == "knowledge_query": - mock_flow_context("graph-rag-request").rag.assert_called() + mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default") elif tool_name == "text_completion": mock_flow_context("prompt-request").question.assert_called() @@ -713,4 +713,127 @@ Final Answer: { # Should not raise JSON serialization errors json_str = json.dumps(variables, indent=4) - assert len(json_str) > 0 \ No newline at end of file + assert len(json_str) > 0 + + @pytest.mark.asyncio + async def test_knowledge_query_with_default_collection(self, mock_flow_context): + """Test KnowledgeQueryImpl uses default collection when not specified""" + # Arrange + tool = KnowledgeQueryImpl(mock_flow_context) + + # Act + result = await tool.invoke(question="What is AI?") + + # Assert + graph_rag_client = mock_flow_context("graph-rag-request") + graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default") + + @pytest.mark.asyncio + async def test_knowledge_query_with_custom_collection(self, mock_flow_context): + """Test KnowledgeQueryImpl uses custom collection when specified""" + # Arrange + tool = KnowledgeQueryImpl(mock_flow_context, collection="custom_collection") + + # Act + result = await tool.invoke(question="What is machine learning?") + + # Assert + graph_rag_client = mock_flow_context("graph-rag-request") + graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection") + + @pytest.mark.asyncio + async def test_knowledge_query_with_none_collection(self, mock_flow_context): + """Test KnowledgeQueryImpl handles None collection properly""" + # Arrange + tool = KnowledgeQueryImpl(mock_flow_context, collection=None) + + # Act + result = await tool.invoke(question="Explain neural networks") + + # Assert + graph_rag_client = mock_flow_context("graph-rag-request") + graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default") + + @pytest.mark.asyncio + async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context): + """Test agent manager integration with KnowledgeQueryImpl collection parameter""" + # Arrange + custom_tools = { + "knowledge_query_custom": Tool( + name="knowledge_query_custom", + description="Query custom knowledge collection", + arguments=[ + Argument( + name="question", + type="string", + description="The question to ask" + ) + ], + implementation=KnowledgeQueryImpl, + config={"collection": "research_papers"} + ), + "knowledge_query_default": Tool( + name="knowledge_query_default", + description="Query default knowledge collection", + arguments=[ + Argument( + name="question", + type="string", + description="The question to ask" + ) + ], + implementation=KnowledgeQueryImpl, + config={} + ) + } + + agent = AgentManager(tools=custom_tools, additional_context="") + + # Mock response for custom collection query + mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers +Action: knowledge_query_custom +Args: { + "question": "Latest AI research?" +}""" + + think_callback = AsyncMock() + observe_callback = AsyncMock() + + # Act + action = await agent.react("Find latest research", [], think_callback, observe_callback, mock_flow_context) + + # Assert + assert isinstance(action, Action) + assert action.name == "knowledge_query_custom" + + # Verify the custom collection was used + graph_rag_client = mock_flow_context("graph-rag-request") + graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers") + + @pytest.mark.asyncio + async def test_knowledge_query_multiple_collections(self, mock_flow_context): + """Test multiple KnowledgeQueryImpl instances with different collections""" + # Arrange + tools = { + "general_kb": KnowledgeQueryImpl(mock_flow_context, collection="general"), + "technical_kb": KnowledgeQueryImpl(mock_flow_context, collection="technical"), + "research_kb": KnowledgeQueryImpl(mock_flow_context, collection="research") + } + + # Act & Assert for each tool + test_cases = [ + ("general_kb", "What is Python?", "general"), + ("technical_kb", "Explain TCP/IP", "technical"), + ("research_kb", "Latest ML papers", "research") + ] + + for tool_name, question, expected_collection in test_cases: + # Reset mock + mock_flow_context("graph-rag-request").reset_mock() + + # Invoke tool + await tools[tool_name].invoke(question=question) + + # Verify correct collection was used + graph_rag_client = mock_flow_context("graph-rag-request") + graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 9b46bd34..ed22ea78 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -269,7 +269,13 @@ class AgentManager: logger.debug(f"TOOL>>> {act}") - resp = await action.implementation(context).invoke( + # Instantiate the tool implementation with context and config + if action.config: + tool_instance = action.implementation(context, **action.config) + else: + tool_instance = action.implementation(context) + + resp = await tool_instance.invoke( **act.arguments ) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index d2a15bba..948424ec 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -27,7 +27,8 @@ class KnowledgeQueryImpl: client = self.context("graph-rag-request") logger.debug("Graph RAG question...") return await client.rag( - arguments.get("question") + arguments.get("question"), + collection=self.collection if self.collection else "default" ) # This tool implementation knows how to do text completion. This uses