mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
* Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
432 lines
No EOL
18 KiB
Python
432 lines
No EOL
18 KiB
Python
"""
|
|
Unit tests for Agent-based Knowledge Graph Extraction
|
|
|
|
These tests verify the core functionality of the agent-driven KG extractor,
|
|
including JSON response parsing, triple generation, entity context creation,
|
|
and RDF URI handling.
|
|
"""
|
|
|
|
import pytest
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
|
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
|
from trustgraph.schema import EntityContext, EntityContexts
|
|
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
|
from trustgraph.template.prompt_manager import PromptManager
|
|
|
|
|
|
@pytest.mark.unit
|
|
class TestAgentKgExtractor:
|
|
"""Unit tests for Agent-based Knowledge Graph Extractor"""
|
|
|
|
@pytest.fixture
|
|
def agent_extractor(self):
|
|
"""Create a mock agent extractor for testing core functionality"""
|
|
# Create a mock that has the methods we want to test
|
|
extractor = MagicMock()
|
|
|
|
# Add real implementations of the methods we want to test
|
|
from trustgraph.extract.kg.agent.extract import Processor
|
|
real_extractor = Processor.__new__(Processor) # Create without calling __init__
|
|
|
|
# Set up the methods we want to test
|
|
extractor.to_uri = real_extractor.to_uri
|
|
extractor.parse_json = real_extractor.parse_json
|
|
extractor.process_extraction_data = real_extractor.process_extraction_data
|
|
extractor.emit_triples = real_extractor.emit_triples
|
|
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
|
|
|
# Mock the prompt manager
|
|
extractor.manager = PromptManager()
|
|
extractor.template_id = "agent-kg-extract"
|
|
extractor.config_key = "prompt"
|
|
extractor.concurrency = 1
|
|
|
|
return extractor
|
|
|
|
@pytest.fixture
|
|
def sample_metadata(self):
|
|
"""Sample metadata for testing"""
|
|
return Metadata(
|
|
id="doc123",
|
|
metadata=[
|
|
Triple(
|
|
s=Value(value="doc123", is_uri=True),
|
|
p=Value(value="http://example.org/type", is_uri=True),
|
|
o=Value(value="document", is_uri=False)
|
|
)
|
|
]
|
|
)
|
|
|
|
@pytest.fixture
|
|
def sample_extraction_data(self):
|
|
"""Sample extraction data in expected format"""
|
|
return {
|
|
"definitions": [
|
|
{
|
|
"entity": "Machine Learning",
|
|
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
|
},
|
|
{
|
|
"entity": "Neural Networks",
|
|
"definition": "Computing systems inspired by biological neural networks that process information."
|
|
}
|
|
],
|
|
"relationships": [
|
|
{
|
|
"subject": "Machine Learning",
|
|
"predicate": "is_subset_of",
|
|
"object": "Artificial Intelligence",
|
|
"object-entity": True
|
|
},
|
|
{
|
|
"subject": "Neural Networks",
|
|
"predicate": "used_in",
|
|
"object": "Machine Learning",
|
|
"object-entity": True
|
|
},
|
|
{
|
|
"subject": "Deep Learning",
|
|
"predicate": "accuracy",
|
|
"object": "95%",
|
|
"object-entity": False
|
|
}
|
|
]
|
|
}
|
|
|
|
def test_to_uri_conversion(self, agent_extractor):
|
|
"""Test URI conversion for entities"""
|
|
# Test simple entity name
|
|
uri = agent_extractor.to_uri("Machine Learning")
|
|
expected = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
assert uri == expected
|
|
|
|
# Test entity with special characters
|
|
uri = agent_extractor.to_uri("Entity with & special chars!")
|
|
expected = f"{TRUSTGRAPH_ENTITIES}Entity%20with%20%26%20special%20chars%21"
|
|
assert uri == expected
|
|
|
|
# Test empty string
|
|
uri = agent_extractor.to_uri("")
|
|
expected = f"{TRUSTGRAPH_ENTITIES}"
|
|
assert uri == expected
|
|
|
|
def test_parse_json_with_code_blocks(self, agent_extractor):
|
|
"""Test JSON parsing from code blocks"""
|
|
# Test JSON in code blocks
|
|
response = '''```json
|
|
{
|
|
"definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}],
|
|
"relationships": []
|
|
}
|
|
```'''
|
|
|
|
result = agent_extractor.parse_json(response)
|
|
|
|
assert result["definitions"][0]["entity"] == "AI"
|
|
assert result["definitions"][0]["definition"] == "Artificial Intelligence"
|
|
assert result["relationships"] == []
|
|
|
|
def test_parse_json_without_code_blocks(self, agent_extractor):
|
|
"""Test JSON parsing without code blocks"""
|
|
response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}'''
|
|
|
|
result = agent_extractor.parse_json(response)
|
|
|
|
assert result["definitions"][0]["entity"] == "ML"
|
|
assert result["definitions"][0]["definition"] == "Machine Learning"
|
|
|
|
def test_parse_json_invalid_format(self, agent_extractor):
|
|
"""Test JSON parsing with invalid format"""
|
|
invalid_response = "This is not JSON at all"
|
|
|
|
with pytest.raises(json.JSONDecodeError):
|
|
agent_extractor.parse_json(invalid_response)
|
|
|
|
def test_parse_json_malformed_code_blocks(self, agent_extractor):
|
|
"""Test JSON parsing with malformed code blocks"""
|
|
# Missing closing backticks
|
|
response = '''```json
|
|
{"definitions": [], "relationships": []}
|
|
'''
|
|
|
|
# Should still parse the JSON content
|
|
with pytest.raises(json.JSONDecodeError):
|
|
agent_extractor.parse_json(response)
|
|
|
|
def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata):
|
|
"""Test processing of definition data"""
|
|
data = {
|
|
"definitions": [
|
|
{
|
|
"entity": "Machine Learning",
|
|
"definition": "A subset of AI that enables learning from data."
|
|
}
|
|
],
|
|
"relationships": []
|
|
}
|
|
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
|
|
# Check entity label triple
|
|
label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
|
assert label_triple is not None
|
|
assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
assert label_triple.s.is_uri == True
|
|
assert label_triple.o.is_uri == False
|
|
|
|
# Check definition triple
|
|
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
|
|
assert def_triple is not None
|
|
assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
assert def_triple.o.value == "A subset of AI that enables learning from data."
|
|
|
|
# Check subject-of triple
|
|
subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None)
|
|
assert subject_of_triple is not None
|
|
assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
assert subject_of_triple.o.value == "doc123"
|
|
|
|
# Check entity context
|
|
assert len(entity_contexts) == 1
|
|
assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
assert entity_contexts[0].context == "A subset of AI that enables learning from data."
|
|
|
|
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata):
|
|
"""Test processing of relationship data"""
|
|
data = {
|
|
"definitions": [],
|
|
"relationships": [
|
|
{
|
|
"subject": "Machine Learning",
|
|
"predicate": "is_subset_of",
|
|
"object": "Artificial Intelligence",
|
|
"object-entity": True
|
|
}
|
|
]
|
|
}
|
|
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
|
|
# Check that subject, predicate, and object labels are created
|
|
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
|
predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of"
|
|
|
|
# Find label triples
|
|
subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None)
|
|
assert subject_label is not None
|
|
assert subject_label.o.value == "Machine Learning"
|
|
|
|
predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None)
|
|
assert predicate_label is not None
|
|
assert predicate_label.o.value == "is_subset_of"
|
|
|
|
# Check main relationship triple
|
|
# NOTE: Current implementation has bugs:
|
|
# 1. Uses data.get("object-entity") instead of rel.get("object-entity")
|
|
# 2. Sets object_value to predicate_uri instead of actual object URI
|
|
# This test documents the current buggy behavior
|
|
rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None)
|
|
assert rel_triple is not None
|
|
# Due to bug, object value is set to predicate_uri
|
|
assert rel_triple.o.value == predicate_uri
|
|
|
|
# Check subject-of relationships
|
|
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"]
|
|
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
|
|
|
|
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
|
|
"""Test processing of relationships with literal objects"""
|
|
data = {
|
|
"definitions": [],
|
|
"relationships": [
|
|
{
|
|
"subject": "Deep Learning",
|
|
"predicate": "accuracy",
|
|
"object": "95%",
|
|
"object-entity": False
|
|
}
|
|
]
|
|
}
|
|
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
|
|
# Check that object labels are not created for literal objects
|
|
object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"]
|
|
# Based on the code logic, it should not create object labels for non-entity objects
|
|
# But there might be a bug in the original implementation
|
|
|
|
def test_process_extraction_data_combined(self, agent_extractor, sample_metadata, sample_extraction_data):
|
|
"""Test processing of combined definitions and relationships"""
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
|
|
|
|
# Check that we have both definition and relationship triples
|
|
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
|
assert len(definition_triples) == 2 # Two definitions
|
|
|
|
# Check entity contexts are created for definitions
|
|
assert len(entity_contexts) == 2
|
|
entity_uris = [ec.entity.value for ec in entity_contexts]
|
|
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
|
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
|
|
|
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
|
|
"""Test processing when metadata has no ID"""
|
|
metadata = Metadata(id=None, metadata=[])
|
|
data = {
|
|
"definitions": [
|
|
{"entity": "Test Entity", "definition": "Test definition"}
|
|
],
|
|
"relationships": []
|
|
}
|
|
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
|
|
|
|
# Should not create subject-of relationships when no metadata ID
|
|
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
|
|
assert len(subject_of_triples) == 0
|
|
|
|
# Should still create entity contexts
|
|
assert len(entity_contexts) == 1
|
|
|
|
def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata):
|
|
"""Test processing of empty extraction data"""
|
|
data = {"definitions": [], "relationships": []}
|
|
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
|
|
# Should only have metadata triples
|
|
assert len(entity_contexts) == 0
|
|
# Triples should only contain metadata triples if any
|
|
|
|
def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata):
|
|
"""Test processing data with missing keys"""
|
|
# Test missing definitions key
|
|
data = {"relationships": []}
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
assert len(entity_contexts) == 0
|
|
|
|
# Test missing relationships key
|
|
data = {"definitions": []}
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
assert len(entity_contexts) == 0
|
|
|
|
# Test completely missing keys
|
|
data = {}
|
|
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
|
assert len(entity_contexts) == 0
|
|
|
|
def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata):
|
|
"""Test processing data with malformed entries"""
|
|
# Test definition missing required fields
|
|
data = {
|
|
"definitions": [
|
|
{"entity": "Test"}, # Missing definition
|
|
{"definition": "Test def"} # Missing entity
|
|
],
|
|
"relationships": [
|
|
{"subject": "A", "predicate": "rel"}, # Missing object
|
|
{"subject": "B", "object": "C"} # Missing predicate
|
|
]
|
|
}
|
|
|
|
# Should handle gracefully or raise appropriate errors
|
|
with pytest.raises(KeyError):
|
|
agent_extractor.process_extraction_data(data, sample_metadata)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emit_triples(self, agent_extractor, sample_metadata):
|
|
"""Test emitting triples to publisher"""
|
|
mock_publisher = AsyncMock()
|
|
|
|
test_triples = [
|
|
Triple(
|
|
s=Value(value="test:subject", is_uri=True),
|
|
p=Value(value="test:predicate", is_uri=True),
|
|
o=Value(value="test object", is_uri=False)
|
|
)
|
|
]
|
|
|
|
await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples)
|
|
|
|
mock_publisher.send.assert_called_once()
|
|
sent_triples = mock_publisher.send.call_args[0][0]
|
|
assert isinstance(sent_triples, Triples)
|
|
# Check metadata fields individually since implementation creates new Metadata object
|
|
assert sent_triples.metadata.id == sample_metadata.id
|
|
assert sent_triples.metadata.user == sample_metadata.user
|
|
assert sent_triples.metadata.collection == sample_metadata.collection
|
|
# Note: metadata.metadata is now empty array in the new implementation
|
|
assert sent_triples.metadata.metadata == []
|
|
assert len(sent_triples.triples) == 1
|
|
assert sent_triples.triples[0].s.value == "test:subject"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_emit_entity_contexts(self, agent_extractor, sample_metadata):
|
|
"""Test emitting entity contexts to publisher"""
|
|
mock_publisher = AsyncMock()
|
|
|
|
test_contexts = [
|
|
EntityContext(
|
|
entity=Value(value="test:entity", is_uri=True),
|
|
context="Test context"
|
|
)
|
|
]
|
|
|
|
await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts)
|
|
|
|
mock_publisher.send.assert_called_once()
|
|
sent_contexts = mock_publisher.send.call_args[0][0]
|
|
assert isinstance(sent_contexts, EntityContexts)
|
|
# Check metadata fields individually since implementation creates new Metadata object
|
|
assert sent_contexts.metadata.id == sample_metadata.id
|
|
assert sent_contexts.metadata.user == sample_metadata.user
|
|
assert sent_contexts.metadata.collection == sample_metadata.collection
|
|
# Note: metadata.metadata is now empty array in the new implementation
|
|
assert sent_contexts.metadata.metadata == []
|
|
assert len(sent_contexts.entities) == 1
|
|
assert sent_contexts.entities[0].entity.value == "test:entity"
|
|
|
|
def test_agent_extractor_initialization_params(self):
|
|
"""Test agent extractor parameter validation"""
|
|
# Test default parameters (we'll mock the initialization)
|
|
def mock_init(self, **kwargs):
|
|
self.template_id = kwargs.get('template-id', 'agent-kg-extract')
|
|
self.config_key = kwargs.get('config-type', 'prompt')
|
|
self.concurrency = kwargs.get('concurrency', 1)
|
|
|
|
with patch.object(AgentKgExtractor, '__init__', mock_init):
|
|
extractor = AgentKgExtractor()
|
|
|
|
# This tests the default parameter logic
|
|
assert extractor.template_id == 'agent-kg-extract'
|
|
assert extractor.config_key == 'prompt'
|
|
assert extractor.concurrency == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_prompt_config_loading_logic(self, agent_extractor):
|
|
"""Test prompt configuration loading logic"""
|
|
# Test the core logic without requiring full FlowProcessor initialization
|
|
config = {
|
|
"prompt": {
|
|
"system": json.dumps("Test system"),
|
|
"template-index": json.dumps(["agent-kg-extract"]),
|
|
"template.agent-kg-extract": json.dumps({
|
|
"prompt": "Extract knowledge from: {{ text }}",
|
|
"response-type": "json"
|
|
})
|
|
}
|
|
}
|
|
|
|
# Test the manager loading directly
|
|
if "prompt" in config:
|
|
agent_extractor.manager.load_config(config["prompt"])
|
|
|
|
# Should not raise an exception
|
|
assert agent_extractor.manager is not None
|
|
|
|
# Test with empty config
|
|
empty_config = {}
|
|
# Should handle gracefully - no config to load |