Update to enable knowledge extraction using the agent framework (#439)
* Implement KG extraction agent (kg-extract-agent), using the ReAct framework (agent-manager-react).
* The ReAct manager had an issue when emitting JSON output: it conflicted with the ReAct manager's own JSON messages, so the ReAct manager was refactored to use traditional ReAct messages in a non-JSON structure.
* Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules; kg-extract-agent uses this framework.
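The "traditional ReAct messages, non-JSON structure" mentioned above is the plain Thought / Action / Args / Final Answer text layout exercised throughout the tests in this commit. A minimal parsing sketch for that layout (parse_react is a hypothetical illustration, not the agent-manager-react implementation, which distinguishes more failure modes than this):

import json
import re

def parse_react(text):
    # Strip surrounding whitespace and any markdown code fence.
    text = text.strip().strip("`").strip()
    thought = re.search(r"Thought:\s*(.*?)(?=\n(?:Action|Final Answer):|\Z)", text, re.S)
    final = re.search(r"Final Answer:\s*(.*)", text, re.S)
    if final:
        # A Final Answer ends the ReAct cycle.
        return {"thought": thought.group(1).strip() if thought else "",
                "final": final.group(1).strip()}
    action = re.search(r"Action:\s*(\S+)", text)
    if not action:
        raise RuntimeError("Could not parse response")
    # Args is a JSON object; a missing Args block means an empty argument dict.
    args = re.search(r"Args:\s*(\{.*\})", text, re.S)
    return {"thought": thought.group(1).strip() if thought else "",
            "action": action.group(1),
            "arguments": json.loads(args.group(1)) if args else {}}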
This commit is contained in:
parent 1fe4ed5226
commit d83e4e3d59

30 changed files with 3192 additions and 799 deletions
481  tests/integration/test_agent_kg_extraction_integration.py  Normal file
@@ -0,0 +1,481 @@
"""
|
||||
Integration tests for Agent-based Knowledge Graph Extraction
|
||||
|
||||
These tests verify the end-to-end functionality of the agent-driven knowledge graph
|
||||
extraction pipeline, testing the integration between agent communication, prompt
|
||||
rendering, JSON response processing, and knowledge graph generation.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
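# Pipeline under test, roughly (a sketch inferred from the fixtures below,
# not a statement of the implementation):
#   Chunk -> PromptManager.render -> agent invoke -> JSON answer
#   -> parse_json -> process_extraction_data -> Triples + EntityContexts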

import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch

from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.template.prompt_manager import PromptManager

@pytest.mark.integration
class TestAgentKgExtractionIntegration:
    """Integration tests for Agent-based Knowledge Graph Extraction"""

    @pytest.fixture
    def mock_flow_context(self):
        """Mock flow context for agent communication and output publishing"""
        context = MagicMock()

        # Mock agent client
        agent_client = AsyncMock()

        # Mock successful agent response
        def mock_agent_response(recipient, question):
            # Simulate agent processing and return structured response
            mock_response = MagicMock()
            mock_response.error = None
            mock_response.answer = '''```json
{
    "definitions": [
        {
            "entity": "Machine Learning",
            "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
        },
        {
            "entity": "Neural Networks",
            "definition": "Computing systems inspired by biological neural networks that process information."
        }
    ],
    "relationships": [
        {
            "subject": "Machine Learning",
            "predicate": "is_subset_of",
            "object": "Artificial Intelligence",
            "object-entity": true
        },
        {
            "subject": "Neural Networks",
            "predicate": "used_in",
            "object": "Machine Learning",
            "object-entity": true
        }
    ]
}
```'''
            return mock_response.answer

        agent_client.invoke = mock_agent_response

        # Mock output publishers
        triples_publisher = AsyncMock()
        entity_contexts_publisher = AsyncMock()

        def context_router(service_name):
            if service_name == "agent-request":
                return agent_client
            elif service_name == "triples":
                return triples_publisher
            elif service_name == "entity-contexts":
                return entity_contexts_publisher
            else:
                return AsyncMock()

        context.side_effect = context_router
        return context

    @pytest.fixture
    def sample_chunk(self):
        """Sample text chunk for knowledge extraction"""
        text = """
        Machine Learning is a subset of Artificial Intelligence that enables computers
        to learn from data without explicit programming. Neural Networks are computing
        systems inspired by biological neural networks that process information.
        Neural Networks are commonly used in Machine Learning applications.
        """

        return Chunk(
            chunk=text.encode('utf-8'),
            metadata=Metadata(
                id="doc123",
                metadata=[
                    Triple(
                        s=Value(value="doc123", is_uri=True),
                        p=Value(value="http://example.org/type", is_uri=True),
                        o=Value(value="document", is_uri=False)
                    )
                ]
            )
        )

    @pytest.fixture
    def configured_agent_extractor(self):
        """Mock agent extractor with loaded configuration for integration testing"""
        # Create a mock extractor that simulates the real behavior
        from trustgraph.extract.kg.agent.extract import Processor

        # Create mock without calling __init__ to avoid FlowProcessor issues
        extractor = MagicMock()
        real_extractor = Processor.__new__(Processor)

        # Copy the methods we want to test
        extractor.to_uri = real_extractor.to_uri
        extractor.parse_json = real_extractor.parse_json
        extractor.process_extraction_data = real_extractor.process_extraction_data
        extractor.emit_triples = real_extractor.emit_triples
        extractor.emit_entity_contexts = real_extractor.emit_entity_contexts

        # Set up the configuration and manager
        extractor.manager = PromptManager()
        extractor.template_id = "agent-kg-extract"
        extractor.config_key = "prompt"

        # Mock configuration
        config = {
            "system": json.dumps("You are a knowledge extraction agent."),
            "template-index": json.dumps(["agent-kg-extract"]),
            "template.agent-kg-extract": json.dumps({
                "prompt": "Extract entities and relationships from: {{ text }}",
                "response-type": "json"
            })
        }

        # Load configuration
        extractor.manager.load_config(config)

        # Mock the on_message method to simulate real behavior
        async def mock_on_message(msg, consumer, flow):
            v = msg.value()
            chunk_text = v.chunk.decode('utf-8')

            # Render prompt
            prompt = extractor.manager.render(extractor.template_id, {"text": chunk_text})

            # Get agent response (the mock returns a string directly)
            agent_client = flow("agent-request")
            agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)

            # Parse and process
            extraction_data = extractor.parse_json(agent_response)
            triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)

            # Add metadata triples
            for t in v.metadata.metadata:
                triples.append(t)

            # Emit outputs
            if triples:
                await extractor.emit_triples(flow("triples"), v.metadata, triples)
            if entity_contexts:
                await extractor.emit_entity_contexts(flow("entity-contexts"), v.metadata, entity_contexts)

        extractor.on_message = mock_on_message

        return extractor

    @pytest.mark.asyncio
    async def test_end_to_end_knowledge_extraction(self, configured_agent_extractor, sample_chunk, mock_flow_context):
        """Test complete end-to-end knowledge extraction workflow"""
        # Arrange
        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert
        # Verify agent was called with rendered prompt
        agent_client = mock_flow_context("agent-request")
        # Check that the mock function was replaced and called
        assert hasattr(agent_client, 'invoke')

        # Verify triples were emitted
        triples_publisher = mock_flow_context("triples")
        triples_publisher.send.assert_called_once()

        sent_triples = triples_publisher.send.call_args[0][0]
        assert isinstance(sent_triples, Triples)
        assert sent_triples.metadata.id == "doc123"
        assert len(sent_triples.triples) > 0

        # Check that we have definition triples
        definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION]
        assert len(definition_triples) >= 2  # Should have definitions for ML and Neural Networks

        # Check that we have label triples
        label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL]
        assert len(label_triples) >= 2  # Should have labels for entities

        # Check subject-of relationships
        subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF]
        assert len(subject_of_triples) >= 2  # Entities should be linked to document

        # Verify entity contexts were emitted
        entity_contexts_publisher = mock_flow_context("entity-contexts")
        entity_contexts_publisher.send.assert_called_once()

        sent_contexts = entity_contexts_publisher.send.call_args[0][0]
        assert isinstance(sent_contexts, EntityContexts)
        assert len(sent_contexts.entities) >= 2  # Should have contexts for both entities

        # Verify entity URIs are properly formed
        entity_uris = [ec.entity.value for ec in sent_contexts.entities]
        assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
        assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris

    @pytest.mark.asyncio
    async def test_agent_error_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
        """Test handling of agent errors"""
        # Arrange - mock agent error response
        agent_client = mock_flow_context("agent-request")

        def mock_error_response(recipient, question):
            # Simulate agent error by raising an exception
            raise RuntimeError("Agent processing failed")

        agent_client.invoke = mock_error_response

        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        with pytest.raises(RuntimeError) as exc_info:
            await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        assert "Agent processing failed" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
        """Test handling of invalid JSON responses from agent"""
        # Arrange - mock invalid JSON response
        agent_client = mock_flow_context("agent-request")

        def mock_invalid_json_response(recipient, question):
            return "This is not valid JSON at all"

        agent_client.invoke = mock_invalid_json_response

        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        with pytest.raises((ValueError, json.JSONDecodeError)):
            await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

    @pytest.mark.asyncio
    async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
        """Test handling of empty extraction results"""
        # Arrange - mock empty extraction response
        agent_client = mock_flow_context("agent-request")

        def mock_empty_response(recipient, question):
            return '{"definitions": [], "relationships": []}'

        agent_client.invoke = mock_empty_response

        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert
        # Should still emit outputs (even if empty) to maintain flow consistency
        triples_publisher = mock_flow_context("triples")
        entity_contexts_publisher = mock_flow_context("entity-contexts")

        # Triples should include metadata triples at minimum
        triples_publisher.send.assert_called_once()
        sent_triples = triples_publisher.send.call_args[0][0]
        assert isinstance(sent_triples, Triples)

        # Entity contexts should not be sent if empty
        entity_contexts_publisher.send.assert_not_called()

    @pytest.mark.asyncio
    async def test_malformed_extraction_data(self, configured_agent_extractor, sample_chunk, mock_flow_context):
        """Test handling of malformed extraction data"""
        # Arrange - mock malformed extraction response
        agent_client = mock_flow_context("agent-request")

        def mock_malformed_response(recipient, question):
            return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''

        agent_client.invoke = mock_malformed_response

        mock_message = MagicMock()
        mock_message.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        with pytest.raises(KeyError):
            await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

    @pytest.mark.asyncio
    async def test_prompt_rendering_integration(self, configured_agent_extractor, mock_flow_context):
        """Test integration with prompt template rendering"""
        # Create a chunk with specific text
        test_text = "Test text for prompt rendering"
        chunk = Chunk(
            chunk=test_text.encode('utf-8'),
            metadata=Metadata(id="test-doc", metadata=[])
        )

        agent_client = mock_flow_context("agent-request")

        def capture_prompt(recipient, question):
            # Verify the prompt contains the test text
            assert test_text in question
            return '{"definitions": [], "relationships": []}'

        agent_client.invoke = capture_prompt

        mock_message = MagicMock()
        mock_message.value.return_value = chunk
        mock_consumer = MagicMock()

        # Act
        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert - prompt should have been rendered with the text
        # The agent_client.invoke is a function, not a mock, so we verify it was called by checking the flow worked
        assert hasattr(agent_client, 'invoke')

    @pytest.mark.asyncio
    async def test_concurrent_processing_simulation(self, configured_agent_extractor, mock_flow_context):
        """Test simulation of concurrent chunk processing"""
        # Create multiple chunks
        chunks = []
        for i in range(3):
            text = f"Test document {i} content"
            chunks.append(Chunk(
                chunk=text.encode('utf-8'),
                metadata=Metadata(id=f"doc{i}", metadata=[])
            ))

        agent_client = mock_flow_context("agent-request")
        responses = []

        def mock_response(recipient, question):
            response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
            responses.append(response)
            return response

        agent_client.invoke = mock_response

        # Process chunks sequentially (simulating concurrent processing)
        for chunk in chunks:
            mock_message = MagicMock()
            mock_message.value.return_value = chunk
            mock_consumer = MagicMock()

            await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert
        assert len(responses) == 3

        # Verify all chunks were processed
        triples_publisher = mock_flow_context("triples")
        assert triples_publisher.send.call_count == 3

    @pytest.mark.asyncio
    async def test_unicode_text_handling(self, configured_agent_extractor, mock_flow_context):
        """Test handling of text with unicode characters"""
        # Create chunk with unicode text
        unicode_text = "Machine Learning (学习机器) は人工知能の一分野です。"
        chunk = Chunk(
            chunk=unicode_text.encode('utf-8'),
            metadata=Metadata(id="unicode-doc", metadata=[])
        )

        agent_client = mock_flow_context("agent-request")

        def mock_unicode_response(recipient, question):
            # Verify unicode text was properly decoded and included
            assert "学习机器" in question
            assert "人工知能" in question
            return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''

        agent_client.invoke = mock_unicode_response

        mock_message = MagicMock()
        mock_message.value.return_value = chunk
        mock_consumer = MagicMock()

        # Act
        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert - should handle unicode properly
        triples_publisher = mock_flow_context("triples")
        triples_publisher.send.assert_called_once()

        sent_triples = triples_publisher.send.call_args[0][0]
        # Check that unicode entity was properly processed
        entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and t.o.value == "機械学習"]
        assert len(entity_labels) > 0

    @pytest.mark.asyncio
    async def test_large_text_chunk_processing(self, configured_agent_extractor, mock_flow_context):
        """Test processing of large text chunks"""
        # Create a large text chunk
        large_text = "Machine Learning is important. " * 1000  # Repeat to create large text
        chunk = Chunk(
            chunk=large_text.encode('utf-8'),
            metadata=Metadata(id="large-doc", metadata=[])
        )

        agent_client = mock_flow_context("agent-request")

        def mock_large_text_response(recipient, question):
            # Verify large text was included
            assert len(question) > 10000
            return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}'''

        agent_client.invoke = mock_large_text_response

        mock_message = MagicMock()
        mock_message.value.return_value = chunk
        mock_consumer = MagicMock()

        # Act
        await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)

        # Assert - should handle large text without issues
        triples_publisher = mock_flow_context("triples")
        triples_publisher.send.assert_called_once()

    def test_configuration_parameter_validation(self):
        """Test parameter validation logic"""
        # Test that default parameter logic would work
        default_template_id = "agent-kg-extract"
        default_config_type = "prompt"
        default_concurrency = 1

        # Simulate parameter handling
        params = {}
        template_id = params.get("template-id", default_template_id)
        config_key = params.get("config-type", default_config_type)
        concurrency = params.get("concurrency", default_concurrency)

        assert template_id == "agent-kg-extract"
        assert config_key == "prompt"
        assert concurrency == 1

        # Test with custom parameters
        custom_params = {
            "template-id": "custom-template",
            "config-type": "custom-config",
            "concurrency": 10
        }

        template_id = custom_params.get("template-id", default_template_id)
        config_key = custom_params.get("config-type", default_config_type)
        concurrency = custom_params.get("concurrency", default_concurrency)

        assert template_id == "custom-template"
        assert config_key == "custom-config"
        assert concurrency == 10
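The URI assertions in the end-to-end test above suggest how entity labels become node identifiers: percent-encoded onto the TRUSTGRAPH_ENTITIES namespace. A sketch of that mapping, assuming standard URL quoting (the namespace value shown is illustrative, not the real constant):

from urllib.parse import quote

TRUSTGRAPH_ENTITIES = "http://trustgraph.ai/e/"  # assumed value, for illustration only

def to_uri(label: str) -> str:
    # Percent-encode the label and append it to the entity namespace,
    # matching assertions like f"{TRUSTGRAPH_ENTITIES}Machine%20Learning".
    return TRUSTGRAPH_ENTITIES + quote(label)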
@@ -28,11 +28,11 @@ class TestAgentManagerIntegration:
 
         # Mock prompt client
         prompt_client = AsyncMock()
-        prompt_client.agent_react.return_value = {
-            "thought": "I need to search for information about machine learning",
-            "action": "knowledge_query",
-            "arguments": {"question": "What is machine learning?"}
-        }
+        prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
+Action: knowledge_query
+Args: {
+    "question": "What is machine learning?"
+}"""
 
         # Mock graph RAG client
         graph_rag_client = AsyncMock()
@@ -147,10 +147,8 @@
     async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
         """Test agent manager returning final answer"""
         # Arrange
-        mock_flow_context("prompt-request").agent_react.return_value = {
-            "thought": "I have enough information to answer the question",
-            "final-answer": "Machine learning is a field of AI that enables computers to learn from data."
-        }
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
+Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
 
         question = "What is machine learning?"
         history = []
@@ -195,10 +193,8 @@
     async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
        """Test ReAct cycle ending with final answer"""
         # Arrange
-        mock_flow_context("prompt-request").agent_react.return_value = {
-            "thought": "I can provide a direct answer",
-            "final-answer": "Machine learning is a branch of artificial intelligence."
-        }
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
+Final Answer: Machine learning is a branch of artificial intelligence."""
 
         question = "What is machine learning?"
         history = []
@@ -258,11 +254,11 @@
 
         for tool_name, expected_service in tool_scenarios:
             # Arrange
-            mock_flow_context("prompt-request").agent_react.return_value = {
-                "thought": f"I need to use {tool_name}",
-                "action": tool_name,
-                "arguments": {"question": "test question"}
-            }
+            mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
+Action: {tool_name}
+Args: {{
+    "question": "test question"
+}}"""
 
             think_callback = AsyncMock()
             observe_callback = AsyncMock()
@@ -288,11 +284,11 @@
     async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
         """Test agent manager error handling for unknown tool"""
         # Arrange
-        mock_flow_context("prompt-request").agent_react.return_value = {
-            "thought": "I need to use an unknown tool",
-            "action": "unknown_tool",
-            "arguments": {"param": "value"}
-        }
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
+Action: unknown_tool
+Args: {
+    "param": "value"
+}"""
 
         think_callback = AsyncMock()
         observe_callback = AsyncMock()
@@ -325,11 +321,11 @@
         question = "Find information about AI and summarize it"
 
         # Mock multi-step reasoning
-        mock_flow_context("prompt-request").agent_react.return_value = {
-            "thought": "I need to search for AI information first",
-            "action": "knowledge_query",
-            "arguments": {"question": "What is artificial intelligence?"}
-        }
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
+Action: knowledge_query
+Args: {
+    "question": "What is artificial intelligence?"
+}"""
 
         # Act
         action = await agent_manager.reason(question, [], mock_flow_context)
@@ -373,11 +369,12 @@
 
         for test_case in test_cases:
             # Arrange
-            mock_flow_context("prompt-request").agent_react.return_value = {
-                "thought": f"Using {test_case['action']}",
-                "action": test_case['action'],
-                "arguments": test_case['arguments']
-            }
+            # Format arguments as JSON
+            import json
+            args_json = json.dumps(test_case['arguments'], indent=4)
+            mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
+Action: {test_case['action']}
+Args: {args_json}"""
 
             think_callback = AsyncMock()
             observe_callback = AsyncMock()
@@ -465,6 +462,193 @@
         # Reset mocks
         mock_flow_context("graph-rag-request").reset_mock()
 
+    @pytest.mark.asyncio
+    async def test_agent_manager_malformed_response_handling(self, agent_manager, mock_flow_context):
+        """Test agent manager handling of malformed text responses"""
+        # Test cases with expected error messages
+        test_cases = [
+            # Missing action/final answer
+            {
+                "response": "Thought: I need to do something",
+                "error_contains": "Response has thought but no action or final answer"
+            },
+            # Invalid JSON in Args
+            {
+                "response": """Thought: I need to search
+Action: knowledge_query
+Args: {invalid json}""",
+                "error_contains": "Invalid JSON in Args"
+            },
+            # Empty response
+            {
+                "response": "",
+                "error_contains": "Could not parse response"
+            },
+            # Only whitespace
+            {
+                "response": "   \n\t  ",
+                "error_contains": "Could not parse response"
+            },
+            # Missing Args for action (should create empty args dict)
+            {
+                "response": """Thought: I need to search
+Action: knowledge_query""",
+                "error_contains": None  # This should actually succeed with empty args
+            },
+            # Incomplete JSON
+            {
+                "response": """Thought: I need to search
+Action: knowledge_query
+Args: {
+    "question": "test"
+""",
+                "error_contains": "Invalid JSON in Args"
+            },
+        ]
+
+        for test_case in test_cases:
+            mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
+
+            if test_case["error_contains"]:
+                # Should raise an error
+                with pytest.raises(RuntimeError) as exc_info:
+                    await agent_manager.reason("test question", [], mock_flow_context)
+
+                assert "Failed to parse agent response" in str(exc_info.value)
+                assert test_case["error_contains"] in str(exc_info.value)
+            else:
+                # Should succeed
+                action = await agent_manager.reason("test question", [], mock_flow_context)
+                assert isinstance(action, Action)
+                assert action.name == "knowledge_query"
+                assert action.arguments == {}
+
+    @pytest.mark.asyncio
+    async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
+        """Test edge cases in text parsing"""
+        # Test response with markdown code blocks
+        mock_flow_context("prompt-request").agent_react.return_value = """```
+Thought: I need to search for information
+Action: knowledge_query
+Args: {
+    "question": "What is AI?"
+}
+```"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Action)
+        assert action.thought == "I need to search for information"
+        assert action.name == "knowledge_query"
+
+        # Test response with extra whitespace
+        mock_flow_context("prompt-request").agent_react.return_value = """
+
+Thought: I need to think about this
+Action: knowledge_query
+Args: {
+    "question": "test"
+}
+
+"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Action)
+        assert action.thought == "I need to think about this"
+        assert action.name == "knowledge_query"
+
+    @pytest.mark.asyncio
+    async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
+        """Test handling of multi-line thoughts and final answers"""
+        # Multi-line thought
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
+1. The user's question is complex
+2. I should search for comprehensive information
+3. This requires using the knowledge query tool
+Action: knowledge_query
+Args: {
+    "question": "complex query"
+}"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Action)
+        assert "multiple factors" in action.thought
+        assert "knowledge query tool" in action.thought
+
+        # Multi-line final answer
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
+Final Answer: Here is a comprehensive answer:
+1. First point about the topic
+2. Second point with details
+3. Final conclusion
+
+This covers all aspects of the question."""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Final)
+        assert "First point" in action.final
+        assert "Final conclusion" in action.final
+        assert "all aspects" in action.final
+
+    @pytest.mark.asyncio
+    async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
+        """Test JSON arguments with special characters and edge cases"""
+        # Test with special characters in JSON (properly escaped)
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
+Action: knowledge_query
+Args: {
+    "question": "What about \\"quotes\\" and 'apostrophes'?",
+    "context": "Line 1\\nLine 2\\tTabbed",
+    "special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
+}"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Action)
+        assert action.arguments["question"] == 'What about "quotes" and \'apostrophes\'?'
+        assert action.arguments["context"] == "Line 1\nLine 2\tTabbed"
+        assert "@#$%^&*" in action.arguments["special"]
+
+        # Test with nested JSON
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
+Action: web_search
+Args: {
+    "query": "test",
+    "options": {
+        "limit": 10,
+        "filters": ["recent", "relevant"],
+        "metadata": {
+            "source": "user",
+            "timestamp": "2024-01-01"
+        }
+    }
+}"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Action)
+        assert action.arguments["options"]["limit"] == 10
+        assert "recent" in action.arguments["options"]["filters"]
+        assert action.arguments["options"]["metadata"]["source"] == "user"
+
+    @pytest.mark.asyncio
+    async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
+        """Test final answers that contain JSON-like content"""
+        # Final answer with JSON content
+        mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
+Final Answer: {
+    "result": "success",
+    "data": {
+        "name": "Machine Learning",
+        "type": "AI Technology",
+        "applications": ["NLP", "Computer Vision", "Robotics"]
+    },
+    "confidence": 0.95
+}"""
+
+        action = await agent_manager.reason("test", [], mock_flow_context)
+        assert isinstance(action, Final)
+        # The final answer should preserve the JSON structure as a string
+        assert '"result": "success"' in action.final
+        assert '"applications":' in action.final
+
     @pytest.mark.asyncio
     @pytest.mark.slow
     async def test_agent_manager_performance_with_large_history(self, agent_manager, mock_flow_context):
205  tests/integration/test_template_service_integration.py  Normal file
@@ -0,0 +1,205 @@
"""
|
||||
Simplified integration tests for Template Service
|
||||
|
||||
These tests verify the basic functionality of the template service
|
||||
without the full message queue infrastructure.
|
||||
"""
|
||||
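# Invocation flow exercised below (a sketch inferred from these tests, not a
# statement of the implementation):
#   load_config(config) -> render(template_id, terms) -> await llm(system, prompt)
#   and, for "response-type": "json", extract any ```json fence, parse the JSON,
#   then validate against the template's schema (failures surface as RuntimeError).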

import pytest
import json
from unittest.mock import AsyncMock, MagicMock

from trustgraph.schema import PromptRequest, PromptResponse
from trustgraph.template.prompt_manager import PromptManager


@pytest.mark.integration
class TestTemplateServiceSimple:
    """Simplified integration tests for Template Service components"""

    @pytest.fixture
    def sample_config(self):
        """Sample configuration for testing"""
        return {
            "system": json.dumps("You are a helpful assistant."),
            "template-index": json.dumps(["greeting", "json_test"]),
            "template.greeting": json.dumps({
                "prompt": "Hello {{ name }}, welcome to {{ system_name }}!",
                "response-type": "text"
            }),
            "template.json_test": json.dumps({
                "prompt": "Generate profile for {{ username }}",
                "response-type": "json",
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "role": {"type": "string"}
                    },
                    "required": ["name", "role"]
                }
            })
        }

    @pytest.fixture
    def prompt_manager(self, sample_config):
        """Create a configured PromptManager"""
        pm = PromptManager()
        pm.load_config(sample_config)
        pm.terms["system_name"] = "TrustGraph"
        return pm

    @pytest.mark.asyncio
    async def test_prompt_manager_text_invocation(self, prompt_manager):
        """Test PromptManager text response invocation"""
        # Mock LLM function
        async def mock_llm(system, prompt):
            assert system == "You are a helpful assistant."
            assert "Hello Alice, welcome to TrustGraph!" in prompt
            return "Welcome message processed!"

        result = await prompt_manager.invoke("greeting", {"name": "Alice"}, mock_llm)

        assert result == "Welcome message processed!"

    @pytest.mark.asyncio
    async def test_prompt_manager_json_invocation(self, prompt_manager):
        """Test PromptManager JSON response invocation"""
        # Mock LLM function
        async def mock_llm(system, prompt):
            assert "Generate profile for johndoe" in prompt
            return '{"name": "John Doe", "role": "user"}'

        result = await prompt_manager.invoke("json_test", {"username": "johndoe"}, mock_llm)

        assert isinstance(result, dict)
        assert result["name"] == "John Doe"
        assert result["role"] == "user"

    @pytest.mark.asyncio
    async def test_prompt_manager_json_validation_error(self, prompt_manager):
        """Test JSON schema validation failure"""
        # Mock LLM function that returns invalid JSON
        async def mock_llm(system, prompt):
            return '{"name": "John Doe"}'  # Missing required "role"

        with pytest.raises(RuntimeError) as exc_info:
            await prompt_manager.invoke("json_test", {"username": "johndoe"}, mock_llm)

        assert "Schema validation fail" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_prompt_manager_json_parse_error(self, prompt_manager):
        """Test JSON parsing failure"""
        # Mock LLM function that returns non-JSON
        async def mock_llm(system, prompt):
            return "This is not JSON at all"

        with pytest.raises(RuntimeError) as exc_info:
            await prompt_manager.invoke("json_test", {"username": "johndoe"}, mock_llm)

        assert "JSON parse fail" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_prompt_manager_unknown_prompt(self, prompt_manager):
        """Test unknown prompt ID handling"""
        async def mock_llm(system, prompt):
            return "Response"

        with pytest.raises(KeyError):
            await prompt_manager.invoke("unknown_prompt", {}, mock_llm)

    @pytest.mark.asyncio
    async def test_prompt_manager_term_merging(self, prompt_manager):
        """Test proper term merging (global + prompt + input)"""
        # Add prompt-specific terms
        prompt_manager.prompts["greeting"].terms = {"greeting_prefix": "Hi"}

        async def mock_llm(system, prompt):
            # Should have global term (system_name), input term (name), and any prompt terms
            assert "TrustGraph" in prompt  # Global term
            assert "Bob" in prompt  # Input term
            return "Merged correctly"

        result = await prompt_manager.invoke("greeting", {"name": "Bob"}, mock_llm)
        assert result == "Merged correctly"

    def test_prompt_manager_template_rendering(self, prompt_manager):
        """Test direct template rendering"""
        result = prompt_manager.render("greeting", {"name": "Charlie"})

        assert "Hello Charlie, welcome to TrustGraph!" == result.strip()

    def test_prompt_manager_configuration_loading(self):
        """Test configuration loading with various formats"""
        pm = PromptManager()

        # Test empty configuration
        pm.load_config({})
        assert pm.config.system_template == "Be helpful."
        assert len(pm.prompts) == 0

        # Test configuration with single prompt
        config = {
            "system": json.dumps("Test system"),
            "template-index": json.dumps(["test"]),
            "template.test": json.dumps({
                "prompt": "Test {{ value }}",
                "response-type": "text"
            })
        }
        pm.load_config(config)

        assert pm.config.system_template == "Test system"
        assert "test" in pm.prompts
        assert pm.prompts["test"].response_type == "text"

    @pytest.mark.asyncio
    async def test_prompt_manager_json_with_markdown(self, prompt_manager):
        """Test JSON extraction from markdown code blocks"""
        async def mock_llm(system, prompt):
            return '''
Here's the profile:
```json
{"name": "Jane Smith", "role": "admin"}
```
'''

        result = await prompt_manager.invoke("json_test", {"username": "jane"}, mock_llm)

        assert isinstance(result, dict)
        assert result["name"] == "Jane Smith"
        assert result["role"] == "admin"

    def test_prompt_manager_error_handling_in_templates(self, prompt_manager):
        """Test error handling in template rendering"""
        # Test with missing variable - ibis might handle this differently than Jinja2
        try:
            result = prompt_manager.render("greeting", {})  # Missing 'name'
            # If no exception, check that result is still a string
            assert isinstance(result, str)
        except Exception as e:
            # If exception is raised, that's also acceptable
            assert "name" in str(e) or "undefined" in str(e).lower() or "variable" in str(e).lower()

    @pytest.mark.asyncio
    async def test_concurrent_prompt_invocations(self, prompt_manager):
        """Test concurrent invocations"""
        async def mock_llm(system, prompt):
            # Extract name from prompt for response
            if "Alice" in prompt:
                return "Alice response"
            elif "Bob" in prompt:
                return "Bob response"
            else:
                return "Default response"

        # Run concurrent invocations
        import asyncio
        results = await asyncio.gather(
            prompt_manager.invoke("greeting", {"name": "Alice"}, mock_llm),
            prompt_manager.invoke("greeting", {"name": "Bob"}, mock_llm),
        )

        assert "Alice response" in results
        assert "Bob response" in results
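As the commit message notes, the prompt template client was pulled out of prompt-template so other modules can use it directly. A minimal standalone usage sketch, assuming only the load_config/render API exercised in the tests above (the config values are illustrative):

import json
from trustgraph.template.prompt_manager import PromptManager

# Build the same config shape the tests use: JSON-encoded values keyed by
# "system", "template-index", and "template.<id>".
config = {
    "system": json.dumps("You are a helpful assistant."),
    "template-index": json.dumps(["greeting"]),
    "template.greeting": json.dumps({
        "prompt": "Hello {{ name }}!",
        "response-type": "text",
    }),
}

pm = PromptManager()
pm.load_config(config)
print(pm.render("greeting", {"name": "Ada"}))  # -> Hello Ada!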