diff --git a/docs/tech-specs/ontorag.md b/docs/tech-specs/ontorag.md index 86a3cd19..460e72ba 100644 --- a/docs/tech-specs/ontorag.md +++ b/docs/tech-specs/ontorag.md @@ -278,7 +278,7 @@ The system uses **FAISS (Facebook AI Similarity Search)** with IndexFlatIP for e 3. **Similarity Search**: - For each text segment embedding, search the vector store - Retrieve top-k (e.g., 10) most similar ontology elements - - Apply similarity threshold (e.g., 0.7) to filter weak matches + - Apply similarity threshold (e.g., 0.3) to filter weak matches - Aggregate results across all segments, tracking match frequencies 4. **Dependency Resolution**: diff --git a/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py new file mode 100644 index 00000000..195e8adf --- /dev/null +++ b/tests/unit/test_extract/test_ontology/test_triple_converter_validation.py @@ -0,0 +1,389 @@ +""" +Tests for TripleConverter domain/range enforcement and +OntologySelector bypass for small ontologies. + +Covers fixes for #908 (bypass_selector_below) and #920 (domain/range validation). +""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from trustgraph.extract.kg.ontology.triple_converter import TripleConverter +from trustgraph.extract.kg.ontology.ontology_selector import ( + OntologySelector, + OntologySubset, +) +from trustgraph.extract.kg.ontology.ontology_loader import ( + Ontology, + OntologyClass, + OntologyProperty, +) +from trustgraph.extract.kg.ontology.simplified_parser import ( + Relationship, + Attribute, +) +from trustgraph.extract.kg.ontology.text_processor import TextSegment + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def ontology_subset(): + """Ontology subset with classes, hierarchy, and constrained properties.""" + return OntologySubset( + ontology_id="test", + classes={ + "Person": { + "uri": "http://example.org/Person", + "type": "owl:Class", + "labels": [{"value": "Person"}], + "subclass_of": None, + }, + "Employee": { + "uri": "http://example.org/Employee", + "type": "owl:Class", + "labels": [{"value": "Employee"}], + "subclass_of": "Person", + }, + "Manager": { + "uri": "http://example.org/Manager", + "type": "owl:Class", + "labels": [{"value": "Manager"}], + "subclass_of": "Employee", + }, + "Company": { + "uri": "http://example.org/Company", + "type": "owl:Class", + "labels": [{"value": "Company"}], + "subclass_of": None, + }, + "Product": { + "uri": "http://example.org/Product", + "type": "owl:Class", + "labels": [{"value": "Product"}], + "subclass_of": None, + }, + }, + object_properties={ + "worksFor": { + "uri": "http://example.org/worksFor", + "type": "owl:ObjectProperty", + "labels": [{"value": "works for"}], + "domain": "Person", + "range": "Company", + }, + "manages": { + "uri": "http://example.org/manages", + "type": "owl:ObjectProperty", + "labels": [{"value": "manages"}], + "domain": "Manager", + "range": "Employee", + }, + "relatedTo": { + "uri": "http://example.org/relatedTo", + "type": "owl:ObjectProperty", + "labels": [{"value": "related to"}], + "domain": None, + "range": None, + }, + }, + datatype_properties={ + "employeeId": { + "uri": "http://example.org/employeeId", + "type": "owl:DatatypeProperty", + "labels": [{"value": "employee ID"}], + "domain": "Employee", + }, + "description": { + "uri": "http://example.org/description", + "type": "owl:DatatypeProperty", + "labels": [{"value": "description"}], + "domain": None, + }, + }, + metadata={"name": "Test Ontology"}, + ) + + +@pytest.fixture +def converter(ontology_subset): + return TripleConverter(ontology_subset=ontology_subset, ontology_id="test") + + +# --------------------------------------------------------------------------- +# Domain/range enforcement — relationships +# --------------------------------------------------------------------------- + +class TestRelationshipDomainRange: + + def test_valid_domain_and_range(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + triple = converter.convert_relationship(rel) + assert triple is not None + + def test_domain_violation_rejected(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is None + + def test_range_violation_rejected(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="worksFor", + object="Widget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + def test_both_domain_and_range_violated(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="worksFor", + object="Gadget", object_type="Product", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Subclass acceptance +# --------------------------------------------------------------------------- + +class TestSubclassAcceptance: + + def test_direct_subclass_matches_domain(self, converter): + """Employee is subclass of Person; worksFor domain is Person.""" + rel = Relationship( + subject="Bob", subject_type="Employee", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_transitive_subclass_matches_domain(self, converter): + """Manager → Employee → Person; worksFor domain is Person.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="worksFor", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_subclass_matches_range(self, converter): + """manages range is Employee; Manager is subclass of Employee.""" + rel = Relationship( + subject="Carol", subject_type="Manager", + relation="manages", + object="Dave", object_type="Manager", + ) + assert converter.convert_relationship(rel) is not None + + def test_superclass_does_not_match_subclass_constraint(self, converter): + """manages domain is Manager; Person is NOT a subclass of Manager.""" + rel = Relationship( + subject="Alice", subject_type="Person", + relation="manages", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is None + + +# --------------------------------------------------------------------------- +# Polymorphic properties (no domain/range) +# --------------------------------------------------------------------------- + +class TestPolymorphicProperties: + + def test_no_domain_no_range_allows_anything(self, converter): + rel = Relationship( + subject="Alice", subject_type="Person", + relation="relatedTo", + object="Acme Corp", object_type="Company", + ) + assert converter.convert_relationship(rel) is not None + + def test_polymorphic_with_unrelated_types(self, converter): + rel = Relationship( + subject="Widget", subject_type="Product", + relation="relatedTo", + object="Bob", object_type="Employee", + ) + assert converter.convert_relationship(rel) is not None + + +# --------------------------------------------------------------------------- +# Datatype property domain enforcement +# --------------------------------------------------------------------------- + +class TestAttributeDomainValidation: + + def test_valid_domain(self, converter): + attr = Attribute( + entity="Bob", entity_type="Employee", + attribute="employeeId", value="E-1234", + ) + assert converter.convert_attribute(attr) is not None + + def test_subclass_matches_domain(self, converter): + """Manager is subclass of Employee; employeeId domain is Employee.""" + attr = Attribute( + entity="Carol", entity_type="Manager", + attribute="employeeId", value="M-5678", + ) + assert converter.convert_attribute(attr) is not None + + def test_domain_violation_rejected(self, converter): + attr = Attribute( + entity="Acme Corp", entity_type="Company", + attribute="employeeId", value="E-0000", + ) + assert converter.convert_attribute(attr) is None + + def test_no_domain_allows_anything(self, converter): + attr = Attribute( + entity="Widget", entity_type="Product", + attribute="description", value="A useful widget", + ) + assert converter.convert_attribute(attr) is not None + + +# --------------------------------------------------------------------------- +# OntologySelector bypass for small ontologies (#908) +# --------------------------------------------------------------------------- + +def _make_ontology(n_classes, n_obj_props=0, n_dt_props=0): + classes = { + f"C{i}": OntologyClass(uri=f"http://example.org/C{i}") + for i in range(n_classes) + } + obj_props = { + f"op{i}": OntologyProperty( + uri=f"http://example.org/op{i}", type="owl:ObjectProperty" + ) + for i in range(n_obj_props) + } + dt_props = { + f"dp{i}": OntologyProperty( + uri=f"http://example.org/dp{i}", type="owl:DatatypeProperty" + ) + for i in range(n_dt_props) + } + return Ontology( + id="tiny", + metadata={"name": "Tiny"}, + classes=classes, + object_properties=obj_props, + datatype_properties=dt_props, + ) + + +def _make_loader(ontology): + loader = Mock() + loader.get_ontology.return_value = ontology + loader.get_all_ontologies.return_value = {"tiny": ontology} + return loader + + +class TestBypassSelectorBelow: + + async def test_bypass_returns_full_ontology(self): + """With 3 elements and bypass_selector_below=5, selector is bypassed.""" + ont = _make_ontology(2, 1, 0) + loader = _make_loader(ont) + embedder = Mock() + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + assert len(subsets) == 1 + assert subsets[0].ontology_id == "tiny" + assert len(subsets[0].classes) == 2 + assert len(subsets[0].object_properties) == 1 + assert subsets[0].relevance_score == 1.0 + # Embedder should never be called + embedder.embed_text.assert_not_called() + + async def test_no_bypass_when_above_threshold(self): + """With 10 elements and bypass_selector_below=5, selector runs normally.""" + ont = _make_ontology(6, 3, 1) + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 10 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Vector store was consulted (selector ran normally) + vector_store.size.assert_called_once() + + async def test_bypass_at_exact_threshold_not_triggered(self): + """With exactly 5 elements and bypass_selector_below=5, selector runs (< not <=).""" + ont = _make_ontology(3, 1, 1) # total = 5 + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1, 0.2]) + vector_store = Mock() + vector_store.size.return_value = 5 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=5, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # Should NOT bypass — 5 is not < 5 + vector_store.size.assert_called_once() + + async def test_bypass_zero_disables(self): + """bypass_selector_below=0 means bypass never triggers.""" + ont = _make_ontology(0, 0, 0) # empty ontology + loader = _make_loader(ont) + + embedder = Mock() + embedder.embed_text = AsyncMock(return_value=[0.1]) + vector_store = Mock() + vector_store.size.return_value = 0 + vector_store.search.return_value = [] + embedder.get_vector_store.return_value = vector_store + + selector = OntologySelector( + ontology_embedder=embedder, + ontology_loader=loader, + bypass_selector_below=0, + ) + + segments = [TextSegment(text="some text", type="sentence", position=0)] + subsets = await selector.select_ontology_subset(segments) + + # 0 is not < 0, so bypass doesn't trigger + vector_store.size.assert_called_once() diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 1d45d3f9..6a43e547 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -121,6 +121,7 @@ class Processor(FlowProcessor): # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.3) + self.bypass_selector_below = params.get("bypass_selector_below", 5) # Per-workspace ontology version tracking self.current_ontology_versions = {} # workspace -> version @@ -187,7 +188,8 @@ class Processor(FlowProcessor): ontology_embedder=ontology_embedder, ontology_loader=loader, top_k=self.top_k, - similarity_threshold=self.similarity_threshold + similarity_threshold=self.similarity_threshold, + bypass_selector_below=self.bypass_selector_below, ) # Store flow-specific components @@ -981,6 +983,13 @@ class Processor(FlowProcessor): default=0.3, help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)' ) + parser.add_argument( + '--bypass-selector-below', + type=int, + default=5, + help='Bypass ontology selector when total ontology elements ' + '(classes + properties) is below this value (default: 5)' + ) parser.add_argument( '--triples-batch-size', type=int, diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py index 5111529a..5fd60a0f 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py @@ -33,19 +33,44 @@ class OntologySelector: def __init__(self, ontology_embedder: OntologyEmbedder, ontology_loader: OntologyLoader, top_k: int = 10, - similarity_threshold: float = 0.7): - """Initialize the ontology selector. - - Args: - ontology_embedder: Embedder with vector store - ontology_loader: Loader with ontology definitions - top_k: Number of top results to retrieve per segment - similarity_threshold: Minimum similarity score - """ + similarity_threshold: float = 0.3, + bypass_selector_below: int = 5): self.embedder = ontology_embedder self.loader = ontology_loader self.top_k = top_k self.similarity_threshold = similarity_threshold + self.bypass_selector_below = bypass_selector_below + + def _total_ontology_elements(self) -> int: + total = 0 + for ontology in self.loader.get_all_ontologies().values(): + total += len(ontology.classes) + total += len(ontology.object_properties) + total += len(ontology.datatype_properties) + return total + + def _build_full_subsets(self) -> List[OntologySubset]: + subsets = [] + for ont_id, ontology in self.loader.get_all_ontologies().items(): + subset = OntologySubset( + ontology_id=ont_id, + classes={ + cid: cls.__dict__ + for cid, cls in ontology.classes.items() + }, + object_properties={ + pid: prop.__dict__ + for pid, prop in ontology.object_properties.items() + }, + datatype_properties={ + pid: prop.__dict__ + for pid, prop in ontology.datatype_properties.items() + }, + metadata=ontology.metadata, + relevance_score=1.0, + ) + subsets.append(subset) + return subsets async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: """Select relevant ontology subsets for text segments. @@ -56,6 +81,15 @@ class OntologySelector: Returns: List of ontology subsets with relevant elements """ + total = self._total_ontology_elements() + if total < self.bypass_selector_below: + logger.info( + f"Ontology has {total} elements (below " + f"bypass_selector_below={self.bypass_selector_below}), " + f"using full ontology" + ) + return self._build_full_subsets() + # Collect all relevant elements relevant_elements = await self._find_relevant_elements(segments) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py index 06fff4f4..d9e6c837 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py @@ -6,7 +6,7 @@ with full URIs and correct is_uri flags. """ import logging -from typing import List, Optional +from typing import List, Optional, Set from .... schema import Triple, Term, IRI, LITERAL from .... rdf import RDF_TYPE, RDF_LABEL @@ -32,6 +32,25 @@ class TripleConverter: self.ontology_id = ontology_id self.entity_registry = EntityRegistry(ontology_id) + def _get_ancestor_classes(self, class_id: str) -> Set[str]: + ancestors = set() + current = class_id + while current: + cls_def = self.ontology_subset.classes.get(current) + if not cls_def: + break + parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None) + if not parent or parent in ancestors: + break + ancestors.add(parent) + current = parent + return ancestors + + def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool: + if actual_type == expected_type: + return True + return expected_type in self._get_ancestor_classes(actual_type) + def convert_all(self, extraction: ExtractionResult) -> List[Triple]: """Convert complete extraction result to RDF triples. @@ -129,6 +148,29 @@ class TripleConverter: logger.warning(f"Unknown relationship '{relationship.relation}', skipping") return None + # Enforce domain/range constraints when declared + prop_def = self.ontology_subset.object_properties.get( + relationship.relation, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None) + + if domain and not self._matches_class_constraint(relationship.subject_type, domain): + logger.warning( + f"Domain violation: '{relationship.relation}' expects " + f"domain '{domain}', got subject type " + f"'{relationship.subject_type}', skipping" + ) + return None + + if range_ and not self._matches_class_constraint(relationship.object_type, range_): + logger.warning( + f"Range violation: '{relationship.relation}' expects " + f"range '{range_}', got object type " + f"'{relationship.object_type}', skipping" + ) + return None + # Generate triple: subject property object return Triple( s=Term(type=IRI, iri=subject_uri), @@ -157,11 +199,25 @@ class TripleConverter: logger.warning(f"Unknown attribute '{attribute.attribute}', skipping") return None + # Enforce domain constraint when declared + prop_def = self.ontology_subset.datatype_properties.get( + attribute.attribute, {} + ) + domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None) + + if domain and not self._matches_class_constraint(attribute.entity_type, domain): + logger.warning( + f"Domain violation: attribute '{attribute.attribute}' " + f"expects domain '{domain}', got entity type " + f"'{attribute.entity_type}', skipping" + ) + return None + # Generate triple: entity property "literal value" return Triple( s=Term(type=IRI, iri=entity_uri), p=Term(type=IRI, iri=property_uri), - o=Term(type=LITERAL, value=attribute.value) # Literal! + o=Term(type=LITERAL, value=attribute.value) ) def _get_class_uri(self, class_id: str) -> Optional[str]: