mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
Update to enable knowledge extraction using the agent framework (#439)
* Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts with ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
This commit is contained in:
parent
1fe4ed5226
commit
d83e4e3d59
30 changed files with 3192 additions and 799 deletions
481
tests/integration/test_agent_kg_extraction_integration.py
Normal file
481
tests/integration/test_agent_kg_extraction_integration.py
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
"""
|
||||
Integration tests for Agent-based Knowledge Graph Extraction
|
||||
|
||||
These tests verify the end-to-end functionality of the agent-driven knowledge graph
|
||||
extraction pipeline, testing the integration between agent communication, prompt
|
||||
rendering, JSON response processing, and knowledge graph generation.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestAgentKgExtractionIntegration:
    """Integration tests for Agent-based Knowledge Graph Extraction"""

    @pytest.fixture
    def mock_flow_context(self):
        """Mock flow context for agent communication and output publishing.

        Returns a MagicMock that, when called with a service name, routes to
        the appropriate mocked client/publisher (agent, triples, entity
        contexts). Repeated calls with the same name return the same mock, so
        tests can both wire responses and assert on emitted messages.
        """
        context = MagicMock()

        # Mock agent client
        agent_client = AsyncMock()

        # Mock successful agent response: a fenced ```json block, as the real
        # agent would produce, containing definitions and relationships.
        def mock_agent_response(recipient, question):
            # Simulate agent processing and return structured response
            mock_response = MagicMock()
            mock_response.error = None
            mock_response.answer = '''```json
{
    "definitions": [
        {
            "entity": "Machine Learning",
            "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
        },
        {
            "entity": "Neural Networks",
            "definition": "Computing systems inspired by biological neural networks that process information."
        }
    ],
    "relationships": [
        {
            "subject": "Machine Learning",
            "predicate": "is_subset_of",
            "object": "Artificial Intelligence",
            "object-entity": true
        },
        {
            "subject": "Neural Networks",
            "predicate": "used_in",
            "object": "Machine Learning",
            "object-entity": true
        }
    ]
}
```'''
            # Only the answer string is returned; callers treat the agent
            # response as plain text to be parsed.
            return mock_response.answer

        agent_client.invoke = mock_agent_response

        # Mock output publishers
        triples_publisher = AsyncMock()
        entity_contexts_publisher = AsyncMock()

        def context_router(service_name):
            # Route known service names to their dedicated mocks; anything
            # else gets a throwaway AsyncMock.
            if service_name == "agent-request":
                return agent_client
            elif service_name == "triples":
                return triples_publisher
            elif service_name == "entity-contexts":
                return entity_contexts_publisher
            else:
                return AsyncMock()

        context.side_effect = context_router
        return context
||||
|
||||
@pytest.fixture
|
||||
def sample_chunk(self):
|
||||
"""Sample text chunk for knowledge extraction"""
|
||||
text = """
|
||||
Machine Learning is a subset of Artificial Intelligence that enables computers
|
||||
to learn from data without explicit programming. Neural Networks are computing
|
||||
systems inspired by biological neural networks that process information.
|
||||
Neural Networks are commonly used in Machine Learning applications.
|
||||
"""
|
||||
|
||||
return Chunk(
|
||||
chunk=text.encode('utf-8'),
|
||||
metadata=Metadata(
|
||||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc123", is_uri=True),
|
||||
p=Value(value="http://example.org/type", is_uri=True),
|
||||
o=Value(value="document", is_uri=False)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def configured_agent_extractor(self):
|
||||
"""Mock agent extractor with loaded configuration for integration testing"""
|
||||
# Create a mock extractor that simulates the real behavior
|
||||
from trustgraph.extract.kg.agent.extract import Processor
|
||||
|
||||
# Create mock without calling __init__ to avoid FlowProcessor issues
|
||||
extractor = MagicMock()
|
||||
real_extractor = Processor.__new__(Processor)
|
||||
|
||||
# Copy the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
||||
# Set up the configuration and manager
|
||||
extractor.manager = PromptManager()
|
||||
extractor.template_id = "agent-kg-extract"
|
||||
extractor.config_key = "prompt"
|
||||
|
||||
# Mock configuration
|
||||
config = {
|
||||
"system": json.dumps("You are a knowledge extraction agent."),
|
||||
"template-index": json.dumps(["agent-kg-extract"]),
|
||||
"template.agent-kg-extract": json.dumps({
|
||||
"prompt": "Extract entities and relationships from: {{ text }}",
|
||||
"response-type": "json"
|
||||
})
|
||||
}
|
||||
|
||||
# Load configuration
|
||||
extractor.manager.load_config(config)
|
||||
|
||||
# Mock the on_message method to simulate real behavior
|
||||
async def mock_on_message(msg, consumer, flow):
|
||||
v = msg.value()
|
||||
chunk_text = v.chunk.decode('utf-8')
|
||||
|
||||
# Render prompt
|
||||
prompt = extractor.manager.render(extractor.template_id, {"text": chunk_text})
|
||||
|
||||
# Get agent response (the mock returns a string directly)
|
||||
agent_client = flow("agent-request")
|
||||
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
|
||||
|
||||
# Parse and process
|
||||
extraction_data = extractor.parse_json(agent_response)
|
||||
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Add metadata triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
|
||||
# Emit outputs
|
||||
if triples:
|
||||
await extractor.emit_triples(flow("triples"), v.metadata, triples)
|
||||
if entity_contexts:
|
||||
await extractor.emit_entity_contexts(flow("entity-contexts"), v.metadata, entity_contexts)
|
||||
|
||||
extractor.on_message = mock_on_message
|
||||
|
||||
return extractor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_knowledge_extraction(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test complete end-to-end knowledge extraction workflow"""
|
||||
# Arrange
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Verify agent was called with rendered prompt
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
# Check that the mock function was replaced and called
|
||||
assert hasattr(agent_client, 'invoke')
|
||||
|
||||
# Verify triples were emitted
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_triples, Triples)
|
||||
assert sent_triples.metadata.id == "doc123"
|
||||
assert len(sent_triples.triples) > 0
|
||||
|
||||
# Check that we have definition triples
|
||||
definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION]
|
||||
assert len(definition_triples) >= 2 # Should have definitions for ML and Neural Networks
|
||||
|
||||
# Check that we have label triples
|
||||
label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL]
|
||||
assert len(label_triples) >= 2 # Should have labels for entities
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF]
|
||||
assert len(subject_of_triples) >= 2 # Entities should be linked to document
|
||||
|
||||
# Verify entity contexts were emitted
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
entity_contexts_publisher.send.assert_called_once()
|
||||
|
||||
sent_contexts = entity_contexts_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_contexts, EntityContexts)
|
||||
assert len(sent_contexts.entities) >= 2 # Should have contexts for both entities
|
||||
|
||||
# Verify entity URIs are properly formed
|
||||
entity_uris = [ec.entity.value for ec in sent_contexts.entities]
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_error_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of agent errors"""
|
||||
# Arrange - mock agent error response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_error_response(recipient, question):
|
||||
# Simulate agent error by raising an exception
|
||||
raise RuntimeError("Agent processing failed")
|
||||
|
||||
agent_client.invoke = mock_error_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
assert "Agent processing failed" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of invalid JSON responses from agent"""
|
||||
# Arrange - mock invalid JSON response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_invalid_json_response(recipient, question):
|
||||
return "This is not valid JSON at all"
|
||||
|
||||
agent_client.invoke = mock_invalid_json_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises((ValueError, json.JSONDecodeError)):
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange - mock empty extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_empty_response(recipient, question):
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
|
||||
agent_client.invoke = mock_empty_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should still emit outputs (even if empty) to maintain flow consistency
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
|
||||
# Triples should include metadata triples at minimum
|
||||
triples_publisher.send.assert_called_once()
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_triples, Triples)
|
||||
|
||||
# Entity contexts should not be sent if empty
|
||||
entity_contexts_publisher.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_extraction_data(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of malformed extraction data"""
|
||||
# Arrange - mock malformed extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_malformed_response(recipient, question):
|
||||
return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''
|
||||
|
||||
agent_client.invoke = mock_malformed_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(KeyError):
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_rendering_integration(self, configured_agent_extractor, mock_flow_context):
|
||||
"""Test integration with prompt template rendering"""
|
||||
# Create a chunk with specific text
|
||||
test_text = "Test text for prompt rendering"
|
||||
chunk = Chunk(
|
||||
chunk=test_text.encode('utf-8'),
|
||||
metadata=Metadata(id="test-doc", metadata=[])
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def capture_prompt(recipient, question):
|
||||
# Verify the prompt contains the test text
|
||||
assert test_text in question
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
|
||||
agent_client.invoke = capture_prompt
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - prompt should have been rendered with the text
|
||||
# The agent_client.invoke is a function, not a mock, so we verify it was called by checking the flow worked
|
||||
assert hasattr(agent_client, 'invoke')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_processing_simulation(self, configured_agent_extractor, mock_flow_context):
|
||||
"""Test simulation of concurrent chunk processing"""
|
||||
# Create multiple chunks
|
||||
chunks = []
|
||||
for i in range(3):
|
||||
text = f"Test document {i} content"
|
||||
chunks.append(Chunk(
|
||||
chunk=text.encode('utf-8'),
|
||||
metadata=Metadata(id=f"doc{i}", metadata=[])
|
||||
))
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
responses = []
|
||||
|
||||
def mock_response(recipient, question):
|
||||
response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
|
||||
responses.append(response)
|
||||
return response
|
||||
|
||||
agent_client.invoke = mock_response
|
||||
|
||||
# Process chunks sequentially (simulating concurrent processing)
|
||||
for chunk in chunks:
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert len(responses) == 3
|
||||
|
||||
# Verify all chunks were processed
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
assert triples_publisher.send.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_text_handling(self, configured_agent_extractor, mock_flow_context):
|
||||
"""Test handling of text with unicode characters"""
|
||||
# Create chunk with unicode text
|
||||
unicode_text = "Machine Learning (学习机器) は人工知能の一分野です。"
|
||||
chunk = Chunk(
|
||||
chunk=unicode_text.encode('utf-8'),
|
||||
metadata=Metadata(id="unicode-doc", metadata=[])
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_unicode_response(recipient, question):
|
||||
# Verify unicode text was properly decoded and included
|
||||
assert "学习机器" in question
|
||||
assert "人工知能" in question
|
||||
return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''
|
||||
|
||||
agent_client.invoke = mock_unicode_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - should handle unicode properly
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
# Check that unicode entity was properly processed
|
||||
entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and t.o.value == "機械学習"]
|
||||
assert len(entity_labels) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_text_chunk_processing(self, configured_agent_extractor, mock_flow_context):
|
||||
"""Test processing of large text chunks"""
|
||||
# Create a large text chunk
|
||||
large_text = "Machine Learning is important. " * 1000 # Repeat to create large text
|
||||
chunk = Chunk(
|
||||
chunk=large_text.encode('utf-8'),
|
||||
metadata=Metadata(id="large-doc", metadata=[])
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_large_text_response(recipient, question):
|
||||
# Verify large text was included
|
||||
assert len(question) > 10000
|
||||
return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}'''
|
||||
|
||||
agent_client.invoke = mock_large_text_response
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - should handle large text without issues
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
|
||||
def test_configuration_parameter_validation(self):
|
||||
"""Test parameter validation logic"""
|
||||
# Test that default parameter logic would work
|
||||
default_template_id = "agent-kg-extract"
|
||||
default_config_type = "prompt"
|
||||
default_concurrency = 1
|
||||
|
||||
# Simulate parameter handling
|
||||
params = {}
|
||||
template_id = params.get("template-id", default_template_id)
|
||||
config_key = params.get("config-type", default_config_type)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
assert template_id == "agent-kg-extract"
|
||||
assert config_key == "prompt"
|
||||
assert concurrency == 1
|
||||
|
||||
# Test with custom parameters
|
||||
custom_params = {
|
||||
"template-id": "custom-template",
|
||||
"config-type": "custom-config",
|
||||
"concurrency": 10
|
||||
}
|
||||
|
||||
template_id = custom_params.get("template-id", default_template_id)
|
||||
config_key = custom_params.get("config-type", default_config_type)
|
||||
concurrency = custom_params.get("concurrency", default_concurrency)
|
||||
|
||||
assert template_id == "custom-template"
|
||||
assert config_key == "custom-config"
|
||||
assert concurrency == 10
|
||||
Loading…
Add table
Add a link
Reference in a new issue