OntoRAG: Ontology-Based Knowledge Extraction and Query Technical Specification (#523)

* Onto-rag tech spec

* New processor kg-extract-ontology, use 'ontology' objects from config to guide triple extraction

* Also entity contexts

* Integrate with ontology extractor from workbench

This is the first phase; the extraction is tested and working, and GraphRAG with the extracted knowledge also works.
This commit is contained in:
cybermaggedon 2025-11-12 20:38:08 +00:00 committed by GitHub
parent 4c3db4dbbe
commit c69f5207a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 11824 additions and 0 deletions

View file

@ -16,6 +16,7 @@ dependencies = [
"scylla-driver",
"cohere",
"cryptography",
"faiss-cpu",
"falkordb",
"fastembed",
"google-genai",
@ -29,6 +30,7 @@ dependencies = [
"minio",
"mistralai",
"neo4j",
"nltk",
"ollama",
"openai",
"pinecone[grpc]",
@ -82,6 +84,7 @@ kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
kg-manager = "trustgraph.cores:run"
kg-store = "trustgraph.storage.knowledge:run"
librarian = "trustgraph.librarian:run"

View file

@ -0,0 +1 @@
from . extract import *

View file

@ -0,0 +1,848 @@
"""
OntoRAG: Ontology-based knowledge extraction service.
Extracts ontology-conformant triples from text chunks.
"""
import json
import logging
import asyncio
from typing import List, Dict, Any, Optional
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec, EmbeddingsClientSpec
from .ontology_loader import OntologyLoader
from .ontology_embedder import OntologyEmbedder
from .vector_store import InMemoryVectorStore
from .text_processor import TextProcessor
from .ontology_selector import OntologySelector, OntologySubset
logger = logging.getLogger(__name__)
default_ident = "kg-extract-ontology"
default_concurrency = 1
# URI prefix mappings for common namespaces
URI_PREFIXES = {
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
"owl:": "http://www.w3.org/2002/07/owl#",
"skos:": "http://www.w3.org/2004/02/skos/core#",
"schema:": "https://schema.org/",
"xsd:": "http://www.w3.org/2001/XMLSchema#",
}
class Processor(FlowProcessor):
"""Main OntoRAG extraction processor."""
def __init__(self, **params):
    """Construct the processor and register all flow specifications.

    Args:
        params: Processor options. Recognised keys: "id" (service
            identifier), "concurrency" (parallel message handlers),
            "top_k" and "similarity_threshold" (ontology selection
            tuning, consumed here; see add_args for CLI defaults).
    """
    # NOTE: shadows the builtin id(); kept to match the framework's
    # parameter naming convention.
    id = params.get("id", default_ident)
    concurrency = params.get("concurrency", default_concurrency)
    super(Processor, self).__init__(
        **params | {
            "id": id,
            "concurrency": concurrency,
        }
    )
    # Register specifications
    # Input: Chunk messages fan into on_message, one handler per
    # concurrency slot.
    self.register_specification(
        ConsumerSpec(
            name="input",
            schema=Chunk,
            handler=self.on_message,
            concurrency=concurrency,
        )
    )
    # Request/response clients for the prompt and embeddings services.
    self.register_specification(
        PromptClientSpec(
            request_name="prompt-request",
            response_name="prompt-response",
        )
    )
    self.register_specification(
        EmbeddingsClientSpec(
            request_name="embeddings-request",
            response_name="embeddings-response"
        )
    )
    # Outputs: extracted triples and entity contexts.
    self.register_specification(
        ProducerSpec(
            name="triples",
            schema=Triples
        )
    )
    self.register_specification(
        ProducerSpec(
            name="entity-contexts",
            schema=EntityContexts
        )
    )
    # Register config handler for ontology updates
    self.register_config_handler(self.on_ontology_config)
    # Shared components (not flow-specific)
    self.ontology_loader = OntologyLoader()
    self.text_processor = TextProcessor()
    # Per-flow components (each flow gets its own embedder/vector store/selector)
    self.flow_components = {} # flow_id -> {embedder, vector_store, selector}
    # Configuration
    self.top_k = params.get("top_k", 10)
    self.similarity_threshold = params.get("similarity_threshold", 0.3)
    # Track loaded ontology version
    self.current_ontology_version = None
    self.loaded_ontology_ids = set()
async def initialize_flow_components(self, flow):
    """Initialize per-flow OntoRAG components.

    Each flow gets its own vector store and embedder to support
    different embedding models across flows. The vector store dimension
    is auto-detected from the embeddings service.

    Initialization is guarded by a per-flow asyncio.Lock: the original
    membership check happened before several awaits, so with
    concurrency > 1 two handlers for the same flow could both run the
    (expensive) ontology embedding. The lock serialises that work.

    Args:
        flow: Flow object for this processing context

    Returns:
        flow_id: Identifier for this flow's components
    """
    # Use flow object as identifier
    flow_id = id(flow)
    if flow_id in self.flow_components:
        return flow_id  # Already initialized for this flow

    # Lazily create the lock table. No await occurs between the read
    # and the write, so this is atomic within the event loop.
    locks = getattr(self, "_flow_init_locks", None)
    if locks is None:
        locks = self._flow_init_locks = {}
    lock = locks.setdefault(flow_id, asyncio.Lock())

    async with lock:
        # Re-check: another task may have completed initialization
        # while we were waiting on the lock.
        if flow_id in self.flow_components:
            return flow_id
        try:
            logger.info(f"Initializing components for flow {flow_id}")
            # Use embeddings client directly (no wrapper needed)
            embeddings_client = flow("embeddings-request")
            # Detect embedding dimension by embedding a test string
            logger.info("Detecting embedding dimension from embeddings service...")
            test_embedding_response = await embeddings_client.embed("test")
            test_embedding = test_embedding_response[0]  # Extract from [[vector]]
            dimension = len(test_embedding)
            logger.info(f"Detected embedding dimension: {dimension}")
            # Initialize vector store with detected dimension
            vector_store = InMemoryVectorStore(
                dimension=dimension,
                index_type='flat'
            )
            ontology_embedder = OntologyEmbedder(
                embedding_service=embeddings_client,
                vector_store=vector_store
            )
            # Embed all loaded ontologies for this flow
            if self.ontology_loader.get_all_ontologies():
                logger.info(f"Embedding ontologies for flow {flow_id}")
                for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
                    await ontology_embedder.embed_ontology(ontology)
                logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
            # Initialize ontology selector
            ontology_selector = OntologySelector(
                ontology_embedder=ontology_embedder,
                ontology_loader=self.ontology_loader,
                top_k=self.top_k,
                similarity_threshold=self.similarity_threshold
            )
            # Store flow-specific components
            self.flow_components[flow_id] = {
                'embedder': ontology_embedder,
                'vector_store': vector_store,
                'selector': ontology_selector,
                'dimension': dimension
            }
            logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")
            return flow_id
        except Exception as e:
            logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
            raise
async def on_ontology_config(self, config, version):
    """
    Handle ontology configuration updates from ConfigPush queue.
    Parses and stores ontologies. Embedding happens per-flow on first message.

    Called automatically when:
    - Processor starts (gets full config history via start_of_messages=True)
    - Config service pushes updates (immediate event-driven notification)

    Args:
        config: Full configuration map - config[type][key] = value
        version: Config version number (monotonically increasing)
    """
    try:
        logger.info(f"Received ontology config update, version={version}")
        # Versions are documented as monotonically increasing, so skip
        # stale replays as well as exact duplicates (the original only
        # skipped version == current, re-applying older versions).
        if (self.current_ontology_version is not None
                and version <= self.current_ontology_version):
            logger.debug(f"Already at version {version}, skipping")
            return
        # Extract ontology configurations
        if "ontology" not in config:
            logger.warning("No 'ontology' section in config")
            return
        ontology_configs = config["ontology"]
        # Parse ontology definitions; a single malformed entry is logged
        # and skipped rather than failing the whole update.
        ontologies = {}
        for ont_id, ont_json in ontology_configs.items():
            try:
                ontologies[ont_id] = json.loads(ont_json)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse ontology '{ont_id}': {e}")
                continue
        logger.info(f"Loaded {len(ontologies)} ontology definitions")
        # Determine what changed (for incremental updates)
        new_ids = set(ontologies.keys())
        added_ids = new_ids - self.loaded_ontology_ids
        removed_ids = self.loaded_ontology_ids - new_ids
        updated_ids = new_ids & self.loaded_ontology_ids  # May have changed content
        if added_ids:
            logger.info(f"New ontologies: {added_ids}")
        if removed_ids:
            logger.info(f"Removed ontologies: {removed_ids}")
        if updated_ids:
            logger.info(f"Updated ontologies: {updated_ids}")
        # Update ontology loader's internal state
        self.ontology_loader.update_ontologies(ontologies)
        # Clear all flow components to force re-embedding with new ontologies
        if added_ids or removed_ids or updated_ids:
            logger.info("Clearing flow components to trigger re-embedding")
            self.flow_components.clear()
        # Update tracking
        self.current_ontology_version = version
        self.loaded_ontology_ids = new_ids
        logger.info(f"Ontology config update complete, version={version}")
    except Exception as e:
        # Config handling is best-effort: log and keep serving with the
        # previously loaded ontologies.
        logger.error(f"Failed to process ontology config: {e}", exc_info=True)
async def on_message(self, msg, consumer, flow):
    """Process an incoming chunk message.

    Selects the relevant ontology subset, asks the prompt service to
    extract triples, validates them, and publishes triples plus entity
    contexts. On any failure — or when no relevant ontology elements are
    found — empty outputs are emitted so downstream consumers always see
    a message for the chunk.
    """
    v = msg.value()
    logger.info(f"Extracting ontology-based triples from {v.metadata.id}...")
    # Initialize flow-specific components if needed
    flow_id = await self.initialize_flow_components(flow)
    components = self.flow_components[flow_id]
    chunk = v.chunk.decode("utf-8")
    logger.debug(f"Processing chunk: {chunk[:200]}...")

    async def emit_empty():
        # Shared "nothing extracted" path (was duplicated in the
        # no-subset branch and the exception handler).
        await self.emit_triples(flow("triples"), v.metadata, [])
        await self.emit_entity_contexts(flow("entity-contexts"), v.metadata, [])

    try:
        # Process text into segments
        segments = self.text_processor.process_chunk(chunk, extract_phrases=True)
        logger.debug(f"Split chunk into {len(segments)} segments")
        # Select relevant ontology subset (using flow-specific selector)
        ontology_subsets = await components['selector'].select_ontology_subset(segments)
        if not ontology_subsets:
            logger.warning("No relevant ontology elements found for chunk")
            await emit_empty()
            return
        # Merge subsets if multiple ontologies matched
        if len(ontology_subsets) > 1:
            ontology_subset = components['selector'].merge_subsets(ontology_subsets)
        else:
            ontology_subset = ontology_subsets[0]
        logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, "
                     f"{len(ontology_subset.object_properties)} object properties, "
                     f"{len(ontology_subset.datatype_properties)} datatype properties")
        # Build extraction prompt variables
        prompt_variables = self.build_extraction_variables(chunk, ontology_subset)
        # Call prompt service for extraction
        try:
            # Use prompt() method with extract-with-ontologies prompt ID
            triples_response = await flow("prompt-request").prompt(
                id="extract-with-ontologies",
                variables=prompt_variables
            )
            logger.debug(f"Extraction response: {triples_response}")
            if not isinstance(triples_response, list):
                logger.error("Expected list of triples from prompt service")
                triples_response = []
        except Exception as e:
            # Prompt failures degrade to "no extracted triples" rather
            # than failing the chunk: ontology triples are still emitted.
            logger.error(f"Prompt service error: {e}", exc_info=True)
            triples_response = []
        # Parse and validate triples
        triples = self.parse_and_validate_triples(triples_response, ontology_subset)
        # Add metadata triples
        for t in v.metadata.metadata:
            triples.append(t)
        # Generate ontology definition triples
        ontology_triples = self.build_ontology_triples(ontology_subset)
        # Combine extracted triples with ontology triples
        all_triples = triples + ontology_triples
        # Build entity contexts from all triples (including ontology elements)
        entity_contexts = self.build_entity_contexts(all_triples)
        # Emit all triples (extracted + ontology definitions)
        await self.emit_triples(
            flow("triples"),
            v.metadata,
            all_triples
        )
        # Emit entity contexts
        await self.emit_entity_contexts(
            flow("entity-contexts"),
            v.metadata,
            entity_contexts
        )
        logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
                    f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
    except Exception as e:
        logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
        # Emit empty outputs on error
        await emit_empty()
def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]:
    """Assemble the template variables for the extraction prompt.

    Args:
        chunk: Text chunk to extract from
        ontology_subset: Relevant ontology elements

    Returns:
        Dict with template variables: text, classes, object_properties,
        datatype_properties
    """
    variables: Dict[str, Any] = dict(
        text=chunk,
        classes=ontology_subset.classes,
        object_properties=ontology_subset.object_properties,
        datatype_properties=ontology_subset.datatype_properties,
    )
    return variables
def parse_and_validate_triples(self, triples_response: List[Any],
                               ontology_subset: OntologySubset) -> List[Triple]:
    """Parse and validate extracted triples against ontology.

    Each entry is expected to be a dict with 'subject', 'predicate' and
    'object' keys. Entries that are not dicts, have missing fields, or
    fail ontology validation are skipped (the original dropped non-dict
    entries silently; they are now logged at debug level).

    Returns:
        List of validated Triple objects with URIs expanded.
    """
    validated_triples = []
    ontology_id = ontology_subset.ontology_id
    for triple_data in triples_response:
        try:
            if not isinstance(triple_data, dict):
                logger.debug(f"Skipping non-dict triple entry: {triple_data!r}")
                continue
            subject = triple_data.get('subject', '')
            predicate = triple_data.get('predicate', '')
            object_val = triple_data.get('object', '')
            # Incomplete triples are silently dropped
            if not subject or not predicate or not object_val:
                continue
            # Validate against ontology
            if not self.is_valid_triple(subject, predicate, object_val, ontology_subset):
                logger.debug(f"Invalid triple: ({subject}, {predicate}, {object_val})")
                continue
            # Expand URIs before creating Value objects
            subject_uri = self.expand_uri(subject, ontology_subset, ontology_id)
            predicate_uri = self.expand_uri(predicate, ontology_subset, ontology_id)
            # Object might be URI or literal - check before expanding
            if self.is_uri(object_val) or self.should_expand_as_uri(object_val, ontology_subset):
                object_uri = self.expand_uri(object_val, ontology_subset, ontology_id)
                is_object_uri = True
            else:
                object_uri = object_val
                is_object_uri = False
            # Create Triple object with expanded URIs
            validated_triples.append(Triple(
                s=Value(value=subject_uri, is_uri=True),
                p=Value(value=predicate_uri, is_uri=True),
                o=Value(value=object_uri, is_uri=is_object_uri),
            ))
        except Exception as e:
            # Keep processing remaining entries on a per-triple failure
            logger.error(f"Error parsing triple: {e}")
    return validated_triples
def should_expand_as_uri(self, value: str, ontology_subset: OntologySubset) -> bool:
    """Decide whether a value names a resource (URI) rather than a literal.

    Returns True if value is a class name, property name, or entity reference.
    """
    # Ontology vocabulary: classes and both kinds of properties
    in_vocabulary = (
        value in ontology_subset.classes
        or value in ontology_subset.object_properties
        or value in ontology_subset.datatype_properties
    )
    if in_vocabulary:
        return True
    # Known namespace prefixes (rdf:, rdfs:, owl:, ...)
    if value.startswith(tuple(URI_PREFIXES)):
        return True
    # Prefix-style entity reference (e.g. "recipe:cornish-pasty")
    return ":" in value and not value.startswith("http")
def is_valid_triple(self, subject: str, predicate: str, object_val: str,
                    ontology_subset: OntologySubset) -> bool:
    """Validate a candidate triple against the ontology subset.

    rdf:type statements must point at a known class; rdfs:label is
    always accepted; any other predicate must be a declared object or
    datatype property.
    """
    # rdf:type — the object has to name a class from the subset
    if predicate in ("rdf:type", str(RDF_TYPE)):
        return object_val in ontology_subset.classes
    # rdfs:label — labels are always valid
    if predicate in ("rdfs:label", str(RDF_LABEL)):
        return True
    # TODO: Add more sophisticated validation (domain/range checking)
    return (predicate in ontology_subset.object_properties
            or predicate in ontology_subset.datatype_properties)
def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str:
    """Expand prefix notation or short names to full URIs.

    The original had three byte-identical lookup/fallback branches for
    classes, object properties and datatype properties; they are
    collapsed into one loop (same precedence order, same output).

    Args:
        value: Value to expand (e.g., "rdf:type", "Recipe", "has_ingredient")
        ontology_subset: Ontology subset for class/property lookup
        ontology_id: ID of the ontology for constructing instance URIs

    Returns:
        Full URI string
    """
    # Already a full URI
    if value.startswith(("http://", "https://")):
        return value
    # Check standard prefixes (rdf:, rdfs:, etc.)
    for prefix, namespace in URI_PREFIXES.items():
        if value.startswith(prefix):
            return namespace + value[len(prefix):]
    # Ontology vocabulary lookup, in the same order as before:
    # classes, then object properties, then datatype properties.
    for table in (ontology_subset.classes,
                  ontology_subset.object_properties,
                  ontology_subset.datatype_properties):
        if value in table:
            definition = table[value]
            # Definitions are dicts (from __dict__ in ontology_selector)
            # and may carry an explicit URI
            if isinstance(definition, dict) and definition.get('uri'):
                return definition['uri']
            # Fallback: construct URI
            return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"
    # Otherwise, treat as entity instance - construct unique URI
    # Normalize the value for URI (lowercase, replace spaces with hyphens)
    normalized = value.replace(" ", "-").lower()
    return f"https://trustgraph.ai/{ontology_id}/{normalized}"
def is_uri(self, value: str) -> bool:
    """Return True when value is already an absolute http(s) URI."""
    return value.startswith(("http://", "https://"))
async def emit_triples(self, pub, metadata: Metadata, triples: List[Triple]):
    """Publish the given triples, copying identity fields from metadata.

    Document metadata triples are deliberately not propagated here
    (metadata=[]); on_message merges them into the triple list itself.
    """
    envelope = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    await pub.send(Triples(metadata=envelope, triples=triples))
async def emit_entity_contexts(self, pub, metadata: Metadata, entities: List[EntityContext]):
    """Publish entity contexts, copying identity fields from metadata.

    As with emit_triples, the envelope carries empty metadata triples.
    """
    envelope = Metadata(
        id=metadata.id,
        metadata=[],
        user=metadata.user,
        collection=metadata.collection,
    )
    await pub.send(EntityContexts(metadata=envelope, entities=entities))
def build_ontology_triples(self, ontology_subset: OntologySubset) -> List[Triple]:
    """Build triples describing the ontology elements themselves.

    Generates triples for classes and properties so they exist in the
    knowledge graph. The three near-identical sections of the original
    (classes / object properties / datatype properties) are factored
    into shared helpers.

    BUG FIX: the original looked up the datatype-property range under
    the key 'rdfs:range', but property definitions store this attribute
    as 'range' (every other attribute — 'domain', 'comment', 'labels' —
    is read under its plain name, and the embedder reads element.range),
    so datatype ranges were never emitted. Now read from 'range'.
    TODO(review): confirm against OntologyProperty's actual fields.

    Args:
        ontology_subset: The ontology subset used for extraction

    Returns:
        List of Triple objects describing ontology elements
    """
    RDF_TYPE_URI = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
    RDFS_COMMENT = "http://www.w3.org/2000/01/rdf-schema#comment"
    RDFS_SUBCLASS_OF = "http://www.w3.org/2000/01/rdf-schema#subClassOf"
    RDFS_DOMAIN = "http://www.w3.org/2000/01/rdf-schema#domain"
    RDFS_RANGE = "http://www.w3.org/2000/01/rdf-schema#range"
    ontology_triples = []

    def fallback_uri(element_id):
        # Constructed URI used when a definition carries no explicit one
        return f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{element_id}"

    def element_uri(element_id, element_def):
        # Prefer the explicit URI from the definition dict, else construct
        if isinstance(element_def, dict) and element_def.get('uri'):
            return element_def['uri']
        return fallback_uri(element_id)

    def class_ref_uri(class_id):
        # Resolve a reference to a class that may or may not be in the subset
        if class_id in ontology_subset.classes:
            return element_uri(class_id, ontology_subset.classes[class_id])
        return fallback_uri(class_id)

    def add(s, p, o, o_is_uri=True):
        ontology_triples.append(Triple(
            s=Value(value=s, is_uri=True),
            p=Value(value=p, is_uri=True),
            o=Value(value=o, is_uri=o_is_uri),
        ))

    def add_label_and_comment(uri, element_id, element_def):
        # rdfs:label (stored as 'labels'); first label only, falling back
        # to the element id when the label dict has no 'value'
        if isinstance(element_def, dict) and 'labels' in element_def:
            labels = element_def['labels']
            if isinstance(labels, list) and labels:
                if isinstance(labels[0], dict):
                    label_val = labels[0].get('value', element_id)
                else:
                    label_val = str(labels[0])
                add(uri, RDF_LABEL, label_val, o_is_uri=False)
        # rdfs:comment (stored as 'comment')
        if isinstance(element_def, dict) and element_def.get('comment'):
            add(uri, RDFS_COMMENT, element_def['comment'], o_is_uri=False)

    # Generate triples for classes
    for class_id, class_def in ontology_subset.classes.items():
        class_uri = element_uri(class_id, class_def)
        add(class_uri, RDF_TYPE_URI, "http://www.w3.org/2002/07/owl#Class")
        add_label_and_comment(class_uri, class_id, class_def)
        # rdfs:subClassOf (stored as 'subclass_of')
        if isinstance(class_def, dict) and class_def.get('subclass_of'):
            add(class_uri, RDFS_SUBCLASS_OF, class_ref_uri(class_def['subclass_of']))

    # Generate triples for object properties
    for prop_id, prop_def in ontology_subset.object_properties.items():
        prop_uri = element_uri(prop_id, prop_def)
        add(prop_uri, RDF_TYPE_URI, "http://www.w3.org/2002/07/owl#ObjectProperty")
        add_label_and_comment(prop_uri, prop_id, prop_def)
        # rdfs:domain / rdfs:range point at classes
        if isinstance(prop_def, dict) and prop_def.get('domain'):
            add(prop_uri, RDFS_DOMAIN, class_ref_uri(prop_def['domain']))
        if isinstance(prop_def, dict) and prop_def.get('range'):
            add(prop_uri, RDFS_RANGE, class_ref_uri(prop_def['range']))

    # Generate triples for datatype properties
    for prop_id, prop_def in ontology_subset.datatype_properties.items():
        prop_uri = element_uri(prop_id, prop_def)
        add(prop_uri, RDF_TYPE_URI, "http://www.w3.org/2002/07/owl#DatatypeProperty")
        add_label_and_comment(prop_uri, prop_id, prop_def)
        if isinstance(prop_def, dict) and prop_def.get('domain'):
            add(prop_uri, RDFS_DOMAIN, class_ref_uri(prop_def['domain']))
        # rdfs:range for datatype properties is a datatype, usually
        # xsd:string, xsd:int, etc. (see bug-fix note in the docstring:
        # was read from the never-present key 'rdfs:range')
        if isinstance(prop_def, dict) and prop_def.get('range'):
            range_val = prop_def['range']
            if range_val.startswith('xsd:'):
                range_uri = f"http://www.w3.org/2001/XMLSchema#{range_val[4:]}"
            else:
                range_uri = range_val
            add(prop_uri, RDFS_RANGE, range_uri)

    logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
    return ontology_triples
def build_entity_contexts(self, triples: List[Triple]) -> List[EntityContext]:
    """Build entity contexts from extracted triples.

    Collects rdfs:label and definition properties for each entity to
    create contextual descriptions for embedding.

    Args:
        triples: List of extracted triples

    Returns:
        List of EntityContext objects
    """
    # Group label/definition literals by subject URI
    entity_data = {}  # subject_uri -> {labels: [], definitions: []}
    for triple in triples:
        info = entity_data.setdefault(
            triple.s.value, {'labels': [], 'definitions': []}
        )
        literal = triple.o.value
        if triple.p.value == RDF_LABEL:
            # Labels are literals
            if not triple.o.is_uri:
                info['labels'].append(literal)
        elif triple.p.value in (DEFINITION, "https://schema.org/description"):
            if not triple.o.is_uri:
                info['definitions'].append(literal)
    # Build EntityContext objects — only for entities with some context
    entity_contexts = []
    for subject_uri, info in entity_data.items():
        parts = []
        if info['labels']:
            parts.append(f"Label: {info['labels'][0]}")
        parts.extend(info['definitions'])
        if not parts:
            continue
        entity_contexts.append(EntityContext(
            entity=Value(value=subject_uri, is_uri=True),
            context=". ".join(parts),
        ))
    logger.debug(f"Built {len(entity_contexts)} entity contexts from {len(triples)} triples")
    return entity_contexts
@staticmethod
def add_args(parser):
    """Register OntoRAG CLI options, then delegate to FlowProcessor
    for the common flow-processor options."""
    options = [
        (('-c', '--concurrency'),
         dict(type=int,
              default=default_concurrency,
              help=f'Concurrent processing threads (default: {default_concurrency})')),
        (('--top-k',),
         dict(type=int,
              default=10,
              help='Number of top ontology elements to retrieve (default: 10)')),
        (('--similarity-threshold',),
         dict(type=float,
              default=0.3,
              help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)')),
    ]
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    FlowProcessor.add_args(parser)
def run():
    """Launch the OntoRAG extraction service."""
    # Entry point referenced by the kg-extract-ontology console script.
    # NOTE(review): presumably FlowProcessor.launch parses CLI args via
    # Processor.add_args and runs the service loop — confirm in the base class.
    Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,310 @@
"""
Ontology embedder component for OntoRAG system.
Generates and stores embeddings for ontology elements.
"""
import asyncio
import logging
import numpy as np
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from .ontology_loader import Ontology, OntologyClass, OntologyProperty
from .vector_store import InMemoryVectorStore
logger = logging.getLogger(__name__)
@dataclass
class OntologyElementMetadata:
    """Metadata for an embedded ontology element.

    Stored alongside each vector (via .__dict__) so search results can
    be mapped back to the originating ontology element.
    """
    type: str  # 'class', 'objectProperty', 'datatypeProperty'
    ontology: str  # Ontology ID
    element: str  # Element ID
    definition: Dict[str, Any]  # Full element definition
    text: str  # Text used for embedding
class OntologyEmbedder:
"""Generates embeddings for ontology elements and stores them in vector store."""
def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None):
    """Initialize the ontology embedder.

    Args:
        embedding_service: Service for generating embeddings
        vector_store: Vector store instance (InMemoryVectorStore);
            a fresh default store is created when none (or a falsy
            store) is supplied.
    """
    self.embedding_service = embedding_service
    # Keep the original truthiness check (not an "is None" test)
    self.vector_store = vector_store if vector_store else InMemoryVectorStore()
    self.embedded_ontologies = set()
def _create_text_representation(self, element_id: str, element: Any,
                                element_type: str) -> str:
    """Create text representation of an ontology element for embedding.

    Concatenates the (humanised) element id, labels, comment, and
    type-specific relations into a single space-joined string.

    Args:
        element_id: ID of the element
        element: The element object (OntologyClass or OntologyProperty)
        element_type: Type of element

    Returns:
        Text representation for embedding
    """
    # Humanise the element id (often meaningful on its own)
    fragments = [element_id.replace('-', ' ').replace('_', ' ')]
    # Labels may be dicts ({'value': ...}) or plain values
    labels = getattr(element, 'labels', None)
    if labels:
        for label in labels:
            if isinstance(label, dict):
                fragments.append(label.get('value', ''))
            else:
                fragments.append(str(label))
    # Comment/description
    comment = getattr(element, 'comment', None)
    if comment:
        fragments.append(comment)
    # Type-specific relations
    if element_type == 'class':
        parent = getattr(element, 'subclass_of', None)
        if parent:
            fragments.append(f"subclass of {parent}")
    elif element_type in ('objectProperty', 'datatypeProperty'):
        domain = getattr(element, 'domain', None)
        if domain:
            fragments.append(f"domain: {domain}")
        rng = getattr(element, 'range', None)
        if rng:
            fragments.append(f"range: {rng}")
    # Drop empty fragments and join with spaces
    return ' '.join(filter(None, fragments))
async def embed_ontology(self, ontology: Ontology) -> int:
"""Generate and store embeddings for all elements in an ontology.
Args:
ontology: The ontology to embed
Returns:
Number of elements embedded
"""
if not self.embedding_service:
logger.warning("No embedding service available, skipping embedding")
return 0
embedded_count = 0
batch_size = 50 # Process embeddings in batches
# Collect all elements to embed
elements_to_embed = []
# Process classes
for class_id, class_def in ontology.classes.items():
text = self._create_text_representation(class_id, class_def, 'class')
elements_to_embed.append({
'id': f"{ontology.id}:class:{class_id}",
'text': text,
'metadata': OntologyElementMetadata(
type='class',
ontology=ontology.id,
element=class_id,
definition=class_def.__dict__,
text=text
).__dict__
})
# Process object properties
for prop_id, prop_def in ontology.object_properties.items():
text = self._create_text_representation(prop_id, prop_def, 'objectProperty')
elements_to_embed.append({
'id': f"{ontology.id}:objectProperty:{prop_id}",
'text': text,
'metadata': OntologyElementMetadata(
type='objectProperty',
ontology=ontology.id,
element=prop_id,
definition=prop_def.__dict__,
text=text
).__dict__
})
# Process datatype properties
for prop_id, prop_def in ontology.datatype_properties.items():
text = self._create_text_representation(prop_id, prop_def, 'datatypeProperty')
elements_to_embed.append({
'id': f"{ontology.id}:datatypeProperty:{prop_id}",
'text': text,
'metadata': OntologyElementMetadata(
type='datatypeProperty',
ontology=ontology.id,
element=prop_id,
definition=prop_def.__dict__,
text=text
).__dict__
})
# Process in batches
for i in range(0, len(elements_to_embed), batch_size):
batch = elements_to_embed[i:i + batch_size]
# Get embeddings for batch
texts = [elem['text'] for elem in batch]
try:
# Call embedding service for each text
# Note: embed() returns 2D array [[vector]], so extract first element
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract vectors from responses (each is [[vector]])
embeddings_list = [resp[0] for resp in embeddings_responses]
# Convert to numpy array
embeddings = np.array(embeddings_list)
# Log embedding shape for debugging
logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
# Store in vector store
ids = [elem['id'] for elem in batch]
metadata_list = [elem['metadata'] for elem in batch]
self.vector_store.add_batch(ids, embeddings, metadata_list)
embedded_count += len(batch)
logger.debug(f"Embedded batch of {len(batch)} elements from ontology {ontology.id}")
except Exception as e:
logger.error(f"Failed to embed batch for ontology {ontology.id}: {e}", exc_info=True)
self.embedded_ontologies.add(ontology.id)
logger.info(f"Embedded {embedded_count} elements from ontology {ontology.id}")
return embedded_count
async def embed_ontologies(self, ontologies: Dict[str, Ontology]) -> int:
"""Generate and store embeddings for multiple ontologies.
Args:
ontologies: Dictionary of ontology ID to Ontology objects
Returns:
Total number of elements embedded
"""
total_embedded = 0
for ont_id, ontology in ontologies.items():
if ont_id not in self.embedded_ontologies:
count = await self.embed_ontology(ontology)
total_embedded += count
else:
logger.debug(f"Ontology {ont_id} already embedded, skipping")
logger.info(f"Total embedded elements: {total_embedded} from {len(ontologies)} ontologies")
return total_embedded
async def embed_text(self, text: str) -> Optional[np.ndarray]:
"""Generate embedding for a single text.
Args:
text: Text to embed
Returns:
Embedding vector or None if failed
"""
if not self.embedding_service:
logger.warning("No embedding service available")
return None
try:
# embed() returns 2D array [[vector]], extract first element
embedding_response = await self.embedding_service.embed(text)
return np.array(embedding_response[0])
except Exception as e:
logger.error(f"Failed to embed text: {e}")
return None
async def embed_texts(self, texts: List[str]) -> Optional[np.ndarray]:
"""Generate embeddings for multiple texts.
Args:
texts: List of texts to embed
Returns:
Array of embeddings or None if failed
"""
if not self.embedding_service:
logger.warning("No embedding service available")
return None
try:
# Call embed() for each text (returns [[vector]] per call)
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract first vector from each response
embeddings_list = [resp[0] for resp in embeddings_responses]
return np.array(embeddings_list)
except Exception as e:
logger.error(f"Failed to embed texts: {e}")
return None
def remove_ontology(self, ontology_id: str):
"""Remove all embeddings for a specific ontology.
Note: FAISS doesn't support efficient deletion, so this currently
requires rebuilding the entire index without the removed ontology.
Args:
ontology_id: ID of ontology to remove
"""
if ontology_id not in self.embedded_ontologies:
logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove")
return
# FAISS doesn't support selective deletion, so we'd need to rebuild the index
# For now, just remove from tracking set
# TODO: Implement index rebuilding if selective removal is needed
self.embedded_ontologies.discard(ontology_id)
logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)")
def clear_embeddings(self, ontology_id: Optional[str] = None):
"""Clear embeddings from vector store.
Args:
ontology_id: If provided, only clear embeddings for this ontology
Otherwise, clear all embeddings
"""
if ontology_id:
self.remove_ontology(ontology_id)
else:
self.vector_store.clear()
self.embedded_ontologies.clear()
logger.info("Cleared all embeddings from vector store")
def get_vector_store(self) -> InMemoryVectorStore:
"""Get the vector store instance.
Returns:
The vector store being used
"""
return self.vector_store
def get_embedded_count(self) -> int:
"""Get the number of embedded elements.
Returns:
Number of elements in the vector store
"""
return self.vector_store.size()
def is_ontology_embedded(self, ontology_id: str) -> bool:
"""Check if an ontology has been embedded.
Args:
ontology_id: ID of the ontology
Returns:
True if the ontology has been embedded
"""
return ontology_id in self.embedded_ontologies

View file

@ -0,0 +1,247 @@
"""
Ontology loader component for OntoRAG system.
Loads and manages ontologies from configuration service.
"""
import json
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class OntologyClass:
    """An OWL-like class definition parsed from the ontology config."""
    uri: str
    type: str = "owl:Class"
    labels: List[Dict[str, str]] = field(default_factory=list)
    comment: Optional[str] = None
    subclass_of: Optional[str] = None
    equivalent_classes: List[str] = field(default_factory=list)
    disjoint_with: List[str] = field(default_factory=list)
    identifier: Optional[str] = None

    @staticmethod
    def from_dict(class_id: str, data: Dict[str, Any]) -> 'OntologyClass':
        """Build an OntologyClass from its JSON-style dict form.

        A scalar ``rdfs:label`` is wrapped in a single-element list; a falsy
        one becomes an empty list.
        """
        raw_labels = data.get('rdfs:label', [])
        if not isinstance(raw_labels, list):
            raw_labels = [raw_labels] if raw_labels else []
        return OntologyClass(
            uri=data.get('uri', ''),
            type=data.get('type', 'owl:Class'),
            labels=raw_labels,
            comment=data.get('rdfs:comment'),
            subclass_of=data.get('rdfs:subClassOf'),
            equivalent_classes=data.get('owl:equivalentClass', []),
            disjoint_with=data.get('owl:disjointWith', []),
            identifier=data.get('dcterms:identifier'),
        )
@dataclass
class OntologyProperty:
    """An object or datatype property definition parsed from the ontology config."""
    uri: str
    type: str
    labels: List[Dict[str, str]] = field(default_factory=list)
    comment: Optional[str] = None
    domain: Optional[str] = None
    range: Optional[str] = None
    inverse_of: Optional[str] = None
    functional: bool = False
    inverse_functional: bool = False
    min_cardinality: Optional[int] = None
    max_cardinality: Optional[int] = None
    cardinality: Optional[int] = None

    @staticmethod
    def from_dict(prop_id: str, data: Dict[str, Any]) -> 'OntologyProperty':
        """Build an OntologyProperty from its JSON-style dict form.

        A scalar ``rdfs:label`` is wrapped in a single-element list; a falsy
        one becomes an empty list.
        """
        raw_labels = data.get('rdfs:label', [])
        if not isinstance(raw_labels, list):
            raw_labels = [raw_labels] if raw_labels else []
        return OntologyProperty(
            uri=data.get('uri', ''),
            type=data.get('type', ''),
            labels=raw_labels,
            comment=data.get('rdfs:comment'),
            domain=data.get('rdfs:domain'),
            range=data.get('rdfs:range'),
            inverse_of=data.get('owl:inverseOf'),
            functional=data.get('owl:functionalProperty', False),
            inverse_functional=data.get('owl:inverseFunctionalProperty', False),
            min_cardinality=data.get('owl:minCardinality'),
            max_cardinality=data.get('owl:maxCardinality'),
            cardinality=data.get('owl:cardinality'),
        )
@dataclass
class Ontology:
    """A complete ontology: metadata plus classes and properties keyed by ID."""
    id: str
    metadata: Dict[str, Any]
    classes: Dict[str, OntologyClass]
    object_properties: Dict[str, OntologyProperty]
    datatype_properties: Dict[str, OntologyProperty]

    def get_class(self, class_id: str) -> Optional[OntologyClass]:
        """Look up a class by ID, returning None if absent."""
        return self.classes.get(class_id)

    def get_property(self, prop_id: str) -> Optional[OntologyProperty]:
        """Look up a property by ID, checking object then datatype properties."""
        hit = self.object_properties.get(prop_id)
        if hit is not None:
            return hit
        return self.datatype_properties.get(prop_id)

    def get_parent_classes(self, class_id: str) -> List[str]:
        """Walk the subClassOf chain upward, returning ancestors in order.

        Cycles are guarded against with a visited set.
        """
        chain: List[str] = []
        seen = set()
        node = class_id
        while node and node not in seen:
            seen.add(node)
            cls = self.get_class(node)
            if cls is None or not cls.subclass_of:
                break
            chain.append(cls.subclass_of)
            node = cls.subclass_of
        return chain

    def validate_structure(self) -> List[str]:
        """Check the ontology for structural problems; returns a list of issues."""
        issues: List[str] = []
        # Detect cycles in the subClassOf hierarchy
        for class_id in self.classes:
            seen = set()
            node = class_id
            while node:
                if node in seen:
                    issues.append(f"Circular inheritance detected for class {class_id}")
                    break
                seen.add(node)
                cls = self.get_class(node)
                if cls is None:
                    break
                node = cls.subclass_of
        # Every declared domain must be a known class; object-property ranges too
        for prop_id, prop in {**self.object_properties, **self.datatype_properties}.items():
            if prop.domain and prop.domain not in self.classes:
                issues.append(f"Property {prop_id} has unknown domain {prop.domain}")
            if prop.type == "owl:ObjectProperty" and prop.range and prop.range not in self.classes:
                issues.append(f"Object property {prop_id} has unknown range class {prop.range}")
        # Disjointness must reference known classes
        for class_id, cls in self.classes.items():
            issues.extend(
                f"Class {class_id} disjoint with unknown class {disjoint_id}"
                for disjoint_id in cls.disjoint_with
                if disjoint_id not in self.classes
            )
        return issues
class OntologyLoader:
    """Manages ontologies received via event-driven config updates.

    No direct database access - receives ontologies via config handler.
    """

    def __init__(self):
        """Start with no ontologies loaded."""
        self.ontologies: Dict[str, Ontology] = {}

    def update_ontologies(self, ontology_configs: Dict[str, Any]):
        """Replace all loaded ontologies with definitions from config.

        Args:
            ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts)
        """
        self.ontologies.clear()
        for ont_id, ont_data in ontology_configs.items():
            try:
                # Parse each element category from its config section
                classes = {
                    cid: OntologyClass.from_dict(cid, cdata)
                    for cid, cdata in ont_data.get('classes', {}).items()
                }
                obj_props = {
                    pid: OntologyProperty.from_dict(pid, pdata)
                    for pid, pdata in ont_data.get('objectProperties', {}).items()
                }
                dt_props = {
                    pid: OntologyProperty.from_dict(pid, pdata)
                    for pid, pdata in ont_data.get('datatypeProperties', {}).items()
                }
                ontology = Ontology(
                    id=ont_id,
                    metadata=ont_data.get('metadata', {}),
                    classes=classes,
                    object_properties=obj_props,
                    datatype_properties=dt_props
                )
                # Validation is advisory: problems are logged, not fatal
                issues = ontology.validate_structure()
                if issues:
                    logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
                self.ontologies[ont_id] = ontology
                logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
                           f"{len(obj_props)} object properties, "
                           f"{len(dt_props)} datatype properties")
            except Exception as e:
                # One malformed ontology must not block loading the others
                logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)

    def get_ontology(self, ont_id: str) -> Optional[Ontology]:
        """Return the ontology with this ID, or None if not loaded.

        Args:
            ont_id: Ontology identifier

        Returns:
            Ontology object or None if not found
        """
        return self.ontologies.get(ont_id)

    def get_all_ontologies(self) -> Dict[str, Ontology]:
        """Return every loaded ontology, keyed by ID."""
        return self.ontologies

    def list_ontology_ids(self) -> List[str]:
        """Return the IDs of all currently loaded ontologies."""
        return list(self.ontologies.keys())

    def clear(self):
        """Drop every loaded ontology."""
        self.ontologies.clear()
        logger.info("Cleared all loaded ontologies")

View file

@ -0,0 +1,356 @@
"""
Ontology selection algorithm for OntoRAG system.
Selects relevant ontology subsets based on text similarity.
"""
import logging
from typing import List, Dict, Any, Set, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict
from .ontology_loader import Ontology, OntologyLoader
from .ontology_embedder import OntologyEmbedder
from .text_processor import TextSegment
from .vector_store import SearchResult
logger = logging.getLogger(__name__)
@dataclass
class OntologySubset:
    """Represents a subset of an ontology relevant to a text chunk."""
    ontology_id: str  # Source ontology ID ("merged" after merge_subsets)
    classes: Dict[str, Any]  # class_id -> class definition dict
    object_properties: Dict[str, Any]  # prop_id -> property definition dict
    datatype_properties: Dict[str, Any]  # prop_id -> property definition dict
    metadata: Dict[str, Any]  # Ontology-level metadata carried along
    relevance_score: float = 0.0  # Aggregate similarity score for ranking
class OntologySelector:
    """Selects relevant ontology elements for text segments using vector similarity.

    Pipeline: embed each text segment, search the ontology element vector
    store, group hits by ontology, then pull in structural dependencies
    (parents, domains/ranges, inverses, and properties touching the selected
    classes).
    """

    def __init__(self, ontology_embedder: OntologyEmbedder,
                 ontology_loader: OntologyLoader,
                 top_k: int = 10,
                 similarity_threshold: float = 0.7):
        """Initialize the ontology selector.

        Args:
            ontology_embedder: Embedder with vector store
            ontology_loader: Loader with ontology definitions
            top_k: Number of top results to retrieve per segment
            similarity_threshold: Minimum similarity score
        """
        self.embedder = ontology_embedder
        self.loader = ontology_loader
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold

    async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
        """Select relevant ontology subsets for text segments.

        Args:
            segments: List of text segments to match

        Returns:
            List of ontology subsets with relevant elements
        """
        # Collect all relevant elements
        relevant_elements = await self._find_relevant_elements(segments)
        # Group by ontology and build subsets
        ontology_subsets = self._build_ontology_subsets(relevant_elements)
        # Resolve dependencies (mutates each subset in place)
        for subset in ontology_subsets:
            self._resolve_dependencies(subset)
        logger.info(f"Selected {len(ontology_subsets)} ontology subsets")
        return ontology_subsets

    # NOTE(review): the 4th tuple member is str(definition), not a dict —
    # annotation adjusted from Dict to str to match the actual values.
    async def _find_relevant_elements(self, segments: List[TextSegment]) -> Set[Tuple[str, str, str, str]]:
        """Find relevant ontology elements for text segments.

        Args:
            segments: Text segments to match

        Returns:
            Set of (ontology_id, element_type, element_id, definition) tuples
        """
        relevant_elements = set()
        # NOTE(review): element_scores is filled below but never returned or
        # read — the per-element scores are lost; see _build_ontology_subsets.
        element_scores = defaultdict(float)
        # Check if vector store has any elements
        vector_store = self.embedder.get_vector_store()
        store_size = vector_store.size()
        logger.info(f"Vector store size: {store_size} elements")
        if store_size == 0:
            logger.warning("Vector store is empty - no ontology elements embedded")
            return relevant_elements
        # Process each segment (log first few for debugging)
        for i, segment in enumerate(segments):
            # Get embedding for segment
            embedding = await self.embedder.embed_text(segment.text)
            if embedding is None:
                logger.warning(f"Failed to embed segment: {segment.text[:50]}...")
                continue
            # Search vector store with no threshold to see all scores
            all_results = vector_store.search(
                embedding=embedding,
                top_k=self.top_k,
                threshold=0.0  # Get all results to see scores
            )
            # Log top scores for first 3 segments to debug
            if i < 3 and all_results:
                top_scores = [r.score for r in all_results[:3]]
                top_elements = [r.metadata['element'] for r in all_results[:3]]
                logger.info(f"Segment {i}: '{segment.text[:60]}...'")
                logger.info(f"  Top 3 scores: {top_scores} (threshold={self.similarity_threshold})")
                logger.info(f"  Top 3 elements: {top_elements}")
            # Filter by threshold
            results = [r for r in all_results if r.score >= self.similarity_threshold]
            # Process results
            for result in results:
                metadata = result.metadata
                element_key = (
                    metadata['ontology'],
                    metadata['type'],
                    metadata['element'],
                    str(metadata['definition'])  # Convert dict to string for hashability
                )
                relevant_elements.add(element_key)
                # Track scores for ranking
                element_scores[element_key] = max(element_scores[element_key], result.score)
        logger.info(f"Found {len(relevant_elements)} relevant elements from {len(segments)} segments")
        return relevant_elements

    def _build_ontology_subsets(self, relevant_elements: Set[Tuple[str, str, str, str]]) -> List[OntologySubset]:
        """Build ontology subsets from relevant elements.

        Args:
            relevant_elements: Set of relevant element tuples

        Returns:
            List of ontology subsets
        """
        # Group elements by ontology
        # NOTE(review): 'scores' is never appended to anywhere below, so
        # relevance_score always ends up 0.0 — the scores gathered in
        # _find_relevant_elements never reach this method. Likely a bug.
        ontology_groups = defaultdict(lambda: {
            'classes': {},
            'object_properties': {},
            'datatype_properties': {},
            'scores': []
        })
        for ont_id, elem_type, elem_id, definition in relevant_elements:
            # Parse definition back from string if needed
            # NOTE(review): the parsed definition is never used afterwards —
            # the element is re-fetched from the loader below. The
            # eval() fallback executes arbitrary repr text and should be
            # removed or replaced (e.g. ast.literal_eval) if ever kept.
            if isinstance(definition, str):
                import json
                try:
                    definition = json.loads(definition.replace("'", '"'))
                except:
                    definition = eval(definition)  # Fallback for dict-like strings
            # Get the actual ontology and element
            ontology = self.loader.get_ontology(ont_id)
            if not ontology:
                logger.warning(f"Ontology {ont_id} not found in loader")
                continue
            # Add element to appropriate category
            if elem_type == 'class':
                cls = ontology.get_class(elem_id)
                if cls:
                    ontology_groups[ont_id]['classes'][elem_id] = cls.__dict__
            elif elem_type == 'objectProperty':
                prop = ontology.object_properties.get(elem_id)
                if prop:
                    ontology_groups[ont_id]['object_properties'][elem_id] = prop.__dict__
            elif elem_type == 'datatypeProperty':
                prop = ontology.datatype_properties.get(elem_id)
                if prop:
                    ontology_groups[ont_id]['datatype_properties'][elem_id] = prop.__dict__
        # Create OntologySubset objects
        subsets = []
        for ont_id, elements in ontology_groups.items():
            ontology = self.loader.get_ontology(ont_id)
            if ontology:
                subset = OntologySubset(
                    ontology_id=ont_id,
                    classes=elements['classes'],
                    object_properties=elements['object_properties'],
                    datatype_properties=elements['datatype_properties'],
                    metadata=ontology.metadata,
                    relevance_score=sum(elements['scores']) / len(elements['scores']) if elements['scores'] else 0.0
                )
                subsets.append(subset)
        return subsets

    def _resolve_dependencies(self, subset: OntologySubset):
        """Resolve dependencies for ontology subset elements.

        Adds, in place: ancestor classes, domain/range classes of selected
        properties, inverse properties, and all ontology properties whose
        domain or range touches a selected class.

        Args:
            subset: Ontology subset to resolve dependencies for
        """
        ontology = self.loader.get_ontology(subset.ontology_id)
        if not ontology:
            return
        # Track classes to add (added in one pass at the end)
        classes_to_add = set()
        # Resolve class hierarchies
        for class_id in list(subset.classes.keys()):
            # Add parent classes
            parents = ontology.get_parent_classes(class_id)
            for parent_id in parents:
                parent_class = ontology.get_class(parent_id)
                if parent_class and parent_id not in subset.classes:
                    classes_to_add.add(parent_id)
        # Resolve property domains and ranges
        for prop_id, prop_def in subset.object_properties.items():
            # Add domain class
            if 'domain' in prop_def and prop_def['domain']:
                domain_id = prop_def['domain']
                if domain_id not in subset.classes:
                    domain_class = ontology.get_class(domain_id)
                    if domain_class:
                        classes_to_add.add(domain_id)
            # Add range class
            if 'range' in prop_def and prop_def['range']:
                range_id = prop_def['range']
                if range_id not in subset.classes:
                    range_class = ontology.get_class(range_id)
                    if range_class:
                        classes_to_add.add(range_id)
        # Resolve datatype property domains
        for prop_id, prop_def in subset.datatype_properties.items():
            if 'domain' in prop_def and prop_def['domain']:
                domain_id = prop_def['domain']
                if domain_id not in subset.classes:
                    domain_class = ontology.get_class(domain_id)
                    if domain_class:
                        classes_to_add.add(domain_id)
        # Add inverse properties
        for prop_id, prop_def in list(subset.object_properties.items()):
            if 'inverse_of' in prop_def and prop_def['inverse_of']:
                inverse_id = prop_def['inverse_of']
                if inverse_id not in subset.object_properties:
                    inverse_prop = ontology.object_properties.get(inverse_id)
                    if inverse_prop:
                        subset.object_properties[inverse_id] = inverse_prop.__dict__
        # NEW: Auto-include properties related to selected classes
        # For each selected class, find all properties that reference it in domain or range
        properties_added = 0
        datatype_properties_added = 0
        for class_id in list(subset.classes.keys()):
            # Check all object properties in the ontology
            for prop_id, prop_def in ontology.object_properties.items():
                if prop_id not in subset.object_properties:
                    # Check if this class is in the property's domain or range
                    prop_domain = getattr(prop_def, 'domain', None)
                    prop_range = getattr(prop_def, 'range', None)
                    if prop_domain == class_id or prop_range == class_id:
                        subset.object_properties[prop_id] = prop_def.__dict__
                        properties_added += 1
                        # Also add the other class (domain or range) if not already present
                        if prop_domain and prop_domain != class_id and prop_domain not in subset.classes:
                            other_class = ontology.get_class(prop_domain)
                            if other_class:
                                classes_to_add.add(prop_domain)
                        if prop_range and prop_range != class_id and prop_range not in subset.classes:
                            other_class = ontology.get_class(prop_range)
                            if other_class:
                                classes_to_add.add(prop_range)
            # Check all datatype properties in the ontology
            for prop_id, prop_def in ontology.datatype_properties.items():
                if prop_id not in subset.datatype_properties:
                    # Check if this class is in the property's domain
                    prop_domain = getattr(prop_def, 'domain', None)
                    if prop_domain == class_id:
                        subset.datatype_properties[prop_id] = prop_def.__dict__
                        datatype_properties_added += 1
        # Add collected classes
        for class_id in classes_to_add:
            cls = ontology.get_class(class_id)
            if cls:
                subset.classes[class_id] = cls.__dict__
        logger.debug(f"Resolved dependencies for subset {subset.ontology_id}: "
                    f"added {len(classes_to_add)} classes, "
                    f"{properties_added} object properties, "
                    f"{datatype_properties_added} datatype properties")

    def merge_subsets(self, subsets: List[OntologySubset]) -> OntologySubset:
        """Merge multiple ontology subsets into one.

        Element keys are namespaced as "<ontology_id>:<element_id>" to avoid
        collisions between ontologies.

        Args:
            subsets: List of subsets to merge

        Returns:
            Merged ontology subset (None if *subsets* is empty)
        """
        if not subsets:
            return None
        if len(subsets) == 1:
            return subsets[0]
        # Use first subset as base
        merged = OntologySubset(
            ontology_id="merged",
            classes={},
            object_properties={},
            datatype_properties={},
            metadata={},
            relevance_score=0.0
        )
        # Merge all subsets
        total_score = 0.0
        for subset in subsets:
            # Merge classes
            for class_id, class_def in subset.classes.items():
                key = f"{subset.ontology_id}:{class_id}"
                merged.classes[key] = class_def
            # Merge object properties
            for prop_id, prop_def in subset.object_properties.items():
                key = f"{subset.ontology_id}:{prop_id}"
                merged.object_properties[key] = prop_def
            # Merge datatype properties
            for prop_id, prop_def in subset.datatype_properties.items():
                key = f"{subset.ontology_id}:{prop_id}"
                merged.datatype_properties[key] = prop_def
            total_score += subset.relevance_score
        # Average relevance score
        merged.relevance_score = total_score / len(subsets)
        logger.info(f"Merged {len(subsets)} subsets into one with "
                   f"{len(merged.classes)} classes, "
                   f"{len(merged.object_properties)} object properties, "
                   f"{len(merged.datatype_properties)} datatype properties")
        return merged

View file

@ -0,0 +1,10 @@
#!/usr/bin/env python3
"""
OntoRAG extraction service launcher.

Thin wrapper so the service can be started directly as a script; the real
entry point lives in the sibling extract module.
"""
from . extract import run

# Allow direct execution (python -m ...) in addition to the console-script entry
if __name__ == "__main__":
    run()

View file

@ -0,0 +1,240 @@
"""
Text processing components for OntoRAG system.
Splits text into sentences and extracts phrases for granular matching.
"""
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import nltk
from nltk.corpus import stopwords
logger = logging.getLogger(__name__)
# Ensure required NLTK data is downloaded.
# Downloads are best-effort: a missing corpus degrades functionality but
# must not prevent the module from importing.
def _ensure_nltk_data(resource: str, primary: str, fallback: Optional[str] = None) -> None:
    """Download an NLTK data package unless *resource* is already installed.

    Args:
        resource: nltk.data.find() path, e.g. 'tokenizers/punkt_tab'
        primary: package name to download when the resource is missing
        fallback: older package name to try if the primary download fails
    """
    try:
        nltk.data.find(resource)
        return  # Already present, nothing to do
    except LookupError:
        pass
    try:
        nltk.download(primary, quiet=True)
    # Narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate
    except Exception:
        if fallback is None:
            return
        try:
            nltk.download(fallback, quiet=True)
        except Exception:
            pass  # Best-effort only; tokenizer init will surface the problem


_ensure_nltk_data('tokenizers/punkt_tab', 'punkt_tab', 'punkt')
_ensure_nltk_data('taggers/averaged_perceptron_tagger_eng',
                  'averaged_perceptron_tagger_eng',
                  'averaged_perceptron_tagger')
_ensure_nltk_data('corpora/stopwords', 'stopwords')
@dataclass
class TextSegment:
    """Represents a segment of text (sentence or phrase)."""
    text: str  # The segment text itself
    type: str  # 'sentence', 'phrase', 'noun_phrase', 'verb_phrase'
    position: int  # Ordinal position within the source chunk
    parent_sentence: Optional[str] = None  # Containing sentence, for phrase segments
    # Fixed annotation: was `Dict[str, Any] = None`, which declares a None
    # default for a non-optional Dict. Runtime behavior is unchanged.
    metadata: Optional[Dict[str, Any]] = None
class SentenceSplitter:
    """Splits text into sentences using NLTK."""

    def __init__(self):
        """Load an NLTK Punkt sentence tokenizer, preferring the newer punkt_tab.

        Raises:
            LookupError: if neither punkt_tab nor punkt data is installed.
        """
        try:
            # Try newer punkt_tab first (NLTK >= 3.8.2 data layout)
            self.sent_detector = nltk.data.load('tokenizers/punkt_tab/english/')
            logger.info("Using NLTK sentence tokenizer (punkt_tab)")
        # Narrowed from bare `except:` — a bare except also swallows
        # SystemExit/KeyboardInterrupt, which must propagate.
        except Exception:
            # Fallback to older punkt pickle
            self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
            logger.info("Using NLTK sentence tokenizer (punkt)")

    def split(self, text: str) -> List[str]:
        """Split text into sentences.

        Args:
            text: Text to split

        Returns:
            List of sentences
        """
        sentences = self.sent_detector.tokenize(text)
        return sentences
class PhraseExtractor:
    """Extracts meaningful phrases from sentences using NLTK."""

    def __init__(self):
        """Initialize phrase extractor."""
        logger.info("Using NLTK phrase extraction")

    @staticmethod
    def _collect_runs(pos_tags: List[Tuple[str, str]],
                      prefixes: Tuple[str, ...],
                      phrase_type: str) -> List[Dict[str, str]]:
        """Collect maximal runs of 2+ consecutive words whose POS tag starts with one of *prefixes*.

        Factored out of extract(): the noun-phrase and verb-phrase scans were
        identical except for the tag prefixes and the label.
        """
        phrases = []
        run: List[str] = []
        for word, pos in pos_tags:
            if pos.startswith(prefixes):
                run.append(word)
            else:
                # Run ended: keep it only if it spans more than one word
                if len(run) > 1:
                    phrases.append({
                        'text': ' '.join(run),
                        'type': phrase_type
                    })
                run = []
        # Flush a run that extends to the end of the sentence
        if len(run) > 1:
            phrases.append({
                'text': ' '.join(run),
                'type': phrase_type
            })
        return phrases

    def extract(self, sentence: str) -> List[Dict[str, str]]:
        """Extract phrases from a sentence.

        Args:
            sentence: Sentence to extract phrases from

        Returns:
            List of phrases with their types
        """
        # Tokenize and POS tag
        tokens = nltk.word_tokenize(sentence)
        pos_tags = nltk.pos_tag(tokens)
        # Noun phrases (nouns/adjectives) first, then verb phrases
        # (verbs/adverbs) — preserves the original output ordering.
        phrases = self._collect_runs(pos_tags, ('NN', 'JJ'), 'noun_phrase')
        phrases += self._collect_runs(pos_tags, ('VB', 'RB'), 'verb_phrase')
        return phrases
class TextProcessor:
    """Main text processing class that coordinates sentence splitting and phrase extraction."""

    def __init__(self):
        """Initialize text processor with its sentence and phrase components."""
        self.sentence_splitter = SentenceSplitter()
        self.phrase_extractor = PhraseExtractor()

    def process_chunk(self, chunk_text: str, extract_phrases: bool = True) -> List[TextSegment]:
        """Process a text chunk into segments.

        Emits one 'sentence' segment per sentence; when extract_phrases is
        True, each sentence's noun/verb phrases follow it as extra segments
        carrying parent_sentence.

        Args:
            chunk_text: Text chunk to process
            extract_phrases: Whether to extract phrases from sentences

        Returns:
            List of TextSegment objects
        """
        segments = []
        position = 0  # Monotonic index shared by sentences and their phrases
        # Split into sentences
        sentences = self.sentence_splitter.split(chunk_text)
        for sentence in sentences:
            # Add sentence segment
            segments.append(TextSegment(
                text=sentence,
                type='sentence',
                position=position
            ))
            position += 1
            # Extract phrases if requested
            if extract_phrases:
                phrases = self.phrase_extractor.extract(sentence)
                for phrase_data in phrases:
                    segments.append(TextSegment(
                        text=phrase_data['text'],
                        type=phrase_data['type'],
                        position=position,
                        parent_sentence=sentence
                    ))
                    position += 1
        logger.debug(f"Processed chunk into {len(segments)} segments")
        return segments

    def extract_key_terms(self, text: str) -> List[str]:
        """Extract key terms from text for matching.

        Returns lowercase unigrams (stopwords and words of <=2 chars removed)
        followed by bigrams whose two words are both non-stopwords.

        Args:
            text: Text to extract terms from

        Returns:
            List of key terms
        """
        terms = []
        # Split on word boundaries
        words = re.findall(r'\b\w+\b', text.lower())
        # Use NLTK stopwords
        stop_words = set(stopwords.words('english'))
        # Filter stopwords and short words
        terms = [w for w in words if w not in stop_words and len(w) > 2]
        # Also extract multi-word terms (bigrams)
        for i in range(len(words) - 1):
            if words[i] not in stop_words and words[i+1] not in stop_words:
                bigram = f"{words[i]} {words[i+1]}"
                terms.append(bigram)
        return terms

    def normalize_text(self, text: str) -> str:
        """Normalize text for consistent processing.

        Args:
            text: Text to normalize

        Returns:
            Normalized text
        """
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove leading/trailing whitespace
        text = text.strip()
        # Normalize quotes
        # NOTE(review): the literals below are Unicode curly quotes; the
        # originals were garbled in transit and are reconstructed here as
        # left/right double and single quotation marks — confirm against
        # the committed file.
        text = text.replace('\u201c', '"').replace('\u201d', '"')
        text = text.replace('\u2018', "'").replace('\u2019', "'")
        return text

View file

@ -0,0 +1,130 @@
"""
Vector store implementation for OntoRAG system.
Provides FAISS-based vector storage for ontology embeddings.
"""
import logging
import numpy as np
from typing import List, Dict, Any
from dataclasses import dataclass
import faiss
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
    """Result from vector similarity search."""
    id: str  # Stored vector ID, e.g. "<ontology>:<type>:<element>"
    score: float  # Similarity score (inner product of L2-normalized vectors)
    metadata: Dict[str, Any]  # Metadata stored alongside the vector
class InMemoryVectorStore:
    """FAISS-based vector store implementation for ontology embeddings.

    Vectors are L2-normalized on insertion and query so the inner-product
    index yields cosine similarity. Metadata and IDs are kept in parallel
    Python lists aligned with FAISS insertion order.
    """

    def __init__(self, dimension: int = 1536, index_type: str = 'flat'):
        """Initialize FAISS vector store.

        Args:
            dimension: Embedding dimension (1536 for text-embedding-3-small)
            index_type: 'flat' for exact search, 'ivf' for larger datasets
        """
        self.dimension = dimension
        self.metadata = []  # Parallel to FAISS rows: metadata dict per vector
        self.ids = []  # Parallel to FAISS rows: external ID per vector
        if index_type == 'flat':
            # Exact search - best for ontologies with <10k elements
            self.index = faiss.IndexFlatIP(dimension)
            logger.info(f"Created FAISS flat index with dimension {dimension}")
        else:
            # Approximate search - for larger ontologies
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, 100)
            # Train with random vectors for initialization
            # NOTE(review): training on random data gives arbitrary cluster
            # centroids, and nprobe is left at the FAISS default (1), which
            # can hurt recall — confirm this is acceptable for 'ivf' use.
            training_data = np.random.randn(1000, dimension).astype('float32')
            training_data = training_data / np.linalg.norm(
                training_data, axis=1, keepdims=True
            )
            self.index.train(training_data)
            logger.info(f"Created FAISS IVF index with dimension {dimension}")

    def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
        """Add single embedding with metadata."""
        # Normalize for cosine similarity
        # NOTE(review): a zero vector yields division by zero (NaN row).
        embedding = embedding / np.linalg.norm(embedding)
        self.index.add(np.array([embedding], dtype=np.float32))
        self.metadata.append(metadata)
        self.ids.append(id)

    def add_batch(self, ids: List[str], embeddings: np.ndarray,
                 metadata_list: List[Dict[str, Any]]):
        """Batch add for initial ontology loading."""
        # Normalize all embeddings
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        normalized = embeddings / norms
        self.index.add(normalized.astype(np.float32))
        self.metadata.extend(metadata_list)
        self.ids.extend(ids)
        logger.debug(f"Added batch of {len(ids)} embeddings to FAISS index")

    def search(self, embedding: np.ndarray, top_k: int = 10,
              threshold: float = 0.0) -> List[SearchResult]:
        """Search for similar vectors.

        Returns at most top_k results with score >= threshold, ordered by
        descending similarity.
        """
        # Normalize query
        embedding = embedding / np.linalg.norm(embedding)
        # Search; k is clamped to the number of stored vectors
        scores, indices = self.index.search(
            np.array([embedding], dtype=np.float32),
            min(top_k, self.index.ntotal)
        )
        # Filter by threshold and format results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and score >= threshold:  # FAISS returns -1 for empty slots
                results.append(SearchResult(
                    id=self.ids[idx],
                    score=float(score),
                    metadata=self.metadata[idx]
                ))
        return results

    def clear(self):
        """Reset the store."""
        self.index.reset()
        self.metadata = []
        self.ids = []
        logger.info("Cleared FAISS vector store")

    def size(self) -> int:
        """Return number of stored vectors."""
        return self.index.ntotal
# Utility functions for vector operations
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute cosine similarity between two vectors."""
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denominator
def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Compute cosine similarity between query vectors and target vectors.

    Args:
        queries: Array of shape (n_queries, dimension)
        targets: Array of shape (n_targets, dimension)

    Returns:
        Array of shape (n_queries, n_targets) with similarity scores
    """
    # Row-normalize both operands, then a single matrix product yields
    # every pairwise cosine similarity.
    unit_queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    unit_targets = targets / np.linalg.norm(targets, axis=1, keepdims=True)
    return unit_queries @ unit_targets.T

View file

@ -0,0 +1,54 @@
"""
OntoRAG Query System.
Ontology-driven natural language query processing with multi-backend support.
Provides semantic query understanding, ontology matching, and answer generation.
"""
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, BackendType, QueryRoute
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata
# Explicit public API of the query package; these names are re-exported
# from the submodules imported above.
__all__ = [
    # Main service
    'OntoRAGQueryService',
    'QueryRequest',
    'QueryResponse',
    # Question analysis
    'QuestionAnalyzer',
    'QuestionComponents',
    'QuestionType',
    # Ontology matching
    'OntologyMatcher',
    'QueryOntologySubset',
    # Backend routing
    'BackendRouter',
    'BackendType',
    'QueryRoute',
    # SPARQL components
    'SPARQLGenerator',
    'SPARQLQuery',
    'SPARQLCassandraEngine',
    'SPARQLResult',
    # Cypher components
    'CypherGenerator',
    'CypherQuery',
    'CypherExecutor',
    'CypherResult',
    # Answer generation
    'AnswerGenerator',
    'GeneratedAnswer',
    'AnswerMetadata',
]

View file

@ -0,0 +1,521 @@
"""
Answer generator for natural language responses.
Converts query results into natural language answers using LLM assistance.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class AnswerMetadata:
    """Metadata about answer generation."""
    # Question type, e.g. 'factual' or 'aggregation'
    # (QuestionType.value of the analyzed question).
    query_type: str
    # Backend that executed the query, e.g. 'cassandra'.
    backend_used: str
    # Answer generation time in seconds.
    execution_time: float
    # Number of raw result rows (SPARQL bindings or Cypher records).
    result_count: int
    # Confidence score; 0.0 on error paths.
    confidence: float
    # Human-readable note about how the answer was produced.
    explanation: str
    # Data source identifiers (currently always empty).
    sources: List[str]
@dataclass
class GeneratedAnswer:
    """Generated natural language answer."""
    # Final natural-language answer text.
    answer: str
    # Metadata about how the answer was generated.
    metadata: AnswerMetadata
    # Up to five extracted facts supporting the answer.
    supporting_facts: List[str]
    # Unmodified backend result the answer was derived from.
    raw_results: Union[SPARQLResult, CypherResult]
    # Seconds spent generating the answer.
    generation_time: float
class AnswerGenerator:
    """Generates natural language answers from query results.

    Prefers LLM-based generation via the injected prompt service and
    falls back to fixed templates when no service is configured or the
    LLM call fails/returns nothing.
    """

    def __init__(self, prompt_service=None):
        """Initialize answer generator.

        Args:
            prompt_service: Service for LLM-based answer generation
        """
        self.prompt_service = prompt_service
        # Answer templates for different question types
        self.templates = {
            'count': "There are {count} {entity_type}.",
            'boolean_true': "Yes, {statement} is true.",
            'boolean_false': "No, {statement} is not true.",
            'list': "The {entity_type} are: {items}.",
            'single': "The {property} of {entity} is {value}.",
            'none': "No results were found for your query.",
            'error': "I encountered an error processing your query: {error}"
        }

    async def generate_answer(self,
                              question_components: QuestionComponents,
                              query_results: Union[SPARQLResult, CypherResult],
                              ontology_subset: QueryOntologySubset,
                              backend_used: str) -> GeneratedAnswer:
        """Generate natural language answer from query results.

        Args:
            question_components: Original question analysis
            query_results: Results from query execution
            ontology_subset: Ontology subset used
            backend_used: Backend that executed the query

        Returns:
            Generated answer with metadata
        """
        start_time = datetime.now()
        try:
            # Try LLM-based generation first
            if self.prompt_service:
                llm_answer = await self._generate_with_llm(
                    question_components, query_results, ontology_subset
                )
                # An empty/None LLM answer falls through to templates
                if llm_answer:
                    execution_time = (datetime.now() - start_time).total_seconds()
                    return self._build_answer_response(
                        llm_answer, question_components, query_results,
                        backend_used, execution_time
                    )
            # Fall back to template-based generation
            template_answer = self._generate_with_template(
                question_components, query_results, ontology_subset
            )
            execution_time = (datetime.now() - start_time).total_seconds()
            return self._build_answer_response(
                template_answer, question_components, query_results,
                backend_used, execution_time
            )
        except Exception as e:
            logger.error(f"Answer generation failed: {e}")
            execution_time = (datetime.now() - start_time).total_seconds()
            error_answer = self.templates['error'].format(error=str(e))
            # Error answers carry zero confidence
            return self._build_answer_response(
                error_answer, question_components, query_results,
                backend_used, execution_time, confidence=0.0
            )

    async def _generate_with_llm(self,
                                 question_components: QuestionComponents,
                                 query_results: Union[SPARQLResult, CypherResult],
                                 ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Generate answer using LLM.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Generated answer or None if failed
        """
        try:
            prompt = self._build_answer_prompt(
                question_components, query_results, ontology_subset
            )
            response = await self.prompt_service.generate_answer(prompt=prompt)
            # Service may return either a dict with an 'answer' key or a
            # bare string
            if response and isinstance(response, dict):
                return response.get('answer', '').strip()
            elif isinstance(response, str):
                return response.strip()
        except Exception as e:
            logger.error(f"LLM answer generation failed: {e}")
        return None

    def _generate_with_template(self,
                                question_components: QuestionComponents,
                                query_results: Union[SPARQLResult, CypherResult],
                                ontology_subset: QueryOntologySubset) -> str:
        """Generate answer using templates.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Template-based answer
        """
        # Handle empty results
        if not self._has_results(query_results):
            return self.templates['none']
        # Handle boolean queries
        if question_components.question_type == QuestionType.BOOLEAN:
            if hasattr(query_results, 'ask_result'):
                # SPARQL ASK result
                statement = self._extract_boolean_statement(question_components)
                if query_results.ask_result:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
            else:
                # Cypher boolean (check if any results)
                has_results = len(query_results.records) > 0
                statement = self._extract_boolean_statement(question_components)
                if has_results:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
        # Handle count queries
        if question_components.question_type == QuestionType.AGGREGATION:
            count = self._extract_count(query_results)
            entity_type = self._infer_entity_type(question_components, ontology_subset)
            return self.templates['count'].format(count=count, entity_type=entity_type)
        # Handle retrieval queries
        if question_components.question_type == QuestionType.RETRIEVAL:
            items = self._extract_items(query_results)
            if len(items) == 1:
                # Single result
                entity = question_components.entities[0] if question_components.entities else "entity"
                property_name = "value"
                return self.templates['single'].format(
                    property=property_name, entity=entity, value=items[0]
                )
            else:
                # Multiple results
                entity_type = self._infer_entity_type(question_components, ontology_subset)
                items_str = ", ".join(items)
                return self.templates['list'].format(entity_type=entity_type, items=items_str)
        # Handle factual queries
        if question_components.question_type == QuestionType.FACTUAL:
            facts = self._extract_facts(query_results)
            return ". ".join(facts) if facts else self.templates['none']
        # Default fallback: list up to five items
        items = self._extract_items(query_results)
        if items:
            return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "")
        else:
            return self.templates['none']

    def _build_answer_prompt(self,
                             question_components: QuestionComponents,
                             query_results: Union[SPARQLResult, CypherResult],
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM answer generation.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Formatted prompt string
        """
        # Format results for prompt
        results_str = self._format_results_for_prompt(query_results)
        # Extract ontology context (cap at 5 names each to bound prompt size)
        context_classes = list(ontology_subset.classes.keys())[:5]
        context_properties = list(ontology_subset.object_properties.keys())[:5]
        prompt = f"""Generate a natural language answer for the following question based on the query results.
ORIGINAL QUESTION: {question_components.original_question}
QUESTION TYPE: {question_components.question_type.value}
EXPECTED ANSWER: {question_components.expected_answer_type}
ONTOLOGY CONTEXT:
- Classes: {', '.join(context_classes)}
- Properties: {', '.join(context_properties)}
QUERY RESULTS:
{results_str}
INSTRUCTIONS:
- Provide a clear, concise answer in natural language
- Use the original question's tone and style
- Include specific facts from the results
- If no results, explain that no information was found
- Be accurate and don't make assumptions beyond the data
- Limit response to 2-3 sentences unless the question requires more detail
ANSWER:"""
        return prompt

    def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str:
        """Format query results for prompt inclusion.

        Args:
            query_results: Query results to format

        Returns:
            Formatted results string
        """
        if isinstance(query_results, SPARQLResult):
            if hasattr(query_results, 'ask_result') and query_results.ask_result is not None:
                return f"Boolean result: {query_results.ask_result}"
            if not query_results.bindings:
                return "No results found"
            # Format SPARQL bindings
            lines = []
            for binding in query_results.bindings[:10]:  # Limit to first 10
                formatted = []
                for var, value in binding.items():
                    # SPARQL values may be {'value': ..., ...} dicts
                    if isinstance(value, dict):
                        formatted.append(f"{var}: {value.get('value', value)}")
                    else:
                        formatted.append(f"{var}: {value}")
                lines.append("- " + ", ".join(formatted))
            if len(query_results.bindings) > 10:
                lines.append(f"... and {len(query_results.bindings) - 10} more results")
            return "\n".join(lines)
        else:  # CypherResult
            if not query_results.records:
                return "No results found"
            # Format Cypher records
            lines = []
            for record in query_results.records[:10]:  # Limit to first 10
                if isinstance(record, dict):
                    formatted = [f"{k}: {v}" for k, v in record.items()]
                    lines.append("- " + ", ".join(formatted))
                else:
                    lines.append(f"- {record}")
            if len(query_results.records) > 10:
                lines.append(f"... and {len(query_results.records) - 10} more results")
            return "\n".join(lines)

    def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool:
        """Check if query results contain data.

        Args:
            query_results: Query results to check

        Returns:
            True if results contain data
        """
        if isinstance(query_results, SPARQLResult):
            # An ASK result counts as data even when it is False
            return bool(query_results.bindings) or query_results.ask_result is not None
        else:  # CypherResult
            return bool(query_results.records)

    def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int:
        """Extract count from aggregation query results.

        Args:
            query_results: Query results

        Returns:
            Count value
        """
        if isinstance(query_results, SPARQLResult):
            if query_results.bindings:
                binding = query_results.bindings[0]
                # Look for count variable
                for var, value in binding.items():
                    if 'count' in var.lower():
                        if isinstance(value, dict):
                            return int(value.get('value', 0))
                        return int(value)
            # No explicit count variable: fall back to the row count
            return len(query_results.bindings)
        else:  # CypherResult
            if query_results.records:
                record = query_results.records[0]
                if isinstance(record, dict):
                    # Look for count key
                    for key, value in record.items():
                        if 'count' in key.lower():
                            return int(value)
                elif isinstance(record, (int, float)):
                    return int(record)
            # No explicit count: fall back to the record count
            return len(query_results.records)

    def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
        """Extract items from query results.

        Args:
            query_results: Query results

        Returns:
            List of extracted items (first value of each row)
        """
        items = []
        if isinstance(query_results, SPARQLResult):
            for binding in query_results.bindings:
                for var, value in binding.items():
                    if isinstance(value, dict):
                        item_value = value.get('value', str(value))
                    else:
                        item_value = str(value)
                    # Clean up URIs: keep only the last path/fragment segment
                    if item_value.startswith('http'):
                        item_value = item_value.split('/')[-1].split('#')[-1]
                    items.append(item_value)
                    break  # Take first value per binding
        else:  # CypherResult
            for record in query_results.records:
                if isinstance(record, dict):
                    # Take first value from record
                    for key, value in record.items():
                        items.append(str(value))
                        break
                else:
                    items.append(str(record))
        return items

    def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
        """Extract facts from query results.

        Args:
            query_results: Query results

        Returns:
            List of facts, one "var: value, ..." string per row
        """
        facts = []
        if isinstance(query_results, SPARQLResult):
            for binding in query_results.bindings:
                fact_parts = []
                for var, value in binding.items():
                    if isinstance(value, dict):
                        val_str = value.get('value', str(value))
                    else:
                        val_str = str(value)
                    # Clean up URIs: keep only the last path/fragment segment
                    if val_str.startswith('http'):
                        val_str = val_str.split('/')[-1].split('#')[-1]
                    fact_parts.append(f"{var}: {val_str}")
                facts.append(", ".join(fact_parts))
        else:  # CypherResult
            for record in query_results.records:
                if isinstance(record, dict):
                    fact_parts = [f"{k}: {v}" for k, v in record.items()]
                    facts.append(", ".join(fact_parts))
                else:
                    facts.append(str(record))
        return facts

    def _extract_boolean_statement(self, question_components: QuestionComponents) -> str:
        """Extract statement for boolean answer.

        Args:
            question_components: Question analysis

        Returns:
            Statement string
        """
        # Extract the key assertion from the question
        question = question_components.original_question.lower()
        # Remove question words (naive: strips these substrings anywhere
        # they occur, not just at the start)
        statement = question.replace('is ', '').replace('are ', '').replace('does ', '')
        statement = statement.replace('?', '').strip()
        return statement

    def _infer_entity_type(self,
                           question_components: QuestionComponents,
                           ontology_subset: QueryOntologySubset) -> str:
        """Infer entity type from question and ontology.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Entity type string
        """
        # Try to match entities to ontology classes (exact or substring)
        for entity in question_components.entities:
            entity_lower = entity.lower()
            for class_id in ontology_subset.classes:
                if class_id.lower() == entity_lower or entity_lower in class_id.lower():
                    return class_id
        # Fallback to first entity or generic term
        if question_components.entities:
            return question_components.entities[0]
        else:
            return "entities"

    def _build_answer_response(self,
                               answer: str,
                               question_components: QuestionComponents,
                               query_results: Union[SPARQLResult, CypherResult],
                               backend_used: str,
                               execution_time: float,
                               confidence: float = 0.8) -> GeneratedAnswer:
        """Build final answer response.

        Args:
            answer: Generated answer text
            question_components: Question analysis
            query_results: Query results
            backend_used: Backend used for query
            execution_time: Answer generation time
            confidence: Confidence score

        Returns:
            Complete answer response
        """
        # Extract supporting facts
        supporting_facts = self._extract_facts(query_results)
        # Build metadata
        result_count = 0
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:  # CypherResult
            result_count = len(query_results.records)
        metadata = AnswerMetadata(
            query_type=question_components.question_type.value,
            backend_used=backend_used,
            execution_time=execution_time,
            result_count=result_count,
            confidence=confidence,
            explanation=f"Generated answer using {backend_used} backend",
            sources=[]  # Could be populated with data source information
        )
        return GeneratedAnswer(
            answer=answer,
            metadata=metadata,
            supporting_facts=supporting_facts[:5],  # Limit to top 5
            raw_results=query_results,
            generation_time=execution_time
        )

View file

@ -0,0 +1,350 @@
"""
Backend router for ontology query system.
Routes queries to appropriate backend based on configuration.
"""
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
from .question_analyzer import QuestionComponents
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
class BackendType(Enum):
    """Supported backend types."""
    # Triple store routed to the SPARQL query path.
    CASSANDRA = "cassandra"
    # Property-graph stores routed to the Cypher query path.
    NEO4J = "neo4j"
    MEMGRAPH = "memgraph"
    FALKORDB = "falkordb"
@dataclass
class BackendConfig:
    """Configuration for a backend."""
    type: BackendType
    # Higher priority wins in priority routing
    # (primary gets 100; fallbacks get 50, 40, ...).
    priority: int = 0
    enabled: bool = True
    # Backend-specific settings; annotation tightened to Optional since
    # the default is None.
    config: Optional[Dict[str, Any]] = None
@dataclass
class QueryRoute:
    """Routing decision for a query."""
    backend_type: BackendType
    query_language: str  # 'sparql' or 'cypher'
    # Routing confidence score.
    confidence: float
    # Human-readable explanation of the routing decision.
    reasoning: str
class BackendRouter:
    """Routes queries to appropriate backends based on configuration and heuristics."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize backend router.

        Args:
            config: Router configuration
        """
        self.config = config
        self.backends = self._parse_backend_config(config)
        self.routing_strategy = config.get('routing_strategy', 'priority')
        self.enable_fallback = config.get('enable_fallback', True)
        # Counter used by round-robin routing to rotate across calls
        self._rr_index = 0

    def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]:
        """Parse backend configuration.

        Args:
            config: Configuration dictionary

        Returns:
            Dictionary of backend type to configuration
        """
        backends = {}
        # Parse primary backend
        primary = config.get('primary', 'cassandra')
        if primary:
            try:
                backend_type = BackendType(primary)
                backends[backend_type] = BackendConfig(
                    type=backend_type,
                    priority=100,
                    enabled=True,
                    config=config.get(primary, {})
                )
            except ValueError:
                logger.warning(f"Unknown primary backend type: {primary}")
        # Parse fallback backends
        fallbacks = config.get('fallback', [])
        for i, fallback in enumerate(fallbacks):
            try:
                backend_type = BackendType(fallback)
                backends[backend_type] = BackendConfig(
                    type=backend_type,
                    priority=50 - i * 10,  # Decreasing priority
                    enabled=True,
                    config=config.get(fallback, {})
                )
            except ValueError:
                logger.warning(f"Unknown fallback backend type: {fallback}")
        return backends

    def route_query(self,
                    question_components: QuestionComponents,
                    ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
        """Route a query to the best backend.

        Args:
            question_components: Analyzed question
            ontology_subsets: Relevant ontology subsets

        Returns:
            QueryRoute with routing decision
        """
        if self.routing_strategy == 'priority':
            return self._route_by_priority()
        elif self.routing_strategy == 'adaptive':
            return self._route_adaptive(question_components, ontology_subsets)
        elif self.routing_strategy == 'round_robin':
            return self._route_round_robin()
        else:
            # Unknown strategy: fall back to priority routing
            return self._route_by_priority()

    def _route_by_priority(self) -> QueryRoute:
        """Route based on backend priority.

        Returns:
            QueryRoute to highest priority backend

        Raises:
            RuntimeError: If no backend is enabled
        """
        # Find highest priority enabled backend
        best_backend = None
        best_priority = -1
        for backend_type, backend_config in self.backends.items():
            if backend_config.enabled and backend_config.priority > best_priority:
                best_backend = backend_type
                best_priority = backend_config.priority
        if best_backend is None:
            raise RuntimeError("No enabled backends available")
        # Determine query language
        query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
        return QueryRoute(
            backend_type=best_backend,
            query_language=query_language,
            confidence=1.0,
            reasoning=f"Priority routing to {best_backend.value}"
        )

    def _route_adaptive(self,
                        question_components: QuestionComponents,
                        ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
        """Route based on question characteristics and ontology complexity.

        Args:
            question_components: Analyzed question
            ontology_subsets: Relevant ontology subsets

        Returns:
            QueryRoute with adaptive decision

        Raises:
            RuntimeError: If no backend is enabled
        """
        scores = {}
        for backend_type, backend_config in self.backends.items():
            if not backend_config.enabled:
                continue
            score = self._calculate_backend_score(
                backend_type, question_components, ontology_subsets
            )
            scores[backend_type] = score
        if not scores:
            raise RuntimeError("No enabled backends available")
        # Select backend with highest score
        best_backend = max(scores.keys(), key=lambda k: scores[k])
        best_score = scores[best_backend]
        # Determine query language
        query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
        return QueryRoute(
            backend_type=best_backend,
            query_language=query_language,
            confidence=best_score,
            reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}"
        )

    def _calculate_backend_score(self,
                                 backend_type: BackendType,
                                 question_components: QuestionComponents,
                                 ontology_subsets: List[QueryOntologySubset]) -> float:
        """Calculate score for a backend based on query characteristics.

        Args:
            backend_type: Backend to score
            question_components: Question analysis
            ontology_subsets: Ontology subsets

        Returns:
            Score (0.0 to 1.0)
        """
        score = 0.0
        # Base priority score
        backend_config = self.backends[backend_type]
        score += backend_config.priority / 100.0
        # Question type preferences
        if backend_type == BackendType.CASSANDRA:
            # SPARQL is good for hierarchical and complex reasoning
            if question_components.question_type.value in ['factual', 'aggregation']:
                score += 0.3
            # Good for ontology-heavy queries
            if len(ontology_subsets) > 1:
                score += 0.2
        else:
            # Cypher is good for graph traversal and relationships
            if question_components.question_type.value in ['relationship', 'retrieval']:
                score += 0.3
            # Good for simple graph patterns
            if len(question_components.relationships) > 0:
                score += 0.2
        # Complexity considerations
        total_elements = sum(
            len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
            for subset in ontology_subsets
        )
        if backend_type == BackendType.CASSANDRA:
            # SPARQL handles complex ontologies well
            if total_elements > 20:
                score += 0.2
        else:
            # Cypher is efficient for simpler queries
            if total_elements <= 10:
                score += 0.2
        # Aggregation considerations
        if question_components.aggregations:
            if backend_type == BackendType.CASSANDRA:
                score += 0.1  # SPARQL has built-in aggregation
            else:
                score += 0.2  # Cypher has excellent aggregation
        return min(score, 1.0)

    def _route_round_robin(self) -> QueryRoute:
        """Route using round-robin strategy.

        Rotates through enabled backends on successive calls (the first
        call still selects the first enabled backend, matching the
        previous behaviour).

        Returns:
            QueryRoute using round-robin selection

        Raises:
            RuntimeError: If no backend is enabled
        """
        enabled_backends = [
            bt for bt, bc in self.backends.items() if bc.enabled
        ]
        if not enabled_backends:
            raise RuntimeError("No enabled backends available")
        # Select the next backend in rotation and advance the counter
        backend_type = enabled_backends[self._rr_index % len(enabled_backends)]
        self._rr_index += 1
        query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher'
        return QueryRoute(
            backend_type=backend_type,
            query_language=query_language,
            confidence=0.8,
            reasoning=f"Round-robin routing to {backend_type.value}"
        )

    def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]:
        """Get fallback route when a backend fails.

        Args:
            failed_backend: Backend that failed

        Returns:
            Fallback route or None if no fallback available
        """
        if not self.enable_fallback:
            return None
        # Find next best backend
        fallback_backends = [
            (bt, bc) for bt, bc in self.backends.items()
            if bc.enabled and bt != failed_backend
        ]
        if not fallback_backends:
            return None
        # Sort by priority
        fallback_backends.sort(key=lambda x: x[1].priority, reverse=True)
        fallback_type = fallback_backends[0][0]
        query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher'
        return QueryRoute(
            backend_type=fallback_type,
            query_language=query_language,
            confidence=0.7,
            reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}"
        )

    def get_available_backends(self) -> List[BackendType]:
        """Get list of available backends.

        Returns:
            List of enabled backend types
        """
        return [bt for bt, bc in self.backends.items() if bc.enabled]

    def is_backend_enabled(self, backend_type: BackendType) -> bool:
        """Check if a backend is enabled.

        Args:
            backend_type: Backend to check

        Returns:
            True if backend is enabled
        """
        backend_config = self.backends.get(backend_type)
        return backend_config is not None and backend_config.enabled

    def update_backend_status(self, backend_type: BackendType, enabled: bool):
        """Update backend enabled status.

        Args:
            backend_type: Backend to update
            enabled: New enabled status
        """
        if backend_type in self.backends:
            self.backends[backend_type].enabled = enabled
            logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}")
        else:
            logger.warning(f"Unknown backend type: {backend_type}")

    def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]:
        """Get configuration for a backend.

        Args:
            backend_type: Backend type

        Returns:
            Configuration dictionary or None
        """
        backend_config = self.backends.get(backend_type)
        return backend_config.config if backend_config else None

View file

@ -0,0 +1,651 @@
"""
Caching system for OntoRAG query results and computations.
Provides multiple cache backends and intelligent cache management.
"""
import logging
import time
import json
import pickle
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
from pathlib import Path
import threading
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """Cache entry with metadata."""
    key: str
    value: Any
    created_at: datetime
    accessed_at: datetime
    access_count: int
    # None means the entry never expires.
    ttl_seconds: Optional[int] = None
    # Annotation tightened to Optional since the default is None.
    tags: Optional[List[str]] = None
    # Approximate serialized size; set by the cache backend.
    size_bytes: int = 0

    def is_expired(self) -> bool:
        """Check if cache entry is expired."""
        if self.ttl_seconds is None:
            return False
        return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds

    def touch(self):
        """Update access time and count."""
        self.accessed_at = datetime.now()
        self.access_count += 1
@dataclass
class CacheStats:
    """Cache performance statistics."""
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    total_entries: int = 0
    total_size_bytes: int = 0
    hit_rate: float = 0.0

    def update_hit_rate(self):
        """Recompute hit_rate from the current hit and miss counters."""
        requests = self.hits + self.misses
        self.hit_rate = 0.0 if requests == 0 else self.hits / requests
class CacheBackend(ABC):
    """Abstract base class for cache backends.

    Implementations must be safe to call concurrently and maintain a
    CacheStats instance reflecting hits, misses and evictions.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[CacheEntry]:
        """Get cache entry by key; None on miss or expiry."""
        pass

    @abstractmethod
    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Set cache entry, replacing any existing entry for key."""
        pass

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Delete cache entry; returns True if an entry was removed."""
        pass

    @abstractmethod
    def clear(self, tags: Optional[List[str]] = None):
        """Clear cache entries; all entries when tags is None, otherwise
        only entries carrying any of the given tags."""
        pass

    @abstractmethod
    def get_stats(self) -> CacheStats:
        """Get cache statistics."""
        pass

    @abstractmethod
    def cleanup_expired(self):
        """Clean up expired entries."""
        pass
class InMemoryCache(CacheBackend):
    """In-memory cache backend with LRU eviction by count and total size."""

    def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024):
        """Initialize in-memory cache.

        Args:
            max_size: Maximum number of entries
            max_size_bytes: Maximum total size in bytes
        """
        self.max_size = max_size
        self.max_size_bytes = max_size_bytes
        self.entries: Dict[str, CacheEntry] = {}
        self.stats = CacheStats()
        # RLock: clear()/get() re-enter delete() while holding the lock
        self._lock = threading.RLock()

    def get(self, key: str) -> Optional[CacheEntry]:
        """Get cache entry by key.

        Returns:
            The entry, or None on miss or when the entry has expired
            (expired entries are removed as a side effect).
        """
        with self._lock:
            entry = self.entries.get(key)
            if entry is None:
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None
            if entry.is_expired():
                # Reuse delete() so eviction bookkeeping stays in one place
                self.delete(key)
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None
            entry.touch()
            self.stats.hits += 1
            self.stats.update_hit_rate()
            return entry

    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Set cache entry, replacing any existing entry for key."""
        with self._lock:
            # Calculate size; fall back to the string form for
            # unpicklable values
            try:
                size_bytes = len(pickle.dumps(value))
            except Exception:
                size_bytes = len(str(value).encode('utf-8'))
            # Create entry
            now = datetime.now()
            entry = CacheEntry(
                key=key,
                value=value,
                created_at=now,
                accessed_at=now,
                access_count=1,
                ttl_seconds=ttl_seconds,
                tags=tags or [],
                size_bytes=size_bytes
            )
            # Remove any existing entry for this key FIRST, so that
            # (a) replacing a key in a full cache does not evict an
            # unrelated entry, and (b) the capacity check accounts for
            # the space the old entry releases.
            old_entry = self.entries.pop(key, None)
            if old_entry is not None:
                self.stats.total_size_bytes -= old_entry.size_bytes
                self.stats.total_entries -= 1
            # Check if we need to evict
            self._ensure_capacity(size_bytes)
            # Store entry
            self.entries[key] = entry
            self.stats.total_entries += 1
            self.stats.total_size_bytes += size_bytes

    def delete(self, key: str) -> bool:
        """Delete cache entry; returns True if an entry was removed."""
        with self._lock:
            entry = self.entries.pop(key, None)
            if entry:
                self.stats.total_entries -= 1
                self.stats.total_size_bytes -= entry.size_bytes
                self.stats.evictions += 1
                return True
            return False

    def clear(self, tags: Optional[List[str]] = None):
        """Clear cache entries (all, or only those matching any tag)."""
        with self._lock:
            if tags is None:
                # Clear all
                self.stats.evictions += len(self.entries)
                self.entries.clear()
                self.stats.total_entries = 0
                self.stats.total_size_bytes = 0
            else:
                # Clear by tags; collect first to avoid mutating while
                # iterating
                to_delete = [
                    key for key, entry in self.entries.items()
                    if any(tag in entry.tags for tag in tags)
                ]
                for key in to_delete:
                    self.delete(key)

    def get_stats(self) -> CacheStats:
        """Get a snapshot copy of cache statistics."""
        with self._lock:
            return CacheStats(
                hits=self.stats.hits,
                misses=self.stats.misses,
                evictions=self.stats.evictions,
                total_entries=self.stats.total_entries,
                total_size_bytes=self.stats.total_size_bytes,
                hit_rate=self.stats.hit_rate
            )

    def cleanup_expired(self):
        """Clean up expired entries."""
        with self._lock:
            to_delete = [
                key for key, entry in self.entries.items() if entry.is_expired()
            ]
            for key in to_delete:
                self.delete(key)

    def _ensure_capacity(self, new_size_bytes: int):
        """Ensure cache has capacity for one new entry of the given size."""
        # Check size limit
        if self.stats.total_size_bytes + new_size_bytes > self.max_size_bytes:
            self._evict_by_size(new_size_bytes)
        # Check count limit (the caller has already removed any entry
        # being replaced, so reaching max_size means a net insertion)
        if len(self.entries) >= self.max_size:
            self._evict_by_count()

    def _evict_by_size(self, new_size_bytes: int):
        """Evict LRU entries until new_size_bytes fits under the byte budget."""
        # Exact number of bytes that must be freed
        overflow = self.stats.total_size_bytes + new_size_bytes - self.max_size_bytes
        if overflow <= 0:
            return
        # Sort by access time (LRU), ties broken by access count
        sorted_entries = sorted(
            self.entries.items(),
            key=lambda x: (x[1].accessed_at, x[1].access_count)
        )
        freed_bytes = 0
        for key, entry in sorted_entries:
            if freed_bytes >= overflow:
                break
            freed_bytes += entry.size_bytes
            del self.entries[key]
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= entry.size_bytes
            self.stats.evictions += 1

    def _evict_by_count(self):
        """Evict the least recently used entry."""
        if not self.entries:
            return
        # Find LRU entry
        lru_key = min(
            self.entries.keys(),
            key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count)
        )
        self.delete(lru_key)
class FileCache(CacheBackend):
    """File-based cache backend.

    Each entry is pickled into its own ``<md5(key)>.cache`` file inside
    ``cache_dir``; aggregate counters are persisted to ``stats.json`` so
    they survive process restarts.
    """

    def __init__(self, cache_dir: str, max_files: int = 10000):
        """Initialize file cache.

        Args:
            cache_dir: Directory to store cache files
            max_files: Maximum number of cache files
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_files = max_files
        self.stats = CacheStats()
        self._lock = threading.RLock()
        # Pick up counters left behind by a previous process, if any.
        self._load_stats()

    def get(self, key: str) -> Optional[CacheEntry]:
        """Get cache entry by key.

        Returns None on a miss, an expired entry, or a corrupt file
        (corrupt files are removed best-effort).
        """
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
            if not cache_file.exists():
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None
            try:
                with open(cache_file, 'rb') as f:
                    entry = pickle.load(f)
                if entry.is_expired():
                    # BUG FIX: also subtract the entry's bytes; previously
                    # only total_entries was decremented here, so the byte
                    # counter drifted upward over time.
                    size = cache_file.stat().st_size
                    cache_file.unlink()
                    self.stats.misses += 1
                    self.stats.evictions += 1
                    self.stats.total_entries -= 1
                    self.stats.total_size_bytes -= size
                    self.stats.update_hit_rate()
                    return None
                entry.touch()
                # Keep file mtime in step with logical access time so LRU
                # eviction in _ensure_capacity() reflects real usage.
                cache_file.touch()
                self.stats.hits += 1
                self.stats.update_hit_rate()
                return entry
            except Exception as e:
                logger.error(f"Error reading cache file {cache_file}: {e}")
                cache_file.unlink(missing_ok=True)
                self.stats.misses += 1
                self.stats.update_hit_rate()
                return None

    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
        """Write *value* under *key*, evicting the oldest file if at capacity."""
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
            # Create entry
            now = datetime.now()
            entry = CacheEntry(
                key=key,
                value=value,
                created_at=now,
                accessed_at=now,
                access_count=1,
                ttl_seconds=ttl_seconds,
                tags=tags or []
            )
            try:
                # BUG FIX: capture existence/size BEFORE writing. The old
                # code checked existence after the write, so the file always
                # existed, total_entries was never incremented, and
                # total_size_bytes grew on every overwrite.
                existed = cache_file.exists()
                old_size = cache_file.stat().st_size if existed else 0
                # Only a brand-new key increases the file count, so only
                # then do we need to make room.
                if not existed:
                    self._ensure_capacity()
                with open(cache_file, 'wb') as f:
                    pickle.dump(entry, f)
                entry.size_bytes = cache_file.stat().st_size
                if not existed:
                    self.stats.total_entries += 1
                # Overwrites account for the delta, not the full new size.
                self.stats.total_size_bytes += entry.size_bytes - old_size
                self._save_stats()
            except Exception as e:
                logger.error(f"Error writing cache file {cache_file}: {e}")

    def delete(self, key: str) -> bool:
        """Delete cache entry. Returns True when a file was removed."""
        with self._lock:
            cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
            if cache_file.exists():
                size = cache_file.stat().st_size
                cache_file.unlink()
                self.stats.total_entries -= 1
                self.stats.total_size_bytes -= size
                self.stats.evictions += 1
                self._save_stats()
                return True
            return False

    def clear(self, tags: Optional[List[str]] = None):
        """Clear cache entries (all of them, or only those matching *tags*)."""
        with self._lock:
            if tags is None:
                # Clear all
                for cache_file in self.cache_dir.glob("*.cache"):
                    cache_file.unlink()
                self.stats.evictions += self.stats.total_entries
                self.stats.total_entries = 0
                self.stats.total_size_bytes = 0
            else:
                # Clear by tags: each entry must be loaded to inspect tags.
                for cache_file in self.cache_dir.glob("*.cache"):
                    try:
                        with open(cache_file, 'rb') as f:
                            entry = pickle.load(f)
                        if any(tag in entry.tags for tag in tags):
                            size = cache_file.stat().st_size
                            cache_file.unlink()
                            self.stats.total_entries -= 1
                            self.stats.total_size_bytes -= size
                            self.stats.evictions += 1
                    except Exception:
                        # Unreadable entry: skip here; cleanup_expired()
                        # removes corrupt files.
                        continue
            self._save_stats()

    def get_stats(self) -> CacheStats:
        """Get a point-in-time copy of the cache statistics."""
        with self._lock:
            return CacheStats(
                hits=self.stats.hits,
                misses=self.stats.misses,
                evictions=self.stats.evictions,
                total_entries=self.stats.total_entries,
                total_size_bytes=self.stats.total_size_bytes,
                hit_rate=self.stats.hit_rate
            )

    def cleanup_expired(self):
        """Clean up expired entries; corrupt files are removed as well."""
        with self._lock:
            for cache_file in self.cache_dir.glob("*.cache"):
                try:
                    with open(cache_file, 'rb') as f:
                        entry = pickle.load(f)
                    if entry.is_expired():
                        size = cache_file.stat().st_size
                        cache_file.unlink()
                        self.stats.total_entries -= 1
                        self.stats.total_size_bytes -= size
                        self.stats.evictions += 1
                except Exception:
                    # Remove corrupted files
                    cache_file.unlink(missing_ok=True)
            self._save_stats()

    def _safe_key(self, key: str) -> str:
        """Convert key to a safe filename (hex md5 digest of the key)."""
        import hashlib
        return hashlib.md5(key.encode()).hexdigest()

    def _ensure_capacity(self):
        """Evict oldest files (by mtime) until below the max_files limit.

        BUG FIX: the old code evicted at most one file per call, so the
        cache could stay over the limit indefinitely.
        """
        cache_files = list(self.cache_dir.glob("*.cache"))
        while len(cache_files) >= self.max_files:
            oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime)
            size = oldest_file.stat().st_size
            oldest_file.unlink()
            cache_files.remove(oldest_file)
            self.stats.total_entries -= 1
            self.stats.total_size_bytes -= size
            self.stats.evictions += 1

    def _load_stats(self):
        """Load persisted statistics from stats.json (best-effort)."""
        stats_file = self.cache_dir / "stats.json"
        if stats_file.exists():
            try:
                with open(stats_file, 'r') as f:
                    data = json.load(f)
                self.stats = CacheStats(**data)
            except Exception:
                # Corrupt or incompatible stats file: start fresh.
                pass

    def _save_stats(self):
        """Persist statistics to stats.json (best-effort)."""
        stats_file = self.cache_dir / "stats.json"
        try:
            with open(stats_file, 'w') as f:
                json.dump(asdict(self.stats), f, default=str)
        except Exception:
            # Stats persistence is advisory; never fail the cache operation.
            pass
class CacheManager:
    """Cache manager with multiple backends and intelligent caching strategies."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize cache manager.

        Args:
            config: Cache configuration
        """
        self.config = config
        self.backends: Dict[str, CacheBackend] = {}
        self.default_backend = config.get('default_backend', 'memory')
        self.default_ttl = config.get('default_ttl_seconds', 3600)  # 1 hour
        self._init_backends()
        # Periodic expiry sweep runs on a background daemon thread.
        self.cleanup_interval = config.get('cleanup_interval_seconds', 300)  # 5 minutes
        self._start_cleanup_task()

    def _init_backends(self):
        """Construct the configured cache backends.

        A backend is created when explicitly configured, or when it is the
        default backend (so the default always exists).
        """
        cfg = self.config.get('backends', {})
        if 'memory' in cfg or self.default_backend == 'memory':
            mem_cfg = cfg.get('memory', {})
            self.backends['memory'] = InMemoryCache(
                max_size=mem_cfg.get('max_size', 1000),
                max_size_bytes=mem_cfg.get('max_size_bytes', 100 * 1024 * 1024),
            )
        if 'file' in cfg or self.default_backend == 'file':
            file_cfg = cfg.get('file', {})
            self.backends['file'] = FileCache(
                cache_dir=file_cfg.get('cache_dir', './cache'),
                max_files=file_cfg.get('max_files', 10000),
            )

    def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]:
        """Get value from cache.

        Args:
            key: Cache key
            backend: Backend name (optional)

        Returns:
            Cached value or None (note: a cached None is indistinguishable
            from a miss for callers).
        """
        name = backend or self.default_backend
        target = self.backends.get(name)
        if target is None:
            logger.warning(f"Cache backend '{name}' not found")
            return None
        hit = target.get(key)
        return hit.value if hit else None

    def set(self,
            key: str,
            value: Any,
            ttl_seconds: Optional[int] = None,
            tags: Optional[List[str]] = None,
            backend: Optional[str] = None):
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl_seconds: Time to live in seconds (defaults to default_ttl)
            tags: Cache tags
            backend: Backend name (optional)
        """
        name = backend or self.default_backend
        target = self.backends.get(name)
        if target is None:
            logger.warning(f"Cache backend '{name}' not found")
            return
        effective_ttl = self.default_ttl if ttl_seconds is None else ttl_seconds
        target.set(key, value, effective_ttl, tags)

    def delete(self, key: str, backend: Optional[str] = None) -> bool:
        """Delete value from cache.

        Args:
            key: Cache key
            backend: Backend name (optional)

        Returns:
            True if deleted
        """
        target = self.backends.get(backend or self.default_backend)
        return False if target is None else target.delete(key)

    def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None):
        """Clear cache entries.

        Args:
            tags: Tags to clear (optional)
            backend: Backend name (optional; all backends when omitted)
        """
        if backend:
            target = self.backends.get(backend)
            if target:
                target.clear(tags)
            return
        for target in self.backends.values():
            target.clear(tags)

    def get_stats(self) -> Dict[str, CacheStats]:
        """Get statistics for all backends.

        Returns:
            Dictionary of backend name to statistics
        """
        return {name: impl.get_stats() for name, impl in self.backends.items()}

    def cleanup_expired(self):
        """Clean up expired entries in all backends (errors are logged)."""
        for impl in self.backends.values():
            try:
                impl.cleanup_expired()
            except Exception as e:
                logger.error(f"Error cleaning up cache backend: {e}")

    def _start_cleanup_task(self):
        """Spawn a daemon thread that sweeps expired entries periodically."""
        def _sweep_forever():
            while True:
                try:
                    time.sleep(self.cleanup_interval)
                    self.cleanup_expired()
                except Exception as e:
                    logger.error(f"Cache cleanup error: {e}")
        threading.Thread(target=_sweep_forever, daemon=True).start()
# Cache decorators and utilities
def cache_result(cache_manager: CacheManager,
                 key_func: Optional[callable] = None,
                 ttl_seconds: Optional[int] = None,
                 tags: Optional[List[str]] = None,
                 backend: Optional[str] = None):
    """Decorator to cache function results.

    Args:
        cache_manager: Cache manager instance
        key_func: Function to generate cache key (receives the call's
            args/kwargs); when omitted a key is derived from the function
            name plus a hash of the arguments
        ttl_seconds: Time to live
        tags: Cache tags
        backend: Backend name

    Note: a result of None is never served from cache (the miss check is
    ``is not None``), so None-returning calls are always re-executed.
    """
    import functools

    def decorator(func):
        # BUG FIX: functools.wraps preserves __name__/__doc__ of the
        # wrapped function; without it all wrapped functions reported
        # the name 'wrapper'.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Generate cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # BUG FIX: hash((args, ...)) raised TypeError when any
                # argument was unhashable (list, dict). Fall back to a
                # repr-based hash for those calls.
                try:
                    arg_sig = hash((args, tuple(sorted(kwargs.items()))))
                except TypeError:
                    arg_sig = hash(repr((args, sorted(kwargs.items()))))
                cache_key = f"{func.__name__}:{arg_sig}"
            # Try to get from cache
            cached_result = cache_manager.get(cache_key, backend)
            if cached_result is not None:
                return cached_result
            # Execute function and cache the result
            result = func(*args, **kwargs)
            cache_manager.set(cache_key, result, ttl_seconds, tags, backend)
            return result
        return wrapper
    return decorator

# ============================================================================
# File boundary (extraction artifact removed): the following module executes
# Cypher queries against Neo4j, Memgraph, and FalkorDB.
# ============================================================================
"""
Cypher executor for multiple graph databases.
Executes Cypher queries against Neo4j, Memgraph, and FalkorDB.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
# Try to import various database drivers
try:
    from neo4j import GraphDatabase, Driver as Neo4jDriver
    NEO4J_AVAILABLE = True
except ImportError:
    # The neo4j driver also serves Memgraph (Bolt protocol); executors
    # check this flag and raise at construction time when it is missing.
    NEO4J_AVAILABLE = False
    # Placeholder so annotations referencing Neo4jDriver still resolve.
    Neo4jDriver = None
try:
    import redis
    REDIS_AVAILABLE = True
except ImportError:
    # redis is only required for the FalkorDB executor.
    REDIS_AVAILABLE = False
@dataclass
class CypherResult:
    """Result from Cypher query execution."""
    # Rows returned by the query, one dict per record (graph values are
    # converted to plain dicts/lists by the executor).
    records: List[Dict[str, Any]]
    # Execution metadata, e.g. {'record_count': N} or {'error': msg} when
    # the query failed (executors report errors here rather than raising).
    summary: Dict[str, Any]
    # Wall-clock execution time in seconds.
    execution_time: float
    # Which backend produced this result: 'neo4j', 'memgraph', 'falkordb'.
    database_type: str
    # Optional query-plan details; not populated by the current executors.
    query_plan: Optional[Dict[str, Any]] = None
class CypherExecutorBase(ABC):
    """Abstract base class for Cypher executors.

    NOTE(review): concrete implementations also provide an async
    ``connect()`` method (relied on by CypherExecutor), although it is
    not declared on this ABC.
    """

    @abstractmethod
    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query."""
        pass

    @abstractmethod
    async def close(self):
        """Close database connection."""
        pass

    @abstractmethod
    def is_connected(self) -> bool:
        """Check if connected to database."""
        pass
class Neo4jExecutor(CypherExecutorBase):
    """Cypher executor for Neo4j database.

    Wraps the synchronous neo4j driver; blocking calls are pushed onto
    the default thread executor so the asyncio event loop stays free.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize Neo4j executor.

        Args:
            config: Neo4j configuration (keys used: uri, username,
                password, connection_pool_size, connection_timeout,
                max_retry_time)

        Raises:
            RuntimeError: if the neo4j driver package is not installed.
        """
        if not NEO4J_AVAILABLE:
            raise RuntimeError("Neo4j driver not available")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None
        self._connection_pool_size = config.get('connection_pool_size', 10)

    async def connect(self):
        """Connect to Neo4j database."""
        try:
            uri = self.config.get('uri', 'bolt://localhost:7687')
            username = self.config.get('username')
            password = self.config.get('password')
            # Authenticate only when both credentials are supplied.
            auth = (username, password) if username and password else None
            # Create driver with connection pool
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self._connection_pool_size,
                connection_timeout=self.config.get('connection_timeout', 30),
                max_retry_time=self.config.get('max_retry_time', 15)
            )
            # Verify connectivity (blocking, hence the thread executor)
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.verify_connectivity
            )
            logger.info(f"Connected to Neo4j at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Neo4j: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Neo4j.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results; on failure an empty CypherResult whose summary
            carries an 'error' key (exceptions are not re-raised).
        """
        if not self.driver:
            # Lazy connect on first use.
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Execute query in a session (off the event loop)
            records = await asyncio.get_event_loop().run_in_executor(
                None, self._execute_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={'record_count': len(records)},
                execution_time=execution_time,
                database_type='neo4j'
            )
        except Exception as e:
            logger.error(f"Neo4j query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='neo4j'
            )

    def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously in thread executor.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries with driver graph objects
            converted to JSON-serializable values
        """
        with self.driver.session() as session:
            result = session.run(cypher_query.query, cypher_query.parameters)
            records = []
            for record in result:
                record_dict = {}
                for key in record.keys():
                    value = record[key]
                    record_dict[key] = self._format_neo4j_value(value)
                records.append(record_dict)
            return records

    def _format_neo4j_value(self, value):
        """Format Neo4j value for JSON serialization.

        Uses duck-typing on driver attributes rather than isinstance so
        it works without importing the driver's graph types directly.

        Args:
            value: Neo4j value

        Returns:
            JSON-serializable value
        """
        # Handle Neo4j node objects
        if hasattr(value, 'labels') and hasattr(value, 'items'):
            return {
                'labels': list(value.labels),
                'properties': dict(value.items())
            }
        # Handle Neo4j relationship objects
        elif hasattr(value, 'type') and hasattr(value, 'items'):
            return {
                'type': value.type,
                'properties': dict(value.items())
            }
        # Handle Neo4j path objects
        elif hasattr(value, 'nodes') and hasattr(value, 'relationships'):
            return {
                'nodes': [self._format_neo4j_value(n) for n in value.nodes],
                'relationships': [self._format_neo4j_value(r) for r in value.relationships]
            }
        else:
            # Primitives and anything unrecognised pass through unchanged.
            return value

    async def close(self):
        """Close Neo4j connection."""
        if self.driver:
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.close
            )
            self.driver = None
            logger.info("Neo4j connection closed")

    def is_connected(self) -> bool:
        """Check if connected to Neo4j (driver object exists; liveness is
        not probed)."""
        return self.driver is not None
class MemgraphExecutor(CypherExecutorBase):
    """Cypher executor for Memgraph database.

    Memgraph speaks the Bolt protocol, so this class reuses the Neo4j
    Python driver with Memgraph-appropriate defaults (port 7688, smaller
    pool, shorter timeout).
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize Memgraph executor.

        Args:
            config: Memgraph configuration

        Raises:
            RuntimeError: if the neo4j driver package is not installed.
        """
        if not NEO4J_AVAILABLE:  # Memgraph uses Neo4j driver
            raise RuntimeError("Neo4j driver required for Memgraph")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None

    async def connect(self):
        """Connect to Memgraph database."""
        try:
            uri = self.config.get('uri', 'bolt://localhost:7688')
            username = self.config.get('username')
            password = self.config.get('password')
            # Authenticate only when both credentials are supplied.
            auth = (username, password) if username and password else None
            # Memgraph uses Neo4j driver but with different defaults
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self.config.get('connection_pool_size', 5),
                connection_timeout=self.config.get('connection_timeout', 10)
            )
            # Verify connectivity (blocking, hence the thread executor)
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.verify_connectivity
            )
            logger.info(f"Connected to Memgraph at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Memgraph.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results; on failure an empty CypherResult whose summary
            carries an 'error' key (exceptions are not re-raised).
        """
        if not self.driver:
            # Lazy connect on first use.
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Execute query with Memgraph-specific optimizations
            records = await asyncio.get_event_loop().run_in_executor(
                None, self._execute_memgraph_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'memgraph'
                },
                execution_time=execution_time,
                database_type='memgraph'
            )
        except Exception as e:
            logger.error(f"Memgraph query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='memgraph'
            )

    def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously for Memgraph.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries (values passed through as-is,
            unlike the Neo4j executor's graph-object conversion)
        """
        with self.driver.session() as session:
            # Add Memgraph-specific query hints if available
            query = cypher_query.query
            if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'):
                # NOTE(review): this only prepends a comment; it does not
                # actually enforce a memory limit in Memgraph — confirm.
                query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}"
            result = session.run(query, cypher_query.parameters)
            records = []
            for record in result:
                record_dict = {}
                for key in record.keys():
                    record_dict[key] = record[key]
                records.append(record_dict)
            return records

    async def close(self):
        """Close Memgraph connection."""
        if self.driver:
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.close
            )
            self.driver = None
            logger.info("Memgraph connection closed")

    def is_connected(self) -> bool:
        """Check if connected to Memgraph (driver object exists; liveness
        is not probed)."""
        return self.driver is not None
class FalkorDBExecutor(CypherExecutorBase):
    """Cypher executor for FalkorDB (Redis-based graph database)."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize FalkorDB executor.

        Args:
            config: FalkorDB configuration (keys used: host, port,
                password, db, graph_name, connection_timeout,
                socket_timeout)

        Raises:
            RuntimeError: if the redis package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise RuntimeError("Redis driver required for FalkorDB")
        self.config = config
        self.redis_client: Optional[redis.Redis] = None
        self.graph_name = config.get('graph_name', 'knowledge_graph')

    async def connect(self):
        """Connect to FalkorDB (Redis)."""
        try:
            self.redis_client = redis.Redis(
                host=self.config.get('host', 'localhost'),
                port=self.config.get('port', 6379),
                password=self.config.get('password'),
                db=self.config.get('db', 0),
                decode_responses=True,
                socket_connect_timeout=self.config.get('connection_timeout', 10),
                socket_timeout=self.config.get('socket_timeout', 10)
            )
            # redis-py is synchronous; ping off the event loop.
            await asyncio.get_event_loop().run_in_executor(
                None, self.redis_client.ping
            )
            logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}")
        except Exception as e:
            logger.error(f"Failed to connect to FalkorDB: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against FalkorDB.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results; on failure an empty CypherResult whose summary
            carries an 'error' key (exceptions are not re-raised).
        """
        if not self.redis_client:
            # Lazy connect on first use.
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Execute query using FalkorDB's GRAPH.QUERY command
            records = await asyncio.get_event_loop().run_in_executor(
                None, self._execute_falkordb_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'falkordb'
                },
                execution_time=execution_time,
                database_type='falkordb'
            )
        except Exception as e:
            logger.error(f"FalkorDB query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='falkordb'
            )

    def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously for FalkorDB.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries
        """
        query = cypher_query.query
        # BUG FIX: substitute longer parameter names first so that '$name'
        # cannot clobber the prefix of '$name2'; also escape backslashes
        # and double quotes in string values so the inlined literal stays
        # well-formed.
        # NOTE(review): inlining values into the query text remains
        # injection-prone; switch to the client's native parameter support
        # if the installed FalkorDB client offers it.
        for param in sorted(cypher_query.parameters, key=len, reverse=True):
            value = cypher_query.parameters[param]
            if isinstance(value, str):
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                query = query.replace(f'${param}', f'"{escaped}"')
            else:
                query = query.replace(f'${param}', str(value))
        # Execute using FalkorDB GRAPH.QUERY command
        result = self.redis_client.execute_command(
            'GRAPH.QUERY', self.graph_name, query
        )
        # Parse FalkorDB result format: [header, data rows, statistics]
        records = []
        if result and len(result) > 1:
            headers = result[0] if result[0] else []
            data_rows = result[1] if len(result) > 1 else []
            for row in data_rows:
                record = {}
                for i, header in enumerate(headers):
                    if i < len(row):
                        record[header] = self._format_falkordb_value(row[i])
                records.append(record)
        return records

    def _format_falkordb_value(self, value):
        """Format FalkorDB value for JSON serialization.

        Args:
            value: FalkorDB value

        Returns:
            JSON-serializable value
        """
        # FalkorDB encodes graph objects as 3-element lists tagged by a
        # leading type code.
        if isinstance(value, list) and len(value) == 3:
            if value[0] == 1:  # Node
                return {
                    'type': 'node',
                    'labels': value[1],
                    'properties': value[2]
                }
            elif value[0] == 2:  # Relationship
                return {
                    'type': 'relationship',
                    'rel_type': value[1],
                    'properties': value[2]
                }
        return value

    async def close(self):
        """Close FalkorDB connection."""
        if self.redis_client:
            await asyncio.get_event_loop().run_in_executor(
                None, self.redis_client.close
            )
            self.redis_client = None
            logger.info("FalkorDB connection closed")

    def is_connected(self) -> bool:
        """Check if connected to FalkorDB (client object exists; liveness
        is not probed)."""
        return self.redis_client is not None
class CypherExecutor:
    """Multi-database Cypher executor with automatic routing."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize multi-database executor.

        Args:
            config: Configuration for all database types

        Raises:
            RuntimeError: when no executor could be initialized.
        """
        self.config = config
        self.executors: Dict[str, CypherExecutorBase] = {}
        self._initialize_executors()

    def _initialize_executors(self):
        """Initialize database executors based on configuration.

        Each candidate is created only when it is configured AND its
        driver package is importable; construction failures are logged
        and skipped.
        """
        candidates = [
            ('neo4j', 'Neo4j', Neo4jExecutor, NEO4J_AVAILABLE),
            ('memgraph', 'Memgraph', MemgraphExecutor, NEO4J_AVAILABLE),
            ('falkordb', 'FalkorDB', FalkorDBExecutor, REDIS_AVAILABLE),
        ]
        for key, label, factory, available in candidates:
            if key not in self.config or not available:
                continue
            try:
                self.executors[key] = factory(self.config[key])
                logger.info(f"{label} executor initialized")
            except Exception as e:
                logger.error(f"Failed to initialize {label} executor: {e}")
        if not self.executors:
            raise RuntimeError("No database executors could be initialized")

    async def execute_cypher(self, cypher_query: CypherQuery,
                             database_type: str) -> CypherResult:
        """Execute Cypher query on specified database.

        Args:
            cypher_query: Cypher query to execute
            database_type: Target database type

        Returns:
            Query results

        Raises:
            ValueError: when the requested database type is unavailable.
        """
        if database_type not in self.executors:
            raise ValueError(f"Database type {database_type} not available. "
                             f"Available: {list(self.executors.keys())}")
        chosen = self.executors[database_type]
        # Connect lazily before dispatching the query.
        if not chosen.is_connected():
            await chosen.connect()
        return await chosen.execute(cypher_query)

    async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]:
        """Execute query on all available databases concurrently.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Results from all databases; failures are captured per-database
            as CypherResult objects carrying an 'error' summary.
        """
        pending = {
            db: asyncio.create_task(
                self.execute_cypher(cypher_query, db),
                name=f"cypher_query_{db}",
            )
            for db in self.executors
        }
        outcome: Dict[str, CypherResult] = {}
        for db, task in pending.items():
            try:
                outcome[db] = await task
            except Exception as e:
                logger.error(f"Query failed on {db}: {e}")
                outcome[db] = CypherResult(
                    records=[],
                    summary={'error': str(e)},
                    execution_time=0.0,
                    database_type=db,
                )
        return outcome

    def get_available_databases(self) -> List[str]:
        """Get list of available database types.

        Returns:
            List of available database type names
        """
        return [*self.executors]

    async def close_all(self):
        """Close all database connections."""
        for impl in self.executors.values():
            await impl.close()
        logger.info("All Cypher executor connections closed")

# ============================================================================
# File boundary (extraction artifact removed): the following module generates
# Cypher queries from natural-language questions.
# ============================================================================
"""
Cypher query generator for ontology-sensitive queries.
Converts natural language questions to Cypher queries for graph databases.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class CypherQuery:
    """Generated Cypher query with metadata."""
    # The Cypher text itself; may contain $param placeholders.
    query: str
    # Values for the $param placeholders in `query`.
    parameters: Dict[str, Any]
    # Names of the variables the query RETURNs.
    variables: List[str]
    # Human-readable description of what the query does.
    explanation: str
    # Heuristic complexity score (templates assign 0.2-0.4; higher means
    # more complex — presumably on a 0..1 scale, confirm).
    complexity_score: float
    database_hints: Dict[str, Any] = None  # Database-specific optimization hints
class CypherGenerator:
"""Generates Cypher queries from natural language questions using LLM assistance."""
def __init__(self, prompt_service=None):
    """Initialize Cypher generator.

    Args:
        prompt_service: Service for LLM-based query generation; when
            None, only template-based and fallback generation are used.
    """
    self.prompt_service = prompt_service
    # Cypher query templates for common patterns. Placeholders in
    # {braces} are filled via str.format() at generation time; the
    # $placeholders are runtime query parameters.
    self.templates = {
        'simple_node_query': """
MATCH (n:{node_label})
RETURN n.name AS name, n.{property} AS {property}
LIMIT {limit}""",
        'relationship_query': """
MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label})
WHERE a.name = $source_name
RETURN b.name AS name, r.{rel_property} AS property""",
        'path_query': """
MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label})
WHERE start.name = $start_name
RETURN path, length(path) AS path_length
ORDER BY path_length""",
        'count_query': """
MATCH (n:{node_label})
{where_clause}
RETURN count(n) AS count""",
        'aggregation_query': """
MATCH (n:{node_label})
{where_clause}
RETURN
    count(n) AS count,
    avg(n.{numeric_property}) AS average,
    sum(n.{numeric_property}) AS total""",
        'boolean_query': """
MATCH (a:{source_label})-[:{relationship}]->(b:{target_label})
WHERE a.name = $source_name AND b.name = $target_name
RETURN count(*) > 0 AS exists""",
        'hierarchy_query': """
MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label})
WHERE parent.name = $parent_name
RETURN child.name AS child_name, parent.name AS parent_name""",
        'property_filter_query': """
MATCH (n:{node_label})
WHERE n.{property} {operator} ${property}_value
RETURN n.name AS name, n.{property} AS {property}
ORDER BY n.{property}"""
    }
async def generate_cypher(self,
                          question_components: QuestionComponents,
                          ontology_subset: QueryOntologySubset,
                          database_type: str = "neo4j") -> CypherQuery:
    """Generate a Cypher query for a question.

    Strategy, in order of preference: cheap template matching, then
    LLM-assisted generation (if a prompt service is configured), and
    finally a trivial keyword-match fallback.

    Args:
        question_components: Analyzed question components
        ontology_subset: Relevant ontology subset
        database_type: Target database (neo4j, memgraph, falkordb)

    Returns:
        Generated Cypher query
    """
    from_template = self._try_template_generation(
        question_components, ontology_subset, database_type
    )
    if from_template:
        logger.debug("Generated Cypher using template")
        return from_template
    if self.prompt_service:
        from_llm = await self._generate_with_llm(
            question_components, ontology_subset, database_type
        )
        if from_llm:
            logger.debug("Generated Cypher using LLM")
            return from_llm
    logger.warning("Falling back to simple Cypher pattern")
    return self._generate_fallback_query(question_components, ontology_subset)
def _try_template_generation(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset,
                             database_type: str) -> Optional[CypherQuery]:
    """Try to generate query using templates.

    Branches are checked in a fixed order — simple retrieval, count
    aggregation, relationship retrieval, boolean existence — and the
    first branch whose ontology lookups all succeed wins.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Generated query or None if no template matches
    """
    # Simple node query (What are the animals?)
    if (question_components.question_type == QuestionType.RETRIEVAL and
        len(question_components.entities) == 1):
        node_label = self._find_matching_node_label(
            question_components.entities[0], ontology_subset
        )
        if node_label:
            # NOTE(review): property='name' makes the template return
            # n.name twice (AS name and AS name) — presumably a
            # deliberate simplification; confirm.
            query = self.templates['simple_node_query'].format(
                node_label=node_label,
                property='name',
                limit=100
            )
            return CypherQuery(
                query=query,
                parameters={},
                variables=['name'],
                explanation=f"Retrieve all nodes of type {node_label}",
                complexity_score=0.2,
                database_hints=self._get_database_hints(database_type, 'simple')
            )
    # Count query (How many animals are there?)
    if (question_components.question_type == QuestionType.AGGREGATION and
        'count' in question_components.aggregations):
        # Fall back to a generic 'Entity' label when no entity was found.
        node_label = self._find_matching_node_label(
            question_components.entities[0] if question_components.entities else 'Entity',
            ontology_subset
        )
        if node_label:
            where_clause = self._build_where_clause(question_components)
            query = self.templates['count_query'].format(
                node_label=node_label,
                where_clause=where_clause
            )
            return CypherQuery(
                query=query,
                parameters=self._extract_parameters(question_components),
                variables=['count'],
                explanation=f"Count nodes of type {node_label}",
                complexity_score=0.3,
                database_hints=self._get_database_hints(database_type, 'aggregation')
            )
    # Relationship query (Which documents were authored by John Smith?)
    # NOTE: entities[1] is treated as the source and entities[0] as the
    # target of the relationship.
    if (question_components.question_type == QuestionType.RETRIEVAL and
        len(question_components.entities) >= 2):
        source_label = self._find_matching_node_label(
            question_components.entities[1], ontology_subset
        )
        target_label = self._find_matching_node_label(
            question_components.entities[0], ontology_subset
        )
        relationship = self._find_matching_relationship(
            question_components, ontology_subset
        )
        if source_label and target_label and relationship:
            query = self.templates['relationship_query'].format(
                source_label=source_label,
                target_label=target_label,
                relationship=relationship,
                rel_property='name'
            )
            return CypherQuery(
                query=query,
                parameters={'source_name': question_components.entities[1]},
                variables=['name'],
                explanation=f"Find {target_label} related to {source_label} via {relationship}",
                complexity_score=0.4,
                database_hints=self._get_database_hints(database_type, 'relationship')
            )
    # Boolean query (Is X related to Y?)
    if question_components.question_type == QuestionType.BOOLEAN:
        if len(question_components.entities) >= 2:
            source_label = self._find_matching_node_label(
                question_components.entities[0], ontology_subset
            )
            target_label = self._find_matching_node_label(
                question_components.entities[1], ontology_subset
            )
            relationship = self._find_matching_relationship(
                question_components, ontology_subset
            )
            if source_label and target_label and relationship:
                query = self.templates['boolean_query'].format(
                    source_label=source_label,
                    target_label=target_label,
                    relationship=relationship
                )
                return CypherQuery(
                    query=query,
                    parameters={
                        'source_name': question_components.entities[0],
                        'target_name': question_components.entities[1]
                    },
                    variables=['exists'],
                    explanation="Boolean check for relationship existence",
                    complexity_score=0.3,
                    database_hints=self._get_database_hints(database_type, 'boolean')
                )
    # No template matched; caller falls through to LLM / fallback.
    return None
async def _generate_with_llm(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset,
                             database_type: str) -> Optional[CypherQuery]:
    """Generate Cypher using the LLM prompt service.

    Returns None when the service fails, replies with a non-dict, or
    produces text that does not start with a recognised Cypher keyword.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Generated query or None if failed
    """
    try:
        llm_prompt = self._build_cypher_prompt(
            question_components, ontology_subset, database_type
        )
        reply = await self.prompt_service.generate_cypher(prompt=llm_prompt)
        if reply and isinstance(reply, dict):
            candidate = reply.get('query', '').strip()
            # Accept only strings beginning with a Cypher keyword.
            if candidate.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')):
                return CypherQuery(
                    query=candidate,
                    parameters=reply.get('parameters', {}),
                    variables=self._extract_variables(candidate),
                    explanation=reply.get('explanation', 'Generated by LLM'),
                    complexity_score=self._calculate_complexity(candidate),
                    database_hints=self._get_database_hints(database_type, 'complex'),
                )
    except Exception as e:
        logger.error(f"LLM Cypher generation failed: {e}")
    return None
def _build_cypher_prompt(self,
                         question_components: QuestionComponents,
                         ontology_subset: QueryOntologySubset,
                         database_type: str) -> str:
    """Build prompt for LLM Cypher generation.

    Assembles the question, the ontology-derived schema (node labels and
    relationship types) and database-specific hints into one prompt.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset
        database_type: Target database type

    Returns:
        Formatted prompt string
    """
    # Format ontology elements as node labels and relationships
    node_labels = self._format_node_labels(ontology_subset.classes)
    relationships = self._format_relationships(
        ontology_subset.object_properties,
        ontology_subset.datatype_properties
    )
    # NOTE: the block below is a single f-string; its exact wording is part
    # of the LLM contract — edit with care.
    prompt = f"""Generate a Cypher query for the following question using the provided ontology.
QUESTION: {question_components.original_question}
TARGET DATABASE: {database_type}
AVAILABLE NODE LABELS (from classes):
{node_labels}
AVAILABLE RELATIONSHIP TYPES (from properties):
{relationships}
RULES:
- Use MATCH patterns for graph traversal
- Include WHERE clauses for filters
- Use aggregation functions when needed (COUNT, SUM, AVG)
- Optimize for {database_type} performance
- Consider index hints for large datasets
- Use parameters for values (e.g., $name)
QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}
DATABASE-SPECIFIC OPTIMIZATIONS:
{self._get_database_specific_hints(database_type)}
Generate a complete Cypher query with parameters:"""
    return prompt
def _generate_fallback_query(self,
                             question_components: 'QuestionComponents',
                             ontology_subset: 'QueryOntologySubset') -> 'CypherQuery':
    """Produce a minimal last-resort query.

    Matches nodes of the first available class (or a generic Entity) whose
    name contains the first question keyword.

    Args:
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Basic Cypher query
    """
    if ontology_subset.classes:
        label = next(iter(ontology_subset.classes))
    else:
        label = 'Entity'
    keywords = question_components.keywords
    keyword = keywords[0] if keywords else 'entity'
    query = f"""MATCH (n:{label})
WHERE n.name CONTAINS $keyword
RETURN n.name AS name, labels(n) AS types
LIMIT 10"""
    return CypherQuery(
        query=query,
        parameters={'keyword': keyword},
        variables=['name', 'types'],
        explanation="Fallback query for basic pattern matching",
        complexity_score=0.1
    )
def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Find matching node label in ontology subset.
Args:
entity: Entity string to match
ontology_subset: Ontology subset
Returns:
Matching node label or None
"""
entity_lower = entity.lower()
# Direct match
for class_id in ontology_subset.classes:
if class_id.lower() == entity_lower:
return class_id
# Label match
for class_id, class_def in ontology_subset.classes.items():
labels = class_def.get('labels', [])
for label in labels:
if isinstance(label, dict):
label_value = label.get('value', '').lower()
if label_value == entity_lower:
return class_id
# Partial match
for class_id in ontology_subset.classes:
if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
return class_id
return None
def _find_matching_relationship(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Find matching relationship type.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Matching relationship type or None
"""
# Look for relationship keywords
for keyword in question_components.keywords:
keyword_lower = keyword.lower()
# Check object properties
for prop_id in ontology_subset.object_properties:
if keyword_lower in prop_id.lower() or prop_id.lower() in keyword_lower:
return prop_id.upper().replace('-', '_')
# Common relationship mappings
relationship_mappings = {
'author': 'AUTHORED_BY',
'created': 'CREATED_BY',
'owns': 'OWNS',
'has': 'HAS',
'contains': 'CONTAINS',
'parent': 'PARENT_OF',
'child': 'CHILD_OF',
'related': 'RELATED_TO'
}
for keyword in question_components.keywords:
if keyword.lower() in relationship_mappings:
return relationship_mappings[keyword.lower()]
# Default relationship
return 'RELATED_TO'
def _build_where_clause(self, question_components: QuestionComponents) -> str:
"""Build WHERE clause for Cypher query.
Args:
question_components: Question analysis
Returns:
WHERE clause string
"""
conditions = []
for constraint in question_components.constraints:
if 'greater than' in constraint.lower():
import re
numbers = re.findall(r'\d+', constraint)
if numbers:
conditions.append(f"n.value > {numbers[0]}")
elif 'less than' in constraint.lower():
numbers = re.findall(r'\d+', constraint)
if numbers:
conditions.append(f"n.value < {numbers[0]}")
if conditions:
return f"WHERE {' AND '.join(conditions)}"
return ""
def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]:
"""Extract parameters from question components.
Args:
question_components: Question analysis
Returns:
Parameters dictionary
"""
parameters = {}
# Extract numeric values
import re
for constraint in question_components.constraints:
numbers = re.findall(r'\d+', constraint)
for i, number in enumerate(numbers):
parameters[f'value_{i}'] = int(number)
return parameters
def _format_node_labels(self, classes: Dict[str, Any]) -> str:
"""Format classes as node labels for prompt.
Args:
classes: Classes dictionary
Returns:
Formatted node labels string
"""
if not classes:
return "None"
lines = []
for class_id, definition in classes.items():
comment = definition.get('comment', '')
lines.append(f"- :{class_id} - {comment}")
return '\n'.join(lines)
def _format_relationships(self,
object_props: Dict[str, Any],
datatype_props: Dict[str, Any]) -> str:
"""Format properties as relationships for prompt.
Args:
object_props: Object properties
datatype_props: Datatype properties
Returns:
Formatted relationships string
"""
lines = []
for prop_id, definition in object_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'Any')
comment = definition.get('comment', '')
rel_type = prop_id.upper().replace('-', '_')
lines.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}")
return '\n'.join(lines) if lines else "None"
def _extract_variables(self, query: str) -> List[str]:
"""Extract variables from Cypher query.
Args:
query: Cypher query string
Returns:
List of variable names
"""
import re
# Extract RETURN clause variables
return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL)
if return_match:
return_clause = return_match.group(1)
variables = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause)
return [var[1] if var[1] else var[0] for var in variables]
return []
def _calculate_complexity(self, query: str) -> float:
"""Calculate complexity score for Cypher query.
Args:
query: Cypher query string
Returns:
Complexity score (0.0 to 1.0)
"""
complexity = 0.0
query_upper = query.upper()
# Count different Cypher features
if 'JOIN' in query_upper or 'UNION' in query_upper:
complexity += 0.3
if 'WHERE' in query_upper:
complexity += 0.2
if 'OPTIONAL' in query_upper:
complexity += 0.1
if 'ORDER BY' in query_upper:
complexity += 0.1
if '*' in query: # Variable length paths
complexity += 0.2
if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
complexity += 0.2
# Count path length
path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query)
for start, end in path_matches:
complexity += (int(end) - int(start)) * 0.05
return min(complexity, 1.0)
def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]:
"""Get database-specific optimization hints.
Args:
database_type: Target database
query_category: Category of query
Returns:
Optimization hints
"""
hints = {}
if database_type == "neo4j":
hints.update({
'use_index': True,
'explain_plan': 'EXPLAIN',
'profile_query': 'PROFILE'
})
elif database_type == "memgraph":
hints.update({
'use_index': True,
'explain_plan': 'EXPLAIN',
'memory_limit': '1GB'
})
elif database_type == "falkordb":
hints.update({
'use_index': False, # Redis-based, different indexing
'cache_result': True
})
return hints
def _get_database_specific_hints(self, database_type: str) -> str:
    """Get database-specific optimization hints as text.

    The result is interpolated verbatim into the LLM prompt built by
    _build_cypher_prompt, so the literal wording matters.

    Args:
        database_type: Target database; anything other than neo4j,
            memgraph or falkordb gets a generic hint line

    Returns:
        Hints as formatted string
    """
    if database_type == "neo4j":
        # Multi-line literals: continuation-line whitespace is part of
        # the prompt text.
        return """- Use USING INDEX hints for large datasets
- Consider PROFILE for query optimization
- Prefer MERGE over CREATE when appropriate"""
    elif database_type == "memgraph":
        return """- Leverage in-memory processing advantages
- Use streaming for large result sets
- Consider query parallelization"""
    elif database_type == "falkordb":
        return """- Optimize for Redis memory constraints
- Use simple patterns for best performance
- Leverage Redis data structures when possible"""
    else:
        return "- Use standard Cypher optimization patterns"

View file

@ -0,0 +1,557 @@
"""
Error handling and recovery mechanisms for OntoRAG.
Provides comprehensive error handling, retry logic, and graceful degradation.
"""
import logging
import time
import asyncio
from typing import Dict, Any, List, Optional, Callable, Union, Type
from dataclasses import dataclass
from enum import Enum
from functools import wraps
import traceback
logger = logging.getLogger(__name__)
class ErrorSeverity(Enum):
    """Error severity levels, used to pick the logging level in
    ErrorReporter.report_error."""
    LOW = "low"            # logged at info level
    MEDIUM = "medium"      # logged at warning level
    HIGH = "high"          # logged at error level
    CRITICAL = "critical"  # logged at critical level
class ErrorCategory(Enum):
    """Error categories for better handling.

    Each category can carry its own retry configuration and fallback
    strategy in ErrorRecoveryStrategy.
    """
    ONTOLOGY_LOADING = "ontology_loading"
    QUESTION_ANALYSIS = "question_analysis"
    QUERY_GENERATION = "query_generation"
    QUERY_EXECUTION = "query_execution"
    ANSWER_GENERATION = "answer_generation"
    BACKEND_CONNECTION = "backend_connection"
    CACHE_ERROR = "cache_error"
    VALIDATION_ERROR = "validation_error"
    TIMEOUT_ERROR = "timeout_error"
    AUTHENTICATION_ERROR = "authentication_error"
@dataclass
class ErrorContext:
    """Context information for an error.

    Attached to every OntoRAGError; drives retry bookkeeping and
    circuit-breaker keys in ErrorRecoveryStrategy.
    """
    category: 'ErrorCategory'
    severity: 'ErrorSeverity'
    component: str                          # component that raised
    operation: str                          # operation that raised
    user_message: Optional[str] = None      # friendly message for end users
    technical_details: Optional[str] = None
    suggestion: Optional[str] = None
    retry_count: int = 0                    # attempts made so far
    max_retries: int = 3
    # BUG FIX: this defaulted to a bare None (typed Dict), and fallback
    # strategies call context.metadata.get(...) — AttributeError for any
    # default-constructed context. Normalise to {} below.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Guarantee metadata is always a dict so consumers can call .get().
        if self.metadata is None:
            self.metadata = {}
class OntoRAGError(Exception):
    """Base exception for the OntoRAG system.

    Carries a structured ErrorContext, the original cause and a creation
    timestamp alongside the human-readable message.
    """

    def __init__(self,
                 message: str,
                 context: 'Optional[ErrorContext]' = None,
                 cause: 'Optional[Exception]' = None):
        """Initialize OntoRAG error.

        Args:
            message: Error message
            context: Error context; a generic validation context is
                substituted when none is supplied
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        self.context = context or self._default_context()
        self.cause = cause
        self.timestamp = time.time()

    @staticmethod
    def _default_context() -> 'ErrorContext':
        """Build the generic context attached when the caller gives none."""
        return ErrorContext(
            category=ErrorCategory.VALIDATION_ERROR,
            severity=ErrorSeverity.MEDIUM,
            component="unknown",
            operation="unknown"
        )
class OntologyLoadingError(OntoRAGError):
    """Raised when an ontology cannot be loaded or parsed."""
    pass
class QuestionAnalysisError(OntoRAGError):
    """Raised when a natural-language question cannot be analyzed."""
    pass
class QueryGenerationError(OntoRAGError):
    """Raised when a SPARQL/Cypher query cannot be generated."""
    pass
class QueryExecutionError(OntoRAGError):
    """Raised when a generated query fails during execution."""
    pass
class AnswerGenerationError(OntoRAGError):
    """Raised when an answer cannot be generated from query results."""
    pass
class BackendConnectionError(OntoRAGError):
    """Raised when a graph/database backend cannot be reached."""
    pass
class TimeoutError(OntoRAGError):
    """Operation timeout error.

    NOTE(review): this shadows the builtin ``TimeoutError``, so builtin /
    asyncio timeouts raised elsewhere will NOT be instances of this class.
    The retry configs that list ``TimeoutError`` therefore match only this
    custom type — confirm that is intended.
    """
    pass
@dataclass
class RetryConfig:
    """Configuration for retry logic (per ErrorCategory)."""
    max_retries: int = 3
    base_delay: float = 1.0           # seconds before the first retry
    max_delay: float = 60.0           # cap on any backoff delay (seconds)
    exponential_backoff: bool = True  # double the delay on each attempt
    jitter: bool = True               # randomise delay to avoid thundering herd
    # None means "no filter configured"; handle_error treats None as [].
    retry_on_exceptions: Optional[List[Type[Exception]]] = None
class ErrorRecoveryStrategy:
    """Strategy for handling and recovering from errors.

    Layers three mechanisms:
      1. per-category retries with exponential backoff and jitter,
      2. a time-windowed circuit breaker keyed on "<category>:<component>",
      3. per-category fallback functions for graceful degradation.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize error recovery strategy.

        Args:
            config: Recovery configuration; recognised keys include
                'circuit_breaker_threshold' and 'circuit_breaker_window'.
        """
        self.config = config or {}
        self.retry_configs = self._build_retry_configs()
        self.fallback_strategies = self._build_fallback_strategies()
        # Both maps are keyed by "<category>:<component>".
        self.error_counters: Dict[str, int] = {}
        self.circuit_breakers: Dict[str, Dict[str, Any]] = {}

    def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]:
        """Build retry configurations for different error categories."""
        # NOTE(review): TimeoutError below is this module's OntoRAGError
        # subclass, not the builtin — builtin timeouts will not match.
        return {
            ErrorCategory.BACKEND_CONNECTION: RetryConfig(
                max_retries=5,
                base_delay=2.0,
                retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError]
            ),
            ErrorCategory.QUERY_EXECUTION: RetryConfig(
                max_retries=3,
                base_delay=1.0,
                retry_on_exceptions=[QueryExecutionError, TimeoutError]
            ),
            ErrorCategory.ONTOLOGY_LOADING: RetryConfig(
                max_retries=2,
                base_delay=0.5,
                retry_on_exceptions=[OntologyLoadingError, IOError]
            ),
            ErrorCategory.QUESTION_ANALYSIS: RetryConfig(
                max_retries=2,
                base_delay=1.0,
                retry_on_exceptions=[QuestionAnalysisError, TimeoutError]
            ),
            ErrorCategory.ANSWER_GENERATION: RetryConfig(
                max_retries=2,
                base_delay=1.0,
                retry_on_exceptions=[AnswerGenerationError, TimeoutError]
            )
        }

    def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]:
        """Build fallback strategies for different error categories."""
        return {
            ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis,
            ErrorCategory.QUERY_GENERATION: self._fallback_query_generation,
            ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution,
            ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation,
            ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection
        }

    async def handle_error(self,
                           error: Exception,
                           context: ErrorContext,
                           operation: Callable,
                           *args,
                           **kwargs) -> Any:
        """Handle error with recovery strategies.

        Args:
            error: The exception that occurred
            context: Error context
            operation: Function to retry
            *args: Operation arguments
            **kwargs: Operation keyword arguments

        Returns:
            Result of successful operation or fallback
        """
        logger.error(f"Handling error in {context.component}.{context.operation}: {error}")
        # Update error counters
        error_key = f"{context.category.value}:{context.component}"
        self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1
        # Circuit open: skip retries and degrade immediately.
        if self._is_circuit_open(error_key):
            return await self._execute_fallback(context, *args, **kwargs)
        # Retry when the category is configured and the exception is retryable.
        retry_config = self.retry_configs.get(context.category)
        if retry_config and context.retry_count < retry_config.max_retries:
            if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []):
                return await self._retry_operation(
                    operation, context, retry_config, *args, **kwargs
                )
        # Execute fallback strategy
        return await self._execute_fallback(context, *args, **kwargs)

    async def _retry_operation(self,
                               operation: Callable,
                               context: ErrorContext,
                               retry_config: RetryConfig,
                               *args,
                               **kwargs) -> Any:
        """Retry operation with backoff; re-enters handle_error on failure
        (retry_count caps the recursion depth)."""
        context.retry_count += 1
        # Calculate delay
        delay = retry_config.base_delay
        if retry_config.exponential_backoff:
            delay *= (2 ** (context.retry_count - 1))
        delay = min(delay, retry_config.max_delay)
        # Add jitter: spread the delay over 0.5x..1.5x.
        if retry_config.jitter:
            import random
            delay *= (0.5 + random.random())
        logger.info(f"Retrying {context.component}.{context.operation} "
                    f"(attempt {context.retry_count}) after {delay:.2f}s")
        await asyncio.sleep(delay)
        try:
            if asyncio.iscoroutinefunction(operation):
                return await operation(*args, **kwargs)
            else:
                return operation(*args, **kwargs)
        except Exception as e:
            return await self.handle_error(e, context, operation, *args, **kwargs)

    async def _execute_fallback(self,
                                context: ErrorContext,
                                *args,
                                **kwargs) -> Any:
        """Execute the category's fallback strategy; falls back to a
        generic default if that itself fails."""
        fallback_func = self.fallback_strategies.get(context.category)
        if fallback_func:
            logger.info(f"Executing fallback for {context.category.value}")
            try:
                if asyncio.iscoroutinefunction(fallback_func):
                    return await fallback_func(context, *args, **kwargs)
                else:
                    return fallback_func(context, *args, **kwargs)
            except Exception as e:
                logger.error(f"Fallback strategy failed: {e}")
        # Default fallback
        return self._default_fallback(context)

    def _is_circuit_open(self, error_key: str) -> bool:
        """Check if the circuit breaker is open for this error key.

        Counts errors in a sliding window; opens once the count reaches
        the configured threshold, and resets when the window expires.
        """
        circuit = self.circuit_breakers.get(error_key, {})
        error_count = self.error_counters.get(error_key, 0)
        error_threshold = self.config.get('circuit_breaker_threshold', 10)
        window_seconds = self.config.get('circuit_breaker_window', 300)  # 5 minutes
        current_time = time.time()
        window_start = circuit.get('window_start', current_time)
        # Reset window if expired
        if current_time - window_start > window_seconds:
            self.circuit_breakers[error_key] = {'window_start': current_time}
            self.error_counters[error_key] = 0
            return False
        return error_count >= error_threshold

    def _default_fallback(self, context: ErrorContext) -> Any:
        """Last-resort response when no category fallback applies/works."""
        if context.category == ErrorCategory.ANSWER_GENERATION:
            return "I'm sorry, I encountered an error while processing your question. Please try again."
        elif context.category == ErrorCategory.QUERY_EXECUTION:
            return {"error": "Query execution failed", "results": []}
        else:
            return None

    # Fallback strategy implementations

    async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs):
        """Fallback for question analysis: crude keyword heuristics."""
        from .question_analyzer import QuestionComponents, QuestionType
        # Simple keyword-based analysis
        question_lower = question.lower()
        # Determine question type
        if any(word in question_lower for word in ['how many', 'count', 'number']):
            question_type = QuestionType.AGGREGATION
        elif question_lower.startswith(('is', 'are', 'does', 'can')):
            question_type = QuestionType.BOOLEAN
        elif any(word in question_lower for word in ['what', 'which', 'who', 'where']):
            question_type = QuestionType.RETRIEVAL
        else:
            question_type = QuestionType.FACTUAL
        # Extract simple entities (nouns)
        import re
        words = re.findall(r'\b[a-zA-Z]+\b', question)
        entities = [word for word in words if len(word) > 3 and word.lower() not in
                    {'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}]
        return QuestionComponents(
            original_question=question,
            normalized_question=question.lower(),
            question_type=question_type,
            entities=entities[:3],  # Limit to 3 entities
            keywords=words[:5],  # Limit to 5 keywords
            relationships=[],
            constraints=[],
            aggregations=['count'] if question_type == QuestionType.AGGREGATION else [],
            expected_answer_type='text'
        )

    async def _fallback_query_generation(self, context: ErrorContext, **kwargs):
        """Fallback for query generation: a trivial catch-all query."""
        # BUG FIX: context.metadata defaults to None, so guard before .get().
        if 'sparql' in (context.metadata or {}).get('query_language', '').lower():
            query = """
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            SELECT ?subject ?predicate ?object WHERE {
                ?subject ?predicate ?object .
            }
            LIMIT 10
            """
            from .sparql_generator import SPARQLQuery
            return SPARQLQuery(
                query=query,
                variables=['subject', 'predicate', 'object'],
                query_type='SELECT',
                explanation='Fallback SPARQL query',
                complexity_score=0.1
            )
        else:
            query = "MATCH (n) RETURN n LIMIT 10"
            from .cypher_generator import CypherQuery
            # NOTE(review): other call sites construct CypherQuery with
            # `parameters=`/`database_hints=` and no `query_type=` —
            # confirm this matches the CypherQuery signature.
            return CypherQuery(
                query=query,
                variables=['n'],
                query_type='MATCH',
                explanation='Fallback Cypher query',
                complexity_score=0.1
            )

    async def _fallback_query_execution(self, context: ErrorContext, **kwargs):
        """Fallback for query execution: return an empty result set."""
        # BUG FIX: same None-metadata guard as in query generation.
        if 'sparql' in (context.metadata or {}).get('query_language', '').lower():
            from .sparql_cassandra import SPARQLResult
            return SPARQLResult(
                bindings=[],
                variables=[],
                execution_time=0.0
            )
        else:
            from .cypher_executor import CypherResult
            return CypherResult(
                records=[],
                summary={'type': 'fallback'},
                metadata={'query': 'fallback'},
                execution_time=0.0
            )

    async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs):
        """Fallback for answer generation: apologetic canned response."""
        fallback_messages = [
            "I'm experiencing some technical difficulties. Please try rephrasing your question.",
            "I couldn't process your question at the moment. Could you try asking it differently?",
            "There seems to be an issue with my analysis. Please try again in a moment.",
            "I'm having trouble understanding your question right now. Please try again."
        ]
        import random
        return random.choice(fallback_messages)

    async def _fallback_backend_connection(self, context: ErrorContext, **kwargs):
        """Fallback for backend connection failures."""
        logger.warning(f"Backend connection failed for {context.component}")
        # Could switch to alternative backend here
        return None
def with_error_handling(category: ErrorCategory,
                        component: str,
                        operation: str,
                        severity: ErrorSeverity = ErrorSeverity.MEDIUM):
    """Decorator for automatic error handling.

    Wraps a sync or async callable; on exception it either delegates to the
    instance's ``_error_strategy`` (async path only) or re-raises the
    failure as an OntoRAGError with structured context.

    Args:
        category: Error category
        component: Component name
        operation: Operation name
        severity: Error severity
    """
    def decorator(func):
        def _make_context(e, args, kwargs) -> ErrorContext:
            # Single construction point shared by both wrappers
            # (previously duplicated verbatim).
            return ErrorContext(
                category=category,
                severity=severity,
                component=component,
                operation=operation,
                technical_details=str(e),
                metadata={'args': str(args), 'kwargs': str(kwargs)}
            )

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                # func is known to be a coroutine function here (selected
                # below), so the original's non-coroutine branch was dead
                # code and has been removed.
                return await func(*args, **kwargs)
            except Exception as e:
                context = _make_context(e, args, kwargs)
                # Delegate to the instance's recovery strategy when present.
                error_strategy = getattr(args[0], '_error_strategy', None) if args else None
                if error_strategy:
                    return await error_strategy.handle_error(e, context, func, *args, **kwargs)
                raise OntoRAGError(
                    f"Error in {component}.{operation}: {str(e)}",
                    context=context,
                    cause=e
                )

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # NOTE(review): the sync path never consults _error_strategy,
            # so retry/fallback recovery applies to async callables only —
            # confirm this asymmetry is intended.
            try:
                return func(*args, **kwargs)
            except Exception as e:
                raise OntoRAGError(
                    f"Error in {component}.{operation}: {str(e)}",
                    context=_make_context(e, args, kwargs),
                    cause=e
                )

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
    return decorator
class ErrorReporter:
    """Reports and tracks errors for monitoring and debugging.

    Keeps a bounded in-memory log of structured error entries and can
    summarise the last hour of activity.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize error reporter.

        Args:
            config: Reporter configuration ('max_log_size' supported)
        """
        self.config = config or {}
        self.error_log: List[Dict[str, Any]] = []
        self.max_log_size = self.config.get('max_log_size', 1000)

    def report_error(self, error: 'OntoRAGError'):
        """Report an error for tracking.

        Args:
            error: The error to report
        """
        ctx = error.context
        entry = {
            'timestamp': error.timestamp,
            'message': error.message,
            'category': ctx.category.value,
            'severity': ctx.severity.value,
            'component': ctx.component,
            'operation': ctx.operation,
            'retry_count': ctx.retry_count,
            'technical_details': ctx.technical_details,
            'stack_trace': traceback.format_exc() if error.cause else None,
        }
        self.error_log.append(entry)
        # Drop the oldest entries once the log exceeds its bound.
        if len(self.error_log) > self.max_log_size:
            del self.error_log[:-self.max_log_size]
        # Route to the logging level matching the severity.
        dispatch = {
            ErrorSeverity.CRITICAL: (logger.critical, "CRITICAL ERROR"),
            ErrorSeverity.HIGH: (logger.error, "HIGH SEVERITY"),
            ErrorSeverity.MEDIUM: (logger.warning, "MEDIUM SEVERITY"),
        }
        log_fn, prefix = dispatch.get(ctx.severity, (logger.info, "LOW SEVERITY"))
        log_fn(f"{prefix}: {error.message}")

    def get_error_summary(self) -> Dict[str, Any]:
        """Get summary of recent errors.

        Returns:
            Error summary statistics (last-hour breakdowns by category,
            severity and component)
        """
        if not self.error_log:
            return {'total_errors': 0}
        now = time.time()
        recent = [e for e in self.error_log if now - e['timestamp'] < 3600]

        def _tally(key: str) -> Dict[str, int]:
            # Count occurrences of entry[key] over the recent window.
            counts: Dict[str, int] = {}
            for entry in recent:
                counts[entry[key]] = counts.get(entry[key], 0) + 1
            return counts

        return {
            'total_errors': len(self.error_log),
            'recent_errors': len(recent),
            'category_breakdown': _tally('category'),
            'severity_breakdown': _tally('severity'),
            'component_breakdown': _tally('component'),
            'most_recent_error': self.error_log[-1] if self.error_log else None,
        }

View file

@ -0,0 +1,737 @@
"""
Performance monitoring and metrics collection for OntoRAG.
Provides comprehensive monitoring of system performance, query patterns, and resource usage.
"""
import logging
import time
import asyncio
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import defaultdict, deque
import statistics
from enum import Enum
logger = logging.getLogger(__name__)
class MetricType(Enum):
    """Types of metrics to collect."""
    COUNTER = "counter"      # monotonically increasing total
    GAUGE = "gauge"          # point-in-time value, may rise or fall
    HISTOGRAM = "histogram"  # declared but not currently emitted by MetricsCollector
    TIMER = "timer"          # duration samples, in seconds
@dataclass
class Metric:
    """Individual metric data point stored in MetricsCollector's series."""
    name: str           # metric series name
    value: float        # sampled value
    timestamp: datetime # wall-clock time the sample was recorded
    labels: Dict[str, str] = field(default_factory=dict)  # dimension labels
    metric_type: MetricType = MetricType.GAUGE
@dataclass
class TimerMetric:
    """Timer metric for measuring duration."""
    name: str
    start_time: float  # time.time() at start (seconds since epoch)
    labels: Dict[str, str] = field(default_factory=dict)

    def stop(self) -> float:
        """Stop timer and return duration in seconds.

        Does not record the measurement; use MetricsCollector.stop_timer
        for that.
        """
        return time.time() - self.start_time
@dataclass
class PerformanceStats:
    """Performance statistics for a component."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    avg_response_time: float = 0.0           # seconds
    min_response_time: float = float('inf')  # inf until the first sample arrives
    max_response_time: float = 0.0
    p95_response_time: float = 0.0
    p99_response_time: float = 0.0
    throughput_per_second: float = 0.0
    # presumably failed/total in 0.0..1.0 — confirm in _update_component_stats
    error_rate: float = 0.0
@dataclass
class SystemHealth:
    """Overall system health metrics snapshot."""
    status: str = "healthy"  # healthy, degraded, unhealthy
    uptime_seconds: float = 0.0
    cpu_usage_percent: float = 0.0
    memory_usage_percent: float = 0.0
    active_connections: int = 0
    queue_size: int = 0
    cache_hit_rate: float = 0.0  # hits / total cache requests
    error_rate: float = 0.0
class MetricsCollector:
    """Collects and stores metrics data.

    Thread-safe: all reads and writes of the shared counter/gauge/timer
    maps take self._lock.
    """

    def __init__(self, max_metrics: int = 10000, retention_hours: int = 24):
        """Initialize metrics collector.

        Args:
            max_metrics: Maximum number of raw samples to retain per series
            retention_hours: Hours to retain metrics
        """
        self.max_metrics = max_metrics
        self.retention_hours = retention_hours
        # Raw time-series samples keyed by metric name; bounded deques.
        self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics))
        self.counters: Dict[str, float] = defaultdict(float)
        self.gauges: Dict[str, float] = defaultdict(float)
        self.timers: Dict[str, List[float]] = defaultdict(list)
        # RLock: public methods lock and may call other locking helpers.
        self._lock = threading.RLock()

    def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None):
        """Increment a counter metric.

        Args:
            name: Metric name
            value: Value to increment by
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.counters[metric_key] += value
            self._add_metric(name, value, MetricType.COUNTER, labels)

    def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None):
        """Set a gauge metric value.

        Args:
            name: Metric name
            value: Gauge value
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.gauges[metric_key] = value
            self._add_metric(name, value, MetricType.GAUGE, labels)

    def record_timer(self, name: str, duration: float, labels: Dict[str, str] = None):
        """Record a timer measurement.

        Args:
            name: Metric name
            duration: Duration in seconds
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.timers[metric_key].append(duration)
            # Keep only recent measurements
            max_timer_values = 1000
            if len(self.timers[metric_key]) > max_timer_values:
                self.timers[metric_key] = self.timers[metric_key][-max_timer_values:]
            self._add_metric(name, duration, MetricType.TIMER, labels)

    def start_timer(self, name: str, labels: Dict[str, str] = None) -> 'TimerMetric':
        """Start a timer.

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Timer metric object
        """
        return TimerMetric(name=name, start_time=time.time(), labels=labels or {})

    def stop_timer(self, timer: 'TimerMetric'):
        """Stop a timer and record the measurement.

        Args:
            timer: Timer metric to stop

        Returns:
            Measured duration in seconds
        """
        duration = timer.stop()
        self.record_timer(timer.name, duration, timer.labels)
        return duration

    def get_counter(self, name: str, labels: Dict[str, str] = None) -> float:
        """Get counter value.

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Counter value (0.0 when never incremented)
        """
        # FIX: reads now take the lock like all writers do; previously this
        # read shared state unlocked while increment() mutated it.
        with self._lock:
            metric_key = self._build_key(name, labels)
            return self.counters.get(metric_key, 0.0)

    def get_gauge(self, name: str, labels: Dict[str, str] = None) -> float:
        """Get gauge value.

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Gauge value (0.0 when never set)
        """
        with self._lock:  # FIX: consistent locking with set_gauge()
            metric_key = self._build_key(name, labels)
            return self.gauges.get(metric_key, 0.0)

    def get_timer_stats(self, name: str, labels: Dict[str, str] = None) -> Dict[str, float]:
        """Get timer statistics.

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Timer statistics (count/sum/avg/min/max/p50/p95/p99), or {}
            when no measurements exist
        """
        # FIX: lock while reading — record_timer may reassign/append to
        # this list concurrently.
        with self._lock:
            metric_key = self._build_key(name, labels)
            values = list(self.timers.get(metric_key, []))
        if not values:
            return {}
        sorted_values = sorted(values)
        return {
            'count': len(values),
            'sum': sum(values),
            'avg': statistics.mean(values),
            'min': min(values),
            'max': max(values),
            'p50': sorted_values[int(len(sorted_values) * 0.5)],
            'p95': sorted_values[int(len(sorted_values) * 0.95)],
            'p99': sorted_values[int(len(sorted_values) * 0.99)]
        }

    def get_metrics(self,
                    name_pattern: Optional[str] = None,
                    since: Optional[datetime] = None) -> 'List[Metric]':
        """Get metrics matching pattern and time range.

        Args:
            name_pattern: Substring to match metric names
            since: Only return metrics since this time (defaults to the
                retention window)

        Returns:
            List of matching metrics, oldest first
        """
        with self._lock:
            results = []
            cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours)
            for metric_name, metric_queue in self.metrics.items():
                if name_pattern and name_pattern not in metric_name:
                    continue
                for metric in metric_queue:
                    if metric.timestamp >= cutoff_time:
                        results.append(metric)
            return sorted(results, key=lambda m: m.timestamp)

    def cleanup_old_metrics(self):
        """Remove old metrics beyond retention period."""
        with self._lock:
            cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
            for metric_name in list(self.metrics.keys()):
                metric_queue = self.metrics[metric_name]
                # Remove old metrics from the front (oldest first).
                while metric_queue and metric_queue[0].timestamp < cutoff_time:
                    metric_queue.popleft()
                # Remove empty queues
                if not metric_queue:
                    del self.metrics[metric_name]

    def _add_metric(self, name: str, value: float, metric_type: 'MetricType', labels: Dict[str, str]):
        """Append a raw sample to the metric's time series (caller holds lock)."""
        metric = Metric(
            name=name,
            value=value,
            timestamp=datetime.now(),
            labels=labels or {},
            metric_type=metric_type
        )
        self.metrics[name].append(metric)

    def _build_key(self, name: str, labels: Dict[str, str]) -> str:
        """Build a stable storage key from name plus sorted labels."""
        if not labels:
            return name
        label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{name}{{{label_str}}}"
class PerformanceMonitor:
    """Monitors system performance and component health.

    Orchestration layer over MetricsCollector: records request, cache
    and ontology metrics, maintains rolling per-component
    PerformanceStats, and derives health summaries and reports. When
    enabled, a daemon thread purges expired metrics every 5 minutes.
    """
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize performance monitor.

        Args:
            config: Monitor configuration. Recognized keys:
                'max_metrics' (default 10000), 'retention_hours'
                (default 24), 'enabled' (default True).
        """
        self.config = config or {}
        self.metrics_collector = MetricsCollector(
            max_metrics=self.config.get('max_metrics', 10000),
            retention_hours=self.config.get('retention_hours', 24)
        )
        # Rolling per-component statistics, keyed by component name.
        self.component_stats: Dict[str, PerformanceStats] = {}
        self.start_time = time.time()
        self.monitoring_enabled = self.config.get('enabled', True)
        # Start background monitoring tasks
        if self.monitoring_enabled:
            self._start_background_tasks()
    def record_request(self,
                       component: str,
                       operation: str,
                       duration: float,
                       success: bool = True,
                       labels: Dict[str, str] = None):
        """Record a request completion.

        No-op when monitoring is disabled.

        Args:
            component: Component name
            operation: Operation name
            duration: Request duration in seconds
            success: Whether request was successful
            labels: Additional labels merged over the component/operation pair
        """
        if not self.monitoring_enabled:
            return
        base_labels = {'component': component, 'operation': operation}
        if labels:
            base_labels.update(labels)
        # Record metrics
        self.metrics_collector.increment('requests_total', labels=base_labels)
        self.metrics_collector.record_timer('request_duration', duration, base_labels)
        if success:
            self.metrics_collector.increment('requests_successful', labels=base_labels)
        else:
            self.metrics_collector.increment('requests_failed', labels=base_labels)
        # Update component stats
        self._update_component_stats(component, duration, success)
    def record_query_complexity(self,
                                complexity_score: float,
                                query_type: str,
                                backend: str):
        """Record query complexity metrics.

        Args:
            complexity_score: Query complexity score (0.0 to 1.0)
            query_type: Type of query (SPARQL, Cypher)
            backend: Backend used
        """
        if not self.monitoring_enabled:
            return
        labels = {'query_type': query_type, 'backend': backend}
        self.metrics_collector.set_gauge('query_complexity', complexity_score, labels)
    def record_cache_access(self, hit: bool, cache_type: str = 'default'):
        """Record cache access.

        Args:
            hit: Whether it was a cache hit
            cache_type: Type of cache
        """
        if not self.monitoring_enabled:
            return
        labels = {'cache_type': cache_type}
        # Always count the access; classify it as a hit or a miss.
        self.metrics_collector.increment('cache_requests_total', labels=labels)
        if hit:
            self.metrics_collector.increment('cache_hits_total', labels=labels)
        else:
            self.metrics_collector.increment('cache_misses_total', labels=labels)
    def record_ontology_selection(self,
                                  selected_elements: int,
                                  total_elements: int,
                                  ontology_id: str):
        """Record ontology selection metrics.

        Args:
            selected_elements: Number of selected ontology elements
            total_elements: Total ontology elements
            ontology_id: Ontology identifier
        """
        if not self.monitoring_enabled:
            return
        labels = {'ontology_id': ontology_id}
        self.metrics_collector.set_gauge('ontology_elements_selected', selected_elements, labels)
        self.metrics_collector.set_gauge('ontology_elements_total', total_elements, labels)
        # Guard against division by zero for empty ontologies.
        selection_ratio = selected_elements / total_elements if total_elements > 0 else 0
        self.metrics_collector.set_gauge('ontology_selection_ratio', selection_ratio, labels)
    def get_component_stats(self, component: str) -> Optional[PerformanceStats]:
        """Get performance statistics for a component.

        Args:
            component: Component name

        Returns:
            Performance statistics, or None if the component has never
            recorded a request.
        """
        return self.component_stats.get(component)
    def get_system_health(self) -> SystemHealth:
        """Get overall system health status.

        Status is derived from the global error rate: >10% failed
        requests -> "degraded", >30% -> "unhealthy".

        Returns:
            System health metrics
        """
        # Calculate uptime
        uptime = time.time() - self.start_time
        # Get error rate
        total_requests = self.metrics_collector.get_counter('requests_total')
        failed_requests = self.metrics_collector.get_counter('requests_failed')
        error_rate = failed_requests / total_requests if total_requests > 0 else 0.0
        # Get cache hit rate
        cache_hits = self.metrics_collector.get_counter('cache_hits_total')
        cache_requests = self.metrics_collector.get_counter('cache_requests_total')
        cache_hit_rate = cache_hits / cache_requests if cache_requests > 0 else 0.0
        # Determine status
        status = "healthy"
        if error_rate > 0.1:  # More than 10% error rate
            status = "degraded"
        if error_rate > 0.3:  # More than 30% error rate
            status = "unhealthy"
        return SystemHealth(
            status=status,
            uptime_seconds=uptime,
            error_rate=error_rate,
            cache_hit_rate=cache_hit_rate
        )
    def get_performance_report(self) -> Dict[str, Any]:
        """Get comprehensive performance report.

        Returns:
            Performance report
        """
        # NOTE(review): 'error_patterns' and 'ontology_usage' are
        # initialized but never populated below — placeholders for now.
        report = {
            'system_health': self.get_system_health(),
            'component_stats': {},
            'top_slow_operations': [],
            'error_patterns': {},
            'cache_performance': {},
            'ontology_usage': {}
        }
        # Component statistics
        for component, stats in self.component_stats.items():
            report['component_stats'][component] = stats
        # Top slow operations
        timer_stats = {}
        for metric_name in self.metrics_collector.timers.keys():
            if 'request_duration' in metric_name:
                stats = self.metrics_collector.get_timer_stats(metric_name)
                if stats:
                    timer_stats[metric_name] = stats
        # Sort by p95 latency
        slow_ops = sorted(
            timer_stats.items(),
            key=lambda x: x[1].get('p95', 0),
            reverse=True
        )[:10]
        report['top_slow_operations'] = [
            {'operation': op, 'stats': stats} for op, stats in slow_ops
        ]
        # Cache performance: recover the cache_type label from the
        # flattened counter keys produced by _build_key
        # (e.g. "cache_hits_total{cache_type=default}").
        cache_types = set()
        for metric_name in self.metrics_collector.counters.keys():
            if 'cache_type=' in metric_name:
                cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
                cache_types.add(cache_type)
        for cache_type in cache_types:
            labels = {'cache_type': cache_type}
            hits = self.metrics_collector.get_counter('cache_hits_total', labels)
            requests = self.metrics_collector.get_counter('cache_requests_total', labels)
            hit_rate = hits / requests if requests > 0 else 0.0
            report['cache_performance'][cache_type] = {
                'hit_rate': hit_rate,
                'total_requests': requests,
                'total_hits': hits
            }
        return report
    def _update_component_stats(self, component: str, duration: float, success: bool):
        """Update component performance statistics after each request."""
        if component not in self.component_stats:
            self.component_stats[component] = PerformanceStats()
        stats = self.component_stats[component]
        stats.total_requests += 1
        if success:
            stats.successful_requests += 1
        else:
            stats.failed_requests += 1
        # Update response time stats
        stats.min_response_time = min(stats.min_response_time, duration)
        stats.max_response_time = max(stats.max_response_time, duration)
        # Percentiles come from the collector's timer aggregation rather
        # than being recomputed here.
        timer_stats = self.metrics_collector.get_timer_stats(
            'request_duration', {'component': component}
        )
        if timer_stats:
            stats.avg_response_time = timer_stats.get('avg', 0.0)
            stats.p95_response_time = timer_stats.get('p95', 0.0)
            stats.p99_response_time = timer_stats.get('p99', 0.0)
        # Calculate rates
        stats.error_rate = stats.failed_requests / stats.total_requests
        # Calculate throughput (requests per second over last minute)
        recent_requests = len([
            m for m in self.metrics_collector.get_metrics('requests_total')
            if m.labels.get('component') == component and
            m.timestamp > datetime.now() - timedelta(minutes=1)
        ])
        stats.throughput_per_second = recent_requests / 60.0
    def _start_background_tasks(self):
        """Start background monitoring tasks (metric retention cleanup)."""
        def cleanup_worker():
            """Worker to clean up old metrics."""
            while True:
                try:
                    time.sleep(300)  # 5 minutes
                    self.metrics_collector.cleanup_old_metrics()
                except Exception as e:
                    logger.error(f"Metrics cleanup error: {e}")
        # Daemon thread: does not block interpreter shutdown.
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
# Monitoring decorators
def monitor_performance(component: str,
                        operation: str,
                        monitor: Optional["PerformanceMonitor"] = None):
    """Decorator to monitor function performance.

    Works on both sync and async callables: each call is timed via the
    monitor's metrics collector and reported through
    ``monitor.record_request``. When no monitor is supplied or
    monitoring is disabled, the wrapped function runs untouched.

    Args:
        component: Component name recorded in metric labels
        operation: Operation name recorded in metric labels
        monitor: Performance monitor instance (optional)
    """
    # Local import: keeps the decorator self-contained without touching
    # module-level imports.
    import functools

    def decorator(func):
        # Decide sync vs async once, at decoration time.
        is_async = asyncio.iscoroutinefunction(func)

        @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped callable
        def wrapper(*args, **kwargs):
            if not monitor or not monitor.monitoring_enabled:
                return func(*args, **kwargs)
            timer = monitor.metrics_collector.start_timer(
                'request_duration',
                {'component': component, 'operation': operation}
            )
            success = True
            try:
                return func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                # Always record, whether the call succeeded or raised.
                duration = monitor.metrics_collector.stop_timer(timer)
                monitor.record_request(component, operation, duration, success)

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not monitor or not monitor.monitoring_enabled:
                return await func(*args, **kwargs)
            timer = monitor.metrics_collector.start_timer(
                'request_duration',
                {'component': component, 'operation': operation}
            )
            success = True
            try:
                return await func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                duration = monitor.metrics_collector.stop_timer(timer)
                monitor.record_request(component, operation, duration, success)

        return async_wrapper if is_async else wrapper

    return decorator
class QueryPatternAnalyzer:
    """Analyzes recorded query patterns to surface optimization insights."""

    def __init__(self, monitor: PerformanceMonitor):
        """Initialize query pattern analyzer.

        Args:
            monitor: Performance monitor instance
        """
        self.monitor = monitor
        # Pattern entries bucketed by "<question_type>:<entity_count>".
        self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    def record_query_pattern(self,
                             question_type: str,
                             entities: List[str],
                             complexity: float,
                             backend: str,
                             duration: float,
                             success: bool):
        """Record one executed query for later pattern analysis.

        Args:
            question_type: Type of question
            entities: Entities in question
            complexity: Query complexity score
            backend: Backend used
            duration: Query duration
            success: Whether query succeeded
        """
        entry = {
            'timestamp': datetime.now(),
            'question_type': question_type,
            'entity_count': len(entities),
            'entities': entities,
            'complexity': complexity,
            'backend': backend,
            'duration': duration,
            'success': success
        }
        bucket_key = f"{question_type}:{len(entities)}"
        self.query_patterns[bucket_key].append(entry)
        # Evict entries older than 24 hours from this bucket.
        horizon = datetime.now() - timedelta(hours=24)
        self.query_patterns[bucket_key] = [
            entry for entry in self.query_patterns[bucket_key]
            if entry['timestamp'] > horizon
        ]

    def get_optimization_insights(self) -> Dict[str, Any]:
        """Derive optimization insights from the recorded patterns.

        Returns:
            Dict with slow patterns, failure-prone patterns, per-backend
            latency, and textual recommendations.
            ('complexity_analysis' is reserved and currently unfilled.)
        """
        insights = {
            'slow_patterns': [],
            'common_failures': [],
            'backend_performance': {},
            'complexity_analysis': {},
            'recommendations': []
        }

        # Flag slow (>5s average) and failure-prone (<80% success) buckets.
        for bucket_key, entries in self.query_patterns.items():
            if not entries:
                continue
            mean_duration = statistics.mean([e['duration'] for e in entries])
            ok_ratio = sum(1 for e in entries if e['success']) / len(entries)
            if mean_duration > 5.0:
                insights['slow_patterns'].append({
                    'pattern': bucket_key,
                    'avg_duration': mean_duration,
                    'count': len(entries),
                    'success_rate': ok_ratio
                })
            if ok_ratio < 0.8:
                insights['common_failures'].append({
                    'pattern': bucket_key,
                    'success_rate': ok_ratio,
                    'count': len(entries)
                })

        # Per-backend latency profile across every bucket.
        durations_by_backend = defaultdict(list)
        for entries in self.query_patterns.values():
            for entry in entries:
                durations_by_backend[entry['backend']].append(entry['duration'])
        for backend, durations in durations_by_backend.items():
            insights['backend_performance'][backend] = {
                'avg_duration': statistics.mean(durations),
                'p95_duration': sorted(durations)[int(len(durations) * 0.95)],
                'query_count': len(durations)
            }

        # Build textual recommendations from the findings above.
        recommendations = []
        for slow_pattern in insights['slow_patterns']:
            recommendations.append(
                f"Consider optimizing {slow_pattern['pattern']} queries - "
                f"average duration {slow_pattern['avg_duration']:.2f}s"
            )
        if len(insights['backend_performance']) > 1:
            fastest_backend = min(
                insights['backend_performance'].items(),
                key=lambda item: item[1]['avg_duration']
            )[0]
            recommendations.append(
                f"Consider routing more queries to {fastest_backend} "
                f"for better performance"
            )
        insights['recommendations'] = recommendations
        return insights

View file

@ -0,0 +1,656 @@
"""
Multi-language support for OntoRAG.
Provides language detection, translation, and multilingual query processing.
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class Language(Enum):
    """Supported languages.

    Values are ISO 639-1 two-letter codes, matching what the detection
    and translation backends return.
    """
    ENGLISH = "en"
    SPANISH = "es"
    FRENCH = "fr"
    GERMAN = "de"
    ITALIAN = "it"
    PORTUGUESE = "pt"
    CHINESE = "zh"
    JAPANESE = "ja"
    KOREAN = "ko"
    ARABIC = "ar"
    RUSSIAN = "ru"
    DUTCH = "nl"
@dataclass
class LanguageDetectionResult:
    """Result of a language detection attempt."""
    language: Language  # best-guess language (detector default on failure)
    confidence: float  # 0.0 (unknown/failed) .. 1.0
    detected_text: str  # the text that was analyzed
    # Ranked (language, probability) alternatives; None when the backend
    # provides no ranking.
    alternative_languages: Optional[List[Tuple[Language, float]]] = None
@dataclass
class TranslationResult:
    """Result of a translation attempt.

    When translation is unavailable or fails, translated_text equals
    original_text and confidence is 0.0.
    """
    original_text: str  # input text, unmodified
    translated_text: str  # output text in target_language
    source_language: Language  # supplied or auto-detected source
    target_language: Language
    confidence: float  # backend-dependent estimate, 0.0 .. 1.0
class LanguageDetector:
    """Detects language of input text.

    Uses langdetect if installed, falling back to TextBlob, and finally
    to a simple keyword heuristic when neither library is available.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize language detector.

        Args:
            config: Detector configuration. Recognized keys:
                'default_language' (ISO code, default 'en'),
                'confidence_threshold' (default 0.7).
        """
        self.config = config or {}
        self.default_language = Language(self.config.get('default_language', 'en'))
        self.confidence_threshold = self.config.get('confidence_threshold', 0.7)
        # Try to import language detection libraries
        self.detector = None
        self._init_detector()

    def _init_detector(self):
        """Choose the best available language detection backend."""
        try:
            # Try langdetect first
            import langdetect
            self.detector = 'langdetect'
            logger.info("Using langdetect for language detection")
        except ImportError:
            try:
                # Try textblob as fallback
                from textblob import TextBlob
                self.detector = 'textblob'
                logger.info("Using TextBlob for language detection")
            except ImportError:
                logger.warning("No language detection library available, using rule-based detection")
                self.detector = 'rule_based'

    def detect_language(self, text: str) -> LanguageDetectionResult:
        """Detect language of input text.

        Args:
            text: Text to analyze

        Returns:
            Language detection result; on empty input or backend
            failure, the configured default language with confidence 0.0.
        """
        if not text or not text.strip():
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )
        try:
            if self.detector == 'langdetect':
                return self._detect_with_langdetect(text)
            elif self.detector == 'textblob':
                return self._detect_with_textblob(text)
            else:
                return self._detect_with_rules(text)
        except Exception as e:
            logger.error(f"Language detection failed: {e}")
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult:
        """Detect language using the langdetect library."""
        import langdetect
        from langdetect.lang_detect_exception import LangDetectException
        try:
            # Get ranked language probabilities
            probabilities = langdetect.detect_langs(text)
            if not probabilities:
                return LanguageDetectionResult(
                    language=self.default_language,
                    confidence=0.0,
                    detected_text=text
                )
            best_match = probabilities[0]
            detected_lang_code = best_match.lang
            confidence = best_match.prob
            # Map to our Language enum
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                # Map common variations / regional codes onto supported languages
                lang_mapping = {
                    'ca': Language.SPANISH,  # Catalan -> Spanish
                    'eu': Language.SPANISH,  # Basque -> Spanish
                    'gl': Language.SPANISH,  # Galician -> Spanish
                    'zh-cn': Language.CHINESE,
                    'zh-tw': Language.CHINESE,
                }
                detected_language = lang_mapping.get(detected_lang_code, self.default_language)
            # Keep up to two ranked alternatives; codes outside the enum are skipped
            alternatives = []
            for lang_prob in probabilities[1:3]:
                try:
                    alt_lang = Language(lang_prob.lang)
                    alternatives.append((alt_lang, lang_prob.prob))
                except ValueError:
                    continue
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text,
                alternative_languages=alternatives
            )
        except LangDetectException:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_textblob(self, text: str) -> LanguageDetectionResult:
        """Detect language using TextBlob."""
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            detected_lang_code = blob.detect_language()
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                detected_language = self.default_language
            # TextBlob doesn't provide confidence, so estimate based on text length
            confidence = min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text
            )
        except Exception:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_rules(self, text: str) -> LanguageDetectionResult:
        """Rule-based language detection fallback (keyword scoring)."""
        text_lower = text.lower()
        # Simple keyword-based detection. Keywords must be non-empty:
        # an empty string is a substring of every input and would
        # inflate that language's score unconditionally.
        language_keywords = {
            Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'],
            Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'],
            Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'],
            Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'],
            Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'],
            Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel']
        }
        best_match = self.default_language
        best_score = 0
        for language, keywords in language_keywords.items():
            score = sum(1 for keyword in keywords if keyword in text_lower)
            if score > best_score:
                best_score = score
                best_match = language
        # Confidence scales with matched keywords, capped at 0.8.
        confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1
        return LanguageDetectionResult(
            language=best_match,
            confidence=confidence,
            detected_text=text
        )
class TextTranslator:
    """Translates text between languages.

    Uses googletrans when available, falling back to TextBlob; when
    neither is installed, translate() returns the input unchanged with
    confidence 0.0.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize text translator.

        Args:
            config: Translator configuration
        """
        self.config = config or {}
        self.translator = None
        self._init_translator()

    def _init_translator(self):
        """Choose the best available translation backend."""
        try:
            # Try Google Translate first
            from googletrans import Translator
            self.translator = Translator()
            self.backend = 'googletrans'
            logger.info("Using Google Translate for translation")
        except ImportError:
            try:
                # Try TextBlob as fallback
                from textblob import TextBlob
                self.backend = 'textblob'
                logger.info("Using TextBlob for translation")
            except ImportError:
                logger.warning("No translation library available")
                self.backend = None

    def translate(self,
                  text: str,
                  target_language: Language,
                  source_language: Optional[Language] = None) -> TranslationResult:
        """Translate text to target language.

        Args:
            text: Text to translate
            target_language: Target language
            source_language: Source language (auto-detect if None)

        Returns:
            Translation result; on empty input, missing backend, or
            backend failure the original text is passed through
            unchanged with confidence 0.0.
        """
        if not text or not text.strip():
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )
        try:
            if self.backend == 'googletrans':
                return self._translate_with_googletrans(text, target_language, source_language)
            elif self.backend == 'textblob':
                return self._translate_with_textblob(text, target_language, source_language)
            else:
                # No translation available - pass the text through
                return TranslationResult(
                    original_text=text,
                    translated_text=text,
                    source_language=source_language or Language.ENGLISH,
                    target_language=target_language,
                    confidence=0.0
                )
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )

    def _translate_with_googletrans(self,
                                    text: str,
                                    target_language: Language,
                                    source_language: Optional[Language]) -> TranslationResult:
        """Translate using Google Translate (googletrans)."""
        try:
            src_code = source_language.value if source_language else 'auto'
            dest_code = target_language.value
            result = self.translator.translate(text, src=src_code, dest=dest_code)
            # googletrans may report codes outside our enum (e.g. 'zh-cn');
            # fall back to the supplied source (or English) rather than
            # letting ValueError discard a successful translation.
            try:
                detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH
            except ValueError:
                detected_source = source_language or Language.ENGLISH
            confidence = 0.9  # Google Translate is generally reliable
            return TranslationResult(
                original_text=text,
                translated_text=result.text,
                source_language=detected_source,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"Google Translate error: {e}")
            raise

    def _translate_with_textblob(self,
                                 text: str,
                                 target_language: Language,
                                 source_language: Optional[Language]) -> TranslationResult:
        """Translate using TextBlob."""
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            if not source_language:
                # Auto-detect source language
                detected_lang = blob.detect_language()
                try:
                    source_language = Language(detected_lang)
                except ValueError:
                    source_language = Language.ENGLISH
            translated_blob = blob.translate(to=target_language.value)
            translated_text = str(translated_blob)
            # TextBlob provides no confidence; estimate from text length
            confidence = 0.7 if len(text) > 10 else 0.5
            return TranslationResult(
                original_text=text,
                translated_text=translated_text,
                source_language=source_language,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"TextBlob translation error: {e}")
            raise
class MultiLanguageQueryProcessor:
    """Processes queries in multiple languages.

    Detects the question's language, translates it into the primary
    processing language when needed, and translates answers back.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize multi-language query processor.

        Args:
            config: Processor configuration. Recognized keys:
                'language_detection', 'translation',
                'supported_languages' (ISO codes, default ['en']),
                'primary_language' (ISO code, default 'en').
        """
        self.config = config or {}
        # Read settings through self.config so that config=None is
        # handled by the guard above (config.get would raise here).
        self.language_detector = LanguageDetector(self.config.get('language_detection', {}))
        self.translator = TextTranslator(self.config.get('translation', {}))
        self.supported_languages = [Language(lang) for lang in self.config.get('supported_languages', ['en'])]
        self.primary_language = Language(self.config.get('primary_language', 'en'))

    async def process_multilingual_query(self, question: str) -> Dict[str, Any]:
        """Process a query in any supported language.

        Args:
            question: Question in any language

        Returns:
            Processing result with language information
        """
        # Step 1: Detect language
        detection_result = self.language_detector.detect_language(question)
        detected_language = detection_result.language
        logger.info(f"Detected language: {detected_language.value} "
                    f"(confidence: {detection_result.confidence:.2f})")
        # Step 2: Translate to primary language if needed
        translated_question = question
        translation_result = None
        if detected_language != self.primary_language:
            # Only trust the detection above the configured threshold.
            if detection_result.confidence >= self.language_detector.confidence_threshold:
                translation_result = self.translator.translate(
                    question, self.primary_language, detected_language
                )
                translated_question = translation_result.translated_text
                logger.info(f"Translated question: {translated_question}")
            else:
                logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}")
        # Step 3: Return processing information
        return {
            'original_question': question,
            'translated_question': translated_question,
            'detected_language': detected_language,
            'detection_confidence': detection_result.confidence,
            'translation_result': translation_result,
            'processing_language': self.primary_language,
            'alternative_languages': detection_result.alternative_languages
        }

    async def translate_answer(self,
                               answer: str,
                               target_language: Language) -> TranslationResult:
        """Translate answer back to target language.

        Args:
            answer: Answer in primary language
            target_language: Target language for answer

        Returns:
            Translation result (identity with confidence 1.0 when the
            target is already the primary language)
        """
        if target_language == self.primary_language:
            # No translation needed
            return TranslationResult(
                original_text=answer,
                translated_text=answer,
                source_language=self.primary_language,
                target_language=target_language,
                confidence=1.0
            )
        return self.translator.translate(answer, target_language, self.primary_language)

    def get_language_specific_ontology_terms(self,
                                             ontology_subset: Dict[str, Any],
                                             language: Language) -> Dict[str, Any]:
        """Get language-specific terms from ontology.

        Args:
            ontology_subset: Ontology subset
            language: Target language

        Returns:
            Ontology subset enriched with 'language_labels' lists.
            Labels may be plain strings (language-neutral, always kept)
            or {'language': ..., 'value': ...} dicts (kept on match).
        """
        # Extract language-specific labels and descriptions
        lang_code = language.value
        result = {}
        # Process classes
        if 'classes' in ontology_subset:
            result['classes'] = {}
            for class_id, class_def in ontology_subset['classes'].items():
                lang_labels = []
                if 'labels' in class_def:
                    for label in class_def['labels']:
                        if isinstance(label, dict) and label.get('language') == lang_code:
                            lang_labels.append(label['value'])
                        elif isinstance(label, str):
                            lang_labels.append(label)
                result['classes'][class_id] = {
                    **class_def,
                    'language_labels': lang_labels
                }
        # Process properties
        for prop_type in ['object_properties', 'datatype_properties']:
            if prop_type in ontology_subset:
                result[prop_type] = {}
                for prop_id, prop_def in ontology_subset[prop_type].items():
                    lang_labels = []
                    if 'labels' in prop_def:
                        for label in prop_def['labels']:
                            if isinstance(label, dict) and label.get('language') == lang_code:
                                lang_labels.append(label['value'])
                            elif isinstance(label, str):
                                lang_labels.append(label)
                    result[prop_type][prop_id] = {
                        **prop_def,
                        'language_labels': lang_labels
                    }
        return result

    def is_language_supported(self, language: Language) -> bool:
        """Check if language is supported.

        Args:
            language: Language to check

        Returns:
            True if language is supported
        """
        return language in self.supported_languages

    def get_supported_languages(self) -> List[Language]:
        """Get list of supported languages.

        Returns:
            Copy of the supported-language list (safe to mutate)
        """
        return self.supported_languages.copy()

    def add_language_support(self, language: Language):
        """Add support for a new language (no-op if already supported).

        Args:
            language: Language to add support for
        """
        if language not in self.supported_languages:
            self.supported_languages.append(language)
            logger.info(f"Added support for language: {language.value}")

    def remove_language_support(self, language: Language):
        """Remove support for a language.

        The primary language can never be removed.

        Args:
            language: Language to remove support for
        """
        if language in self.supported_languages and language != self.primary_language:
            self.supported_languages.remove(language)
            logger.info(f"Removed support for language: {language.value}")
        else:
            logger.warning(f"Cannot remove primary language or unsupported language: {language.value}")
class LanguageSpecificTemplates:
    """Manages language-specific query and answer templates.

    Unlisted languages fall back to English for both question patterns
    and answer templates.
    """

    def __init__(self):
        """Initialize language-specific templates."""
        # Keyword cues used to classify question intent per language.
        # All entries must be non-empty strings (an empty keyword would
        # match every input).
        self.question_templates = {
            Language.ENGLISH: {
                'count': ['how many', 'count of', 'number of'],
                'boolean': ['is', 'are', 'does', 'can', 'will'],
                'retrieval': ['what', 'which', 'who', 'where'],
                'factual': ['tell me about', 'describe', 'explain']
            },
            Language.SPANISH: {
                'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'],
                'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'],
                'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'],
                'factual': ['dime sobre', 'describe', 'explica']
            },
            Language.FRENCH: {
                'count': ['combien', 'nombre de', 'quantité de'],
                'boolean': ['est', 'sont', 'peut', 'peuvent'],
                'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'],
                'factual': ['dis-moi sur', 'décris', 'explique']
            },
            Language.GERMAN: {
                'count': ['wie viele', 'anzahl der', 'zahl der'],
                'boolean': ['ist', 'sind', 'kann', 'können'],
                'retrieval': ['was', 'welche', 'wer', 'wo'],
                'factual': ['erzähl mir über', 'beschreibe', 'erkläre']
            }
        }
        # Answer templates, formatted via str.format with named fields.
        self.answer_templates = {
            Language.ENGLISH: {
                'count': 'There are {count} {entity}.',
                'boolean_true': 'Yes, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No information found.',
                'error': 'Sorry, I encountered an error.'
            },
            Language.SPANISH: {
                'count': 'Hay {count} {entity}.',
                'boolean_true': 'Sí, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No se encontró información.',
                'error': 'Lo siento, encontré un error.'
            },
            Language.FRENCH: {
                'count': 'Il y a {count} {entity}.',
                'boolean_true': 'Oui, {statement}.',
                'boolean_false': 'Non, {statement}.',
                'not_found': 'Aucune information trouvée.',
                'error': 'Désolé, j\'ai rencontré une erreur.'
            },
            Language.GERMAN: {
                'count': 'Es gibt {count} {entity}.',
                'boolean_true': 'Ja, {statement}.',
                'boolean_false': 'Nein, {statement}.',
                'not_found': 'Keine Informationen gefunden.',
                'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.'
            }
        }

    def get_question_patterns(self, language: Language) -> Dict[str, List[str]]:
        """Get question patterns for a language.

        Args:
            language: Target language

        Returns:
            Dictionary of question patterns (English when unlisted)
        """
        return self.question_templates.get(language, self.question_templates[Language.ENGLISH])

    def get_answer_template(self, language: Language, template_type: str) -> str:
        """Get answer template for a language and type.

        Args:
            language: Target language
            template_type: Template type

        Returns:
            Answer template string (falls back to the language's 'error'
            template, then the literal 'Error')
        """
        templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH])
        return templates.get(template_type, templates.get('error', 'Error'))

    def format_answer(self,
                      language: Language,
                      template_type: str,
                      **kwargs) -> str:
        """Format answer using language-specific template.

        Args:
            language: Target language
            template_type: Template type
            **kwargs: Template variables

        Returns:
            Formatted answer; on a missing template variable, logs and
            returns the error template instead
        """
        template = self.get_answer_template(language, template_type)
        try:
            return template.format(**kwargs)
        except KeyError as e:
            logger.error(f"Missing template variable: {e}")
            return self.get_answer_template(language, 'error')

View file

@ -0,0 +1,256 @@
"""
Ontology matcher for query system.
Identifies relevant ontology subsets for answering questions.
"""
import logging
from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
from ...extract.kg.ontology.text_processor import TextSegment
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from .question_analyzer import QuestionComponents, QuestionType
logger = logging.getLogger(__name__)
@dataclass
class QueryOntologySubset(OntologySubset):
    """Extended ontology subset for query processing.

    The extra fields are populated by
    OntologyMatcherForQueries._enhance_for_query; they default to None
    until a subset has been enhanced.
    """
    traversal_properties: Optional[Dict[str, Any]] = None  # Additional properties for graph traversal
    inference_rules: Optional[List[Dict[str, Any]]] = None  # Inference rules for reasoning
class OntologyMatcherForQueries(OntologySelector):
"""
Specialized ontology matcher for question answering.
Extends OntologySelector with query-specific logic.
"""
def __init__(self, ontology_embedder: OntologyEmbedder,
             ontology_loader: OntologyLoader,
             top_k: int = 15,  # Higher k for queries
             similarity_threshold: float = 0.6):  # Lower threshold for broader coverage
    """Initialize query-specific ontology matcher.

    Defaults are tuned for question answering relative to the base
    OntologySelector: a larger top_k and a lower similarity threshold
    so recall is favored over precision.

    Args:
        ontology_embedder: Embedder with vector store
        ontology_loader: Loader with ontology definitions
        top_k: Number of top results to retrieve
        similarity_threshold: Minimum similarity score
    """
    super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold)
async def match_question_to_ontology(self,
                                     question_components: QuestionComponents,
                                     question_segments: List[str]) -> List[QueryOntologySubset]:
    """Match question components to relevant ontology elements.

    Args:
        question_components: Analyzed question components
        question_segments: Text segments from question

    Returns:
        List of query-optimized ontology subsets
    """
    # Wrap the raw question segments for the embedding-based selector.
    segments = []
    for position, text in enumerate(question_segments):
        segments.append(TextSegment(text=text, type='question', position=position))

    # Base selection via the parent class, then query-specific enrichment.
    selected = await self.select_ontology_subset(segments)
    return [
        self._enhance_for_query(subset, question_components)
        for subset in selected
    ]
def _enhance_for_query(self, subset: OntologySubset,
                       question_components: QuestionComponents) -> QueryOntologySubset:
    """Build a query-enriched copy of a base ontology subset.

    Args:
        subset: Base ontology subset
        question_components: Analyzed question components

    Returns:
        Enhanced query ontology subset
    """
    # Shallow-copy the element maps so enrichment never mutates the base.
    enriched = QueryOntologySubset(
        ontology_id=subset.ontology_id,
        classes=dict(subset.classes),
        object_properties=dict(subset.object_properties),
        datatype_properties=dict(subset.datatype_properties),
        metadata=subset.metadata,
        relevance_score=subset.relevance_score,
        traversal_properties={},
        inference_rules=[]
    )
    # Enrichment passes: traversal props by question type, then
    # related (inverse/sibling) props, then inference rules.
    self._add_traversal_properties(enriched, question_components)
    self._add_related_properties(enriched)
    self._add_inference_rules(enriched, question_components)
    return enriched
def _add_traversal_properties(self, subset: QueryOntologySubset,
                              question_components: QuestionComponents):
    """Add properties useful for graph traversal, by question type.

    Args:
        subset: Query ontology subset to enhance
        question_components: Question analysis
    """
    ontology = self.loader.get_ontology(subset.ontology_id)
    if not ontology:
        return

    qtype = question_components.question_type

    if qtype == QuestionType.RELATIONSHIP:
        # Pull in every object property whose domain or range touches
        # a selected class.
        for pid, pdef in ontology.object_properties.items():
            touches = pdef.domain in subset.classes or pdef.range in subset.classes
            if touches and pid not in subset.object_properties:
                subset.traversal_properties[pid] = pdef.__dict__
                logger.debug(f"Added traversal property: {pid}")
    elif qtype == QuestionType.RETRIEVAL:
        # Properties rooted in the selected classes may filter results.
        for pid, pdef in ontology.object_properties.items():
            if pdef.domain in subset.classes and pid not in subset.object_properties:
                subset.traversal_properties[pid] = pdef.__dict__
        for pid, pdef in ontology.datatype_properties.items():
            if pdef.domain in subset.classes and pid not in subset.datatype_properties:
                subset.traversal_properties[pid] = pdef.__dict__
    elif qtype == QuestionType.AGGREGATION:
        # Surface countable datatype properties by name heuristic.
        for pid, pdef in ontology.datatype_properties.items():
            lowered = pid.lower()
            if ('count' in lowered or 'number' in lowered) \
                    and pid not in subset.datatype_properties:
                subset.traversal_properties[pid] = pdef.__dict__
def _add_related_properties(self, subset: QueryOntologySubset):
    """Pull in properties related to those already selected.

    Two kinds of relatives are added: declared inverses of selected
    object properties (promoted to first-class selections) and a
    handful of "sibling" properties sharing a domain with a selection
    (added as traversal hints, capped at three in total).

    Args:
        subset: Query ontology subset to enhance (mutated in place).
    """
    onto = self.loader.get_ontology(subset.ontology_id)
    if not onto:
        return

    # Phase 1: declared inverses of already-selected object properties.
    for pid in list(subset.object_properties.keys()):
        prop = onto.object_properties.get(pid)
        if not (prop and prop.inverse_of):
            continue
        inv = onto.object_properties.get(prop.inverse_of)
        if inv and prop.inverse_of not in subset.object_properties:
            subset.object_properties[prop.inverse_of] = inv.__dict__
            logger.debug(f"Added inverse property: {prop.inverse_of}")

    # Phase 2: siblings — ontology properties sharing a domain with a
    # selection.  (Selected property values are plain dicts here.)
    seen_domains = {
        pdef['domain']
        for pdef in subset.object_properties.values()
        if 'domain' in pdef and pdef['domain']
    }
    for dom in seen_domains:
        for pid, pdef in onto.object_properties.items():
            if pdef.domain == dom and pid not in subset.object_properties:
                # Cap the extras so the subset stays small.
                if len(subset.traversal_properties) < 3:
                    subset.traversal_properties[pid] = pdef.__dict__
def _add_inference_rules(self, subset: QueryOntologySubset,
                         question_components: QuestionComponents):
    """Attach inference rules the query engine can apply.

    Emits a transitivity rule when any selected class has a parent, a
    symmetry rule when any selected class declares equivalents, and an
    inverse rule for every selected object property with a declared
    inverse.  (``question_components`` is currently unused but kept
    for interface stability.)

    Args:
        subset: Query ontology subset to enhance (mutated in place).
        question_components: Question analysis.
    """
    class_defs = subset.classes.values()

    # Subclass chains can be followed transitively.
    if any(c.get('subclass_of') for c in class_defs):
        subset.inference_rules.append({
            'type': 'transitivity',
            'property': 'rdfs:subClassOf',
            'description': 'Subclass relationships are transitive'
        })

    # Equivalence holds in both directions.
    if any(c.get('equivalent_classes') for c in class_defs):
        subset.inference_rules.append({
            'type': 'symmetry',
            'property': 'owl:equivalentClass',
            'description': 'Equivalent class relationships are symmetric'
        })

    # One inverse rule per property that declares an inverse.
    for pid, pdef in subset.object_properties.items():
        if 'inverse_of' in pdef and pdef['inverse_of']:
            subset.inference_rules.append({
                'type': 'inverse',
                'property': pid,
                'inverse': pdef['inverse_of'],
                'description': f'{pid} is inverse of {pdef["inverse_of"]}'
            })
def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset:
    """Grow the subset so that full class hierarchies are present.

    Every selected class gains all of its ancestors and all of its
    direct children, letting hierarchical queries (e.g. "all kinds of
    X") resolve without further ontology lookups.

    Args:
        subset: Query ontology subset (mutated in place).

    Returns:
        The same subset, with hierarchy classes added.
    """
    onto = self.loader.get_ontology(subset.ontology_id)
    if not onto:
        return subset

    pending = set()
    for cid in list(subset.classes.keys()):
        # All ancestors of the selected class (only resolvable ones).
        for parent_id in onto.get_parent_classes(cid):
            if parent_id not in subset.classes and onto.get_class(parent_id):
                pending.add(parent_id)
        # All direct children of the selected class.
        for other_id, other in onto.classes.items():
            if other.subclass_of == cid and other_id not in subset.classes:
                pending.add(other_id)

    for cid in pending:
        cls_def = onto.get_class(cid)
        if cls_def:
            subset.classes[cid] = cls_def.__dict__

    logger.debug(f"Expanded hierarchy: added {len(pending)} classes")
    return subset

View file

@ -0,0 +1,640 @@
"""
Query explanation system for OntoRAG.
Provides detailed explanations of how queries are processed and results are derived.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class ExplanationStep:
    """Individual step in query explanation."""
    # 1-based position of this step in the processing pipeline.
    step_number: int
    # Name of the component that performed the step (e.g. "query_executor").
    component: str
    # Operation the component carried out (e.g. "execute_query").
    operation: str
    # Inputs the step received, keyed by name.
    input_data: Dict[str, Any]
    # Outputs the step produced, keyed by name.
    output_data: Dict[str, Any]
    # Human-readable description of what happened in this step.
    explanation: str
    # Wall-clock duration of the step in milliseconds.
    duration_ms: float
    # Whether the step completed successfully.
    success: bool
    # Optional extra details (e.g. the generated query text).
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class QueryExplanation:
    """Complete explanation of query processing."""
    # Unique identifier for this query run.
    query_id: str
    # The user's question as originally asked.
    original_question: str
    # Ordered per-phase explanation steps.
    processing_steps: List[ExplanationStep]
    # The natural-language answer that was produced.
    final_answer: str
    # Overall confidence in the answer/explanation, 0.0-1.0.
    confidence_score: float
    # End-to-end processing time in milliseconds.
    total_duration_ms: float
    # Identifiers of the ontologies consulted.
    ontologies_used: List[str]
    # Query backend that executed the query.
    backend_used: str
    # Human-readable chain of reasoning statements.
    reasoning_chain: List[str]
    # Debug/optimization details extracted from processing metadata.
    technical_details: Dict[str, Any]
    # Short plain-language narrative for end users.
    user_friendly_explanation: str
class QueryExplainer:
    """Generates explanations for query processing."""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize query explainer.

        Args:
            config: Explainer configuration.  Recognized keys:
                ``explanation_level`` ('basic' | 'detailed' | 'technical',
                default 'detailed'), ``include_technical_details``
                (default True) and ``max_reasoning_steps`` (default 10).
        """
        self.config = config or {}
        self.explanation_level = self.config.get('explanation_level', 'detailed')  # basic, detailed, technical
        self.include_technical_details = self.config.get('include_technical_details', True)
        self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10)
        # Templates for different explanation types.  Outer key is the
        # pipeline phase, inner key the explanation level; placeholders are
        # filled by the matching _explain_* method.
        self.step_templates = {
            'question_analysis': {
                'basic': "I analyzed your question to understand what you're asking.",
                'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.",
                'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}"
            },
            'ontology_matching': {
                'basic': "I found relevant knowledge about {entities} in the available ontologies.",
                'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.",
                'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}"
            },
            'query_generation': {
                'basic': "I generated a query to search for the information.",
                'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.",
                'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}"
            },
            'query_execution': {
                'basic': "I searched the database and found {result_count} results.",
                'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.",
                'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms"
            },
            'answer_generation': {
                'basic': "I generated a natural language answer from the results.",
                'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.",
                'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}"
            }
        }
        # Sentence templates used when building the reasoning chain.
        self.reasoning_templates = {
            'entity_identification': "I identified '{entity}' as a key concept in your question.",
            'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.",
            'query_strategy': "I chose a {strategy} query approach because {reason}.",
            'result_filtering': "I filtered the results to show only the most relevant {count} items.",
            'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}."
        }
def explain_query_processing(self,
                             question: str,
                             question_components: QuestionComponents,
                             ontology_subsets: List[QueryOntologySubset],
                             generated_query: Union[SPARQLQuery, CypherQuery],
                             query_results: Union[SPARQLResult, CypherResult],
                             final_answer: str,
                             processing_metadata: Dict[str, Any]) -> QueryExplanation:
    """Generate comprehensive explanation of query processing.

    Walks the five pipeline phases (analysis, matching, generation,
    execution, answer) and produces one ExplanationStep per phase, a
    reasoning chain, a confidence score and a user-friendly narrative.

    Args:
        question: Original question
        question_components: Analyzed question components
        ontology_subsets: Selected ontology subsets
        generated_query: Generated query
        query_results: Query execution results
        final_answer: Final generated answer
        processing_metadata: Processing metadata
    Returns:
        Complete query explanation
    """
    query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}")
    # NOTE(review): start_time is currently unused — total duration is
    # read from processing_metadata['total_duration_ms'] below.
    start_time = processing_metadata.get('start_time', datetime.now())
    # Build explanation steps, one per pipeline phase, in order.
    steps = []
    step_number = 1
    # Step 1: Question Analysis
    steps.append(self._explain_question_analysis(
        step_number, question, question_components
    ))
    step_number += 1
    # Step 2: Ontology Matching
    steps.append(self._explain_ontology_matching(
        step_number, question_components, ontology_subsets
    ))
    step_number += 1
    # Step 3: Query Generation
    steps.append(self._explain_query_generation(
        step_number, generated_query, processing_metadata
    ))
    step_number += 1
    # Step 4: Query Execution
    steps.append(self._explain_query_execution(
        step_number, generated_query, query_results, processing_metadata
    ))
    step_number += 1
    # Step 5: Answer Generation
    steps.append(self._explain_answer_generation(
        step_number, query_results, final_answer, processing_metadata
    ))
    # Build reasoning chain
    reasoning_chain = self._build_reasoning_chain(
        question_components, ontology_subsets, generated_query, processing_metadata
    )
    # Calculate overall confidence
    confidence_score = self._calculate_explanation_confidence(
        question_components, query_results, processing_metadata
    )
    # Generate user-friendly explanation
    user_friendly_explanation = self._generate_user_friendly_explanation(
        question, question_components, ontology_subsets, final_answer
    )
    # Calculate total duration
    total_duration = processing_metadata.get('total_duration_ms', 0)
    return QueryExplanation(
        query_id=query_id,
        original_question=question,
        processing_steps=steps,
        final_answer=final_answer,
        confidence_score=confidence_score,
        total_duration_ms=total_duration,
        # NOTE(review): subsets elsewhere expose `ontology_id` as a direct
        # attribute — confirm that `metadata` also carries the id here.
        ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets],
        backend_used=processing_metadata.get('backend_used', 'unknown'),
        reasoning_chain=reasoning_chain,
        technical_details=self._extract_technical_details(processing_metadata),
        user_friendly_explanation=user_friendly_explanation
    )
def _explain_question_analysis(self,
                               step_number: int,
                               question: str,
                               question_components: QuestionComponents) -> ExplanationStep:
    """Build the explanation step for the question-analysis phase.

    Args:
        step_number: Position of this step in the pipeline.
        question: The user's question.
        question_components: Analyzed question components.

    Returns:
        ExplanationStep describing the analysis phase.
    """
    template = self.step_templates['question_analysis'][self.explanation_level]
    level = self.explanation_level
    if level == 'detailed':
        explanation = template.format(
            question=question,
            question_type=question_components.question_type.value.replace('_', ' '),
            entities=', '.join(question_components.entities[:3])
        )
    elif level == 'technical':
        explanation = template.format(
            question_type=question_components.question_type.value,
            entities=question_components.entities,
            keywords=question_components.keywords,
            answer_type=question_components.expected_answer_type
        )
    else:  # the basic template carries no placeholders
        explanation = template
    return ExplanationStep(
        step_number=step_number,
        component="question_analyzer",
        operation="analyze_question",
        input_data={"question": question},
        output_data={
            "question_type": question_components.question_type.value,
            "entities": question_components.entities,
            "keywords": question_components.keywords
        },
        explanation=explanation,
        duration_ms=0.0,  # timing not yet instrumented
        success=True
    )
def _explain_ontology_matching(self,
                               step_number: int,
                               question_components: QuestionComponents,
                               ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep:
    """Build the explanation step for the ontology-matching phase.

    Args:
        step_number: Position of this step in the pipeline.
        question_components: Analyzed question components.
        ontology_subsets: Subsets selected for this query.

    Returns:
        ExplanationStep describing the matching phase.
    """
    template = self.step_templates['ontology_matching'][self.explanation_level]
    # Total ontology elements (classes + both property kinds) selected.
    element_count = 0
    for sub in ontology_subsets:
        element_count += (len(sub.classes) + len(sub.object_properties)
                          + len(sub.datatype_properties))
    level = self.explanation_level
    if level == 'basic':
        explanation = template.format(
            entities=', '.join(question_components.entities[:3])
        )
    elif level == 'detailed':
        explanation = template.format(
            ontology_count=len(ontology_subsets),
            selected_elements=element_count
        )
    else:  # technical
        class_count = sum(len(s.classes) for s in ontology_subsets)
        prop_count = sum(
            len(s.object_properties) + len(s.datatype_properties)
            for s in ontology_subsets
        )
        names = [s.metadata.get('ontology_id', 'unknown') for s in ontology_subsets]
        explanation = template.format(
            classes=class_count,
            properties=prop_count,
            ontologies=', '.join(names)
        )
    return ExplanationStep(
        step_number=step_number,
        component="ontology_matcher",
        operation="select_relevant_subset",
        input_data={"entities": question_components.entities},
        output_data={
            "ontology_count": len(ontology_subsets),
            "total_elements": element_count
        },
        explanation=explanation,
        duration_ms=0.0,
        success=True
    )
def _explain_query_generation(self,
                              step_number: int,
                              generated_query: Union[SPARQLQuery, CypherQuery],
                              metadata: Dict[str, Any]) -> ExplanationStep:
    """Build the explanation step for the query-generation phase.

    Args:
        step_number: Position of this step in the pipeline.
        generated_query: The query that was generated.
        metadata: Processing metadata ('backend_used' is read).

    Returns:
        ExplanationStep describing the generation phase.
    """
    template = self.step_templates['query_generation'][self.explanation_level]
    lang = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
    backend = metadata.get('backend_used', 'unknown')
    level = self.explanation_level
    if level == 'detailed':
        explanation = template.format(
            query_type=generated_query.query_type,
            query_language=lang,
            backend=backend
        )
    elif level == 'technical':
        explanation = template.format(
            query_language=lang,
            variables=len(generated_query.variables),
            complexity=f"{generated_query.complexity_score:.2f}"
        )
    else:  # the basic template carries no placeholders
        explanation = template
    return ExplanationStep(
        step_number=step_number,
        component="query_generator",
        operation="generate_query",
        input_data={"query_type": generated_query.query_type},
        output_data={
            "query_language": lang,
            "variables": generated_query.variables,
            "complexity": generated_query.complexity_score
        },
        explanation=explanation,
        duration_ms=0.0,
        success=True,
        # Keep the full query text available for technical displays.
        metadata={"generated_query": generated_query.query}
    )
def _explain_query_execution(self,
                             step_number: int,
                             generated_query: Union[SPARQLQuery, CypherQuery],
                             query_results: Union[SPARQLResult, CypherResult],
                             metadata: Dict[str, Any]) -> ExplanationStep:
    """Explain query execution step.

    Args:
        step_number: Position of this step in the pipeline.
        generated_query: The query that was executed.
        query_results: Backend result object (SPARQL or Cypher).
        metadata: Processing metadata ('backend_used' is read).

    Returns:
        ExplanationStep describing the execution phase.
    """
    template = self.step_templates['query_execution'][self.explanation_level]
    backend = metadata.get('backend_used', 'unknown')
    # Result objects report seconds; steps report milliseconds.
    duration = getattr(query_results, 'execution_time', 0) * 1000  # Convert to ms
    # SPARQL results carry bindings; Cypher results carry records.
    if isinstance(query_results, SPARQLResult):
        result_count = len(query_results.bindings)
    else:  # CypherResult
        result_count = len(query_results.records)
    if self.explanation_level == 'basic':
        explanation = template.format(result_count=result_count)
    elif self.explanation_level == 'detailed':
        explanation = template.format(
            backend=backend,
            result_count=result_count,
            duration=f"{duration:.1f}"
        )
    else:  # technical
        # NOTE(review): identical arguments to the 'detailed' branch —
        # only the template text differs between the two levels.
        explanation = template.format(
            backend=backend,
            result_count=result_count,
            duration=f"{duration:.1f}"
        )
    return ExplanationStep(
        step_number=step_number,
        component="query_executor",
        operation="execute_query",
        input_data={"query": generated_query.query},
        output_data={
            "result_count": result_count,
            "execution_time_ms": duration
        },
        explanation=explanation,
        duration_ms=duration,
        # NOTE(review): len() can never be negative, so this is always
        # True — confirm whether a real success flag from the result
        # object was intended instead.
        success=result_count >= 0
    )
def _explain_answer_generation(self,
                               step_number: int,
                               query_results: Union[SPARQLResult, CypherResult],
                               final_answer: str,
                               metadata: Dict[str, Any]) -> ExplanationStep:
    """Build the explanation step for the answer-generation phase.

    Args:
        step_number: Position of this step in the pipeline.
        query_results: Backend result object (SPARQL or Cypher).
        final_answer: The natural-language answer produced.
        metadata: Processing metadata ('answer_confidence',
            'generation_method' are read).

    Returns:
        ExplanationStep describing the answer phase.
    """
    template = self.step_templates['answer_generation'][self.explanation_level]
    # SPARQL results carry bindings; Cypher results carry records.
    if isinstance(query_results, SPARQLResult):
        result_count = len(query_results.bindings)
    else:  # CypherResult
        result_count = len(query_results.records)
    confidence = metadata.get('answer_confidence', 0.8) * 100  # as a percentage
    level = self.explanation_level
    if level == 'detailed':
        explanation = template.format(
            result_count=result_count,
            confidence=f"{confidence:.0f}"
        )
    elif level == 'technical':
        explanation = template.format(
            result_count=result_count,
            generation_method=metadata.get('generation_method', 'template_based'),
            confidence=f"{confidence:.1f}"
        )
    else:  # the basic template carries no placeholders
        explanation = template
    return ExplanationStep(
        step_number=step_number,
        component="answer_generator",
        operation="generate_answer",
        input_data={"result_count": result_count},
        output_data={
            "answer": final_answer,
            "confidence": confidence / 100
        },
        explanation=explanation,
        duration_ms=0.0,
        # An empty answer counts as a failed generation step.
        success=bool(final_answer)
    )
def _build_reasoning_chain(self,
                           question_components: QuestionComponents,
                           ontology_subsets: List[QueryOntologySubset],
                           generated_query: Union[SPARQLQuery, CypherQuery],
                           metadata: Dict[str, Any]) -> List[str]:
    """Build reasoning chain explaining the decision process.

    Args:
        question_components: Analyzed question components.
        ontology_subsets: Ontology subsets that were consulted.
        generated_query: The generated query (currently unused; kept
            for interface stability).
        metadata: Processing metadata ('answer_confidence' is read).

    Returns:
        Up to ``self.max_reasoning_steps`` human-readable statements.
    """
    reasoning = []

    # Entity identification reasoning (first three entities only).
    if question_components.entities:
        for entity in question_components.entities[:3]:
            reasoning.append(
                self.reasoning_templates['entity_identification'].format(entity=entity)
            )

    # Ontology selection reasoning, based on the primary subset.
    if ontology_subsets:
        primary_ontology = ontology_subsets[0]
        # NOTE(review): elsewhere the subset exposes `ontology_id` as an
        # attribute — confirm whether `metadata` also carries it.
        ontology_id = primary_ontology.metadata.get('ontology_id', 'primary')
        concepts = list(primary_ontology.classes.keys())[:3]
        reasoning.append(
            self.reasoning_templates['ontology_selection'].format(
                ontology=ontology_id,
                concepts=', '.join(concepts)
            )
        )

    # Query strategy reasoning.  (The previous version also computed the
    # query language here but never used it; that dead code is removed.)
    if question_components.question_type == QuestionType.AGGREGATION:
        strategy = "aggregation"
        reason = "you asked for a count or sum"
    elif question_components.question_type == QuestionType.BOOLEAN:
        strategy = "boolean"
        reason = "you asked a yes/no question"
    else:
        strategy = "retrieval"
        reason = "you asked for specific information"
    reasoning.append(
        self.reasoning_templates['query_strategy'].format(
            strategy=strategy,
            reason=reason
        )
    )

    # Confidence assessment, bucketed into three ranges.
    confidence = metadata.get('answer_confidence', 0.8) * 100
    if confidence > 90:
        confidence_reason = "the query matched well with available data"
    elif confidence > 70:
        confidence_reason = "the query found relevant information with some uncertainty"
    else:
        confidence_reason = "the available data partially matches your question"
    reasoning.append(
        self.reasoning_templates['confidence_assessment'].format(
            confidence=f"{confidence:.0f}",
            reasoning=confidence_reason
        )
    )

    return reasoning[:self.max_reasoning_steps]
def _calculate_explanation_confidence(self,
                                      question_components: QuestionComponents,
                                      query_results: Union[SPARQLResult, CypherResult],
                                      metadata: Dict[str, Any]) -> float:
    """Score how confident we are in the explanation (0.0-1.0).

    Starts from a base of 0.8 and adds small bonuses for having
    results, having many results, recognized entities, and a fully
    successful pipeline; the total is capped at 1.0.
    """
    # SPARQL results carry bindings; Cypher results carry records.
    if isinstance(query_results, SPARQLResult):
        result_count = len(query_results.bindings)
    else:
        result_count = len(query_results.records)

    score = 0.8  # base confidence
    # Finding results at all — and a healthy number — builds trust.
    if result_count > 0:
        score += 0.1
    if result_count > 5:
        score += 0.05
    # Questions with recognizable entities are easier to answer well.
    if len(question_components.entities) > 0:
        score += 0.05
    # Reward a pipeline where every phase succeeded.
    if metadata.get('all_steps_successful', True):
        score += 0.05
    return min(score, 1.0)
def _generate_user_friendly_explanation(self,
                                        question: str,
                                        question_components: QuestionComponents,
                                        ontology_subsets: List[QueryOntologySubset],
                                        final_answer: str) -> str:
    """Produce a short plain-language narrative of the whole process."""
    parts = [f"To answer your question '{question}', I followed these steps:"]

    # Step 1: how the question was interpreted.
    qtype = question_components.question_type
    if qtype == QuestionType.AGGREGATION:
        parts.append("1. I recognized this as a counting or aggregation question")
    elif qtype == QuestionType.BOOLEAN:
        parts.append("1. I recognized this as a yes/no question")
    else:
        parts.append("1. I analyzed your question to understand what information you need")

    # Step 2: which knowledge bases were consulted.
    if ontology_subsets:
        count = len(ontology_subsets)
        if count == 1:
            parts.append("2. I searched through the relevant knowledge base")
        else:
            parts.append(f"2. I searched through {count} knowledge bases")

    # Step 3 and conclusion.
    parts.append("3. I found the relevant information and generated your answer")
    parts.append(f"The answer is: {final_answer}")
    return " ".join(parts)
def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Collect debugging/optimization details from processing metadata.

    Missing sections default to empty dicts so downstream display code
    never has to null-check.
    """
    sections = (
        'query_optimization',
        'backend_performance',
        'cache_usage',
        'error_handling',
        'routing_decision',
    )
    return {name: metadata.get(name, {}) for name in sections}
def format_explanation_for_display(self,
                                   explanation: QueryExplanation,
                                   format_type: str = 'html') -> str:
    """Render an explanation in the requested output format.

    Args:
        explanation: Query explanation to render.
        format_type: One of 'html', 'markdown' or 'text'; anything
            unrecognized falls back to plain text.

    Returns:
        Formatted explanation string.
    """
    renderers = {
        'html': self._format_html_explanation,
        'markdown': self._format_markdown_explanation,
    }
    renderer = renderers.get(format_type, self._format_text_explanation)
    return renderer(explanation)
def _format_html_explanation(self, explanation: QueryExplanation) -> str:
    """Render an explanation as a single HTML fragment."""
    out = [
        f"<h2>Query Explanation: {explanation.query_id}</h2>",
        f"<p><strong>Question:</strong> {explanation.original_question}</p>",
        f"<p><strong>Answer:</strong> {explanation.final_answer}</p>",
        f"<p><strong>Confidence:</strong> {explanation.confidence_score:.1%}</p>",
        "<h3>Processing Steps:</h3>",
        "<ol>",
    ]
    out.extend(
        f"<li><strong>{step.component}</strong>: {step.explanation}</li>"
        for step in explanation.processing_steps
    )
    out.append("</ol>")
    out.append("<h3>Reasoning:</h3>")
    out.append("<ul>")
    out.extend(f"<li>{item}</li>" for item in explanation.reasoning_chain)
    out.append("</ul>")
    return "".join(out)
def _format_markdown_explanation(self, explanation: QueryExplanation) -> str:
    """Render an explanation as Markdown."""
    lines = [
        f"## Query Explanation: {explanation.query_id}",
        f"**Question:** {explanation.original_question}",
        f"**Answer:** {explanation.final_answer}",
        f"**Confidence:** {explanation.confidence_score:.1%}",
        "",
        "### Processing Steps:",
        "",
    ]
    for idx, step in enumerate(explanation.processing_steps, 1):
        lines.append(f"{idx}. **{step.component}**: {step.explanation}")
    lines += ["", "### Reasoning:", ""]
    lines += [f"- {item}" for item in explanation.reasoning_chain]
    return "\n".join(lines)
def _format_text_explanation(self, explanation: QueryExplanation) -> str:
    """Render an explanation as plain text."""
    lines = [
        f"Query Explanation: {explanation.query_id}",
        f"Question: {explanation.original_question}",
        f"Answer: {explanation.final_answer}",
        f"Confidence: {explanation.confidence_score:.1%}",
        "",
        "Processing Steps:",
    ]
    for idx, step in enumerate(explanation.processing_steps, 1):
        lines.append(f"  {idx}. {step.component}: {step.explanation}")
    lines += ["", "Reasoning:"]
    lines += [f"  - {item}" for item in explanation.reasoning_chain]
    return "\n".join(lines)

View file

@ -0,0 +1,519 @@
"""
Query optimization module for OntoRAG.
Optimizes SPARQL and Cypher queries for better performance and accuracy.
"""
import logging
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass
from enum import Enum
import re
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
class OptimizationStrategy(Enum):
    """Query optimization strategies."""
    # Favour execution speed (limits, clause reordering, index hints).
    PERFORMANCE = "performance"
    # Favour result precision (type constraints, DISTINCT, namespace checks).
    ACCURACY = "accuracy"
    # Apply both the performance and accuracy passes.
    BALANCED = "balanced"
@dataclass
class OptimizationHint:
    """Optimization hint for query processing."""
    # Which optimization passes to apply.
    strategy: OptimizationStrategy
    # If set, a LIMIT clause is appended to unbounded queries.
    max_results: Optional[int] = None
    # Suggested execution timeout; not enforced by the optimizer itself.
    timeout_seconds: Optional[int] = None
    # Whether index hints should be emitted for the backend.
    use_indices: bool = True
    # Whether parallel execution may be used (backend-dependent).
    enable_parallel: bool = False
    # Whether results may be cached.
    cache_results: bool = True
@dataclass
class QueryPlan:
    """Query execution plan with optimization metadata."""
    # Query text as received from the generator.
    original_query: str
    # Query text after the optimization passes.
    optimized_query: str
    # Heuristic cost estimate for the optimized query.
    estimated_cost: float
    # Human-readable notes describing each applied optimization.
    optimization_notes: List[str]
    # Suggested backend indices (e.g. "type_index", "property_index:age").
    index_hints: List[str]
    # Planned clause execution order; currently left empty by callers.
    execution_order: List[str]
class QueryOptimizer:
    """Optimizes SPARQL and Cypher queries for performance and accuracy."""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize query optimizer.

        Args:
            config: Optimizer configuration.  Recognized keys:
                ``default_strategy`` ('performance' | 'accuracy' |
                'balanced', default 'balanced'),
                ``max_query_complexity`` (default 10),
                ``enable_query_rewriting`` (default True),
                ``large_result_threshold`` (default 1000) and
                ``complex_join_threshold`` (default 3).
        """
        self.config = config or {}
        # Strategy used when a caller supplies no OptimizationHint.
        self.default_strategy = OptimizationStrategy(
            self.config.get('default_strategy', 'balanced')
        )
        self.max_query_complexity = self.config.get('max_query_complexity', 10)
        self.enable_query_rewriting = self.config.get('enable_query_rewriting', True)
        # Performance thresholds
        self.large_result_threshold = self.config.get('large_result_threshold', 1000)
        self.complex_join_threshold = self.config.get('complex_join_threshold', 3)
def optimize_sparql(self,
                    sparql_query: SPARQLQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]:
    """Optimize SPARQL query.

    Runs the performance pass for PERFORMANCE/BALANCED strategies and
    the accuracy pass for ACCURACY/BALANCED strategies, then wraps the
    result in a QueryPlan.

    Args:
        sparql_query: Original SPARQL query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints
    Returns:
        Optimized SPARQL query and execution plan
    """
    hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    optimized_query = sparql_query.query
    optimization_notes = []
    index_hints = []
    # NOTE(review): execution_order is never populated — it reaches the
    # QueryPlan empty.
    execution_order = []
    # Apply optimizations based on strategy
    if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
        optimized_query, perf_notes, perf_hints = self._optimize_sparql_performance(
            optimized_query, question_components, ontology_subset, hint
        )
        optimization_notes.extend(perf_notes)
        index_hints.extend(perf_hints)
    if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
        optimized_query, acc_notes = self._optimize_sparql_accuracy(
            optimized_query, question_components, ontology_subset
        )
        optimization_notes.extend(acc_notes)
    # Estimate query cost
    estimated_cost = self._estimate_sparql_cost(optimized_query, ontology_subset)
    # Build execution plan
    query_plan = QueryPlan(
        original_query=sparql_query.query,
        optimized_query=optimized_query,
        estimated_cost=estimated_cost,
        optimization_notes=optimization_notes,
        index_hints=index_hints,
        execution_order=execution_order
    )
    # Create optimized query object
    optimized_sparql = SPARQLQuery(
        query=optimized_query,
        variables=sparql_query.variables,
        query_type=sparql_query.query_type,
        explanation=f"Optimized: {sparql_query.explanation}",
        # NOTE(review): the 0.8 factor is an unmeasured heuristic.
        complexity_score=min(sparql_query.complexity_score * 0.8, 1.0)  # Assume optimization reduces complexity
    )
    return optimized_sparql, query_plan
def optimize_cypher(self,
                    cypher_query: CypherQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]:
    """Optimize Cypher query.

    Mirror of ``optimize_sparql`` for the Cypher backend: performance
    pass for PERFORMANCE/BALANCED strategies, accuracy pass for
    ACCURACY/BALANCED strategies, wrapped in a QueryPlan.

    Args:
        cypher_query: Original Cypher query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints
    Returns:
        Optimized Cypher query and execution plan
    """
    hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    optimized_query = cypher_query.query
    optimization_notes = []
    index_hints = []
    # NOTE(review): execution_order is never populated — it reaches the
    # QueryPlan empty.
    execution_order = []
    # Apply optimizations based on strategy
    if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
        optimized_query, perf_notes, perf_hints = self._optimize_cypher_performance(
            optimized_query, question_components, ontology_subset, hint
        )
        optimization_notes.extend(perf_notes)
        index_hints.extend(perf_hints)
    if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
        optimized_query, acc_notes = self._optimize_cypher_accuracy(
            optimized_query, question_components, ontology_subset
        )
        optimization_notes.extend(acc_notes)
    # Estimate query cost
    estimated_cost = self._estimate_cypher_cost(optimized_query, ontology_subset)
    # Build execution plan
    query_plan = QueryPlan(
        original_query=cypher_query.query,
        optimized_query=optimized_query,
        estimated_cost=estimated_cost,
        optimization_notes=optimization_notes,
        index_hints=index_hints,
        execution_order=execution_order
    )
    # Create optimized query object
    optimized_cypher = CypherQuery(
        query=optimized_query,
        variables=cypher_query.variables,
        query_type=cypher_query.query_type,
        explanation=f"Optimized: {cypher_query.explanation}",
        # NOTE(review): the 0.8 factor is an unmeasured heuristic.
        complexity_score=min(cypher_query.complexity_score * 0.8, 1.0)
    )
    return optimized_cypher, query_plan
def _optimize_sparql_performance(self,
                                 query: str,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
    """Apply performance optimizations to SPARQL query.

    Optimizations applied:
      * move OPTIONAL blocks to the end of the graph pattern — they are
        re-inserted just before the closing '}' of the query body, so
        the result stays syntactically valid SPARQL (the previous
        implementation inserted them before ORDER BY/LIMIT, i.e.
        OUTSIDE the WHERE braces, producing a broken query);
      * append a LIMIT clause when ``hint.max_results`` is set;
      * emit index hints for common triple patterns;
      * flag FILTER clauses for placement review.

    Args:
        query: SPARQL query string.
        question_components: Question analysis (currently unused here).
        ontology_subset: Ontology subset (currently unused here).
        hint: Optimization hints.

    Returns:
        Tuple of (optimized query, optimization notes, index hints).
    """
    optimized = query
    notes = []
    index_hints = []

    # Move OPTIONAL clauses to the end of the graph pattern.  Done
    # BEFORE appending LIMIT so the insertion point cannot land after a
    # trailing solution modifier.
    # NOTE: [^}]+ does not handle nested groups inside OPTIONAL.
    optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL)
    optionals = optional_pattern.findall(optimized)
    if optionals:
        for optional in optionals:
            optimized = optimized.replace(optional, '')
        # Re-insert inside the WHERE block, just before the last '}',
        # which closes the graph pattern.
        insert_point = optimized.rfind('}')
        if insert_point == -1:
            insert_point = len(optimized.rstrip())
        for optional in optionals:
            optimized = (optimized[:insert_point]
                         + f"\n  {optional}\n"
                         + optimized[insert_point:])
        notes.append("Moved OPTIONAL clauses to end for better performance")

    # Add LIMIT if not present and large results expected.
    if hint.max_results and 'LIMIT' not in optimized.upper():
        optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
        notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")

    # Suggest indices for common triple patterns (Cassandra backend).
    if 'WHERE' in optimized.upper():
        if '?subject rdf:type' in optimized:
            index_hints.append("type_index")
        if 'rdfs:subClassOf' in optimized:
            index_hints.append("hierarchy_index")

    # FILTER clauses: flag their placement for manual review.
    filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE)
    if filter_pattern.findall(optimized):
        notes.append("FILTER clauses present - ensure they're positioned optimally")

    return optimized, notes, index_hints
def _optimize_sparql_accuracy(self,
                              query: str,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
    """Apply accuracy optimizations to SPARQL query.

    Optimizations applied:
      * note namespace consistency for retrieval queries;
      * constrain an untyped ``?entity`` variable to the subset's
        primary class;
      * add DISTINCT to retrieval SELECT queries to de-duplicate rows.

    Args:
        query: SPARQL query string.
        question_components: Question analysis.
        ontology_subset: Ontology subset.

    Returns:
        Tuple of (optimized query, optimization notes).
    """
    optimized = query
    notes = []

    # Namespace sanity note for retrieval queries.
    if question_components.question_type == QuestionType.RETRIEVAL:
        # Ensure we're not mixing namespaces inappropriately
        if 'http://' in optimized and '?' in optimized:
            notes.append("Verified namespace consistency for accuracy")

    # Constrain untyped ?entity variables to the most relevant class.
    if '?entity' in optimized and 'rdf:type' not in optimized:
        where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE)
        if where_clause and ontology_subset.classes:
            # First selected class is treated as the primary subject type.
            main_class = list(ontology_subset.classes.keys())[0]
            type_constraint = f"\n ?entity rdf:type :{main_class} ."
            # Insert immediately after the opening brace of WHERE.
            where_start = where_clause.start(1)
            optimized = optimized[:where_start] + type_constraint + optimized[where_start:]
            notes.append(f"Added type constraint for {main_class} to improve accuracy")

    # De-duplicate retrieval results.  The keyword is matched
    # case-insensitively so lower-case 'select' queries are handled too
    # (the previous literal replace('SELECT ', ...) silently skipped
    # them even though the guard above was case-insensitive).
    if (question_components.question_type == QuestionType.RETRIEVAL and
            'DISTINCT' not in optimized.upper() and
            'SELECT' in optimized.upper()):
        optimized = re.sub(r'\bSELECT\b', 'SELECT DISTINCT', optimized,
                           count=1, flags=re.IGNORECASE)
        notes.append("Added DISTINCT to eliminate duplicate results")

    return optimized, notes
def _optimize_cypher_performance(self,
query: str,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset,
hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
"""Apply performance optimizations to Cypher query.
Args:
query: Cypher query string
question_components: Question analysis
ontology_subset: Ontology subset
hint: Optimization hints
Returns:
Optimized query, optimization notes, and index hints
"""
optimized = query
notes = []
index_hints = []
# Add LIMIT if not present
if hint.max_results and 'LIMIT' not in optimized.upper():
optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
# Use parameters for literals to enable query plan caching
if "'" in optimized or '"' in optimized:
notes.append("Consider using parameters for literal values to enable query plan caching")
# Suggest indices based on query patterns
if 'MATCH (n:' in optimized:
label_match = re.search(r'MATCH \(n:(\w+)\)', optimized)
if label_match:
label = label_match.group(1)
index_hints.append(f"node_label_index:{label}")
if 'WHERE' in optimized.upper() and '.' in optimized:
# Property access patterns
property_pattern = re.compile(r'\.(\w+)', re.IGNORECASE)
properties = property_pattern.findall(optimized)
for prop in set(properties):
index_hints.append(f"property_index:{prop}")
# Optimize relationship traversals
if '-[' in optimized and '*' in optimized:
notes.append("Variable length path detected - consider adding relationship type filters")
# Early filtering optimization
if 'WHERE' in optimized.upper():
# Move WHERE clauses closer to MATCH clauses
notes.append("WHERE clauses present - ensure early filtering for performance")
return optimized, notes, index_hints
def _optimize_cypher_accuracy(self,
query: str,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
"""Apply accuracy optimizations to Cypher query.
Args:
query: Cypher query string
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Optimized query and optimization notes
"""
optimized = query
notes = []
# Add DISTINCT if not present for retrieval queries
if (question_components.question_type == QuestionType.RETRIEVAL and
'DISTINCT' not in optimized.upper() and
'RETURN' in optimized.upper()):
optimized = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', optimized, count=1, flags=re.IGNORECASE)
notes.append("Added DISTINCT to eliminate duplicate results")
# Ensure proper relationship direction
if '-[' in optimized and question_components.relationships:
notes.append("Verified relationship directions for semantic accuracy")
# Add null checks for optional properties
if '?' in optimized or 'OPTIONAL' in optimized.upper():
notes.append("Consider adding null checks for optional properties")
return optimized, notes
def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for SPARQL query.
Args:
query: SPARQL query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Join complexity
triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query))
cost += triple_patterns * 0.1
# OPTIONAL clauses
optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE))
cost += optional_count * 0.15
# FILTER clauses
filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE))
cost += filter_count * 0.1
# Property paths
path_count = len(re.findall(r'\*|\+', query))
cost += path_count * 0.2
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for Cypher query.
Args:
query: Cypher query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Pattern complexity
match_count = len(re.findall(r'MATCH', query, re.IGNORECASE))
cost += match_count * 0.1
# Relationship traversals
rel_count = len(re.findall(r'-\[.*?\]-', query))
cost += rel_count * 0.1
# Variable length paths
var_path_count = len(re.findall(r'\*\d*\.\.', query))
cost += var_path_count * 0.3
# WHERE clauses
where_count = len(re.findall(r'WHERE', query, re.IGNORECASE))
cost += where_count * 0.05
# Aggregation functions
agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE))
cost += agg_count * 0.1
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def should_use_cache(self,
query: str,
question_components: QuestionComponents,
optimization_hint: OptimizationHint) -> bool:
"""Determine if query results should be cached.
Args:
query: Query string
question_components: Question analysis
optimization_hint: Optimization hints
Returns:
True if results should be cached
"""
if not optimization_hint.cache_results:
return False
# Cache simple retrieval and factual queries
if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]:
return True
# Cache expensive aggregation queries
if (question_components.question_type == QuestionType.AGGREGATION and
('COUNT' in query.upper() or 'SUM' in query.upper())):
return True
# Don't cache real-time or time-sensitive queries
if any(keyword in question_components.original_question.lower()
for keyword in ['now', 'current', 'latest', 'recent']):
return False
return False
def get_cache_key(self,
query: str,
ontology_subset: QueryOntologySubset) -> str:
"""Generate cache key for query.
Args:
query: Query string
ontology_subset: Ontology subset
Returns:
Cache key string
"""
import hashlib
# Create stable representation
ontology_repr = f"{sorted(ontology_subset.classes.keys())}-{sorted(ontology_subset.object_properties.keys())}"
combined = f"{query.strip()}-{ontology_repr}"
return hashlib.md5(combined.encode()).hexdigest()

View file

@ -0,0 +1,438 @@
"""
Main OntoRAG query service.
Orchestrates question analysis, ontology matching, query generation, execution, and answer generation.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from ....flow.flow_processor import FlowProcessor
from ....tables.config import ConfigTableStore
from ...extract.kg.ontology.ontology_loader import OntologyLoader
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
from .question_analyzer import QuestionAnalyzer, QuestionComponents
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, QueryRoute, BackendType
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer
logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
    """Query request from user."""
    question: str                        # natural-language question text
    context: Optional[str] = None        # optional extra context passed to question analysis
    ontology_hint: Optional[str] = None  # optional id of a specific ontology to load
    max_results: int = 10                # requested cap on result count
    confidence_threshold: float = 0.7    # caller's minimum acceptable confidence
@dataclass
class QueryResponse:
    """Complete query response."""
    answer: str                                        # natural-language answer text
    confidence: float                                  # min of routing and answer confidence
    execution_time: float                              # end-to-end wall-clock seconds
    question_analysis: QuestionComponents              # decomposed question
    ontology_subsets: List[QueryOntologySubset]        # ontology fragments judged relevant
    query_route: QueryRoute                            # backend routing decision
    generated_query: Union[SPARQLQuery, CypherQuery]   # the query that was executed
    raw_results: Union[SPARQLResult, CypherResult]     # backend result set
    supporting_facts: List[str]                        # evidence snippets behind the answer
    metadata: Dict[str, Any]                           # diagnostics (backend, timings, counts)
class OntoRAGQueryService(FlowProcessor):
    """Main OntoRAG query service orchestrating all components.

    Pipeline per request: question analysis -> ontology subset matching ->
    backend routing -> query generation and execution (SPARQL over
    Cassandra, or Cypher over a graph database) -> natural-language answer
    generation.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize OntoRAG query service.

        Args:
            config: Service configuration
        """
        super().__init__(config)
        self.config = config
        # Initialize components.
        # All collaborators are created in init(); they remain None until then.
        self.config_store = None
        self.ontology_loader = None
        self.vector_store = None
        self.question_analyzer = None
        self.ontology_matcher = None
        self.backend_router = None
        self.sparql_generator = None
        self.sparql_engine = None
        self.cypher_generator = None
        self.cypher_executor = None
        self.answer_generator = None
        # Cache for loaded ontologies (currently unused in this class).
        self.ontology_cache = {}

    async def init(self):
        """Initialize all components."""
        await super().init()
        # Initialize configuration store
        self.config_store = ConfigTableStore(self.config.get('config_store', {}))
        # Initialize ontology components
        self.ontology_loader = OntologyLoader(self.config_store)
        # Initialize vector store
        vector_config = self.config.get('vector_store', {})
        self.vector_store = InMemoryVectorStore.create(
            store_type=vector_config.get('type', 'numpy'),
            dimension=vector_config.get('dimension', 384),
            similarity_threshold=vector_config.get('similarity_threshold', 0.7)
        )
        # Initialize question analyzer
        # NOTE(review): QuestionAnalyzer in question_analyzer.py defines
        # __init__(self) with no parameters — confirm it actually accepts
        # prompt_service/config keyword arguments.
        # NOTE(review): self.prompt_service / self.embedding_service are
        # presumably provided by the FlowProcessor base — verify.
        analyzer_config = self.config.get('question_analyzer', {})
        self.question_analyzer = QuestionAnalyzer(
            prompt_service=self.prompt_service,
            config=analyzer_config
        )
        # Initialize ontology matcher
        matcher_config = self.config.get('ontology_matcher', {})
        self.ontology_matcher = OntologyMatcher(
            vector_store=self.vector_store,
            embedding_service=self.embedding_service,
            config=matcher_config
        )
        # Initialize backend router
        router_config = self.config.get('backend_router', {})
        self.backend_router = BackendRouter(router_config)
        # Initialize query generators
        self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service)
        self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service)
        # Initialize executors
        sparql_config = self.config.get('sparql_executor', {})  # NOTE(review): read but never used — confirm intent
        if self.backend_router.is_backend_enabled(BackendType.CASSANDRA):
            cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA)
            if cassandra_config:
                self.sparql_engine = SPARQLCassandraEngine(cassandra_config)
                await self.sparql_engine.initialize()
        cypher_config = self.config.get('cypher_executor', {})
        # One shared Cypher executor serves every enabled graph backend.
        enabled_graph_backends = [
            bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB]
            if self.backend_router.is_backend_enabled(bt)
        ]
        if enabled_graph_backends:
            self.cypher_executor = CypherExecutor(cypher_config)
            await self.cypher_executor.initialize()
        # Initialize answer generator
        self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service)
        logger.info("OntoRAG query service initialized")

    async def process(self, request: QueryRequest) -> QueryResponse:
        """Process a natural language query.

        Args:
            request: Query request

        Returns:
            Complete query response; on failure an error QueryResponse with
            confidence 0.0 is returned instead of raising.
        """
        start_time = datetime.now()
        try:
            logger.info(f"Processing query: {request.question}")
            # Step 1: Analyze question
            # NOTE(review): QuestionAnalyzer (question_analyzer.py) exposes a
            # synchronous analyze() method, not analyze_question() — confirm
            # this call target and the await.
            question_components = await self.question_analyzer.analyze_question(
                request.question, context=request.context
            )
            logger.debug(f"Question analysis: {question_components.question_type}")
            # Step 2: Load and match ontologies
            ontology_subsets = await self._load_and_match_ontologies(
                question_components, request.ontology_hint
            )
            logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets")
            # Step 3: Route to appropriate backend
            query_route = self.backend_router.route_query(
                question_components, ontology_subsets
            )
            logger.debug(f"Routed to {query_route.backend_type.value} backend")
            # Step 4: Generate and execute query
            if query_route.query_language == 'sparql':
                query_results = await self._execute_sparql_path(
                    question_components, ontology_subsets, query_route
                )
            else:  # cypher
                query_results = await self._execute_cypher_path(
                    question_components, ontology_subsets, query_route
                )
            # Step 5: Generate natural language answer
            generated_answer = await self.answer_generator.generate_answer(
                question_components,
                query_results['raw_results'],
                ontology_subsets[0] if ontology_subsets else None,
                query_route.backend_type.value
            )
            # Build response
            execution_time = (datetime.now() - start_time).total_seconds()
            response = QueryResponse(
                answer=generated_answer.answer,
                # Overall confidence is the weaker of routing and answer confidence.
                confidence=min(query_route.confidence, generated_answer.metadata.confidence),
                execution_time=execution_time,
                question_analysis=question_components,
                ontology_subsets=ontology_subsets,
                query_route=query_route,
                generated_query=query_results['generated_query'],
                raw_results=query_results['raw_results'],
                supporting_facts=generated_answer.supporting_facts,
                metadata={
                    'backend_used': query_route.backend_type.value,
                    'query_language': query_route.query_language,
                    'ontology_count': len(ontology_subsets),
                    'result_count': generated_answer.metadata.result_count,
                    'routing_reasoning': query_route.reasoning,
                    'generation_time': generated_answer.generation_time
                }
            )
            logger.info(f"Query processed successfully in {execution_time:.2f}s")
            return response
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            execution_time = (datetime.now() - start_time).total_seconds()
            # Return error response
            # NOTE(review): QuestionComponents (question_analyzer.py) has no
            # 'normalized_question' field, and question_type is typed as
            # QuestionType, not Optional — this fallback construction looks
            # like it would raise TypeError; confirm against the dataclass.
            return QueryResponse(
                answer=f"I encountered an error processing your query: {str(e)}",
                confidence=0.0,
                execution_time=execution_time,
                question_analysis=QuestionComponents(
                    original_question=request.question,
                    normalized_question=request.question,
                    question_type=None,
                    entities=[], keywords=[], relationships=[], constraints=[],
                    aggregations=[], expected_answer_type="unknown"
                ),
                ontology_subsets=[],
                query_route=None,
                generated_query=None,
                raw_results=None,
                supporting_facts=[],
                metadata={'error': str(e), 'execution_time': execution_time}
            )

    async def _load_and_match_ontologies(self,
                                         question_components: QuestionComponents,
                                         ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]:
        """Load ontologies and find relevant subsets.

        Args:
            question_components: Analyzed question
            ontology_hint: Optional ontology hint

        Returns:
            List of relevant ontology subsets (empty on any failure)
        """
        try:
            # Load available ontologies
            if ontology_hint:
                # Load specific ontology
                ontologies = [await self.ontology_loader.load_ontology(ontology_hint)]
            else:
                # Load all available ontologies
                available_ontologies = await self.ontology_loader.list_available_ontologies()
                ontologies = []
                for ontology_id in available_ontologies[:5]:  # Limit to 5 for performance
                    try:
                        ontology = await self.ontology_loader.load_ontology(ontology_id)
                        ontologies.append(ontology)
                    except Exception as e:
                        # Best-effort: skip ontologies that fail to load.
                        logger.warning(f"Failed to load ontology {ontology_id}: {e}")
            if not ontologies:
                logger.warning("No ontologies loaded")
                return []
            # Extract relevant subsets
            ontology_subsets = []
            for ontology in ontologies:
                subset = await self.ontology_matcher.select_relevant_subset(
                    question_components, ontology
                )
                # Keep only subsets that matched at least one element.
                if subset and (subset.classes or subset.object_properties or subset.datatype_properties):
                    ontology_subsets.append(subset)
            return ontology_subsets
        except Exception as e:
            logger.error(f"Failed to load and match ontologies: {e}")
            return []

    async def _execute_sparql_path(self,
                                   question_components: QuestionComponents,
                                   ontology_subsets: List[QueryOntologySubset],
                                   query_route: QueryRoute) -> Dict[str, Any]:
        """Execute SPARQL query path.

        Args:
            question_components: Question analysis
            ontology_subsets: Ontology subsets
            query_route: Query route

        Returns:
            Dict with 'generated_query' and 'raw_results'

        Raises:
            RuntimeError: if the SPARQL engine was never initialized.
        """
        if not self.sparql_engine:
            raise RuntimeError("SPARQL engine not initialized")
        # Generate SPARQL query from the highest-ranked subset only.
        primary_subset = ontology_subsets[0] if ontology_subsets else None
        sparql_query = await self.sparql_generator.generate_sparql(
            question_components, primary_subset
        )
        logger.debug(f"Generated SPARQL: {sparql_query.query}")
        # Execute query (synchronous engine call).
        sparql_results = self.sparql_engine.execute_sparql(sparql_query.query)
        return {
            'generated_query': sparql_query,
            'raw_results': sparql_results
        }

    async def _execute_cypher_path(self,
                                   question_components: QuestionComponents,
                                   ontology_subsets: List[QueryOntologySubset],
                                   query_route: QueryRoute) -> Dict[str, Any]:
        """Execute Cypher query path.

        Args:
            question_components: Question analysis
            ontology_subsets: Ontology subsets
            query_route: Query route

        Returns:
            Dict with 'generated_query' and 'raw_results'

        Raises:
            RuntimeError: if the Cypher executor was never initialized.
        """
        if not self.cypher_executor:
            raise RuntimeError("Cypher executor not initialized")
        # Generate Cypher query from the highest-ranked subset only.
        primary_subset = ontology_subsets[0] if ontology_subsets else None
        cypher_query = await self.cypher_generator.generate_cypher(
            question_components, primary_subset
        )
        logger.debug(f"Generated Cypher: {cypher_query.query}")
        # Execute query against the routed graph backend.
        database_type = query_route.backend_type.value
        cypher_results = await self.cypher_executor.execute_query(
            cypher_query.query, database_type=database_type
        )
        return {
            'generated_query': cypher_query,
            'raw_results': cypher_results
        }

    async def get_supported_backends(self) -> List[str]:
        """Get list of supported and enabled backends.

        Returns:
            List of backend names
        """
        return [bt.value for bt in self.backend_router.get_available_backends()]

    async def get_available_ontologies(self) -> List[str]:
        """Get list of available ontologies.

        Returns:
            List of ontology identifiers (empty before init())
        """
        if self.ontology_loader:
            return await self.ontology_loader.list_available_ontologies()
        return []

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on all components.

        Returns:
            Health status of all components; 'service' degrades to
            'degraded' with an 'error' entry if any probe raises.
        """
        health = {
            'service': 'healthy',
            'components': {},
            'backends': {},
            'ontologies': {}
        }
        try:
            # Check ontology loader
            if self.ontology_loader:
                ontologies = await self.ontology_loader.list_available_ontologies()
                health['components']['ontology_loader'] = 'healthy'
                health['ontologies']['count'] = len(ontologies)
            else:
                health['components']['ontology_loader'] = 'not_initialized'
            # Check vector store
            if self.vector_store:
                health['components']['vector_store'] = 'healthy'
                health['components']['vector_store_type'] = type(self.vector_store).__name__
            else:
                health['components']['vector_store'] = 'not_initialized'
            # Check backends
            for backend_type in self.backend_router.get_available_backends():
                if backend_type == BackendType.CASSANDRA and self.sparql_engine:
                    health['backends']['cassandra'] = 'healthy'
                elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor:
                    health['backends'][backend_type.value] = 'healthy'
                else:
                    health['backends'][backend_type.value] = 'configured_but_not_initialized'
        except Exception as e:
            health['service'] = 'degraded'
            health['error'] = str(e)
        return health

    async def close(self):
        """Close all connections and cleanup resources."""
        try:
            if self.sparql_engine:
                # Synchronous close on the Cassandra-backed engine.
                self.sparql_engine.close()
            if self.cypher_executor:
                await self.cypher_executor.close()
            if self.config_store:
                # ConfigTableStore cleanup if needed
                pass
            logger.info("OntoRAG query service closed")
        except Exception as e:
            logger.error(f"Error closing OntoRAG query service: {e}")

View file

@ -0,0 +1,364 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
"""
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class QuestionType(Enum):
    """Types of questions that can be asked."""
    FACTUAL = "factual"            # What is X?
    RETRIEVAL = "retrieval"        # Find all X
    AGGREGATION = "aggregation"    # How many X?
    COMPARISON = "comparison"      # Is X better than Y?
    RELATIONSHIP = "relationship"  # How is X related to Y?
    BOOLEAN = "boolean"            # Yes/no questions
    PROCESS = "process"            # How to do X?
    TEMPORAL = "temporal"          # When did X happen?
    SPATIAL = "spatial"            # Where is X?


@dataclass
class QuestionComponents:
    """Components extracted from a question."""
    original_question: str       # verbatim question text
    question_type: QuestionType  # classified question category
    entities: List[str]          # candidate entities (proper nouns, quoted strings)
    relationships: List[str]     # relationship indicator words/phrases
    constraints: List[str]       # filtering constraints (ranges, comparisons, ...)
    aggregations: List[str]      # aggregation keywords found (count, sum, ...)
    expected_answer_type: str    # expected answer kind: number/boolean/list/text/...
    keywords: List[str]          # content words with stop words removed


class QuestionAnalyzer:
    """Analyzes natural language questions to extract semantic components."""

    def __init__(self):
        """Initialize question analyzer with pattern tables."""
        # Question word patterns: first matching pattern decides the type,
        # evaluated in dict insertion order.
        self.question_patterns = {
            QuestionType.FACTUAL: [
                r'^what\s+(?:is|are)',
                r'^who\s+(?:is|are)',
                r'^which\s+',
            ],
            QuestionType.RETRIEVAL: [
                r'^find\s+',
                r'^list\s+',
                r'^show\s+',
                r'^get\s+',
                r'^retrieve\s+',
            ],
            QuestionType.AGGREGATION: [
                r'^how\s+many',
                r'^count\s+',
                r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)',
            ],
            QuestionType.COMPARISON: [
                r'(?:better|worse|more|less|greater|smaller)\s+than',
                r'compare\s+',
                r'difference\s+between',
            ],
            QuestionType.RELATIONSHIP: [
                r'^how\s+(?:is|are).*related',
                r'relationship\s+between',
                r'connection\s+between',
            ],
            QuestionType.BOOLEAN: [
                r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)',
                r'^has\s+',
                r'^have\s+',
            ],
            QuestionType.PROCESS: [
                r'^how\s+(?:to|do)',
                r'^explain\s+how',
            ],
            QuestionType.TEMPORAL: [
                r'^when\s+',
                r'what\s+time',
                r'what\s+date',
            ],
            QuestionType.SPATIAL: [
                r'^where\s+',
                r'location\s+of',
            ],
        }
        # Aggregation keywords (searched as substrings of the question)
        self.aggregation_keywords = [
            'count', 'sum', 'total', 'average', 'mean', 'median',
            'maximum', 'minimum', 'max', 'min', 'number of'
        ]
        # Constraint patterns; some contain multiple capture groups, in
        # which case re.findall yields tuples of groups.
        self.constraint_patterns = [
            r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)',
            r'(?:greater|less|more|fewer)\s+than\s+(\d+)',
            r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)',
            r'(?:before|after|since|until)\s+(.+)',
        ]

    def analyze(self, question: str) -> QuestionComponents:
        """Analyze a question to extract components.

        Args:
            question: Natural language question

        Returns:
            QuestionComponents with extracted information
        """
        # Normalize question (lowercasing for all pattern matching;
        # entity extraction needs the original casing).
        question_lower = question.lower().strip()
        # Determine question type
        question_type = self._identify_question_type(question_lower)
        # Extract entities (from the original, case-sensitive text)
        entities = self._extract_entities(question)
        # Extract relationships
        relationships = self._extract_relationships(question_lower)
        # Extract constraints
        constraints = self._extract_constraints(question_lower)
        # Extract aggregations
        aggregations = self._extract_aggregations(question_lower)
        # Determine expected answer type
        answer_type = self._determine_answer_type(question_type, aggregations)
        # Extract keywords
        keywords = self._extract_keywords(question_lower)
        return QuestionComponents(
            original_question=question,
            question_type=question_type,
            entities=entities,
            relationships=relationships,
            constraints=constraints,
            aggregations=aggregations,
            expected_answer_type=answer_type,
            keywords=keywords
        )

    def _identify_question_type(self, question: str) -> QuestionType:
        """Identify the type of question.

        Args:
            question: Lowercase question text

        Returns:
            QuestionType enum value (FACTUAL when nothing matches)
        """
        for q_type, patterns in self.question_patterns.items():
            for pattern in patterns:
                if re.search(pattern, question):
                    return q_type
        # Default to factual
        return QuestionType.FACTUAL

    def _extract_entities(self, question: str) -> List[str]:
        """Extract potential entities from question.

        Args:
            question: Original question text

        Returns:
            List of entity strings (order-preserving, de-duplicated)
        """
        entities = []
        # Extract capitalized words/phrases (potential proper nouns);
        # consecutive capitalized words form one entity.
        pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'
        entities.extend(re.findall(pattern, question))
        # Extract quoted strings (double then single quotes)
        entities.extend(re.findall(r'"([^"]+)"', question))
        entities.extend(re.findall(r"'([^']+)'", question))
        # Remove duplicates while preserving order
        seen = set()
        unique_entities = []
        for entity in entities:
            if entity not in seen:
                seen.add(entity)
                unique_entities.append(entity)
        return unique_entities

    def _extract_relationships(self, question: str) -> List[str]:
        """Extract relationship indicators from question.

        Args:
            question: Lowercase question text

        Returns:
            List of relationship strings (de-duplicated, unordered)
        """
        relationships = []
        # Common relationship patterns; patterns without a group yield the
        # whole match from findall.
        rel_patterns = [
            r'(\w+)\s+(?:of|by|from|to|with|for)\s+',
            r'has\s+(\w+)',
            r'belongs?\s+to',
            r'(?:created|written|authored|owned)\s+by',
            r'related\s+to',
            r'connected\s+to',
            r'associated\s+with',
        ]
        for pattern in rel_patterns:
            relationships.extend(re.findall(pattern, question))
        # Drop very short tokens (articles, prepositions caught by groups)
        relationships = [r for r in relationships if len(r) > 2]
        return list(set(relationships))

    def _extract_constraints(self, question: str) -> List[str]:
        """Extract constraints from question.

        Args:
            question: Lowercase question text

        Returns:
            List of constraint strings
        """
        constraints = []
        for pattern in self.constraint_patterns:
            for match in re.findall(pattern, question):
                # findall yields tuples when the pattern has multiple
                # capture groups. BUG FIX: previously only the groups of the
                # FIRST match were kept (matches[0]); now every match
                # contributes its groups.
                if isinstance(match, tuple):
                    constraints.extend(match)
                else:
                    constraints.append(match)
        # Clean up: strip whitespace, drop empties
        return [c.strip() for c in constraints if c and len(c.strip()) > 0]

    def _extract_aggregations(self, question: str) -> List[str]:
        """Extract aggregation operations from question.

        Args:
            question: Lowercase question text

        Returns:
            List of aggregation keywords present in the question
        """
        return [keyword for keyword in self.aggregation_keywords
                if keyword in question]

    def _determine_answer_type(self, question_type: QuestionType,
                               aggregations: List[str]) -> str:
        """Determine expected answer type.

        Args:
            question_type: Type of question
            aggregations: Aggregation operations found

        Returns:
            Expected answer type string
        """
        # Counting/averaging/summing aggregations imply a numeric answer;
        # other aggregation keywords (e.g. 'maximum') fall through to the
        # question-type mapping, preserving historical behavior.
        if aggregations:
            if any(a in ['count', 'number of', 'total'] for a in aggregations):
                return 'number'
            elif any(a in ['average', 'mean', 'median'] for a in aggregations):
                return 'number'
            elif any(a in ['sum'] for a in aggregations):
                return 'number'
        if question_type == QuestionType.BOOLEAN:
            return 'boolean'
        elif question_type == QuestionType.TEMPORAL:
            return 'datetime'
        elif question_type == QuestionType.SPATIAL:
            return 'location'
        elif question_type == QuestionType.RETRIEVAL:
            return 'list'
        elif question_type == QuestionType.COMPARISON:
            return 'comparison'
        else:
            return 'text'

    def _extract_keywords(self, question: str) -> List[str]:
        """Extract important keywords from question.

        Args:
            question: Lowercase question text

        Returns:
            List of keywords (order-preserving, de-duplicated)
        """
        # Remove common stop words
        stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
            'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are',
            'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do',
            'does', 'did', 'will', 'would', 'could', 'should', 'may',
            'might', 'must', 'can', 'shall', 'what', 'which', 'who',
            'when', 'where', 'why', 'how'
        }
        # Extract words
        words = re.findall(r'\b\w+\b', question)
        # Filter stop words and short words
        keywords = [w for w in words if w not in stop_words and len(w) > 2]
        # Remove duplicates while preserving order
        seen = set()
        unique_keywords = []
        for kw in keywords:
            if kw not in seen:
                seen.add(kw)
                unique_keywords.append(kw)
        return unique_keywords

    def get_question_segments(self, question: str) -> List[str]:
        """Split question into segments for embedding.

        Args:
            question: Question text

        Returns:
            List of question segments (full question, clauses, entities,
            keywords), de-duplicated with order preserved
        """
        segments = []
        # Add full question
        segments.append(question)
        # Split by clauses
        clauses = re.split(r'[,;]', question)
        segments.extend([c.strip() for c in clauses if len(c.strip()) > 3])
        # Extract key phrases
        components = self.analyze(question)
        segments.extend(components.entities)
        segments.extend(components.keywords)
        # Remove duplicates (dict preserves insertion order)
        return list(dict.fromkeys(segments))

View file

@ -0,0 +1,481 @@
"""
SPARQL-Cassandra engine using Python rdflib.
Executes SPARQL queries against Cassandra using a custom Store implementation.
"""
import logging
from typing import Dict, Any, List, Optional, Iterator, Tuple
from dataclasses import dataclass
import json
# Try to import rdflib
try:
from rdflib import Graph, Namespace, URIRef, Literal, BNode
from rdflib.store import Store
from rdflib.plugins.sparql.processor import SPARQLResult
from rdflib.plugins.sparql import prepareQuery
from rdflib.term import Node
RDFLIB_AVAILABLE = True
except ImportError:
RDFLIB_AVAILABLE = False
# Try to import Cassandra driver
try:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.policies import DCAwareRoundRobinPolicy
CASSANDRA_AVAILABLE = True
except ImportError:
CASSANDRA_AVAILABLE = False
from ....tables.config import ConfigTableStore
logger = logging.getLogger(__name__)
@dataclass
class SPARQLResult:
    """Result from SPARQL query execution.

    NOTE(review): this deliberately(?) shadows the rdflib SPARQLResult
    imported in the try-block above — confirm the rdflib import is unused
    or rename one of them to avoid confusion.
    """
    bindings: List[Dict[str, Any]]     # one dict per solution row: variable -> value
    variables: List[str]               # projected variable names
    ask_result: Optional[bool] = None  # For ASK queries
    execution_time: float = 0.0        # wall-clock seconds spent executing
    query_plan: Optional[str] = None   # optional textual plan / diagnostics
class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
"""Custom rdflib Store implementation for Cassandra."""
    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize Cassandra triple store.

        Args:
            cassandra_config: Cassandra connection configuration

        Raises:
            RuntimeError: if the optional cassandra-driver or rdflib
                dependencies failed to import.
        """
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver not available")
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib not available")
        super().__init__()
        self.cassandra_config = cassandra_config
        self.cluster = None  # cassandra Cluster; set by open()
        self.session = None  # active session; None until open() succeeds
        self.keyspace = cassandra_config.get('keyspace', 'trustgraph')
        # Triple storage table structure (fully-qualified names used in CQL)
        self.triple_table = f"{self.keyspace}.triples"
        self.metadata_table = f"{self.keyspace}.triple_metadata"
    def open(self, configuration=None, create=False):
        """Open connection to Cassandra.

        Args:
            configuration: Unused; connection details come from the config
                dict given to __init__ (parameter kept for the rdflib Store
                API).
            create: When True, create the keyspace/tables/indexes first.

        Returns:
            True on success, False on any failure (errors are logged).
            NOTE(review): rdflib's Store.open is documented to return
            VALID_STORE/NO_STORE constants — confirm a bool is acceptable
            to callers.
        """
        try:
            # Create authentication if provided
            auth_provider = None
            if 'username' in self.cassandra_config and 'password' in self.cassandra_config:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_config['username'],
                    password=self.cassandra_config['password']
                )
            # Create cluster
            self.cluster = Cluster(
                [self.cassandra_config.get('host', 'localhost')],
                port=self.cassandra_config.get('port', 9042),
                auth_provider=auth_provider,
                load_balancing_policy=DCAwareRoundRobinPolicy()
            )
            # Connect
            self.session = self.cluster.connect()
            # Ensure keyspace exists
            if create:
                self._create_schema()
            # Set keyspace
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            return False
    def close(self, commit_pending_transaction=True):
        """Close Cassandra connection.

        Args:
            commit_pending_transaction: Accepted for rdflib Store API
                compatibility; there are no transactions to commit here.
        """
        if self.session:
            self.session.shutdown()
        if self.cluster:
            self.cluster.shutdown()
    def _create_schema(self):
        """Create Cassandra schema for triple storage.

        Creates the keyspace, the triples table with secondary indexes on
        predicate and object, and a metadata table.
        """
        # Create keyspace
        self.session.execute(f"""
            CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
            WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
        """)
        # Create triples table optimized for SPARQL queries
        # (partitioned by subject; predicate/object as clustering columns)
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.triple_table} (
                subject text,
                predicate text,
                object text,
                object_datatype text,
                object_language text,
                is_literal boolean,
                graph_id text,
                PRIMARY KEY ((subject), predicate, object)
            )
        """)
        # Create indexes for efficient querying
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate)
        """)
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object)
        """)
        # Metadata table for graph information
        # NOTE(review): Cassandra does not permit counter columns to be
        # mixed with regular (timestamp) columns in the same table — this
        # CREATE TABLE looks like it would be rejected; confirm and split
        # the counter into its own table if so.
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                graph_id text PRIMARY KEY,
                created timestamp,
                modified timestamp,
                triple_count counter
            )
        """)
    def triples(self, triple_pattern, context=None):
        """Retrieve triples matching the given pattern (rdflib Store API).

        Args:
            triple_pattern: (subject, predicate, object) pattern with None for variables
            context: Graph context (optional; ignored here)

        Yields:
            Matching triples as (subject, predicate, object) tuples.
            NOTE(review): rdflib's Store.triples contract expects
            ((s, p, o), context-iterator) pairs — confirm callers accept
            bare triples.
        """
        # Not opened (or open() failed): nothing to yield.
        if not self.session:
            return
        subject, predicate, object_val = triple_pattern
        # Build CQL query based on pattern
        cql_queries = self._pattern_to_cql(subject, predicate, object_val)
        for cql, params in cql_queries:
            try:
                rows = self.session.execute(cql, params)
                for row in rows:
                    yield self._row_to_triple(row)
            except Exception as e:
                # Best-effort: log the failure and continue with any
                # remaining queries rather than aborting iteration.
                logger.error(f"Error executing CQL query: {e}")
def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]:
"""Convert triple pattern to CQL queries.
Args:
subject: Subject node or None
predicate: Predicate node or None
object_val: Object node or None
Returns:
List of (CQL query, parameters) tuples
"""
queries = []
# Convert None to wildcard, nodes to strings
s_str = str(subject) if subject else None
p_str = str(predicate) if predicate else None
o_str = str(object_val) if object_val else None
if s_str and p_str and o_str:
# Specific triple lookup
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ? AND object = ?"
queries.append((cql, [s_str, p_str, o_str]))
elif s_str and p_str:
# Subject and predicate known
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ?"
queries.append((cql, [s_str, p_str]))
elif s_str:
# Subject known
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ?"
queries.append((cql, [s_str]))
elif p_str:
# Predicate known (requires index scan)
cql = f"SELECT * FROM {self.triple_table} WHERE predicate = ? ALLOW FILTERING"
queries.append((cql, [p_str]))
elif o_str:
# Object known (requires index scan)
cql = f"SELECT * FROM {self.triple_table} WHERE object = ? ALLOW FILTERING"
queries.append((cql, [o_str]))
else:
# Full scan (should be avoided in production)
cql = f"SELECT * FROM {self.triple_table}"
queries.append((cql, []))
return queries
def _row_to_triple(self, row):
"""Convert Cassandra row to RDF triple.
Args:
row: Cassandra row object
Returns:
(subject, predicate, object) tuple with rdflib nodes
"""
# Convert to rdflib nodes
subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject)
predicate = URIRef(row.predicate)
if row.is_literal:
# Create literal with datatype/language
if row.object_datatype:
object_node = Literal(row.object, datatype=URIRef(row.object_datatype))
elif row.object_language:
object_node = Literal(row.object, lang=row.object_language)
else:
object_node = Literal(row.object)
else:
object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object)
return (subject, predicate, object_node)
def add(self, triple, context=None, quoted=False):
"""Add a triple to the store.
Args:
triple: (subject, predicate, object) tuple
context: Graph context
quoted: Whether triple is quoted
"""
if not self.session:
return
subject, predicate, object_val = triple
# Convert to storage format
s_str = str(subject)
p_str = str(predicate)
is_literal = isinstance(object_val, Literal)
o_str = str(object_val)
o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None
o_language = object_val.language if is_literal and object_val.language else None
# Insert into Cassandra
cql = f"""
INSERT INTO {self.triple_table}
(subject, predicate, object, object_datatype, object_language, is_literal, graph_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
"""
try:
self.session.execute(cql, [
s_str, p_str, o_str, o_datatype, o_language, is_literal,
str(context) if context else 'default'
])
except Exception as e:
logger.error(f"Error adding triple: {e}")
def remove(self, triple, context=None):
"""Remove a triple from the store.
Args:
triple: (subject, predicate, object) tuple
context: Graph context
"""
if not self.session:
return
subject, predicate, object_val = triple
cql = f"""
DELETE FROM {self.triple_table}
WHERE subject = ? AND predicate = ? AND object = ?
"""
try:
self.session.execute(cql, [str(subject), str(predicate), str(object_val)])
except Exception as e:
logger.error(f"Error removing triple: {e}")
def __len__(self, context=None):
"""Get number of triples in store.
Args:
context: Graph context
Returns:
Number of triples
"""
if not self.session:
return 0
try:
cql = f"SELECT COUNT(*) FROM {self.triple_table}"
result = self.session.execute(cql)
return result.one().count
except Exception as e:
logger.error(f"Error counting triples: {e}")
return 0
class SPARQLCassandraEngine:
    """SPARQL processor using a Cassandra backend.

    Wraps a CassandraTripleStore in an rdflib Graph so that rdflib's
    SPARQL engine plans the query while Cassandra answers the underlying
    triple-pattern lookups.
    """

    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize SPARQL-Cassandra engine.

        Args:
            cassandra_config: Cassandra connection configuration.

        Raises:
            RuntimeError: If rdflib or the Cassandra driver is unavailable.
        """
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib is required for SPARQL processing")
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver is required")
        self.cassandra_config = cassandra_config
        self.store = CassandraTripleStore(cassandra_config)
        self.graph = Graph(store=self.store)
        # Common namespaces, pre-bound so queries can use short prefixes.
        self.namespaces = {
            'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
            'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'),
            'owl': Namespace('http://www.w3.org/2002/07/owl#'),
            'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'),
        }
        for prefix, namespace in self.namespaces.items():
            self.graph.bind(prefix, namespace)

    async def initialize(self, create_schema=False):
        """Open the Cassandra connection behind the store.

        Args:
            create_schema: Whether to create the Cassandra schema.

        Raises:
            RuntimeError: If the Cassandra connection cannot be opened.
        """
        success = self.store.open(create=create_schema)
        if not success:
            raise RuntimeError("Failed to connect to Cassandra")
        logger.info("SPARQL-Cassandra engine initialized")

    def execute_sparql(self, sparql_query: str) -> SPARQLResult:
        """Execute a SPARQL query against Cassandra.

        Args:
            sparql_query: SPARQL query string (SELECT or ASK).

        Returns:
            Query results; on error an empty SPARQLResult carrying only
            the elapsed time (the error itself is logged).
        """
        import time
        start_time = time.time()
        try:
            # Prepare and execute the query via rdflib.
            prepared_query = prepareQuery(sparql_query)
            result = self.graph.query(prepared_query)
            execution_time = time.time() - start_time
            # Decide the result shape from rdflib's parsed result type
            # rather than sniffing the query text: the original
            # startswith('ASK') check misclassified ASK queries that begin
            # with PREFIX declarations, sending them down the SELECT path.
            if getattr(result, 'type', None) == 'ASK':
                return SPARQLResult(
                    bindings=[],
                    variables=[],
                    ask_result=bool(result.askAnswer),
                    execution_time=execution_time
                )
            # SELECT (binding-producing) query.
            bindings = []
            variables = result.vars if hasattr(result, 'vars') else []
            if variables is None:
                # Non-binding result forms expose vars=None; guard the loop.
                variables = []
            for row in result:
                binding = {}
                for i, var in enumerate(variables):
                    if i < len(row):
                        value = row[i]
                        binding[str(var)] = self._format_result_value(value)
                bindings.append(binding)
            return SPARQLResult(
                bindings=bindings,
                variables=[str(v) for v in variables],
                execution_time=execution_time
            )
        except Exception as e:
            logger.error(f"SPARQL execution error: {e}")
            return SPARQLResult(
                bindings=[],
                variables=[],
                execution_time=time.time() - start_time
            )

    def _format_result_value(self, value):
        """Format an RDF term as a SPARQL-JSON-style result value.

        Args:
            value: RDF value (URIRef, Literal, BNode).

        Returns:
            Dict with 'type' and 'value' keys, plus 'datatype'/'language'
            for literals that carry them.
        """
        if isinstance(value, URIRef):
            return {'type': 'uri', 'value': str(value)}
        elif isinstance(value, Literal):
            result = {'type': 'literal', 'value': str(value)}
            if value.datatype:
                result['datatype'] = str(value.datatype)
            if value.language:
                result['language'] = value.language
            return result
        elif isinstance(value, BNode):
            return {'type': 'bnode', 'value': str(value)}
        else:
            # Unbound OPTIONAL values (None) and anything unexpected.
            return {'type': 'unknown', 'value': str(value)}

    def load_triples_from_store(self, config_store: ConfigTableStore):
        """Load triples from TrustGraph's storage into the RDF graph.

        Currently a placeholder: the mapping from TrustGraph's own
        Cassandra tables to RDF triples is not yet implemented.

        Args:
            config_store: Configuration store with triples.
        """
        logger.info("Loading triples from TrustGraph store...")
        # TODO: query the TrustGraph triple tables and add() each row.

    def close(self):
        """Close the engine and its store connection."""
        if self.store:
            self.store.close()
        logger.info("SPARQL-Cassandra engine closed")

View file

@ -0,0 +1,487 @@
"""
SPARQL query generator for ontology-sensitive queries.
Converts natural language questions to SPARQL queries for Cassandra execution.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class SPARQLQuery:
    """Generated SPARQL query with metadata."""
    # Full SPARQL query text, including PREFIX declarations.
    query: str
    # Projected variable names (without the leading '?'); empty for ASK.
    variables: List[str]
    query_type: str  # SELECT, ASK, CONSTRUCT, DESCRIBE
    # Human-readable summary of what the query retrieves.
    explanation: str
    # Heuristic difficulty estimate in [0.0, 1.0].
    complexity_score: float
class SPARQLGenerator:
    """Generates SPARQL queries from natural language questions using LLM assistance."""

    def __init__(self, prompt_service=None):
        """Initialize SPARQL generator.

        Args:
            prompt_service: Service for LLM-based query generation; when
                None, only template and fallback generation are available.
        """
        self.prompt_service = prompt_service
        # SPARQL query templates for common patterns. Each template is
        # completed with str.format(); literal SPARQL braces are escaped
        # as {{ }} so format() leaves them intact.
        self.templates = {
            # All instances of a class, with their optional rdfs:label.
            'simple_class_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?entity ?label WHERE {{
    ?entity rdf:type :{class_name} .
    OPTIONAL {{ ?entity rdfs:label ?label }}
}}""",
            # Subject/object pairs linked by a property, with typed subject.
            'property_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subject ?object WHERE {{
    ?subject :{property} ?object .
    ?subject rdf:type :{subject_class} .
}}""",
            # Transitive subclass closure under a root class.
            'hierarchy_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subclass ?superclass WHERE {{
    ?subclass rdfs:subClassOf* ?superclass .
    ?superclass rdf:type :{root_class} .
}}""",
            # Instance count for a class, with optional extra FILTER clauses.
            'count_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
SELECT (COUNT(?entity) AS ?count) WHERE {{
    ?entity rdf:type :{class_name} .
    {additional_constraints}
}}""",
            # Yes/no fact check over a single triple pattern.
            'boolean_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
ASK {{
    {triple_pattern}
}}"""
        }
async def generate_sparql(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> SPARQLQuery:
"""Generate SPARQL query for a question.
Args:
question_components: Analyzed question components
ontology_subset: Relevant ontology subset
Returns:
Generated SPARQL query
"""
# Try template-based generation first
template_query = self._try_template_generation(question_components, ontology_subset)
if template_query:
logger.debug("Generated SPARQL using template")
return template_query
# Fall back to LLM-based generation
if self.prompt_service:
llm_query = await self._generate_with_llm(question_components, ontology_subset)
if llm_query:
logger.debug("Generated SPARQL using LLM")
return llm_query
# Final fallback to simple pattern
logger.warning("Falling back to simple SPARQL pattern")
return self._generate_fallback_query(question_components, ontology_subset)
def _try_template_generation(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
"""Try to generate query using templates.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Generated query or None if no template matches
"""
namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
# Simple class query (What are the animals?)
if (question_components.question_type == QuestionType.RETRIEVAL and
len(question_components.entities) == 1 and
question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]):
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
if class_name:
query = self.templates['simple_class_query'].format(
namespace=namespace,
class_name=class_name
)
return SPARQLQuery(
query=query,
variables=['entity', 'label'],
query_type='SELECT',
explanation=f"Retrieve all instances of {class_name}",
complexity_score=0.3
)
# Count query (How many animals are there?)
if (question_components.question_type == QuestionType.AGGREGATION and
'count' in question_components.aggregations and
len(question_components.entities) >= 1):
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
if class_name:
query = self.templates['count_query'].format(
namespace=namespace,
class_name=class_name,
additional_constraints=self._build_constraints(question_components, ontology_subset)
)
return SPARQLQuery(
query=query,
variables=['count'],
query_type='SELECT',
explanation=f"Count instances of {class_name}",
complexity_score=0.4
)
# Boolean query (Is X a Y?)
if question_components.question_type == QuestionType.BOOLEAN:
triple_pattern = self._build_boolean_pattern(question_components, ontology_subset)
if triple_pattern:
query = self.templates['boolean_query'].format(
namespace=namespace,
triple_pattern=triple_pattern
)
return SPARQLQuery(
query=query,
variables=[],
query_type='ASK',
explanation="Boolean query for fact checking",
complexity_score=0.2
)
return None
async def _generate_with_llm(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
"""Generate SPARQL using LLM.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Generated query or None if failed
"""
try:
prompt = self._build_sparql_prompt(question_components, ontology_subset)
response = await self.prompt_service.generate_sparql(prompt=prompt)
if response and isinstance(response, dict):
query = response.get('query', '').strip()
if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')):
return SPARQLQuery(
query=query,
variables=self._extract_variables(query),
query_type=query.split()[0].upper(),
explanation=response.get('explanation', 'Generated by LLM'),
complexity_score=self._calculate_complexity(query)
)
except Exception as e:
logger.error(f"LLM SPARQL generation failed: {e}")
return None
    def _build_sparql_prompt(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM SPARQL generation.

        Args:
            question_components: Question analysis.
            ontology_subset: Ontology subset.

        Returns:
            Formatted prompt string.
        """
        # Default namespace mirrors the one used by the query templates.
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Render the ontology subset as plain-text lists the LLM can read.
        classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace)
        props_str = self._format_properties_for_prompt(
            ontology_subset.object_properties,
            ontology_subset.datatype_properties,
            namespace
        )
        prompt = f"""Generate a SPARQL query for the following question using the provided ontology.
QUESTION: {question_components.original_question}
ONTOLOGY NAMESPACE: {namespace}
AVAILABLE CLASSES:
{classes_str}
AVAILABLE PROPERTIES:
{props_str}
RULES:
- Use proper SPARQL syntax
- Include appropriate prefixes
- Use property paths for hierarchical queries (rdfs:subClassOf*)
- Add FILTER clauses for constraints
- Optimize for Cassandra backend
- Return both query and explanation
QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}
Generate a complete SPARQL query:"""
        return prompt
    def _generate_fallback_query(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> SPARQLQuery:
        """Generate simple fallback query.

        Used when neither template nor LLM generation produced a query.

        Args:
            question_components: Question analysis.
            ontology_subset: Ontology subset.

        Returns:
            Basic SPARQL query matching any triple whose subject IRI
            contains the first extracted keyword ('entity' when none),
            capped at 10 rows.
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Very basic SELECT query
        query = f"""PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subject ?predicate ?object WHERE {{
    ?subject ?predicate ?object .
    FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}"))
}}
LIMIT 10"""
        return SPARQLQuery(
            query=query,
            variables=['subject', 'predicate', 'object'],
            query_type='SELECT',
            explanation="Fallback query for basic pattern matching",
            complexity_score=0.1
        )
def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Find matching class in ontology subset.
Args:
entity: Entity string to match
ontology_subset: Ontology subset
Returns:
Matching class name or None
"""
entity_lower = entity.lower()
# Direct match
for class_id in ontology_subset.classes:
if class_id.lower() == entity_lower:
return class_id
# Label match
for class_id, class_def in ontology_subset.classes.items():
labels = class_def.get('labels', [])
for label in labels:
if isinstance(label, dict):
label_value = label.get('value', '').lower()
if label_value == entity_lower:
return class_id
# Partial match
for class_id in ontology_subset.classes:
if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
return class_id
return None
def _build_constraints(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> str:
"""Build constraint clauses for SPARQL.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
SPARQL constraint string
"""
constraints = []
for constraint in question_components.constraints:
# Simple constraint patterns
if 'greater than' in constraint.lower():
# Extract number
import re
numbers = re.findall(r'\d+', constraint)
if numbers:
constraints.append(f"FILTER(?value > {numbers[0]})")
elif 'less than' in constraint.lower():
numbers = re.findall(r'\d+', constraint)
if numbers:
constraints.append(f"FILTER(?value < {numbers[0]})")
return '\n '.join(constraints)
def _build_boolean_pattern(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Build triple pattern for boolean queries.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
SPARQL triple pattern or None
"""
if len(question_components.entities) >= 2:
subject = question_components.entities[0]
object_val = question_components.entities[1]
# Try to find connecting property
for prop_id in ontology_subset.object_properties:
return f":{subject} :{prop_id} :{object_val} ."
# Fallback to type check
return f":{subject} rdf:type :{object_val} ."
return None
def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str:
"""Format classes for prompt.
Args:
classes: Classes dictionary
namespace: Ontology namespace
Returns:
Formatted classes string
"""
if not classes:
return "None"
lines = []
for class_id, definition in classes.items():
comment = definition.get('comment', '')
parent = definition.get('subclass_of', 'Thing')
lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}")
return '\n'.join(lines)
def _format_properties_for_prompt(self,
object_props: Dict[str, Any],
datatype_props: Dict[str, Any],
namespace: str) -> str:
"""Format properties for prompt.
Args:
object_props: Object properties
datatype_props: Datatype properties
namespace: Ontology namespace
Returns:
Formatted properties string
"""
lines = []
for prop_id, definition in object_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'Any')
comment = definition.get('comment', '')
lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}")
for prop_id, definition in datatype_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'xsd:string')
comment = definition.get('comment', '')
lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}")
return '\n'.join(lines) if lines else "None"
def _extract_variables(self, query: str) -> List[str]:
"""Extract variables from SPARQL query.
Args:
query: SPARQL query string
Returns:
List of variable names
"""
import re
variables = re.findall(r'\?(\w+)', query)
return list(set(variables))
def _calculate_complexity(self, query: str) -> float:
"""Calculate complexity score for SPARQL query.
Args:
query: SPARQL query string
Returns:
Complexity score (0.0 to 1.0)
"""
complexity = 0.0
# Count different SPARQL features
query_upper = query.upper()
if 'JOIN' in query_upper or 'UNION' in query_upper:
complexity += 0.3
if 'FILTER' in query_upper:
complexity += 0.2
if 'OPTIONAL' in query_upper:
complexity += 0.1
if 'GROUP BY' in query_upper:
complexity += 0.2
if 'ORDER BY' in query_upper:
complexity += 0.1
if '*' in query: # Property paths
complexity += 0.1
# Count variables
variables = self._extract_variables(query)
complexity += len(variables) * 0.05
return min(complexity, 1.0)