Fix ontology selector defaults, add bypass mode, enforce domain/range (#908, #920)

- Align similarity_threshold default to 0.3 everywhere (class signature
  had stale 0.7). Fix the tech-spec example, which still quoted 0.7.
- Add bypass_selector_below parameter (default 5) to skip vector
  similarity selection when ontology element count is small enough.
- Enforce domain/range constraints in TripleConverter for object
  properties and datatype properties, with subclass hierarchy support.
  Properties with no declared domain/range pass through unchanged.
- Add unit tests for domain/range validation, subclass acceptance,
  polymorphic pass-through, and selector bypass.
This commit is contained in:
Cyber MacGeddon 2026-05-16 13:27:08 +01:00
parent aea4c2df8e
commit 79cfbc6abd
5 changed files with 501 additions and 13 deletions

View file

@ -278,7 +278,7 @@ The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for e
3. **Similarity Search**: 3. **Similarity Search**:
- For each text segment embedding, search the vector store - For each text segment embedding, search the vector store
- Retrieve top-k (e.g., 10) most similar ontology elements - Retrieve top-k (e.g., 10) most similar ontology elements
- Apply similarity threshold (e.g., 0.7) to filter weak matches - Apply similarity threshold (e.g., 0.3) to filter weak matches
- Aggregate results across all segments, tracking match frequencies - Aggregate results across all segments, tracking match frequencies
4. **Dependency Resolution**: 4. **Dependency Resolution**:

View file

@ -0,0 +1,389 @@
"""
Tests for TripleConverter domain/range enforcement and
OntologySelector bypass for small ontologies.
Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation).
"""
import pytest
from unittest.mock import Mock, AsyncMock
from trustgraph.extract.kg.ontology.triple_converter import TripleConverter
from trustgraph.extract.kg.ontology.ontology_selector import (
OntologySelector,
OntologySubset,
)
from trustgraph.extract.kg.ontology.ontology_loader import (
Ontology,
OntologyClass,
OntologyProperty,
)
from trustgraph.extract.kg.ontology.simplified_parser import (
Relationship,
Attribute,
)
from trustgraph.extract.kg.ontology.text_processor import TextSegment
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def ontology_subset():
    """Build the OntologySubset shared by the domain/range tests.

    Class hierarchy: Manager -> Employee -> Person, with Company and
    Product as unrelated roots. ``worksFor`` and ``manages`` declare both
    domain and range, ``relatedTo`` declares neither. ``employeeId`` is
    restricted to Employee, ``description`` is unrestricted.
    """
    base = "http://example.org/"

    def owl_class(name, parent=None):
        # The class id doubles as its label in this fixture.
        return {
            "uri": base + name,
            "type": "owl:Class",
            "labels": [{"value": name}],
            "subclass_of": parent,
        }

    def object_property(name, label, domain=None, range_=None):
        return {
            "uri": base + name,
            "type": "owl:ObjectProperty",
            "labels": [{"value": label}],
            "domain": domain,
            "range": range_,
        }

    def datatype_property(name, label, domain=None):
        # Datatype properties in this fixture carry no "range" key.
        return {
            "uri": base + name,
            "type": "owl:DatatypeProperty",
            "labels": [{"value": label}],
            "domain": domain,
        }

    classes = {
        "Person": owl_class("Person"),
        "Employee": owl_class("Employee", parent="Person"),
        "Manager": owl_class("Manager", parent="Employee"),
        "Company": owl_class("Company"),
        "Product": owl_class("Product"),
    }
    object_properties = {
        "worksFor": object_property(
            "worksFor", "works for", domain="Person", range_="Company"
        ),
        "manages": object_property(
            "manages", "manages", domain="Manager", range_="Employee"
        ),
        "relatedTo": object_property("relatedTo", "related to"),
    }
    datatype_properties = {
        "employeeId": datatype_property(
            "employeeId", "employee ID", domain="Employee"
        ),
        "description": datatype_property("description", "description"),
    }

    return OntologySubset(
        ontology_id="test",
        classes=classes,
        object_properties=object_properties,
        datatype_properties=datatype_properties,
        metadata={"name": "Test Ontology"},
    )
@pytest.fixture
def converter(ontology_subset):
    """Return a TripleConverter bound to the shared test ontology subset."""
    triple_converter = TripleConverter(
        ontology_id="test",
        ontology_subset=ontology_subset,
    )
    return triple_converter
# ---------------------------------------------------------------------------
# Domain/range enforcement — relationships
# ---------------------------------------------------------------------------
class TestRelationshipDomainRange:
    """Check worksFor (domain Person, range Company) against typed ends."""

    @staticmethod
    def _works_for(subject, subject_type, obj, object_type):
        # Every case in this class exercises the same property.
        return Relationship(
            subject=subject, subject_type=subject_type,
            relation="worksFor",
            object=obj, object_type=object_type,
        )

    def test_valid_domain_and_range(self, converter):
        """A Person working for a Company yields a triple."""
        relationship = self._works_for(
            "Alice", "Person", "Acme Corp", "Company"
        )
        result = converter.convert_relationship(relationship)
        assert result is not None

    def test_domain_violation_rejected(self, converter):
        """A Product subject is outside the Person domain."""
        relationship = self._works_for(
            "Widget", "Product", "Acme Corp", "Company"
        )
        result = converter.convert_relationship(relationship)
        assert result is None

    def test_range_violation_rejected(self, converter):
        """A Product object is outside the Company range."""
        relationship = self._works_for(
            "Alice", "Person", "Widget", "Product"
        )
        result = converter.convert_relationship(relationship)
        assert result is None

    def test_both_domain_and_range_violated(self, converter):
        """Both ends mistyped is rejected as well."""
        relationship = self._works_for(
            "Widget", "Product", "Gadget", "Product"
        )
        result = converter.convert_relationship(relationship)
        assert result is None
# ---------------------------------------------------------------------------
# Subclass acceptance
# ---------------------------------------------------------------------------
class TestSubclassAcceptance:
    """Subclasses satisfy a constraint on an ancestor, never the reverse."""

    def test_direct_subclass_matches_domain(self, converter):
        """Employee is subclass of Person; worksFor domain is Person."""
        relationship = Relationship(
            subject="Bob", subject_type="Employee",
            relation="worksFor",
            object="Acme Corp", object_type="Company",
        )
        result = converter.convert_relationship(relationship)
        assert result is not None

    def test_transitive_subclass_matches_domain(self, converter):
        """Manager → Employee → Person; worksFor domain is Person."""
        relationship = Relationship(
            subject="Carol", subject_type="Manager",
            relation="worksFor",
            object="Acme Corp", object_type="Company",
        )
        result = converter.convert_relationship(relationship)
        assert result is not None

    def test_subclass_matches_range(self, converter):
        """manages range is Employee; Manager is subclass of Employee."""
        relationship = Relationship(
            subject="Carol", subject_type="Manager",
            relation="manages",
            object="Dave", object_type="Manager",
        )
        result = converter.convert_relationship(relationship)
        assert result is not None

    def test_superclass_does_not_match_subclass_constraint(self, converter):
        """manages domain is Manager; Person is NOT a subclass of Manager."""
        relationship = Relationship(
            subject="Alice", subject_type="Person",
            relation="manages",
            object="Bob", object_type="Employee",
        )
        result = converter.convert_relationship(relationship)
        assert result is None
# ---------------------------------------------------------------------------
# Polymorphic properties (no domain/range)
# ---------------------------------------------------------------------------
class TestPolymorphicProperties:
    """relatedTo declares no domain or range, so any typed pair passes."""

    @staticmethod
    def _related_to(subject, subject_type, obj, object_type):
        return Relationship(
            subject=subject, subject_type=subject_type,
            relation="relatedTo",
            object=obj, object_type=object_type,
        )

    def test_no_domain_no_range_allows_anything(self, converter):
        """Person -> Company is accepted."""
        relationship = self._related_to(
            "Alice", "Person", "Acme Corp", "Company"
        )
        result = converter.convert_relationship(relationship)
        assert result is not None

    def test_polymorphic_with_unrelated_types(self, converter):
        """Product -> Employee is accepted too."""
        relationship = self._related_to(
            "Widget", "Product", "Bob", "Employee"
        )
        result = converter.convert_relationship(relationship)
        assert result is not None
# ---------------------------------------------------------------------------
# Datatype property domain enforcement
# ---------------------------------------------------------------------------
class TestAttributeDomainValidation:
    """Check domain enforcement for datatype properties."""

    @staticmethod
    def _employee_id(entity, entity_type, value):
        # employeeId is the only domain-restricted attribute in the fixture.
        return Attribute(
            entity=entity, entity_type=entity_type,
            attribute="employeeId", value=value,
        )

    def test_valid_domain(self, converter):
        """An Employee may carry employeeId."""
        attribute = self._employee_id("Bob", "Employee", "E-1234")
        result = converter.convert_attribute(attribute)
        assert result is not None

    def test_subclass_matches_domain(self, converter):
        """Manager is subclass of Employee; employeeId domain is Employee."""
        attribute = self._employee_id("Carol", "Manager", "M-5678")
        result = converter.convert_attribute(attribute)
        assert result is not None

    def test_domain_violation_rejected(self, converter):
        """A Company is outside the Employee domain."""
        attribute = self._employee_id("Acme Corp", "Company", "E-0000")
        result = converter.convert_attribute(attribute)
        assert result is None

    def test_no_domain_allows_anything(self, converter):
        """description declares no domain, so a Product may carry it."""
        attribute = Attribute(
            entity="Widget", entity_type="Product",
            attribute="description", value="A useful widget",
        )
        result = converter.convert_attribute(attribute)
        assert result is not None
# ---------------------------------------------------------------------------
# OntologySelector bypass for small ontologies (#908)
# ---------------------------------------------------------------------------
def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0):
    """Build an Ontology with id "tiny" holding the requested element counts.

    Classes are keyed C0..Cn-1, object properties op0.., datatype
    properties dp0.., each with a matching http://example.org/ URI.
    """
    def _properties(prefix, owl_type, count):
        # Object and datatype properties differ only in key prefix and type.
        return {
            f"{prefix}{i}": OntologyProperty(
                uri=f"http://example.org/{prefix}{i}", type=owl_type
            )
            for i in range(count)
        }

    class_map = {}
    for i in range(n_classes):
        class_map[f"C{i}"] = OntologyClass(uri=f"http://example.org/C{i}")

    return Ontology(
        id="tiny",
        metadata={"name": "Tiny"},
        classes=class_map,
        object_properties=_properties("op", "owl:ObjectProperty", n_obj_props),
        datatype_properties=_properties(
            "dp", "owl:DatatypeProperty", n_dt_props
        ),
    )
def _make_loader(ontology):
loader = Mock()
loader.get_ontology.return_value = ontology
loader.get_all_ontologies.return_value = {"tiny": ontology}
return loader
@pytest.mark.asyncio
class TestBypassSelectorBelow:
    """Check when OntologySelector skips vector selection (#908).

    The class-level asyncio marker makes the coroutine tests run under
    pytest-asyncio whether or not the project enables asyncio_mode=auto.
    """

    @staticmethod
    def _segments():
        # One arbitrary segment; its content is irrelevant to these tests.
        return [TextSegment(text="some text", type="sentence", position=0)]

    @staticmethod
    def _embedder_with_store(store_size, embedding):
        """Return (embedder, vector_store) mocks for the non-bypass path."""
        embedder = Mock()
        embedder.embed_text = AsyncMock(return_value=embedding)
        vector_store = Mock()
        vector_store.size.return_value = store_size
        vector_store.search.return_value = []
        embedder.get_vector_store.return_value = vector_store
        return embedder, vector_store

    async def test_bypass_returns_full_ontology(self):
        """With 3 elements and bypass_selector_below=5, selector is bypassed."""
        ont = _make_ontology(2, 1, 0)
        loader = _make_loader(ont)
        embedder = Mock()
        selector = OntologySelector(
            ontology_embedder=embedder,
            ontology_loader=loader,
            bypass_selector_below=5,
        )

        subsets = await selector.select_ontology_subset(self._segments())

        assert len(subsets) == 1
        assert subsets[0].ontology_id == "tiny"
        assert len(subsets[0].classes) == 2
        assert len(subsets[0].object_properties) == 1
        assert subsets[0].relevance_score == 1.0
        # Embedder should never be called
        embedder.embed_text.assert_not_called()

    async def test_no_bypass_when_above_threshold(self):
        """With 10 elements and bypass_selector_below=5, selector runs normally."""
        ont = _make_ontology(6, 3, 1)
        loader = _make_loader(ont)
        embedder, vector_store = self._embedder_with_store(10, [0.1, 0.2])
        selector = OntologySelector(
            ontology_embedder=embedder,
            ontology_loader=loader,
            bypass_selector_below=5,
        )

        await selector.select_ontology_subset(self._segments())

        # Vector store was consulted (selector ran normally)
        vector_store.size.assert_called_once()

    async def test_bypass_at_exact_threshold_not_triggered(self):
        """With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=)."""
        ont = _make_ontology(3, 1, 1)  # total = 5
        loader = _make_loader(ont)
        embedder, vector_store = self._embedder_with_store(5, [0.1, 0.2])
        selector = OntologySelector(
            ontology_embedder=embedder,
            ontology_loader=loader,
            bypass_selector_below=5,
        )

        await selector.select_ontology_subset(self._segments())

        # Should NOT bypass — 5 is not < 5
        vector_store.size.assert_called_once()

    async def test_bypass_zero_disables(self):
        """bypass_selector_below=0 means bypass never triggers."""
        ont = _make_ontology(0, 0, 0)  # empty ontology
        loader = _make_loader(ont)
        embedder, vector_store = self._embedder_with_store(0, [0.1])
        selector = OntologySelector(
            ontology_embedder=embedder,
            ontology_loader=loader,
            bypass_selector_below=0,
        )

        await selector.select_ontology_subset(self._segments())

        # 0 is not < 0, so bypass doesn't trigger
        vector_store.size.assert_called_once()

View file

@ -121,6 +121,7 @@ class Processor(FlowProcessor):
# Configuration # Configuration
self.top_k = params.get("top_k", 10) self.top_k = params.get("top_k", 10)
self.similarity_threshold = params.get("similarity_threshold", 0.3) self.similarity_threshold = params.get("similarity_threshold", 0.3)
self.bypass_selector_below = params.get("bypass_selector_below", 5)
# Per-workspace ontology version tracking # Per-workspace ontology version tracking
self.current_ontology_versions = {} # workspace -> version self.current_ontology_versions = {} # workspace -> version
@ -187,7 +188,8 @@ class Processor(FlowProcessor):
ontology_embedder=ontology_embedder, ontology_embedder=ontology_embedder,
ontology_loader=loader, ontology_loader=loader,
top_k=self.top_k, top_k=self.top_k,
similarity_threshold=self.similarity_threshold similarity_threshold=self.similarity_threshold,
bypass_selector_below=self.bypass_selector_below,
) )
# Store flow-specific components # Store flow-specific components
@ -981,6 +983,13 @@ class Processor(FlowProcessor):
default=0.3, default=0.3,
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
) )
parser.add_argument(
'--bypass-selector-below',
type=int,
default=5,
help='Bypass ontology selector when total ontology elements '
'(classes + properties) is below this value (default: 5)'
)
parser.add_argument( parser.add_argument(
'--triples-batch-size', '--triples-batch-size',
type=int, type=int,

View file

@ -33,19 +33,44 @@ class OntologySelector:
def __init__(self, ontology_embedder: OntologyEmbedder, def __init__(self, ontology_embedder: OntologyEmbedder,
ontology_loader: OntologyLoader, ontology_loader: OntologyLoader,
top_k: int = 10, top_k: int = 10,
similarity_threshold: float = 0.7): similarity_threshold: float = 0.3,
"""Initialize the ontology selector. bypass_selector_below: int = 5):
Args:
ontology_embedder: Embedder with vector store
ontology_loader: Loader with ontology definitions
top_k: Number of top results to retrieve per segment
similarity_threshold: Minimum similarity score
"""
self.embedder = ontology_embedder self.embedder = ontology_embedder
self.loader = ontology_loader self.loader = ontology_loader
self.top_k = top_k self.top_k = top_k
self.similarity_threshold = similarity_threshold self.similarity_threshold = similarity_threshold
self.bypass_selector_below = bypass_selector_below
def _total_ontology_elements(self) -> int:
total = 0
for ontology in self.loader.get_all_ontologies().values():
total += len(ontology.classes)
total += len(ontology.object_properties)
total += len(ontology.datatype_properties)
return total
def _build_full_subsets(self) -> List[OntologySubset]:
    """Return one OntologySubset per loaded ontology, holding every element.

    Each class and property is exported as a dict of its instance
    attributes. The dicts are shallow copies, so adding or replacing keys
    on a subset entry does not write through to the loaded ontology
    objects; nested values are still shared. Metadata is passed through
    unchanged and every subset gets relevance_score 1.0.

    Returns:
        List of full-ontology subsets, in the loader's iteration order.
    """
    def _export(elements):
        # dict(vars(x)) copies the attribute dict instead of aliasing
        # the live x.__dict__ of the loaded ontology object.
        return {
            element_id: dict(vars(element))
            for element_id, element in elements.items()
        }

    return [
        OntologySubset(
            ontology_id=ont_id,
            classes=_export(ontology.classes),
            object_properties=_export(ontology.object_properties),
            datatype_properties=_export(ontology.datatype_properties),
            metadata=ontology.metadata,
            relevance_score=1.0,
        )
        for ont_id, ontology in self.loader.get_all_ontologies().items()
    ]
async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
"""Select relevant ontology subsets for text segments. """Select relevant ontology subsets for text segments.
@ -56,6 +81,15 @@ class OntologySelector:
Returns: Returns:
List of ontology subsets with relevant elements List of ontology subsets with relevant elements
""" """
total = self._total_ontology_elements()
if total < self.bypass_selector_below:
logger.info(
f"Ontology has {total} elements (below "
f"bypass_selector_below={self.bypass_selector_below}), "
f"using full ontology"
)
return self._build_full_subsets()
# Collect all relevant elements # Collect all relevant elements
relevant_elements = await self._find_relevant_elements(segments) relevant_elements = await self._find_relevant_elements(segments)

View file

@ -6,7 +6,7 @@ with full URIs and correct is_uri flags.
""" """
import logging import logging
from typing import List, Optional from typing import List, Optional, Set
from .... schema import Triple, Term, IRI, LITERAL from .... schema import Triple, Term, IRI, LITERAL
from .... rdf import RDF_TYPE, RDF_LABEL from .... rdf import RDF_TYPE, RDF_LABEL
@ -32,6 +32,25 @@ class TripleConverter:
self.ontology_id = ontology_id self.ontology_id = ontology_id
self.entity_registry = EntityRegistry(ontology_id) self.entity_registry = EntityRegistry(ontology_id)
def _get_ancestor_classes(self, class_id: str) -> Set[str]:
ancestors = set()
current = class_id
while current:
cls_def = self.ontology_subset.classes.get(current)
if not cls_def:
break
parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None)
if not parent or parent in ancestors:
break
ancestors.add(parent)
current = parent
return ancestors
def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool:
if actual_type == expected_type:
return True
return expected_type in self._get_ancestor_classes(actual_type)
def convert_all(self, extraction: ExtractionResult) -> List[Triple]: def convert_all(self, extraction: ExtractionResult) -> List[Triple]:
"""Convert complete extraction result to RDF triples. """Convert complete extraction result to RDF triples.
@ -129,6 +148,29 @@ class TripleConverter:
logger.warning(f"Unknown relationship '{relationship.relation}', skipping") logger.warning(f"Unknown relationship '{relationship.relation}', skipping")
return None return None
# Enforce domain/range constraints when declared
prop_def = self.ontology_subset.object_properties.get(
relationship.relation, {}
)
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None)
if domain and not self._matches_class_constraint(relationship.subject_type, domain):
logger.warning(
f"Domain violation: '{relationship.relation}' expects "
f"domain '{domain}', got subject type "
f"'{relationship.subject_type}', skipping"
)
return None
if range_ and not self._matches_class_constraint(relationship.object_type, range_):
logger.warning(
f"Range violation: '{relationship.relation}' expects "
f"range '{range_}', got object type "
f"'{relationship.object_type}', skipping"
)
return None
# Generate triple: subject property object # Generate triple: subject property object
return Triple( return Triple(
s=Term(type=IRI, iri=subject_uri), s=Term(type=IRI, iri=subject_uri),
@ -157,11 +199,25 @@ class TripleConverter:
logger.warning(f"Unknown attribute '{attribute.attribute}', skipping") logger.warning(f"Unknown attribute '{attribute.attribute}', skipping")
return None return None
# Enforce domain constraint when declared
prop_def = self.ontology_subset.datatype_properties.get(
attribute.attribute, {}
)
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
if domain and not self._matches_class_constraint(attribute.entity_type, domain):
logger.warning(
f"Domain violation: attribute '{attribute.attribute}' "
f"expects domain '{domain}', got entity type "
f"'{attribute.entity_type}', skipping"
)
return None
# Generate triple: entity property "literal value" # Generate triple: entity property "literal value"
return Triple( return Triple(
s=Term(type=IRI, iri=entity_uri), s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri=property_uri), p=Term(type=IRI, iri=property_uri),
o=Term(type=LITERAL, value=attribute.value) # Literal! o=Term(type=LITERAL, value=attribute.value)
) )
def _get_class_uri(self, class_id: str) -> Optional[str]: def _get_class_uri(self, class_id: str) -> Optional[str]: