mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
421 lines
19 KiB
Python
421 lines
19 KiB
Python
|
|
"""
Unit tests for relationship extraction logic

Tests the core business logic for extracting relationships between entities,
including pattern matching, relationship classification, and validation.
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import Mock
|
||
|
|
import re
|
||
|
|
|
||
|
|
|
||
|
|
class TestRelationshipExtractionLogic:
    """Test cases for relationship extraction business logic.

    Every test defines the logic it exercises as a local helper function,
    so the tests are self-contained and rely only on the standard library.
    """

    def test_simple_relationship_patterns(self):
        """Test simple pattern-based relationship extraction."""
        # Arrange
        text = "John Smith works for OpenAI in San Francisco."
        entities = [
            {"text": "John Smith", "type": "PERSON", "start": 0, "end": 10},
            {"text": "OpenAI", "type": "ORG", "start": 21, "end": 27},
            {"text": "San Francisco", "type": "PLACE", "start": 31, "end": 44},
        ]

        def extract_relationships_pattern_based(text, entities):
            """Relate adjacent entities joined by a known connector phrase.

            Entities are ordered by their ``start`` offset; for each
            neighbouring pair the text between the two spans is matched, in
            full, against the connector patterns.  Anchoring on the entity
            spans avoids greedy captures that swallow the connector words
            and therefore never equal an entity's text.

            Returns a list of relationship dicts with the keys ``subject``,
            ``predicate``, ``object``, ``confidence``, ``subject_type`` and
            ``object_type``.
            """
            # (connector regex, relation type); the first full match wins.
            connector_patterns = [
                (r"works\s+for", "works_for"),
                (r"is\s+employed\s+by", "employed_by"),
                (r"in", "located_in"),
                (r"founded", "founded"),
                (r"developed", "developed"),
            ]

            ordered = sorted(entities, key=lambda e: e["start"])
            relationships = []

            for subject_entity, object_entity in zip(ordered, ordered[1:]):
                connector = text[subject_entity["end"]:object_entity["start"]].strip()
                for pattern, relation_type in connector_patterns:
                    if re.fullmatch(pattern, connector, re.IGNORECASE):
                        relationships.append({
                            "subject": subject_entity["text"],
                            "predicate": relation_type,
                            "object": object_entity["text"],
                            "confidence": 0.8,
                            "subject_type": subject_entity["type"],
                            "object_type": object_entity["type"],
                        })
                        break

            return relationships

        # Guard the fixture itself: offsets must point at the entity text.
        for entity in entities:
            assert text[entity["start"]:entity["end"]] == entity["text"]

        # Act
        relationships = extract_relationships_pattern_based(text, entities)

        # Assert
        by_predicate = {rel["predicate"]: rel for rel in relationships}
        assert len(relationships) == 2
        assert set(by_predicate) == {"works_for", "located_in"}

        work_rel = by_predicate["works_for"]
        assert work_rel["subject"] == "John Smith"
        assert work_rel["object"] == "OpenAI"
        assert work_rel["subject_type"] == "PERSON"
        assert work_rel["object_type"] == "ORG"

        location_rel = by_predicate["located_in"]
        assert location_rel["subject"] == "OpenAI"
        assert location_rel["object"] == "San Francisco"
        assert location_rel["object_type"] == "PLACE"

    def test_relationship_type_classification(self):
        """Test relationship type classification and normalization."""
        # Arrange
        raw_relationships = [
            ("John Smith", "works for", "OpenAI"),
            ("John Smith", "is employed by", "OpenAI"),
            ("John Smith", "job at", "OpenAI"),
            ("OpenAI", "located in", "San Francisco"),
            ("OpenAI", "based in", "San Francisco"),
            ("OpenAI", "headquarters in", "San Francisco"),
            ("John Smith", "developed", "ChatGPT"),
            ("John Smith", "created", "ChatGPT"),
            ("John Smith", "built", "ChatGPT"),
        ]

        def classify_relationship_type(predicate):
            """Map a free-text predicate to a coarse relationship category.

            The predicate is lower-cased and stripped, then checked for any
            of the category's phrases as a substring.  Categories are tried
            in the listed order; ``"generic"`` is returned when none match.
            """
            predicate_lower = predicate.lower().strip()

            categories = [
                ("employment", ["works for", "employed by", "job at", "position at"]),
                ("location", ["located in", "based in", "headquarters in", "situated in"]),
                ("creation", ["developed", "created", "built", "designed", "invented"]),
                ("ownership", ["owns", "founded", "established", "started"]),
            ]

            for category, phrases in categories:
                if any(phrase in predicate_lower for phrase in phrases):
                    return category

            return "generic"

        # Act & Assert
        expected_classifications = {
            "works for": "employment",
            "is employed by": "employment",
            "job at": "employment",
            "located in": "location",
            "based in": "location",
            "headquarters in": "location",
            "developed": "creation",
            "created": "creation",
            "built": "creation",
        }

        for _, predicate, _ in raw_relationships:
            if predicate in expected_classifications:
                classification = classify_relationship_type(predicate)
                expected = expected_classifications[predicate]
                assert classification == expected, (
                    f"'{predicate}' classified as {classification}, expected {expected}"
                )

    def test_relationship_validation(self):
        """Test relationship validation rules."""
        # Arrange
        relationships = [
            {"subject": "John Smith", "predicate": "works_for", "object": "OpenAI",
             "subject_type": "PERSON", "object_type": "ORG"},
            {"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco",
             "subject_type": "ORG", "object_type": "PLACE"},
            # Self-reference
            {"subject": "John Smith", "predicate": "located_in", "object": "John Smith",
             "subject_type": "PERSON", "object_type": "PERSON"},
            # Empty subject
            {"subject": "", "predicate": "works_for", "object": "OpenAI",
             "subject_type": "PERSON", "object_type": "ORG"},
            # Valid object relationship
            {"subject": "Chair", "predicate": "located_in", "object": "Room",
             "subject_type": "OBJECT", "object_type": "PLACE"},
        ]

        def validate_relationship(relationship):
            """Return ``(is_valid, reason)`` for a relationship dict.

            A relationship is rejected when subject, predicate or object is
            missing/empty, when subject equals object, or when the predicate
            has a type rule and the subject/object type is not allowed by
            it.  Predicates without a type rule are not type-checked.
            """
            subject = relationship.get("subject", "")
            predicate = relationship.get("predicate", "")
            obj = relationship.get("object", "")
            subject_type = relationship.get("subject_type", "")
            object_type = relationship.get("object_type", "")

            # Basic validation rules
            if not (subject and predicate and obj):
                return False, "Missing required fields"

            if subject == obj:
                return False, "Self-referential relationship"

            # Type compatibility rules
            type_rules = {
                "works_for": {"valid_subject": ["PERSON"],
                              "valid_object": ["ORG", "COMPANY"]},
                "located_in": {"valid_subject": ["PERSON", "ORG", "OBJECT"],
                               "valid_object": ["PLACE", "LOCATION"]},
                "developed": {"valid_subject": ["PERSON", "ORG"],
                              "valid_object": ["PRODUCT", "SOFTWARE"]},
            }

            rule = type_rules.get(predicate)
            if rule is not None:
                if subject_type not in rule["valid_subject"]:
                    return False, f"Invalid subject type {subject_type} for predicate {predicate}"
                if object_type not in rule["valid_object"]:
                    return False, f"Invalid object type {object_type} for predicate {predicate}"

            return True, "Valid"

        # Act & Assert
        expected_results = [True, True, False, False, True]
        # zip() would silently truncate, so pin the fixture lengths first.
        assert len(expected_results) == len(relationships)

        for i, (relationship, expected) in enumerate(zip(relationships, expected_results)):
            is_valid, reason = validate_relationship(relationship)
            assert is_valid == expected, f"Relationship {i} validation mismatch: {reason}"

    def test_relationship_confidence_scoring(self):
        """Test relationship confidence scoring."""
        # Arrange
        def calculate_relationship_confidence(relationship, context):
            """Score a relationship in ``[0.5, 1.0]`` from additive boosts.

            Starts at 0.5, then adds a per-predicate boost, a boost for
            compatible entity types, and 0.05 for every predicate-specific
            cue word found in the lower-cased ``context``.  The sum is
            capped at 1.0.
            """
            base_confidence = 0.5

            predicate = relationship["predicate"]
            subject_type = relationship.get("subject_type", "")
            object_type = relationship.get("object_type", "")

            # Boost confidence for common, reliable patterns
            reliable_patterns = {
                "works_for": 0.3,
                "employed_by": 0.3,
                "located_in": 0.2,
                "founded": 0.4,
            }
            if predicate in reliable_patterns:
                base_confidence += reliable_patterns[predicate]

            # Boost for type compatibility
            if predicate == "works_for" and subject_type == "PERSON" and object_type == "ORG":
                base_confidence += 0.2

            if predicate == "located_in" and object_type in ["PLACE", "LOCATION"]:
                base_confidence += 0.1

            # Context clues
            context_lower = context.lower()
            context_boost_words = {
                "works_for": ["employee", "staff", "team member"],
                "located_in": ["address", "office", "building"],
                "developed": ["creator", "developer", "engineer"],
            }
            for word in context_boost_words.get(predicate, []):
                if word in context_lower:
                    base_confidence += 0.05

            return min(base_confidence, 1.0)

        test_cases = [
            ({"predicate": "works_for", "subject_type": "PERSON", "object_type": "ORG"},
             "John Smith is an employee at OpenAI", 0.9),
            ({"predicate": "located_in", "subject_type": "ORG", "object_type": "PLACE"},
             "The office building is in downtown", 0.8),
            # Reduced expectation for unknown relationships
            ({"predicate": "unknown", "subject_type": "UNKNOWN", "object_type": "UNKNOWN"},
             "Some random text", 0.5),
        ]

        # Act & Assert
        for relationship, context, expected_min in test_cases:
            confidence = calculate_relationship_confidence(relationship, context)
            assert confidence >= expected_min, (
                f"Confidence {confidence} too low for {relationship['predicate']}"
            )
            assert confidence <= 1.0, f"Confidence {confidence} exceeds maximum"

    def test_relationship_directionality(self):
        """Test relationship directionality and symmetry."""
        # Arrange
        def analyze_relationship_directionality(predicate):
            """Return the ``directed``/``symmetric``/``inverse`` rule for a predicate.

            Unknown predicates fall back to a directed, non-symmetric rule
            with no inverse.
            """
            directional_rules = {
                "works_for": {"directed": True, "symmetric": False, "inverse": "employs"},
                "located_in": {"directed": True, "symmetric": False, "inverse": "contains"},
                "married_to": {"directed": False, "symmetric": True, "inverse": "married_to"},
                "sibling_of": {"directed": False, "symmetric": True, "inverse": "sibling_of"},
                "founded": {"directed": True, "symmetric": False, "inverse": "founded_by"},
                "owns": {"directed": True, "symmetric": False, "inverse": "owned_by"},
            }
            default_rule = {"directed": True, "symmetric": False, "inverse": None}
            return directional_rules.get(predicate, default_rule)

        # Act & Assert
        test_cases = [
            ("works_for", True, False, "employs"),
            ("married_to", False, True, "married_to"),
            ("located_in", True, False, "contains"),
            ("sibling_of", False, True, "sibling_of"),
        ]

        for predicate, is_directed, is_symmetric, inverse in test_cases:
            rules = analyze_relationship_directionality(predicate)
            assert rules["directed"] == is_directed, f"{predicate} directionality mismatch"
            assert rules["symmetric"] == is_symmetric, f"{predicate} symmetry mismatch"
            assert rules["inverse"] == inverse, f"{predicate} inverse mismatch"

    def test_temporal_relationship_extraction(self):
        """Test extraction of temporal aspects in relationships."""
        # Arrange: (text, expected temporal type, expected details)
        temporal_cases = [
            ("John Smith worked for OpenAI from 2020 to 2023.",
             "duration", {"start_year": "2020", "end_year": "2023"}),
            ("Mary Johnson currently works at Microsoft.", "present", {}),
            ("Bob will join Google next month.", "future", {}),
            ("Alice previously worked for Apple.", "past", {}),
            ("Carol has worked at IBM since 2019.",
             "ongoing", {"start_year": "2019"}),
            # "goodwill " must not be read as the future marker "will ".
            ("Dave led the goodwill program at Acme until 2015.",
             "ended", {"end_year": "2015"}),
        ]

        def extract_temporal_info(text, relationship):
            """Classify the temporal cue in ``text``.

            Patterns are tried in order and the first one found decides the
            result.  Returns ``{"type": ..., "details": {...}}`` where
            ``details`` holds the captured years (as strings) for the
            "duration", "ongoing" and "ended" types and is empty otherwise;
            ``type`` is ``"unknown"`` when no pattern matches.

            ``relationship`` is accepted so callers can pass the relationship
            being annotated, but it is not consulted: the cues come from the
            text alone.
            """
            # Cue words are anchored with \b so they only match whole words.
            temporal_patterns = [
                (r'\bfrom\s+(\d{4})\s+to\s+(\d{4})\b', "duration"),
                (r'\bcurrently\s+', "present"),
                (r'\bwill\s+', "future"),
                (r'\bpreviously\s+', "past"),
                (r'\bformerly\s+', "past"),
                (r'\bsince\s+(\d{4})\b', "ongoing"),
                (r'\buntil\s+(\d{4})\b', "ended"),
            ]
            # Names for the captured year groups, per temporal type.
            year_fields = {
                "duration": ("start_year", "end_year"),
                "ongoing": ("start_year",),
                "ended": ("end_year",),
            }

            for pattern, temp_type in temporal_patterns:
                match = re.search(pattern, text, re.IGNORECASE)
                if match:
                    details = dict(zip(year_fields.get(temp_type, ()), match.groups()))
                    return {"type": temp_type, "details": details}

            return {"type": "unknown", "details": {}}

        # Act & Assert
        for text, expected_type, expected_details in temporal_cases:
            # Mock relationship for testing
            relationship = {"subject": "Test", "predicate": "works_for", "object": "Company"}
            temporal = extract_temporal_info(text, relationship)

            assert temporal["type"] == expected_type, f"Wrong temporal type for: {text}"
            assert temporal["details"] == expected_details, f"Wrong details for: {text}"

        no_cue = extract_temporal_info("Eve works for Initech.", {})
        assert no_cue == {"type": "unknown", "details": {}}

    def test_relationship_clustering(self):
        """Test clustering similar relationships."""
        # Arrange
        relationships = [
            {"subject": "John", "predicate": "works_for", "object": "OpenAI"},
            {"subject": "John", "predicate": "employed_by", "object": "OpenAI"},
            {"subject": "Mary", "predicate": "works_at", "object": "Microsoft"},
            {"subject": "Bob", "predicate": "located_in", "object": "New York"},
            {"subject": "OpenAI", "predicate": "based_in", "object": "San Francisco"},
        ]

        def cluster_similar_relationships(relationships):
            """Group relationships by ``(subject, semantic group, object)``.

            The semantic group is the name of the first equivalence group
            containing the predicate, or ``"other"`` when none does.
            Returns a dict mapping each key to its list of relationships.
            """
            # Define semantic equivalence groups
            equivalence_groups = {
                "employment": ["works_for", "employed_by", "works_at", "job_at"],
                "location": ["located_in", "based_in", "situated_in", "in"],
            }

            clusters = {}
            for rel in relationships:
                semantic_group = next(
                    (name for name, predicates in equivalence_groups.items()
                     if rel["predicate"] in predicates),
                    "other",
                )
                cluster_key = (rel["subject"], semantic_group, rel["object"])
                clusters.setdefault(cluster_key, []).append(rel)

            return clusters

        # Act
        clusters = cluster_similar_relationships(relationships)

        # Assert
        # John's employment relationships should be clustered
        john_employment_key = ("John", "employment", "OpenAI")
        assert john_employment_key in clusters
        assert len(clusters[john_employment_key]) == 2  # works_for and employed_by
        assert [rel["predicate"] for rel in clusters[john_employment_key]] == [
            "works_for", "employed_by",
        ]

        # Different subjects/objects must land in separate clusters.
        assert set(clusters) == {
            john_employment_key,
            ("Mary", "employment", "Microsoft"),
            ("Bob", "location", "New York"),
            ("OpenAI", "location", "San Francisco"),
        }

    def test_relationship_chain_analysis(self):
        """Test analysis of relationship chains and paths."""
        # Arrange
        relationships = [
            {"subject": "John", "predicate": "works_for", "object": "OpenAI"},
            {"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"},
            {"subject": "San Francisco", "predicate": "located_in", "object": "California"},
            {"subject": "Mary", "predicate": "works_for", "object": "OpenAI"},
        ]

        def find_relationship_chains(relationships, start_entity, max_depth=3):
            """Enumerate paths that start at ``start_entity``.

            Each path is a list of ``(entity, predicate)`` pairs beginning
            with ``(start_entity, "start")``.  Every prefix reached during
            the depth-first walk is included in the result, entities already
            on the current path are not revisited, and at most ``max_depth``
            relationships are followed.
            """
            # Build adjacency list: subject -> [(predicate, object), ...]
            graph = {}
            for rel in relationships:
                graph.setdefault(rel["subject"], []).append((rel["predicate"], rel["object"]))

            def dfs_chains(current, path, depth):
                if depth >= max_depth:
                    return [path]

                chains = [path]  # Include current path
                visited = {entity for entity, _ in path}

                for predicate, next_entity in graph.get(current, []):
                    if next_entity not in visited:  # Avoid cycles
                        new_path = path + [(next_entity, predicate)]
                        chains.extend(dfs_chains(next_entity, new_path, depth + 1))

                return chains

            return dfs_chains(start_entity, [(start_entity, "start")], 0)

        # Act
        john_chains = find_relationship_chains(relationships, "John")

        # Assert
        # Should find chains like: John -> OpenAI -> San Francisco -> California
        chain_lengths = [len(chain) for chain in john_chains]
        assert max(chain_lengths) >= 3  # At least a 3-entity chain

        # Check for specific expected chain
        long_chains = [chain for chain in john_chains if len(chain) >= 4]
        assert len(long_chains) > 0

        # Verify chain contains expected entities
        longest_chain = max(john_chains, key=len)
        chain_entities = [entity for entity, _ in longest_chain]
        assert "John" in chain_entities
        assert "OpenAI" in chain_entities
        assert "San Francisco" in chain_entities