Fix ontology selector defaults, add bypass mode, enforce domain/range (#929)

- Align similarity_threshold default to 0.3 everywhere (class signature
  had stale 0.7). Fix matching contradiction in tech-spec.
- Add bypass_selector_below parameter (default 5) to skip vector
  similarity selection when ontology element count is small enough.
- Enforce domain/range constraints in TripleConverter for object
  properties and datatype properties, with subclass hierarchy support.
  Properties with no declared domain/range pass through unchanged.
- Add unit tests for domain/range validation, subclass acceptance,
  polymorphic pass-through, and selector bypass.

Fixes #908, #920
This commit is contained in:
cybermaggedon 2026-05-16 15:13:38 +01:00 committed by GitHub
parent aea4c2df8e
commit 38d9c746a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 501 additions and 13 deletions

View file

@ -278,7 +278,7 @@ The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for e
3. **Similarity Search**:
- For each text segment embedding, search the vector store
- Retrieve top-k (e.g., 10) most similar ontology elements
- Apply similarity threshold (e.g., 0.7) to filter weak matches
- Apply similarity threshold (e.g., 0.3) to filter weak matches
- Aggregate results across all segments, tracking match frequencies
4. **Dependency Resolution**:

View file

@ -0,0 +1,389 @@
"""
Tests for TripleConverter domain/range enforcement and
OntologySelector bypass for small ontologies.
Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation).
"""
import pytest
from unittest.mock import Mock, AsyncMock
from trustgraph.extract.kg.ontology.triple_converter import TripleConverter
from trustgraph.extract.kg.ontology.ontology_selector import (
OntologySelector,
OntologySubset,
)
from trustgraph.extract.kg.ontology.ontology_loader import (
Ontology,
OntologyClass,
OntologyProperty,
)
from trustgraph.extract.kg.ontology.simplified_parser import (
Relationship,
Attribute,
)
from trustgraph.extract.kg.ontology.text_processor import TextSegment
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def ontology_subset():
"""Ontology subset with classes, hierarchy, and constrained properties."""
return OntologySubset(
ontology_id="test",
classes={
"Person": {
"uri": "http://example.org/Person",
"type": "owl:Class",
"labels": [{"value": "Person"}],
"subclass_of": None,
},
"Employee": {
"uri": "http://example.org/Employee",
"type": "owl:Class",
"labels": [{"value": "Employee"}],
"subclass_of": "Person",
},
"Manager": {
"uri": "http://example.org/Manager",
"type": "owl:Class",
"labels": [{"value": "Manager"}],
"subclass_of": "Employee",
},
"Company": {
"uri": "http://example.org/Company",
"type": "owl:Class",
"labels": [{"value": "Company"}],
"subclass_of": None,
},
"Product": {
"uri": "http://example.org/Product",
"type": "owl:Class",
"labels": [{"value": "Product"}],
"subclass_of": None,
},
},
object_properties={
"worksFor": {
"uri": "http://example.org/worksFor",
"type": "owl:ObjectProperty",
"labels": [{"value": "works for"}],
"domain": "Person",
"range": "Company",
},
"manages": {
"uri": "http://example.org/manages",
"type": "owl:ObjectProperty",
"labels": [{"value": "manages"}],
"domain": "Manager",
"range": "Employee",
},
"relatedTo": {
"uri": "http://example.org/relatedTo",
"type": "owl:ObjectProperty",
"labels": [{"value": "related to"}],
"domain": None,
"range": None,
},
},
datatype_properties={
"employeeId": {
"uri": "http://example.org/employeeId",
"type": "owl:DatatypeProperty",
"labels": [{"value": "employee ID"}],
"domain": "Employee",
},
"description": {
"uri": "http://example.org/description",
"type": "owl:DatatypeProperty",
"labels": [{"value": "description"}],
"domain": None,
},
},
metadata={"name": "Test Ontology"},
)
@pytest.fixture
def converter(ontology_subset):
return TripleConverter(ontology_subset=ontology_subset, ontology_id="test")
# ---------------------------------------------------------------------------
# Domain/range enforcement — relationships
# ---------------------------------------------------------------------------
class TestRelationshipDomainRange:
def test_valid_domain_and_range(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
triple = converter.convert_relationship(rel)
assert triple is not None
def test_domain_violation_rejected(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is None
def test_range_violation_rejected(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="worksFor",
object="Widget", object_type="Product",
)
assert converter.convert_relationship(rel) is None
def test_both_domain_and_range_violated(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="worksFor",
object="Gadget", object_type="Product",
)
assert converter.convert_relationship(rel) is None
# ---------------------------------------------------------------------------
# Subclass acceptance
# ---------------------------------------------------------------------------
class TestSubclassAcceptance:
def test_direct_subclass_matches_domain(self, converter):
"""Employee is subclass of Person; worksFor domain is Person."""
rel = Relationship(
subject="Bob", subject_type="Employee",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_transitive_subclass_matches_domain(self, converter):
"""Manager → Employee → Person; worksFor domain is Person."""
rel = Relationship(
subject="Carol", subject_type="Manager",
relation="worksFor",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_subclass_matches_range(self, converter):
"""manages range is Employee; Manager is subclass of Employee."""
rel = Relationship(
subject="Carol", subject_type="Manager",
relation="manages",
object="Dave", object_type="Manager",
)
assert converter.convert_relationship(rel) is not None
def test_superclass_does_not_match_subclass_constraint(self, converter):
"""manages domain is Manager; Person is NOT a subclass of Manager."""
rel = Relationship(
subject="Alice", subject_type="Person",
relation="manages",
object="Bob", object_type="Employee",
)
assert converter.convert_relationship(rel) is None
# ---------------------------------------------------------------------------
# Polymorphic properties (no domain/range)
# ---------------------------------------------------------------------------
class TestPolymorphicProperties:
def test_no_domain_no_range_allows_anything(self, converter):
rel = Relationship(
subject="Alice", subject_type="Person",
relation="relatedTo",
object="Acme Corp", object_type="Company",
)
assert converter.convert_relationship(rel) is not None
def test_polymorphic_with_unrelated_types(self, converter):
rel = Relationship(
subject="Widget", subject_type="Product",
relation="relatedTo",
object="Bob", object_type="Employee",
)
assert converter.convert_relationship(rel) is not None
# ---------------------------------------------------------------------------
# Datatype property domain enforcement
# ---------------------------------------------------------------------------
class TestAttributeDomainValidation:
def test_valid_domain(self, converter):
attr = Attribute(
entity="Bob", entity_type="Employee",
attribute="employeeId", value="E-1234",
)
assert converter.convert_attribute(attr) is not None
def test_subclass_matches_domain(self, converter):
"""Manager is subclass of Employee; employeeId domain is Employee."""
attr = Attribute(
entity="Carol", entity_type="Manager",
attribute="employeeId", value="M-5678",
)
assert converter.convert_attribute(attr) is not None
def test_domain_violation_rejected(self, converter):
attr = Attribute(
entity="Acme Corp", entity_type="Company",
attribute="employeeId", value="E-0000",
)
assert converter.convert_attribute(attr) is None
def test_no_domain_allows_anything(self, converter):
attr = Attribute(
entity="Widget", entity_type="Product",
attribute="description", value="A useful widget",
)
assert converter.convert_attribute(attr) is not None
# ---------------------------------------------------------------------------
# OntologySelector bypass for small ontologies (#908)
# ---------------------------------------------------------------------------
def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0):
classes = {
f"C{i}": OntologyClass(uri=f"http://example.org/C{i}")
for i in range(n_classes)
}
obj_props = {
f"op{i}": OntologyProperty(
uri=f"http://example.org/op{i}", type="owl:ObjectProperty"
)
for i in range(n_obj_props)
}
dt_props = {
f"dp{i}": OntologyProperty(
uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty"
)
for i in range(n_dt_props)
}
return Ontology(
id="tiny",
metadata={"name": "Tiny"},
classes=classes,
object_properties=obj_props,
datatype_properties=dt_props,
)
def _make_loader(ontology):
loader = Mock()
loader.get_ontology.return_value = ontology
loader.get_all_ontologies.return_value = {"tiny": ontology}
return loader
class TestBypassSelectorBelow:
async def test_bypass_returns_full_ontology(self):
"""With 3 elements and bypass_selector_below=5, selector is bypassed."""
ont = _make_ontology(2, 1, 0)
loader = _make_loader(ont)
embedder = Mock()
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
assert len(subsets) == 1
assert subsets[0].ontology_id == "tiny"
assert len(subsets[0].classes) == 2
assert len(subsets[0].object_properties) == 1
assert subsets[0].relevance_score == 1.0
# Embedder should never be called
embedder.embed_text.assert_not_called()
async def test_no_bypass_when_above_threshold(self):
"""With 10 elements and bypass_selector_below=5, selector runs normally."""
ont = _make_ontology(6, 3, 1)
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
vector_store = Mock()
vector_store.size.return_value = 10
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# Vector store was consulted (selector ran normally)
vector_store.size.assert_called_once()
async def test_bypass_at_exact_threshold_not_triggered(self):
"""With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=)."""
ont = _make_ontology(3, 1, 1) # total = 5
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1, 0.2])
vector_store = Mock()
vector_store.size.return_value = 5
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=5,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# Should NOT bypass — 5 is not < 5
vector_store.size.assert_called_once()
async def test_bypass_zero_disables(self):
"""bypass_selector_below=0 means bypass never triggers."""
ont = _make_ontology(0, 0, 0) # empty ontology
loader = _make_loader(ont)
embedder = Mock()
embedder.embed_text = AsyncMock(return_value=[0.1])
vector_store = Mock()
vector_store.size.return_value = 0
vector_store.search.return_value = []
embedder.get_vector_store.return_value = vector_store
selector = OntologySelector(
ontology_embedder=embedder,
ontology_loader=loader,
bypass_selector_below=0,
)
segments = [TextSegment(text="some text", type="sentence", position=0)]
subsets = await selector.select_ontology_subset(segments)
# 0 is not < 0, so bypass doesn't trigger
vector_store.size.assert_called_once()

View file

@ -121,6 +121,7 @@ class Processor(FlowProcessor):
# Configuration
self.top_k = params.get("top_k", 10)
self.similarity_threshold = params.get("similarity_threshold", 0.3)
self.bypass_selector_below = params.get("bypass_selector_below", 5)
# Per-workspace ontology version tracking
self.current_ontology_versions = {} # workspace -> version
@ -187,7 +188,8 @@ class Processor(FlowProcessor):
ontology_embedder=ontology_embedder,
ontology_loader=loader,
top_k=self.top_k,
similarity_threshold=self.similarity_threshold
similarity_threshold=self.similarity_threshold,
bypass_selector_below=self.bypass_selector_below,
)
# Store flow-specific components
@ -981,6 +983,13 @@ class Processor(FlowProcessor):
default=0.3,
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
)
parser.add_argument(
'--bypass-selector-below',
type=int,
default=5,
help='Bypass ontology selector when total ontology elements '
'(classes + properties) is below this value (default: 5)'
)
parser.add_argument(
'--triples-batch-size',
type=int,

View file

@ -33,19 +33,44 @@ class OntologySelector:
def __init__(self, ontology_embedder: OntologyEmbedder,
ontology_loader: OntologyLoader,
top_k: int = 10,
similarity_threshold: float = 0.7):
"""Initialize the ontology selector.
Args:
ontology_embedder: Embedder with vector store
ontology_loader: Loader with ontology definitions
top_k: Number of top results to retrieve per segment
similarity_threshold: Minimum similarity score
"""
similarity_threshold: float = 0.3,
bypass_selector_below: int = 5):
self.embedder = ontology_embedder
self.loader = ontology_loader
self.top_k = top_k
self.similarity_threshold = similarity_threshold
self.bypass_selector_below = bypass_selector_below
def _total_ontology_elements(self) -> int:
total = 0
for ontology in self.loader.get_all_ontologies().values():
total += len(ontology.classes)
total += len(ontology.object_properties)
total += len(ontology.datatype_properties)
return total
def _build_full_subsets(self) -> List[OntologySubset]:
subsets = []
for ont_id, ontology in self.loader.get_all_ontologies().items():
subset = OntologySubset(
ontology_id=ont_id,
classes={
cid: cls.__dict__
for cid, cls in ontology.classes.items()
},
object_properties={
pid: prop.__dict__
for pid, prop in ontology.object_properties.items()
},
datatype_properties={
pid: prop.__dict__
for pid, prop in ontology.datatype_properties.items()
},
metadata=ontology.metadata,
relevance_score=1.0,
)
subsets.append(subset)
return subsets
async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
"""Select relevant ontology subsets for text segments.
@ -56,6 +81,15 @@ class OntologySelector:
Returns:
List of ontology subsets with relevant elements
"""
total = self._total_ontology_elements()
if total < self.bypass_selector_below:
logger.info(
f"Ontology has {total} elements (below "
f"bypass_selector_below={self.bypass_selector_below}), "
f"using full ontology"
)
return self._build_full_subsets()
# Collect all relevant elements
relevant_elements = await self._find_relevant_elements(segments)

View file

@ -6,7 +6,7 @@ with full URIs and correct is_uri flags.
"""
import logging
from typing import List, Optional
from typing import List, Optional, Set
from .... schema import Triple, Term, IRI, LITERAL
from .... rdf import RDF_TYPE, RDF_LABEL
@ -32,6 +32,25 @@ class TripleConverter:
self.ontology_id = ontology_id
self.entity_registry = EntityRegistry(ontology_id)
def _get_ancestor_classes(self, class_id: str) -> Set[str]:
ancestors = set()
current = class_id
while current:
cls_def = self.ontology_subset.classes.get(current)
if not cls_def:
break
parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None)
if not parent or parent in ancestors:
break
ancestors.add(parent)
current = parent
return ancestors
def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool:
if actual_type == expected_type:
return True
return expected_type in self._get_ancestor_classes(actual_type)
def convert_all(self, extraction: ExtractionResult) -> List[Triple]:
"""Convert complete extraction result to RDF triples.
@ -129,6 +148,29 @@ class TripleConverter:
logger.warning(f"Unknown relationship '{relationship.relation}', skipping")
return None
# Enforce domain/range constraints when declared
prop_def = self.ontology_subset.object_properties.get(
relationship.relation, {}
)
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None)
if domain and not self._matches_class_constraint(relationship.subject_type, domain):
logger.warning(
f"Domain violation: '{relationship.relation}' expects "
f"domain '{domain}', got subject type "
f"'{relationship.subject_type}', skipping"
)
return None
if range_ and not self._matches_class_constraint(relationship.object_type, range_):
logger.warning(
f"Range violation: '{relationship.relation}' expects "
f"range '{range_}', got object type "
f"'{relationship.object_type}', skipping"
)
return None
# Generate triple: subject property object
return Triple(
s=Term(type=IRI, iri=subject_uri),
@ -157,11 +199,25 @@ class TripleConverter:
logger.warning(f"Unknown attribute '{attribute.attribute}', skipping")
return None
# Enforce domain constraint when declared
prop_def = self.ontology_subset.datatype_properties.get(
attribute.attribute, {}
)
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
if domain and not self._matches_class_constraint(attribute.entity_type, domain):
logger.warning(
f"Domain violation: attribute '{attribute.attribute}' "
f"expects domain '{domain}', got entity type "
f"'{attribute.entity_type}', skipping"
)
return None
# Generate triple: entity property "literal value"
return Triple(
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri=property_uri),
o=Term(type=LITERAL, value=attribute.value) # Literal!
o=Term(type=LITERAL, value=attribute.value)
)
def _get_class_uri(self, class_id: str) -> Optional[str]: