diff --git a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py index 9f9c8551..bae6bdbd 100644 --- a/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py +++ b/tests/unit/test_extract/test_ontology/test_prompt_and_extraction.py @@ -231,6 +231,52 @@ class TestTripleValidation: is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset) assert is_valid == expected, f"Validation of {predicate} should be {expected}" + def test_validates_domain_correctly_with_entity_types(self, extractor, sample_ontology_subset): + """Test domain validation correctly compares against extracted entity_types.""" + subject = "my-recipe" + predicate = "produces" + object_val = "my-food" + + # Proper domain for produces is Recipe + entity_types = { + "my-recipe": "Recipe", + "my-food": "Food" + } + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types) + assert is_valid, "Valid domain should be accepted" + + # Invalid domain + entity_types_invalid = { + "my-recipe": "Ingredient", + "my-food": "Food" + } + is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid) + assert not is_invalid, "Invalid domain should be rejected" + + def test_validates_range_correctly_with_entity_types(self, extractor, sample_ontology_subset): + """Test range validation correctly compares against extracted entity_types.""" + subject = "my-recipe" + predicate = "produces" + object_val = "my-food" + + # Proper range for produces is Food + entity_types = { + "my-recipe": "Recipe", + "my-food": "Food" + } + + is_valid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types) + assert is_valid, "Valid range should be accepted" + + # Invalid range + entity_types_invalid = { + "my-recipe": "Recipe", + "my-food": "Recipe" + } + is_invalid = extractor.is_valid_triple(subject, predicate, object_val, sample_ontology_subset, entity_types_invalid) + assert not is_invalid, "Invalid range should be rejected" + class TestTripleParsing: """Test suite for parsing triples from LLM responses.""" diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index bdb0e6e8..e024ad40 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -429,6 +429,16 @@ class Processor(FlowProcessor): validated_triples = [] ontology_id = ontology_subset.ontology_id + # Gather entity types for domain/range validation + entity_types = {} + for triple_data in triples_response: + if isinstance(triple_data, dict): + s = triple_data.get('subject', '') + p = triple_data.get('predicate', '') + o = triple_data.get('object', '') + if s and p and o and (p == "rdf:type" or p == str(RDF_TYPE)): + entity_types[s] = o + for triple_data in triples_response: try: if isinstance(triple_data, dict): @@ -440,7 +450,7 @@ class Processor(FlowProcessor): continue # Validate against ontology - if self.is_valid_triple(subject, predicate, object_val, ontology_subset): + if self.is_valid_triple(subject, predicate, object_val, ontology_subset, entity_types): # Expand URIs before creating Value objects subject_uri = self.expand_uri(subject, ontology_subset, ontology_id) predicate_uri = self.expand_uri(predicate, ontology_subset, ontology_id) @@ -493,8 +503,11 @@ class Processor(FlowProcessor): return False def is_valid_triple(self, subject: str, predicate: str, object_val: str, - ontology_subset: OntologySubset) -> bool: + ontology_subset: OntologySubset, entity_types: dict = None) -> bool: """Validate triple against ontology constraints.""" + if entity_types is None: + entity_types = {} + # Special case for rdf:type if predicate == "rdf:type" or predicate == str(RDF_TYPE): # Check if object is a valid class @@ -511,7 +524,45 @@ class Processor(FlowProcessor): if not is_obj_prop and not is_dt_prop: return False # Unknown property - # TODO: Add more sophisticated validation (domain/range checking) + prop_def = ontology_subset.object_properties[predicate] if is_obj_prop else ontology_subset.datatype_properties[predicate] + if not isinstance(prop_def, dict): + prop_def = prop_def.__dict__ if hasattr(prop_def, '__dict__') else {} + + # Domain validation + expected_domain = prop_def.get('domain') + if expected_domain and subject in entity_types: + actual_domain = entity_types[subject] + if actual_domain != expected_domain: + is_subclass = False + curr_class = actual_domain + while curr_class in ontology_subset.classes: + cls_def = ontology_subset.classes[curr_class] + parent = cls_def.get('subclass_of') if isinstance(cls_def, dict) else None + if parent == expected_domain: + is_subclass = True + break + curr_class = parent + if not is_subclass: + return False + + # Range validation + if is_obj_prop: + expected_range = prop_def.get('range') + if expected_range and object_val in entity_types: + actual_range = entity_types[object_val] + if actual_range != expected_range: + is_subclass = False + curr_class = actual_range + while curr_class in ontology_subset.classes: + cls_def = ontology_subset.classes[curr_class] + parent = cls_def.get('subclass_of') if isinstance(cls_def, dict) else None + if parent == expected_range: + is_subclass = True + break + curr_class = parent + if not is_subclass: + return False + return True def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str: