mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-18 20:05:13 +02:00
Fix ontology selector defaults, add bypass mode, enforce domain/range (#929)
- Align the similarity_threshold default to 0.3 everywhere (the OntologySelector class signature still had a stale 0.7). Fix the matching contradiction in the tech spec.
- Add a bypass_selector_below parameter (default 5) that skips vector-similarity selection and uses the full ontology when the total ontology element count (classes + object properties + datatype properties) is strictly below this value.
- Enforce domain/range constraints in TripleConverter for object properties, and domain constraints for datatype properties, with subclass-hierarchy support. Properties with no declared domain/range pass through unchanged.
- Add unit tests for domain/range validation, subclass acceptance, polymorphic pass-through, and selector bypass.

Fixes #908, #920
This commit is contained in:
parent
aea4c2df8e
commit
38d9c746a8
5 changed files with 501 additions and 13 deletions
|
|
@ -121,6 +121,7 @@ class Processor(FlowProcessor):
|
|||
# Configuration
|
||||
self.top_k = params.get("top_k", 10)
|
||||
self.similarity_threshold = params.get("similarity_threshold", 0.3)
|
||||
self.bypass_selector_below = params.get("bypass_selector_below", 5)
|
||||
|
||||
# Per-workspace ontology version tracking
|
||||
self.current_ontology_versions = {} # workspace -> version
|
||||
|
|
@ -187,7 +188,8 @@ class Processor(FlowProcessor):
|
|||
ontology_embedder=ontology_embedder,
|
||||
ontology_loader=loader,
|
||||
top_k=self.top_k,
|
||||
similarity_threshold=self.similarity_threshold
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
bypass_selector_below=self.bypass_selector_below,
|
||||
)
|
||||
|
||||
# Store flow-specific components
|
||||
|
|
@ -981,6 +983,13 @@ class Processor(FlowProcessor):
|
|||
default=0.3,
|
||||
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--bypass-selector-below',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Bypass ontology selector when total ontology elements '
|
||||
'(classes + properties) is below this value (default: 5)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--triples-batch-size',
|
||||
type=int,
|
||||
|
|
|
|||
|
|
@ -33,19 +33,44 @@ class OntologySelector:
|
|||
def __init__(self, ontology_embedder: OntologyEmbedder,
|
||||
ontology_loader: OntologyLoader,
|
||||
top_k: int = 10,
|
||||
similarity_threshold: float = 0.7):
|
||||
"""Initialize the ontology selector.
|
||||
|
||||
Args:
|
||||
ontology_embedder: Embedder with vector store
|
||||
ontology_loader: Loader with ontology definitions
|
||||
top_k: Number of top results to retrieve per segment
|
||||
similarity_threshold: Minimum similarity score
|
||||
"""
|
||||
similarity_threshold: float = 0.3,
|
||||
bypass_selector_below: int = 5):
|
||||
self.embedder = ontology_embedder
|
||||
self.loader = ontology_loader
|
||||
self.top_k = top_k
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.bypass_selector_below = bypass_selector_below
|
||||
|
||||
def _total_ontology_elements(self) -> int:
|
||||
total = 0
|
||||
for ontology in self.loader.get_all_ontologies().values():
|
||||
total += len(ontology.classes)
|
||||
total += len(ontology.object_properties)
|
||||
total += len(ontology.datatype_properties)
|
||||
return total
|
||||
|
||||
def _build_full_subsets(self) -> List[OntologySubset]:
    """Build one subset per loaded ontology containing all its elements.

    Used when vector-similarity selection is bypassed: every class,
    object property and datatype property of each ontology is included,
    and the subset is given a relevance score of 1.0.

    Returns:
        A list with one OntologySubset per ontology known to the loader
        (empty when the loader holds no ontologies).
    """

    def as_plain_dict(definition):
        # Return a shallow copy rather than the live ``__dict__`` so that
        # a consumer mutating a subset entry cannot silently rewrite the
        # loaded ontology.  Definitions already stored as dicts are
        # accepted as well.
        if isinstance(definition, dict):
            return dict(definition)
        return dict(vars(definition))

    return [
        OntologySubset(
            ontology_id=ont_id,
            classes={
                cid: as_plain_dict(cls)
                for cid, cls in ontology.classes.items()
            },
            object_properties={
                pid: as_plain_dict(prop)
                for pid, prop in ontology.object_properties.items()
            },
            datatype_properties={
                pid: as_plain_dict(prop)
                for pid, prop in ontology.datatype_properties.items()
            },
            metadata=ontology.metadata,
            relevance_score=1.0,
        )
        for ont_id, ontology in self.loader.get_all_ontologies().items()
    ]
|
||||
|
||||
async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
|
||||
"""Select relevant ontology subsets for text segments.
|
||||
|
|
@ -56,6 +81,15 @@ class OntologySelector:
|
|||
Returns:
|
||||
List of ontology subsets with relevant elements
|
||||
"""
|
||||
total = self._total_ontology_elements()
|
||||
if total < self.bypass_selector_below:
|
||||
logger.info(
|
||||
f"Ontology has {total} elements (below "
|
||||
f"bypass_selector_below={self.bypass_selector_below}), "
|
||||
f"using full ontology"
|
||||
)
|
||||
return self._build_full_subsets()
|
||||
|
||||
# Collect all relevant elements
|
||||
relevant_elements = await self._find_relevant_elements(segments)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ with full URIs and correct is_uri flags.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from .... schema import Triple, Term, IRI, LITERAL
|
||||
from .... rdf import RDF_TYPE, RDF_LABEL
|
||||
|
|
@ -32,6 +32,25 @@ class TripleConverter:
|
|||
self.ontology_id = ontology_id
|
||||
self.entity_registry = EntityRegistry(ontology_id)
|
||||
|
||||
def _get_ancestor_classes(self, class_id: str) -> Set[str]:
|
||||
ancestors = set()
|
||||
current = class_id
|
||||
while current:
|
||||
cls_def = self.ontology_subset.classes.get(current)
|
||||
if not cls_def:
|
||||
break
|
||||
parent = cls_def.get("subclass_of") if isinstance(cls_def, dict) else getattr(cls_def, "subclass_of", None)
|
||||
if not parent or parent in ancestors:
|
||||
break
|
||||
ancestors.add(parent)
|
||||
current = parent
|
||||
return ancestors
|
||||
|
||||
def _matches_class_constraint(self, actual_type: str, expected_type: str) -> bool:
|
||||
if actual_type == expected_type:
|
||||
return True
|
||||
return expected_type in self._get_ancestor_classes(actual_type)
|
||||
|
||||
def convert_all(self, extraction: ExtractionResult) -> List[Triple]:
|
||||
"""Convert complete extraction result to RDF triples.
|
||||
|
||||
|
|
@ -129,6 +148,29 @@ class TripleConverter:
|
|||
logger.warning(f"Unknown relationship '{relationship.relation}', skipping")
|
||||
return None
|
||||
|
||||
# Enforce domain/range constraints when declared
|
||||
prop_def = self.ontology_subset.object_properties.get(
|
||||
relationship.relation, {}
|
||||
)
|
||||
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
|
||||
range_ = prop_def.get("range") if isinstance(prop_def, dict) else getattr(prop_def, "range", None)
|
||||
|
||||
if domain and not self._matches_class_constraint(relationship.subject_type, domain):
|
||||
logger.warning(
|
||||
f"Domain violation: '{relationship.relation}' expects "
|
||||
f"domain '{domain}', got subject type "
|
||||
f"'{relationship.subject_type}', skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
if range_ and not self._matches_class_constraint(relationship.object_type, range_):
|
||||
logger.warning(
|
||||
f"Range violation: '{relationship.relation}' expects "
|
||||
f"range '{range_}', got object type "
|
||||
f"'{relationship.object_type}', skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
# Generate triple: subject property object
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=subject_uri),
|
||||
|
|
@ -157,11 +199,25 @@ class TripleConverter:
|
|||
logger.warning(f"Unknown attribute '{attribute.attribute}', skipping")
|
||||
return None
|
||||
|
||||
# Enforce domain constraint when declared
|
||||
prop_def = self.ontology_subset.datatype_properties.get(
|
||||
attribute.attribute, {}
|
||||
)
|
||||
domain = prop_def.get("domain") if isinstance(prop_def, dict) else getattr(prop_def, "domain", None)
|
||||
|
||||
if domain and not self._matches_class_constraint(attribute.entity_type, domain):
|
||||
logger.warning(
|
||||
f"Domain violation: attribute '{attribute.attribute}' "
|
||||
f"expects domain '{domain}', got entity type "
|
||||
f"'{attribute.entity_type}', skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
# Generate triple: entity property "literal value"
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri=property_uri),
|
||||
o=Term(type=LITERAL, value=attribute.value) # Literal!
|
||||
o=Term(type=LITERAL, value=attribute.value)
|
||||
)
|
||||
|
||||
def _get_class_uri(self, class_id: str) -> Optional[str]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue