trustgraph/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
2025-11-12 17:55:10 +00:00

628 lines
No EOL
24 KiB
Python

"""
OntoRAG: Ontology-based knowledge extraction service.
Extracts ontology-conformant triples from text chunks.
"""
import json
import logging
import asyncio
from typing import List, Dict, Any, Optional
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec, EmbeddingsClientSpec
from .ontology_loader import OntologyLoader
from .ontology_embedder import OntologyEmbedder
from .vector_store import InMemoryVectorStore
from .text_processor import TextProcessor
from .ontology_selector import OntologySelector, OntologySubset
logger = logging.getLogger(__name__)
default_ident = "kg-extract-ontology"
default_concurrency = 1
# URI prefix mappings for common namespaces
URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"skos:": "http://www.w3.org/2004/02/skos/core#",
"schema:": "https://schema.org/",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
class Processor(FlowProcessor):
"""Main OntoRAG extraction processor."""
def __init__(self, **params):
    """Configure the OntoRAG extraction processor.

    Registers the chunk consumer, prompt/embeddings clients, triple and
    entity-context producers, and the ontology config handler.

    Args:
        params: Processor options. Recognised keys: ``id`` (service
            identifier), ``concurrency`` (parallel message handlers),
            ``top_k`` and ``similarity_threshold`` (ontology selection
            tuning). Remaining keys pass through to FlowProcessor.
    """
    # Named `ident` rather than `id` so the builtin id() (used by
    # initialize_flow_components) is not shadowed.
    ident = params.get("id", default_ident)
    concurrency = params.get("concurrency", default_concurrency)

    super(Processor, self).__init__(
        **params | {
            "id": ident,
            "concurrency": concurrency,
        }
    )

    # Register specifications
    self.register_specification(
        ConsumerSpec(
            name="input",
            schema=Chunk,
            handler=self.on_message,
            concurrency=concurrency,
        )
    )

    self.register_specification(
        PromptClientSpec(
            request_name="prompt-request",
            response_name="prompt-response",
        )
    )

    self.register_specification(
        EmbeddingsClientSpec(
            request_name="embeddings-request",
            response_name="embeddings-response"
        )
    )

    self.register_specification(
        ProducerSpec(
            name="triples",
            schema=Triples
        )
    )

    self.register_specification(
        ProducerSpec(
            name="entity-contexts",
            schema=EntityContexts
        )
    )

    # Register config handler for ontology updates
    self.register_config_handler(self.on_ontology_config)

    # Shared components (not flow-specific)
    self.ontology_loader = OntologyLoader()
    self.text_processor = TextProcessor()

    # Per-flow components (each flow gets its own embedder/vector store/selector)
    self.flow_components = {}  # flow_id -> {embedder, vector_store, selector}

    # Configuration
    self.top_k = params.get("top_k", 10)
    self.similarity_threshold = params.get("similarity_threshold", 0.3)

    # Track loaded ontology version
    self.current_ontology_version = None
    self.loaded_ontology_ids = set()
async def initialize_flow_components(self, flow):
    """Initialize per-flow OntoRAG components.

    Each flow gets its own vector store and embedder to support
    different embedding models across flows. The vector store dimension
    is auto-detected from the embeddings service.

    Args:
        flow: Flow object for this processing context

    Returns:
        flow_id: Identifier for this flow's components

    Raises:
        Exception: re-raised (after logging) when any component fails to
            initialize, e.g. the embeddings service is unreachable.
    """
    # Use flow object as identifier.
    # NOTE(review): id() values are only unique while the flow object is
    # alive; flow_components is cleared wholesale on ontology updates, so
    # stale ids are presumably not a collision risk in practice - confirm.
    flow_id = id(flow)

    if flow_id in self.flow_components:
        return flow_id  # Already initialized for this flow

    # NOTE(review): with concurrency > 1, two messages on a new flow can
    # both pass the membership check above and run this init concurrently;
    # the last assignment wins, so this looks like duplicated work rather
    # than corruption - confirm.
    try:
        logger.info(f"Initializing components for flow {flow_id}")

        # Use embeddings client directly (no wrapper needed)
        embeddings_client = flow("embeddings-request")

        # Detect embedding dimension by embedding a test string
        logger.info("Detecting embedding dimension from embeddings service...")
        test_embedding_response = await embeddings_client.embed("test")
        test_embedding = test_embedding_response[0]  # Extract from [[vector]]
        dimension = len(test_embedding)
        logger.info(f"Detected embedding dimension: {dimension}")

        # Initialize vector store with detected dimension
        vector_store = InMemoryVectorStore(
            dimension=dimension,
            index_type='flat'
        )

        ontology_embedder = OntologyEmbedder(
            embedding_service=embeddings_client,
            vector_store=vector_store
        )

        # Embed all loaded ontologies for this flow (ontologies were parsed
        # and stored by on_ontology_config; embedding is deferred to here
        # because it needs this flow's embeddings client).
        if self.ontology_loader.get_all_ontologies():
            logger.info(f"Embedding ontologies for flow {flow_id}")
            for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
                await ontology_embedder.embed_ontology(ontology)
            logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")

        # Initialize ontology selector
        ontology_selector = OntologySelector(
            ontology_embedder=ontology_embedder,
            ontology_loader=self.ontology_loader,
            top_k=self.top_k,
            similarity_threshold=self.similarity_threshold
        )

        # Store flow-specific components
        self.flow_components[flow_id] = {
            'embedder': ontology_embedder,
            'vector_store': vector_store,
            'selector': ontology_selector,
            'dimension': dimension
        }

        logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")

        return flow_id

    except Exception as e:
        logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
        raise
async def on_ontology_config(self, config, version):
    """
    Handle ontology configuration updates from ConfigPush queue.

    Parses and stores ontologies. Embedding happens per-flow on first
    message. Flow components are cleared (forcing re-embedding) only
    when ontologies were added, removed, or actually changed content -
    not on every config push.

    Called automatically when:
    - Processor starts (gets full config history via start_of_messages=True)
    - Config service pushes updates (immediate event-driven notification)

    Args:
        config: Full configuration map - config[type][key] = value
        version: Config version number (monotonically increasing)
    """
    try:
        logger.info(f"Received ontology config update, version={version}")

        # Skip if we've already processed this version
        if version == self.current_ontology_version:
            logger.debug(f"Already at version {version}, skipping")
            return

        # Extract ontology configurations
        if "ontology" not in config:
            logger.warning("No 'ontology' section in config")
            return

        ontology_configs = config["ontology"]

        # Parse ontology definitions; skip any that fail to parse
        ontologies = {}
        for ont_id, ont_json in ontology_configs.items():
            try:
                ontologies[ont_id] = json.loads(ont_json)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse ontology '{ont_id}': {e}")
                continue

        logger.info(f"Loaded {len(ontologies)} ontology definitions")

        # Raw JSON from the previous update, used to detect real content
        # changes. getattr fallback keeps this safe even if the cache
        # attribute was never initialized.
        prev_raw = getattr(self, "_ontology_raw_cache", {})

        # Determine what changed (for incremental updates)
        new_ids = set(ontologies.keys())
        added_ids = new_ids - self.loaded_ontology_ids
        removed_ids = self.loaded_ontology_ids - new_ids
        # A retained ontology counts as updated only if its raw JSON
        # differs from what we loaded before. (Previously every retained
        # ontology was treated as updated, forcing a full re-embed on
        # every config push.)
        updated_ids = {
            ont_id for ont_id in (new_ids & self.loaded_ontology_ids)
            if ontology_configs.get(ont_id) != prev_raw.get(ont_id)
        }

        if added_ids:
            logger.info(f"New ontologies: {added_ids}")
        if removed_ids:
            logger.info(f"Removed ontologies: {removed_ids}")
        if updated_ids:
            logger.info(f"Updated ontologies: {updated_ids}")

        # Update ontology loader's internal state
        self.ontology_loader.update_ontologies(ontologies)

        # Clear all flow components to force re-embedding with new ontologies
        if added_ids or removed_ids or updated_ids:
            logger.info("Clearing flow components to trigger re-embedding")
            self.flow_components.clear()

        # Update tracking
        self.current_ontology_version = version
        self.loaded_ontology_ids = new_ids
        self._ontology_raw_cache = {
            ont_id: ontology_configs[ont_id] for ont_id in new_ids
        }

        logger.info(f"Ontology config update complete, version={version}")

    except Exception as e:
        logger.error(f"Failed to process ontology config: {e}", exc_info=True)
async def _emit_empty_outputs(self, flow, metadata):
    """Emit empty triples and entity-context messages for this chunk.

    Used when no ontology elements match or extraction fails, so that
    downstream consumers still observe the chunk.
    """
    await self.emit_triples(flow("triples"), metadata, [])
    await self.emit_entity_contexts(flow("entity-contexts"), metadata, [])

async def on_message(self, msg, consumer, flow):
    """Process an incoming chunk message.

    Selects the relevant ontology subset for the chunk, asks the prompt
    service to extract ontology-conformant triples, validates them, and
    emits triples plus derived entity contexts. On extraction failure,
    empty outputs are emitted instead.
    """
    v = msg.value()
    logger.info(f"Extracting ontology-based triples from {v.metadata.id}...")

    # Initialize flow-specific components if needed.
    # NOTE: deliberately outside the try below - an initialization
    # failure propagates to the consumer framework rather than being
    # converted into empty outputs.
    flow_id = await self.initialize_flow_components(flow)
    components = self.flow_components[flow_id]

    chunk = v.chunk.decode("utf-8")
    logger.debug(f"Processing chunk: {chunk[:200]}...")

    try:
        # Process text into segments
        segments = self.text_processor.process_chunk(chunk, extract_phrases=True)
        logger.debug(f"Split chunk into {len(segments)} segments")

        # Select relevant ontology subset (using flow-specific selector)
        ontology_subsets = await components['selector'].select_ontology_subset(segments)

        if not ontology_subsets:
            logger.warning("No relevant ontology elements found for chunk")
            await self._emit_empty_outputs(flow, v.metadata)
            return

        # Merge subsets if multiple ontologies matched
        if len(ontology_subsets) > 1:
            ontology_subset = components['selector'].merge_subsets(ontology_subsets)
        else:
            ontology_subset = ontology_subsets[0]

        logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, "
                     f"{len(ontology_subset.object_properties)} object properties, "
                     f"{len(ontology_subset.datatype_properties)} datatype properties")

        # Build extraction prompt variables
        prompt_variables = self.build_extraction_variables(chunk, ontology_subset)

        # Call prompt service for extraction
        try:
            # Use prompt() method with extract-with-ontologies prompt ID
            triples_response = await flow("prompt-request").prompt(
                id="extract-with-ontologies",
                variables=prompt_variables
            )
            logger.debug(f"Extraction response: {triples_response}")

            if not isinstance(triples_response, list):
                logger.error("Expected list of triples from prompt service")
                triples_response = []
        except Exception as e:
            # Prompt failures degrade to an empty extraction rather than
            # failing the whole chunk.
            logger.error(f"Prompt service error: {e}", exc_info=True)
            triples_response = []

        # Parse and validate triples
        triples = self.parse_and_validate_triples(triples_response, ontology_subset)

        # Pass through source/document metadata triples unchanged
        for t in v.metadata.metadata:
            triples.append(t)

        # Build entity contexts from triples
        entity_contexts = self.build_entity_contexts(triples)

        # Emit triples
        await self.emit_triples(
            flow("triples"),
            v.metadata,
            triples
        )

        # Emit entity contexts
        await self.emit_entity_contexts(
            flow("entity-contexts"),
            v.metadata,
            entity_contexts
        )

        logger.info(f"Extracted {len(triples)} ontology-conformant triples and {len(entity_contexts)} entity contexts")

    except Exception as e:
        logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
        # Emit empty outputs on error
        await self._emit_empty_outputs(flow, v.metadata)
def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]:
"""Build variables for ontology-based extraction prompt template.
Args:
chunk: Text chunk to extract from
ontology_subset: Relevant ontology elements
Returns:
Dict with template variables: text, classes, object_properties, datatype_properties
"""
return {
"text": chunk,
"classes": ontology_subset.classes,
"object_properties": ontology_subset.object_properties,
"datatype_properties": ontology_subset.datatype_properties
}
def parse_and_validate_triples(self, triples_response: List[Any],
                               ontology_subset: OntologySubset) -> List[Triple]:
    """Convert raw extraction output into validated Triple objects.

    Non-dict entries and incomplete triples are skipped silently;
    triples that fail ontology validation are logged and dropped.
    Subject/predicate URIs are always expanded; objects are expanded
    only when they look like URIs rather than literals.
    """
    results = []
    ontology_id = ontology_subset.ontology_id

    for item in triples_response:
        try:
            if not isinstance(item, dict):
                continue

            subj = item.get('subject', '')
            pred = item.get('predicate', '')
            obj = item.get('object', '')

            # All three positions must be present and non-empty.
            if not (subj and pred and obj):
                continue

            if not self.is_valid_triple(subj, pred, obj, ontology_subset):
                logger.debug(f"Invalid triple: ({subj}, {pred}, {obj})")
                continue

            # Subject and predicate are always URIs; expand prefix forms.
            s_uri = self.expand_uri(subj, ontology_subset, ontology_id)
            p_uri = self.expand_uri(pred, ontology_subset, ontology_id)

            # Object may be a URI or a literal - decide before expanding.
            if self.is_uri(obj) or self.should_expand_as_uri(obj, ontology_subset):
                o_val = self.expand_uri(obj, ontology_subset, ontology_id)
                o_is_uri = True
            else:
                o_val = obj
                o_is_uri = False

            results.append(Triple(
                s=Value(value=s_uri, is_uri=True),
                p=Value(value=p_uri, is_uri=True),
                o=Value(value=o_val, is_uri=o_is_uri),
            ))

        except Exception as e:
            logger.error(f"Error parsing triple: {e}")

    return results
def should_expand_as_uri(self, value: str, ontology_subset: OntologySubset) -> bool:
    """Decide whether *value* should be treated as a URI, not a literal.

    True when the value is an ontology class/property name, uses a known
    namespace prefix, or looks like an entity reference in prefix:name
    form (e.g. "recipe:cornish-pasty").
    """
    # Ontology element names are always URIs.
    in_ontology = (
        value in ontology_subset.classes
        or value in ontology_subset.object_properties
        or value in ontology_subset.datatype_properties
    )
    if in_ontology:
        return True

    # Known namespace prefix (rdf:, rdfs:, owl:, ...).
    if value.startswith(tuple(URI_PREFIXES)):
        return True

    # Entity reference in prefix:name form, excluding http(s) URLs.
    return ":" in value and not value.startswith("http")
def is_valid_triple(self, subject: str, predicate: str, object_val: str,
                    ontology_subset: OntologySubset) -> bool:
    """Check a candidate triple against the selected ontology subset.

    rdf:type is valid only when the object names a known class;
    rdfs:label is always valid; any other predicate must be a declared
    object or datatype property.
    """
    # rdf:type - accepted in both prefixed and full-URI form.
    if predicate in ("rdf:type", str(RDF_TYPE)):
        return object_val in ontology_subset.classes

    # rdfs:label assertions are always accepted.
    if predicate in ("rdfs:label", str(RDF_LABEL)):
        return True

    # Otherwise the predicate must be a declared property.
    # TODO: Add more sophisticated validation (domain/range checking)
    return (predicate in ontology_subset.object_properties
            or predicate in ontology_subset.datatype_properties)
def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str:
    """Expand prefix notation or short names to full URIs.

    Args:
        value: Value to expand (e.g., "rdf:type", "Recipe", "has_ingredient")
        ontology_subset: Ontology subset for class/property lookup
        ontology_id: ID of the ontology for constructing instance URIs

    Returns:
        Full URI string
    """
    # Already a full URI
    if value.startswith("http://") or value.startswith("https://"):
        return value

    # Check standard prefixes (rdf:, rdfs:, etc.)
    for prefix, namespace in URI_PREFIXES.items():
        if value.startswith(prefix):
            return namespace + value[len(prefix):]

    # Classes, object properties and datatype properties all resolve the
    # same way: use the declared 'uri' field if present, otherwise
    # construct one in the ontology namespace. (The definitions are
    # dicts, built from __dict__ in ontology_selector.) This loop
    # replaces three identical copies of the lookup; order matches the
    # original (classes first).
    for elements in (ontology_subset.classes,
                     ontology_subset.object_properties,
                     ontology_subset.datatype_properties):
        if value in elements:
            definition = elements[value]
            if isinstance(definition, dict) and definition.get('uri'):
                return definition['uri']
            # Fallback: construct URI within the ontology namespace
            return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"

    # Otherwise, treat as entity instance - construct unique URI.
    # Normalize the value for URI (lowercase, replace spaces with hyphens).
    normalized = value.replace(" ", "-").lower()
    return f"https://trustgraph.ai/{ontology_id}/{normalized}"
def is_uri(self, value: str) -> bool:
    """Return True when *value* is already an absolute http(s) URI."""
    return value.startswith(("http://", "https://"))
async def emit_triples(self, pub, metadata: Metadata, triples: List[Triple]):
    """Publish a Triples message carrying *triples* for this document.

    The outgoing metadata copies identity fields only; the embedded
    metadata triple list is emptied (metadata triples travel in the
    triples payload instead).
    """
    out_metadata = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    await pub.send(Triples(metadata=out_metadata, triples=triples))
async def emit_entity_contexts(self, pub, metadata: Metadata, entities: List[EntityContext]):
    """Publish an EntityContexts message carrying *entities*.

    Mirrors emit_triples: identity metadata is copied, the embedded
    metadata triple list is emptied.
    """
    out_metadata = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    await pub.send(EntityContexts(metadata=out_metadata, entities=entities))
def build_entity_contexts(self, triples: List[Triple]) -> List[EntityContext]:
    """Build entity contexts from extracted triples.

    Collects rdfs:label and definition properties for each entity to create
    contextual descriptions for embedding.

    Args:
        triples: List of extracted triples

    Returns:
        List of EntityContext objects
    """
    # Compare against the str() forms of the well-known predicates, for
    # consistency with is_valid_triple (which uses str(RDF_TYPE) /
    # str(RDF_LABEL)). The previous raw comparison against RDF_LABEL /
    # DEFINITION would never match if those constants are not plain
    # strings; if they already are, str() is a no-op.
    label_uri = str(RDF_LABEL)
    definition_uris = (str(DEFINITION), "https://schema.org/description")

    # Group triples by subject to collect entity information
    entity_data = {}  # subject_uri -> {labels: [], definitions: []}

    for triple in triples:
        subject_uri = triple.s.value
        predicate_uri = triple.p.value
        object_val = triple.o.value

        # Initialize entity data if not exists
        if subject_uri not in entity_data:
            entity_data[subject_uri] = {'labels': [], 'definitions': []}

        # Collect labels (rdfs:label) - labels are literals, so ignore
        # URI-valued objects.
        if predicate_uri == label_uri:
            if not triple.o.is_uri:
                entity_data[subject_uri]['labels'].append(object_val)

        # Collect definitions (skos:definition, schema:description)
        elif predicate_uri in definition_uris:
            if not triple.o.is_uri:
                entity_data[subject_uri]['definitions'].append(object_val)

    # Build EntityContext objects
    entity_contexts = []
    for subject_uri, data in entity_data.items():
        # Build context text from labels and definitions
        context_parts = []

        if data['labels']:
            context_parts.append(f"Label: {data['labels'][0]}")

        if data['definitions']:
            context_parts.extend(data['definitions'])

        # Only create EntityContext if we have meaningful context
        if context_parts:
            context_text = ". ".join(context_parts)
            entity_contexts.append(EntityContext(
                entity=Value(value=subject_uri, is_uri=True),
                context=context_text
            ))

    logger.debug(f"Built {len(entity_contexts)} entity contexts from {len(triples)} triples")

    return entity_contexts
@staticmethod
def add_args(parser):
    """Register OntoRAG command-line options, then the base processor's."""
    options = [
        (('-c', '--concurrency'),
         dict(type=int,
              default=default_concurrency,
              help=f'Concurrent processing threads (default: {default_concurrency})')),
        (('--top-k',),
         dict(type=int,
              default=10,
              help='Number of top ontology elements to retrieve (default: 10)')),
        (('--similarity-threshold',),
         dict(type=float,
              default=0.3,
              help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)')),
    ]

    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)

    FlowProcessor.add_args(parser)
def run():
    """Launch the OntoRAG extraction service.

    Delegates to Processor.launch() with the default service identifier;
    the module docstring is passed along, presumably as the service
    description/help text - confirm against FlowProcessor.launch().
    """
    Processor.launch(default_ident, __doc__)