mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
OntoRAG: Ontology-Based Knowledge Extraction and Query Technical Specification (#523)
* Onto-rag tech spec * New processor kg-extract-ontology, use 'ontology' objects from config to guide triple extraction * Also entity contexts * Integrate with ontology extractor from workbench This is first phase, the extraction is tested and working, also GraphRAG with the extracted knowledge works
This commit is contained in:
parent
4c3db4dbbe
commit
c69f5207a4
28 changed files with 11824 additions and 0 deletions
|
|
@ -16,6 +16,7 @@ dependencies = [
|
|||
"scylla-driver",
|
||||
"cohere",
|
||||
"cryptography",
|
||||
"faiss-cpu",
|
||||
"falkordb",
|
||||
"fastembed",
|
||||
"google-genai",
|
||||
|
|
@ -29,6 +30,7 @@ dependencies = [
|
|||
"minio",
|
||||
"mistralai",
|
||||
"neo4j",
|
||||
"nltk",
|
||||
"ollama",
|
||||
"openai",
|
||||
"pinecone[grpc]",
|
||||
|
|
@ -82,6 +84,7 @@ kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
|||
kg-extract-objects = "trustgraph.extract.kg.objects:run"
|
||||
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
|
||||
kg-extract-topics = "trustgraph.extract.kg.topics:run"
|
||||
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
|
||||
kg-manager = "trustgraph.cores:run"
|
||||
kg-store = "trustgraph.storage.knowledge:run"
|
||||
librarian = "trustgraph.librarian:run"
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from . extract import *
|
||||
848
trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
Normal file
848
trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
Normal file
|
|
@ -0,0 +1,848 @@
|
|||
"""
|
||||
OntoRAG: Ontology-based knowledge extraction service.
|
||||
Extracts ontology-conformant triples from text chunks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import EntityContext, EntityContexts
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL, DEFINITION
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
|
||||
from .ontology_loader import OntologyLoader
|
||||
from .ontology_embedder import OntologyEmbedder
|
||||
from .vector_store import InMemoryVectorStore
|
||||
from .text_processor import TextProcessor
|
||||
from .ontology_selector import OntologySelector, OntologySubset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "kg-extract-ontology"
|
||||
default_concurrency = 1
|
||||
|
||||
# URI prefix mappings for common namespaces
|
||||
URI_PREFIXES = {
|
||||
"rdf:": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
|
||||
"rdfs:": "http://www.w3.org/2000/01/rdf-schema#",
|
||||
"owl:": "http://www.w3.org/2002/07/owl#",
|
||||
"skos:": "http://www.w3.org/2004/02/skos/core#",
|
||||
"schema:": "https://schema.org/",
|
||||
"xsd:": "http://www.w3.org/2001/XMLSchema#",
|
||||
}
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
"""Main OntoRAG extraction processor."""
|
||||
|
||||
def __init__(self, **params):
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
||||
# Register specifications
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name="input",
|
||||
schema=Chunk,
|
||||
handler=self.on_message,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
PromptClientSpec(
|
||||
request_name="prompt-request",
|
||||
response_name="prompt-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name="embeddings-request",
|
||||
response_name="embeddings-response"
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="triples",
|
||||
schema=Triples
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name="entity-contexts",
|
||||
schema=EntityContexts
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handler for ontology updates
|
||||
self.register_config_handler(self.on_ontology_config)
|
||||
|
||||
# Shared components (not flow-specific)
|
||||
self.ontology_loader = OntologyLoader()
|
||||
self.text_processor = TextProcessor()
|
||||
|
||||
# Per-flow components (each flow gets its own embedder/vector store/selector)
|
||||
self.flow_components = {} # flow_id -> {embedder, vector_store, selector}
|
||||
|
||||
# Configuration
|
||||
self.top_k = params.get("top_k", 10)
|
||||
self.similarity_threshold = params.get("similarity_threshold", 0.3)
|
||||
|
||||
# Track loaded ontology version
|
||||
self.current_ontology_version = None
|
||||
self.loaded_ontology_ids = set()
|
||||
|
||||
async def initialize_flow_components(self, flow):
|
||||
"""Initialize per-flow OntoRAG components.
|
||||
|
||||
Each flow gets its own vector store and embedder to support
|
||||
different embedding models across flows. The vector store dimension
|
||||
is auto-detected from the embeddings service.
|
||||
|
||||
Args:
|
||||
flow: Flow object for this processing context
|
||||
|
||||
Returns:
|
||||
flow_id: Identifier for this flow's components
|
||||
"""
|
||||
# Use flow object as identifier
|
||||
flow_id = id(flow)
|
||||
|
||||
if flow_id in self.flow_components:
|
||||
return flow_id # Already initialized for this flow
|
||||
|
||||
try:
|
||||
logger.info(f"Initializing components for flow {flow_id}")
|
||||
|
||||
# Use embeddings client directly (no wrapper needed)
|
||||
embeddings_client = flow("embeddings-request")
|
||||
|
||||
# Detect embedding dimension by embedding a test string
|
||||
logger.info("Detecting embedding dimension from embeddings service...")
|
||||
test_embedding_response = await embeddings_client.embed("test")
|
||||
test_embedding = test_embedding_response[0] # Extract from [[vector]]
|
||||
dimension = len(test_embedding)
|
||||
logger.info(f"Detected embedding dimension: {dimension}")
|
||||
|
||||
# Initialize vector store with detected dimension
|
||||
vector_store = InMemoryVectorStore(
|
||||
dimension=dimension,
|
||||
index_type='flat'
|
||||
)
|
||||
|
||||
ontology_embedder = OntologyEmbedder(
|
||||
embedding_service=embeddings_client,
|
||||
vector_store=vector_store
|
||||
)
|
||||
|
||||
# Embed all loaded ontologies for this flow
|
||||
if self.ontology_loader.get_all_ontologies():
|
||||
logger.info(f"Embedding ontologies for flow {flow_id}")
|
||||
for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
|
||||
await ontology_embedder.embed_ontology(ontology)
|
||||
logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
|
||||
|
||||
# Initialize ontology selector
|
||||
ontology_selector = OntologySelector(
|
||||
ontology_embedder=ontology_embedder,
|
||||
ontology_loader=self.ontology_loader,
|
||||
top_k=self.top_k,
|
||||
similarity_threshold=self.similarity_threshold
|
||||
)
|
||||
|
||||
# Store flow-specific components
|
||||
self.flow_components[flow_id] = {
|
||||
'embedder': ontology_embedder,
|
||||
'vector_store': vector_store,
|
||||
'selector': ontology_selector,
|
||||
'dimension': dimension
|
||||
}
|
||||
|
||||
logger.info(f"Flow {flow_id} components initialized successfully (dimension={dimension})")
|
||||
return flow_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def on_ontology_config(self, config, version):
|
||||
"""
|
||||
Handle ontology configuration updates from ConfigPush queue.
|
||||
|
||||
Parses and stores ontologies. Embedding happens per-flow on first message.
|
||||
|
||||
Called automatically when:
|
||||
- Processor starts (gets full config history via start_of_messages=True)
|
||||
- Config service pushes updates (immediate event-driven notification)
|
||||
|
||||
Args:
|
||||
config: Full configuration map - config[type][key] = value
|
||||
version: Config version number (monotonically increasing)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Received ontology config update, version={version}")
|
||||
|
||||
# Skip if we've already processed this version
|
||||
if version == self.current_ontology_version:
|
||||
logger.debug(f"Already at version {version}, skipping")
|
||||
return
|
||||
|
||||
# Extract ontology configurations
|
||||
if "ontology" not in config:
|
||||
logger.warning("No 'ontology' section in config")
|
||||
return
|
||||
|
||||
ontology_configs = config["ontology"]
|
||||
|
||||
# Parse ontology definitions
|
||||
ontologies = {}
|
||||
for ont_id, ont_json in ontology_configs.items():
|
||||
try:
|
||||
ontologies[ont_id] = json.loads(ont_json)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse ontology '{ont_id}': {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loaded {len(ontologies)} ontology definitions")
|
||||
|
||||
# Determine what changed (for incremental updates)
|
||||
new_ids = set(ontologies.keys())
|
||||
added_ids = new_ids - self.loaded_ontology_ids
|
||||
removed_ids = self.loaded_ontology_ids - new_ids
|
||||
updated_ids = new_ids & self.loaded_ontology_ids # May have changed content
|
||||
|
||||
if added_ids:
|
||||
logger.info(f"New ontologies: {added_ids}")
|
||||
if removed_ids:
|
||||
logger.info(f"Removed ontologies: {removed_ids}")
|
||||
if updated_ids:
|
||||
logger.info(f"Updated ontologies: {updated_ids}")
|
||||
|
||||
# Update ontology loader's internal state
|
||||
self.ontology_loader.update_ontologies(ontologies)
|
||||
|
||||
# Clear all flow components to force re-embedding with new ontologies
|
||||
if added_ids or removed_ids or updated_ids:
|
||||
logger.info("Clearing flow components to trigger re-embedding")
|
||||
self.flow_components.clear()
|
||||
|
||||
# Update tracking
|
||||
self.current_ontology_version = version
|
||||
self.loaded_ontology_ids = new_ids
|
||||
|
||||
logger.info(f"Ontology config update complete, version={version}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process ontology config: {e}", exc_info=True)
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
"""Process incoming chunk message."""
|
||||
v = msg.value()
|
||||
logger.info(f"Extracting ontology-based triples from {v.metadata.id}...")
|
||||
|
||||
# Initialize flow-specific components if needed
|
||||
flow_id = await self.initialize_flow_components(flow)
|
||||
components = self.flow_components[flow_id]
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
logger.debug(f"Processing chunk: {chunk[:200]}...")
|
||||
|
||||
try:
|
||||
# Process text into segments
|
||||
segments = self.text_processor.process_chunk(chunk, extract_phrases=True)
|
||||
logger.debug(f"Split chunk into {len(segments)} segments")
|
||||
|
||||
# Select relevant ontology subset (using flow-specific selector)
|
||||
ontology_subsets = await components['selector'].select_ontology_subset(segments)
|
||||
|
||||
if not ontology_subsets:
|
||||
logger.warning("No relevant ontology elements found for chunk")
|
||||
# Emit empty outputs
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
return
|
||||
|
||||
# Merge subsets if multiple ontologies matched
|
||||
if len(ontology_subsets) > 1:
|
||||
ontology_subset = components['selector'].merge_subsets(ontology_subsets)
|
||||
else:
|
||||
ontology_subset = ontology_subsets[0]
|
||||
|
||||
logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, "
|
||||
f"{len(ontology_subset.object_properties)} object properties, "
|
||||
f"{len(ontology_subset.datatype_properties)} datatype properties")
|
||||
|
||||
# Build extraction prompt variables
|
||||
prompt_variables = self.build_extraction_variables(chunk, ontology_subset)
|
||||
|
||||
# Call prompt service for extraction
|
||||
try:
|
||||
# Use prompt() method with extract-with-ontologies prompt ID
|
||||
triples_response = await flow("prompt-request").prompt(
|
||||
id="extract-with-ontologies",
|
||||
variables=prompt_variables
|
||||
)
|
||||
logger.debug(f"Extraction response: {triples_response}")
|
||||
|
||||
if not isinstance(triples_response, list):
|
||||
logger.error("Expected list of triples from prompt service")
|
||||
triples_response = []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt service error: {e}", exc_info=True)
|
||||
triples_response = []
|
||||
|
||||
# Parse and validate triples
|
||||
triples = self.parse_and_validate_triples(triples_response, ontology_subset)
|
||||
|
||||
# Add metadata triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
|
||||
# Generate ontology definition triples
|
||||
ontology_triples = self.build_ontology_triples(ontology_subset)
|
||||
|
||||
# Combine extracted triples with ontology triples
|
||||
all_triples = triples + ontology_triples
|
||||
|
||||
# Build entity contexts from all triples (including ontology elements)
|
||||
entity_contexts = self.build_entity_contexts(all_triples)
|
||||
|
||||
# Emit all triples (extracted + ontology definitions)
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
all_triples
|
||||
)
|
||||
|
||||
# Emit entity contexts
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
entity_contexts
|
||||
)
|
||||
|
||||
logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
|
||||
f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
|
||||
# Emit empty outputs on error
|
||||
await self.emit_triples(
|
||||
flow("triples"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
await self.emit_entity_contexts(
|
||||
flow("entity-contexts"),
|
||||
v.metadata,
|
||||
[]
|
||||
)
|
||||
|
||||
def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]:
|
||||
"""Build variables for ontology-based extraction prompt template.
|
||||
|
||||
Args:
|
||||
chunk: Text chunk to extract from
|
||||
ontology_subset: Relevant ontology elements
|
||||
|
||||
Returns:
|
||||
Dict with template variables: text, classes, object_properties, datatype_properties
|
||||
"""
|
||||
return {
|
||||
"text": chunk,
|
||||
"classes": ontology_subset.classes,
|
||||
"object_properties": ontology_subset.object_properties,
|
||||
"datatype_properties": ontology_subset.datatype_properties
|
||||
}
|
||||
|
||||
def parse_and_validate_triples(self, triples_response: List[Any],
|
||||
ontology_subset: OntologySubset) -> List[Triple]:
|
||||
"""Parse and validate extracted triples against ontology."""
|
||||
validated_triples = []
|
||||
ontology_id = ontology_subset.ontology_id
|
||||
|
||||
for triple_data in triples_response:
|
||||
try:
|
||||
if isinstance(triple_data, dict):
|
||||
subject = triple_data.get('subject', '')
|
||||
predicate = triple_data.get('predicate', '')
|
||||
object_val = triple_data.get('object', '')
|
||||
|
||||
if not subject or not predicate or not object_val:
|
||||
continue
|
||||
|
||||
# Validate against ontology
|
||||
if self.is_valid_triple(subject, predicate, object_val, ontology_subset):
|
||||
# Expand URIs before creating Value objects
|
||||
subject_uri = self.expand_uri(subject, ontology_subset, ontology_id)
|
||||
predicate_uri = self.expand_uri(predicate, ontology_subset, ontology_id)
|
||||
|
||||
# Object might be URI or literal - check before expanding
|
||||
if self.is_uri(object_val) or self.should_expand_as_uri(object_val, ontology_subset):
|
||||
object_uri = self.expand_uri(object_val, ontology_subset, ontology_id)
|
||||
is_object_uri = True
|
||||
else:
|
||||
object_uri = object_val
|
||||
is_object_uri = False
|
||||
|
||||
# Create Triple object with expanded URIs
|
||||
s_value = Value(value=subject_uri, is_uri=True)
|
||||
p_value = Value(value=predicate_uri, is_uri=True)
|
||||
o_value = Value(value=object_uri, is_uri=is_object_uri)
|
||||
|
||||
validated_triples.append(Triple(
|
||||
s=s_value,
|
||||
p=p_value,
|
||||
o=o_value
|
||||
))
|
||||
else:
|
||||
logger.debug(f"Invalid triple: ({subject}, {predicate}, {object_val})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing triple: {e}")
|
||||
|
||||
return validated_triples
|
||||
|
||||
def should_expand_as_uri(self, value: str, ontology_subset: OntologySubset) -> bool:
|
||||
"""Check if a value should be treated as URI (not literal).
|
||||
|
||||
Returns True if value is a class name, property name, or entity reference.
|
||||
"""
|
||||
# Check if it's a class or property from ontology
|
||||
if value in ontology_subset.classes:
|
||||
return True
|
||||
if value in ontology_subset.object_properties:
|
||||
return True
|
||||
if value in ontology_subset.datatype_properties:
|
||||
return True
|
||||
# Check if it starts with a known prefix
|
||||
for prefix in URI_PREFIXES.keys():
|
||||
if value.startswith(prefix):
|
||||
return True
|
||||
# Check if it looks like an entity reference (e.g., "recipe:cornish-pasty")
|
||||
if ":" in value and not value.startswith("http"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_valid_triple(self, subject: str, predicate: str, object_val: str,
|
||||
ontology_subset: OntologySubset) -> bool:
|
||||
"""Validate triple against ontology constraints."""
|
||||
# Special case for rdf:type
|
||||
if predicate == "rdf:type" or predicate == str(RDF_TYPE):
|
||||
# Check if object is a valid class
|
||||
return object_val in ontology_subset.classes
|
||||
|
||||
# Special case for rdfs:label
|
||||
if predicate == "rdfs:label" or predicate == str(RDF_LABEL):
|
||||
return True # Labels are always valid
|
||||
|
||||
# Check if predicate is a valid property
|
||||
is_obj_prop = predicate in ontology_subset.object_properties
|
||||
is_dt_prop = predicate in ontology_subset.datatype_properties
|
||||
|
||||
if not is_obj_prop and not is_dt_prop:
|
||||
return False # Unknown property
|
||||
|
||||
# TODO: Add more sophisticated validation (domain/range checking)
|
||||
return True
|
||||
|
||||
def expand_uri(self, value: str, ontology_subset: OntologySubset, ontology_id: str = "unknown") -> str:
|
||||
"""Expand prefix notation or short names to full URIs.
|
||||
|
||||
Args:
|
||||
value: Value to expand (e.g., "rdf:type", "Recipe", "has_ingredient")
|
||||
ontology_subset: Ontology subset for class/property lookup
|
||||
ontology_id: ID of the ontology for constructing instance URIs
|
||||
|
||||
Returns:
|
||||
Full URI string
|
||||
"""
|
||||
# Already a full URI
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
return value
|
||||
|
||||
# Check standard prefixes (rdf:, rdfs:, etc.)
|
||||
for prefix, namespace in URI_PREFIXES.items():
|
||||
if value.startswith(prefix):
|
||||
return namespace + value[len(prefix):]
|
||||
|
||||
# Check if it's an ontology class
|
||||
if value in ontology_subset.classes:
|
||||
class_def = ontology_subset.classes[value]
|
||||
# class_def is a dict (from cls.__dict__ in ontology_selector)
|
||||
if isinstance(class_def, dict) and 'uri' in class_def and class_def['uri']:
|
||||
return class_def['uri']
|
||||
# Fallback: construct URI
|
||||
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"
|
||||
|
||||
# Check if it's an ontology property
|
||||
if value in ontology_subset.object_properties:
|
||||
prop_def = ontology_subset.object_properties[value]
|
||||
# prop_def is a dict (from prop.__dict__ in ontology_selector)
|
||||
if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']:
|
||||
return prop_def['uri']
|
||||
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"
|
||||
|
||||
if value in ontology_subset.datatype_properties:
|
||||
prop_def = ontology_subset.datatype_properties[value]
|
||||
# prop_def is a dict (from prop.__dict__ in ontology_selector)
|
||||
if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']:
|
||||
return prop_def['uri']
|
||||
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"
|
||||
|
||||
# Otherwise, treat as entity instance - construct unique URI
|
||||
# Normalize the value for URI (lowercase, replace spaces with hyphens)
|
||||
normalized = value.replace(" ", "-").lower()
|
||||
return f"https://trustgraph.ai/{ontology_id}/{normalized}"
|
||||
|
||||
def is_uri(self, value: str) -> bool:
|
||||
"""Check if value is already a full URI."""
|
||||
return value.startswith("http://") or value.startswith("https://")
|
||||
|
||||
async def emit_triples(self, pub, metadata: Metadata, triples: List[Triple]):
|
||||
"""Emit triples to output."""
|
||||
t = Triples(
|
||||
metadata=Metadata(
|
||||
id=metadata.id,
|
||||
metadata=[],
|
||||
user=metadata.user,
|
||||
collection=metadata.collection,
|
||||
),
|
||||
triples=triples,
|
||||
)
|
||||
await pub.send(t)
|
||||
|
||||
async def emit_entity_contexts(self, pub, metadata: Metadata, entities: List[EntityContext]):
|
||||
"""Emit entity contexts to output."""
|
||||
ec = EntityContexts(
|
||||
metadata=Metadata(
|
||||
id=metadata.id,
|
||||
metadata=[],
|
||||
user=metadata.user,
|
||||
collection=metadata.collection,
|
||||
),
|
||||
entities=entities,
|
||||
)
|
||||
await pub.send(ec)
|
||||
|
||||
def build_ontology_triples(self, ontology_subset: OntologySubset) -> List[Triple]:
|
||||
"""Build triples describing the ontology elements themselves.
|
||||
|
||||
Generates triples for classes and properties so they exist in the knowledge graph.
|
||||
|
||||
Args:
|
||||
ontology_subset: The ontology subset used for extraction
|
||||
|
||||
Returns:
|
||||
List of Triple objects describing ontology elements
|
||||
"""
|
||||
ontology_triples = []
|
||||
|
||||
# Generate triples for classes
|
||||
for class_id, class_def in ontology_subset.classes.items():
|
||||
# Get URI for class
|
||||
if isinstance(class_def, dict) and 'uri' in class_def and class_def['uri']:
|
||||
class_uri = class_def['uri']
|
||||
else:
|
||||
# Fallback to constructed URI
|
||||
class_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{class_id}"
|
||||
|
||||
# rdf:type owl:Class
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#Class", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyClass.__dict__)
|
||||
if isinstance(class_def, dict) and 'labels' in class_def:
|
||||
labels = class_def['labels']
|
||||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', class_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyClass.__dict__)
|
||||
if isinstance(class_def, dict) and 'comment' in class_def and class_def['comment']:
|
||||
comment = class_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:subClassOf (stored as 'subclass_of' in OntologyClass.__dict__)
|
||||
if isinstance(class_def, dict) and 'subclass_of' in class_def and class_def['subclass_of']:
|
||||
parent = class_def['subclass_of']
|
||||
# Get parent URI
|
||||
if parent in ontology_subset.classes:
|
||||
parent_class_def = ontology_subset.classes[parent]
|
||||
if isinstance(parent_class_def, dict) and 'uri' in parent_class_def and parent_class_def['uri']:
|
||||
parent_uri = parent_class_def['uri']
|
||||
else:
|
||||
parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}"
|
||||
else:
|
||||
parent_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{parent}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=class_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#subClassOf", is_uri=True),
|
||||
o=Value(value=parent_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# Generate triples for object properties
|
||||
for prop_id, prop_def in ontology_subset.object_properties.items():
|
||||
# Get URI for property
|
||||
if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']:
|
||||
prop_uri = prop_def['uri']
|
||||
else:
|
||||
prop_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{prop_id}"
|
||||
|
||||
# rdf:type owl:ObjectProperty
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#ObjectProperty", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'labels' in prop_def:
|
||||
labels = prop_def['labels']
|
||||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
|
||||
comment = prop_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'domain' in prop_def and prop_def['domain']:
|
||||
domain = prop_def['domain']
|
||||
# Get domain class URI
|
||||
if domain in ontology_subset.classes:
|
||||
domain_class_def = ontology_subset.classes[domain]
|
||||
if isinstance(domain_class_def, dict) and 'uri' in domain_class_def and domain_class_def['uri']:
|
||||
domain_uri = domain_class_def['uri']
|
||||
else:
|
||||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
else:
|
||||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=Value(value=domain_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:range (stored as 'range' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'range' in prop_def and prop_def['range']:
|
||||
range_val = prop_def['range']
|
||||
# Get range class URI
|
||||
if range_val in ontology_subset.classes:
|
||||
range_class_def = ontology_subset.classes[range_val]
|
||||
if isinstance(range_class_def, dict) and 'uri' in range_class_def and range_class_def['uri']:
|
||||
range_uri = range_class_def['uri']
|
||||
else:
|
||||
range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}"
|
||||
else:
|
||||
range_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{range_val}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=Value(value=range_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# Generate triples for datatype properties
|
||||
for prop_id, prop_def in ontology_subset.datatype_properties.items():
|
||||
# Get URI for property
|
||||
if isinstance(prop_def, dict) and 'uri' in prop_def and prop_def['uri']:
|
||||
prop_uri = prop_def['uri']
|
||||
else:
|
||||
prop_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{prop_id}"
|
||||
|
||||
# rdf:type owl:DatatypeProperty
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://www.w3.org/2002/07/owl#DatatypeProperty", is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:label (stored as 'labels' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'labels' in prop_def:
|
||||
labels = prop_def['labels']
|
||||
if isinstance(labels, list) and labels:
|
||||
label_val = labels[0].get('value', prop_id) if isinstance(labels[0], dict) else str(labels[0])
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=label_val, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:comment (stored as 'comment' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'comment' in prop_def and prop_def['comment']:
|
||||
comment = prop_def['comment']
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#comment", is_uri=True),
|
||||
o=Value(value=comment, is_uri=False)
|
||||
))
|
||||
|
||||
# rdfs:domain (stored as 'domain' in OntologyProperty.__dict__)
|
||||
if isinstance(prop_def, dict) and 'domain' in prop_def and prop_def['domain']:
|
||||
domain = prop_def['domain']
|
||||
# Get domain class URI
|
||||
if domain in ontology_subset.classes:
|
||||
domain_class_def = ontology_subset.classes[domain]
|
||||
if isinstance(domain_class_def, dict) and 'uri' in domain_class_def and domain_class_def['uri']:
|
||||
domain_uri = domain_class_def['uri']
|
||||
else:
|
||||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
else:
|
||||
domain_uri = f"https://trustgraph.ai/ontology/{ontology_subset.ontology_id}#{domain}"
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#domain", is_uri=True),
|
||||
o=Value(value=domain_uri, is_uri=True)
|
||||
))
|
||||
|
||||
# rdfs:range (datatype)
|
||||
if isinstance(prop_def, dict) and 'rdfs:range' in prop_def and prop_def['rdfs:range']:
|
||||
range_val = prop_def['rdfs:range']
|
||||
# Range for datatype properties is usually xsd:string, xsd:int, etc.
|
||||
if range_val.startswith('xsd:'):
|
||||
range_uri = f"http://www.w3.org/2001/XMLSchema#{range_val[4:]}"
|
||||
else:
|
||||
range_uri = range_val
|
||||
|
||||
ontology_triples.append(Triple(
|
||||
s=Value(value=prop_uri, is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#range", is_uri=True),
|
||||
o=Value(value=range_uri, is_uri=True)
|
||||
))
|
||||
|
||||
logger.info(f"Generated {len(ontology_triples)} triples describing ontology elements")
|
||||
return ontology_triples
|
||||
|
||||
def build_entity_contexts(self, triples: List[Triple]) -> List[EntityContext]:
|
||||
"""Build entity contexts from extracted triples.
|
||||
|
||||
Collects rdfs:label and definition properties for each entity to create
|
||||
contextual descriptions for embedding.
|
||||
|
||||
Args:
|
||||
triples: List of extracted triples
|
||||
|
||||
Returns:
|
||||
List of EntityContext objects
|
||||
"""
|
||||
# Group triples by subject to collect entity information
|
||||
entity_data = {} # subject_uri -> {labels: [], definitions: []}
|
||||
|
||||
for triple in triples:
|
||||
subject_uri = triple.s.value
|
||||
predicate_uri = triple.p.value
|
||||
object_val = triple.o.value
|
||||
|
||||
# Initialize entity data if not exists
|
||||
if subject_uri not in entity_data:
|
||||
entity_data[subject_uri] = {'labels': [], 'definitions': []}
|
||||
|
||||
# Collect labels (rdfs:label)
|
||||
if predicate_uri == RDF_LABEL:
|
||||
if not triple.o.is_uri: # Labels are literals
|
||||
entity_data[subject_uri]['labels'].append(object_val)
|
||||
|
||||
# Collect definitions (skos:definition, schema:description)
|
||||
elif predicate_uri == DEFINITION or predicate_uri == "https://schema.org/description":
|
||||
if not triple.o.is_uri:
|
||||
entity_data[subject_uri]['definitions'].append(object_val)
|
||||
|
||||
# Build EntityContext objects
|
||||
entity_contexts = []
|
||||
for subject_uri, data in entity_data.items():
|
||||
# Build context text from labels and definitions
|
||||
context_parts = []
|
||||
|
||||
if data['labels']:
|
||||
context_parts.append(f"Label: {data['labels'][0]}")
|
||||
|
||||
if data['definitions']:
|
||||
context_parts.extend(data['definitions'])
|
||||
|
||||
# Only create EntityContext if we have meaningful context
|
||||
if context_parts:
|
||||
context_text = ". ".join(context_parts)
|
||||
entity_contexts.append(EntityContext(
|
||||
entity=Value(value=subject_uri, is_uri=True),
|
||||
context=context_text
|
||||
))
|
||||
|
||||
logger.debug(f"Built {len(entity_contexts)} entity contexts from {len(triples)} triples")
|
||||
return entity_contexts
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments."""
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--top-k',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Number of top ontology elements to retrieve (default: 10)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--similarity-threshold',
|
||||
type=float,
|
||||
default=0.3,
|
||||
help='Similarity threshold for ontology matching (default: 0.3, range: 0.0-1.0)'
|
||||
)
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
|
||||
def run():
|
||||
"""Launch the OntoRAG extraction service."""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -0,0 +1,310 @@
|
|||
"""
|
||||
Ontology embedder component for OntoRAG system.
|
||||
Generates and stores embeddings for ontology elements.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .ontology_loader import Ontology, OntologyClass, OntologyProperty
|
||||
from .vector_store import InMemoryVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyElementMetadata:
|
||||
"""Metadata for an embedded ontology element."""
|
||||
type: str # 'class', 'objectProperty', 'datatypeProperty'
|
||||
ontology: str # Ontology ID
|
||||
element: str # Element ID
|
||||
definition: Dict[str, Any] # Full element definition
|
||||
text: str # Text used for embedding
|
||||
|
||||
|
||||
class OntologyEmbedder:
|
||||
"""Generates embeddings for ontology elements and stores them in vector store."""
|
||||
|
||||
def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None):
|
||||
"""Initialize the ontology embedder.
|
||||
|
||||
Args:
|
||||
embedding_service: Service for generating embeddings
|
||||
vector_store: Vector store instance (InMemoryVectorStore)
|
||||
"""
|
||||
self.embedding_service = embedding_service
|
||||
self.vector_store = vector_store or InMemoryVectorStore()
|
||||
self.embedded_ontologies = set()
|
||||
|
||||
def _create_text_representation(self, element_id: str, element: Any,
|
||||
element_type: str) -> str:
|
||||
"""Create text representation of an ontology element for embedding.
|
||||
|
||||
Args:
|
||||
element_id: ID of the element
|
||||
element: The element object (OntologyClass or OntologyProperty)
|
||||
element_type: Type of element
|
||||
|
||||
Returns:
|
||||
Text representation for embedding
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Add the element ID (often meaningful)
|
||||
parts.append(element_id.replace('-', ' ').replace('_', ' '))
|
||||
|
||||
# Add labels
|
||||
if hasattr(element, 'labels') and element.labels:
|
||||
for label in element.labels:
|
||||
if isinstance(label, dict):
|
||||
parts.append(label.get('value', ''))
|
||||
else:
|
||||
parts.append(str(label))
|
||||
|
||||
# Add comment/description
|
||||
if hasattr(element, 'comment') and element.comment:
|
||||
parts.append(element.comment)
|
||||
|
||||
# Add type-specific information
|
||||
if element_type == 'class':
|
||||
if hasattr(element, 'subclass_of') and element.subclass_of:
|
||||
parts.append(f"subclass of {element.subclass_of}")
|
||||
elif element_type in ['objectProperty', 'datatypeProperty']:
|
||||
if hasattr(element, 'domain') and element.domain:
|
||||
parts.append(f"domain: {element.domain}")
|
||||
if hasattr(element, 'range') and element.range:
|
||||
parts.append(f"range: {element.range}")
|
||||
|
||||
# Join all parts with spaces
|
||||
text = ' '.join(filter(None, parts))
|
||||
return text
|
||||
|
||||
async def embed_ontology(self, ontology: Ontology) -> int:
|
||||
"""Generate and store embeddings for all elements in an ontology.
|
||||
|
||||
Args:
|
||||
ontology: The ontology to embed
|
||||
|
||||
Returns:
|
||||
Number of elements embedded
|
||||
"""
|
||||
if not self.embedding_service:
|
||||
logger.warning("No embedding service available, skipping embedding")
|
||||
return 0
|
||||
|
||||
embedded_count = 0
|
||||
batch_size = 50 # Process embeddings in batches
|
||||
|
||||
# Collect all elements to embed
|
||||
elements_to_embed = []
|
||||
|
||||
# Process classes
|
||||
for class_id, class_def in ontology.classes.items():
|
||||
text = self._create_text_representation(class_id, class_def, 'class')
|
||||
elements_to_embed.append({
|
||||
'id': f"{ontology.id}:class:{class_id}",
|
||||
'text': text,
|
||||
'metadata': OntologyElementMetadata(
|
||||
type='class',
|
||||
ontology=ontology.id,
|
||||
element=class_id,
|
||||
definition=class_def.__dict__,
|
||||
text=text
|
||||
).__dict__
|
||||
})
|
||||
|
||||
# Process object properties
|
||||
for prop_id, prop_def in ontology.object_properties.items():
|
||||
text = self._create_text_representation(prop_id, prop_def, 'objectProperty')
|
||||
elements_to_embed.append({
|
||||
'id': f"{ontology.id}:objectProperty:{prop_id}",
|
||||
'text': text,
|
||||
'metadata': OntologyElementMetadata(
|
||||
type='objectProperty',
|
||||
ontology=ontology.id,
|
||||
element=prop_id,
|
||||
definition=prop_def.__dict__,
|
||||
text=text
|
||||
).__dict__
|
||||
})
|
||||
|
||||
# Process datatype properties
|
||||
for prop_id, prop_def in ontology.datatype_properties.items():
|
||||
text = self._create_text_representation(prop_id, prop_def, 'datatypeProperty')
|
||||
elements_to_embed.append({
|
||||
'id': f"{ontology.id}:datatypeProperty:{prop_id}",
|
||||
'text': text,
|
||||
'metadata': OntologyElementMetadata(
|
||||
type='datatypeProperty',
|
||||
ontology=ontology.id,
|
||||
element=prop_id,
|
||||
definition=prop_def.__dict__,
|
||||
text=text
|
||||
).__dict__
|
||||
})
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(elements_to_embed), batch_size):
|
||||
batch = elements_to_embed[i:i + batch_size]
|
||||
|
||||
# Get embeddings for batch
|
||||
texts = [elem['text'] for elem in batch]
|
||||
try:
|
||||
# Call embedding service for each text
|
||||
# Note: embed() returns 2D array [[vector]], so extract first element
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
# Extract vectors from responses (each is [[vector]])
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
|
||||
# Convert to numpy array
|
||||
embeddings = np.array(embeddings_list)
|
||||
|
||||
# Log embedding shape for debugging
|
||||
logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
|
||||
|
||||
# Store in vector store
|
||||
ids = [elem['id'] for elem in batch]
|
||||
metadata_list = [elem['metadata'] for elem in batch]
|
||||
|
||||
self.vector_store.add_batch(ids, embeddings, metadata_list)
|
||||
embedded_count += len(batch)
|
||||
|
||||
logger.debug(f"Embedded batch of {len(batch)} elements from ontology {ontology.id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed batch for ontology {ontology.id}: {e}", exc_info=True)
|
||||
|
||||
self.embedded_ontologies.add(ontology.id)
|
||||
logger.info(f"Embedded {embedded_count} elements from ontology {ontology.id}")
|
||||
return embedded_count
|
||||
|
||||
async def embed_ontologies(self, ontologies: Dict[str, Ontology]) -> int:
|
||||
"""Generate and store embeddings for multiple ontologies.
|
||||
|
||||
Args:
|
||||
ontologies: Dictionary of ontology ID to Ontology objects
|
||||
|
||||
Returns:
|
||||
Total number of elements embedded
|
||||
"""
|
||||
total_embedded = 0
|
||||
|
||||
for ont_id, ontology in ontologies.items():
|
||||
if ont_id not in self.embedded_ontologies:
|
||||
count = await self.embed_ontology(ontology)
|
||||
total_embedded += count
|
||||
else:
|
||||
logger.debug(f"Ontology {ont_id} already embedded, skipping")
|
||||
|
||||
logger.info(f"Total embedded elements: {total_embedded} from {len(ontologies)} ontologies")
|
||||
return total_embedded
|
||||
|
||||
async def embed_text(self, text: str) -> Optional[np.ndarray]:
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector or None if failed
|
||||
"""
|
||||
if not self.embedding_service:
|
||||
logger.warning("No embedding service available")
|
||||
return None
|
||||
|
||||
try:
|
||||
# embed() returns 2D array [[vector]], extract first element
|
||||
embedding_response = await self.embedding_service.embed(text)
|
||||
return np.array(embedding_response[0])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed text: {e}")
|
||||
return None
|
||||
|
||||
async def embed_texts(self, texts: List[str]) -> Optional[np.ndarray]:
|
||||
"""Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
Array of embeddings or None if failed
|
||||
"""
|
||||
if not self.embedding_service:
|
||||
logger.warning("No embedding service available")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Call embed() for each text (returns [[vector]] per call)
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Extract first vector from each response
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
return np.array(embeddings_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed texts: {e}")
|
||||
return None
|
||||
|
||||
def remove_ontology(self, ontology_id: str):
|
||||
"""Remove all embeddings for a specific ontology.
|
||||
|
||||
Note: FAISS doesn't support efficient deletion, so this currently
|
||||
requires rebuilding the entire index without the removed ontology.
|
||||
|
||||
Args:
|
||||
ontology_id: ID of ontology to remove
|
||||
"""
|
||||
if ontology_id not in self.embedded_ontologies:
|
||||
logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove")
|
||||
return
|
||||
|
||||
# FAISS doesn't support selective deletion, so we'd need to rebuild the index
|
||||
# For now, just remove from tracking set
|
||||
# TODO: Implement index rebuilding if selective removal is needed
|
||||
self.embedded_ontologies.discard(ontology_id)
|
||||
logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)")
|
||||
|
||||
def clear_embeddings(self, ontology_id: Optional[str] = None):
|
||||
"""Clear embeddings from vector store.
|
||||
|
||||
Args:
|
||||
ontology_id: If provided, only clear embeddings for this ontology
|
||||
Otherwise, clear all embeddings
|
||||
"""
|
||||
if ontology_id:
|
||||
self.remove_ontology(ontology_id)
|
||||
else:
|
||||
self.vector_store.clear()
|
||||
self.embedded_ontologies.clear()
|
||||
logger.info("Cleared all embeddings from vector store")
|
||||
|
||||
def get_vector_store(self) -> InMemoryVectorStore:
|
||||
"""Get the vector store instance.
|
||||
|
||||
Returns:
|
||||
The vector store being used
|
||||
"""
|
||||
return self.vector_store
|
||||
|
||||
def get_embedded_count(self) -> int:
|
||||
"""Get the number of embedded elements.
|
||||
|
||||
Returns:
|
||||
Number of elements in the vector store
|
||||
"""
|
||||
return self.vector_store.size()
|
||||
|
||||
def is_ontology_embedded(self, ontology_id: str) -> bool:
|
||||
"""Check if an ontology has been embedded.
|
||||
|
||||
Args:
|
||||
ontology_id: ID of the ontology
|
||||
|
||||
Returns:
|
||||
True if the ontology has been embedded
|
||||
"""
|
||||
return ontology_id in self.embedded_ontologies
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Ontology loader component for OntoRAG system.
|
||||
Loads and manages ontologies from configuration service.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyClass:
|
||||
"""Represents an OWL-like class in the ontology."""
|
||||
uri: str
|
||||
type: str = "owl:Class"
|
||||
labels: List[Dict[str, str]] = field(default_factory=list)
|
||||
comment: Optional[str] = None
|
||||
subclass_of: Optional[str] = None
|
||||
equivalent_classes: List[str] = field(default_factory=list)
|
||||
disjoint_with: List[str] = field(default_factory=list)
|
||||
identifier: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(class_id: str, data: Dict[str, Any]) -> 'OntologyClass':
|
||||
"""Create OntologyClass from dictionary representation."""
|
||||
labels = data.get('rdfs:label', [])
|
||||
if isinstance(labels, list):
|
||||
labels = labels
|
||||
else:
|
||||
labels = [labels] if labels else []
|
||||
|
||||
return OntologyClass(
|
||||
uri=data.get('uri', ''),
|
||||
type=data.get('type', 'owl:Class'),
|
||||
labels=labels,
|
||||
comment=data.get('rdfs:comment'),
|
||||
subclass_of=data.get('rdfs:subClassOf'),
|
||||
equivalent_classes=data.get('owl:equivalentClass', []),
|
||||
disjoint_with=data.get('owl:disjointWith', []),
|
||||
identifier=data.get('dcterms:identifier')
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyProperty:
|
||||
"""Represents a property (object or datatype) in the ontology."""
|
||||
uri: str
|
||||
type: str
|
||||
labels: List[Dict[str, str]] = field(default_factory=list)
|
||||
comment: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
range: Optional[str] = None
|
||||
inverse_of: Optional[str] = None
|
||||
functional: bool = False
|
||||
inverse_functional: bool = False
|
||||
min_cardinality: Optional[int] = None
|
||||
max_cardinality: Optional[int] = None
|
||||
cardinality: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(prop_id: str, data: Dict[str, Any]) -> 'OntologyProperty':
|
||||
"""Create OntologyProperty from dictionary representation."""
|
||||
labels = data.get('rdfs:label', [])
|
||||
if isinstance(labels, list):
|
||||
labels = labels
|
||||
else:
|
||||
labels = [labels] if labels else []
|
||||
|
||||
return OntologyProperty(
|
||||
uri=data.get('uri', ''),
|
||||
type=data.get('type', ''),
|
||||
labels=labels,
|
||||
comment=data.get('rdfs:comment'),
|
||||
domain=data.get('rdfs:domain'),
|
||||
range=data.get('rdfs:range'),
|
||||
inverse_of=data.get('owl:inverseOf'),
|
||||
functional=data.get('owl:functionalProperty', False),
|
||||
inverse_functional=data.get('owl:inverseFunctionalProperty', False),
|
||||
min_cardinality=data.get('owl:minCardinality'),
|
||||
max_cardinality=data.get('owl:maxCardinality'),
|
||||
cardinality=data.get('owl:cardinality')
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ontology:
|
||||
"""Represents a complete ontology with metadata, classes, and properties."""
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
classes: Dict[str, OntologyClass]
|
||||
object_properties: Dict[str, OntologyProperty]
|
||||
datatype_properties: Dict[str, OntologyProperty]
|
||||
|
||||
def get_class(self, class_id: str) -> Optional[OntologyClass]:
|
||||
"""Get a class by ID."""
|
||||
return self.classes.get(class_id)
|
||||
|
||||
def get_property(self, prop_id: str) -> Optional[OntologyProperty]:
|
||||
"""Get a property (object or datatype) by ID."""
|
||||
prop = self.object_properties.get(prop_id)
|
||||
if prop is None:
|
||||
prop = self.datatype_properties.get(prop_id)
|
||||
return prop
|
||||
|
||||
def get_parent_classes(self, class_id: str) -> List[str]:
|
||||
"""Get all parent classes (following subClassOf hierarchy)."""
|
||||
parents = []
|
||||
current = class_id
|
||||
visited = set()
|
||||
|
||||
while current and current not in visited:
|
||||
visited.add(current)
|
||||
cls = self.get_class(current)
|
||||
if cls and cls.subclass_of:
|
||||
parents.append(cls.subclass_of)
|
||||
current = cls.subclass_of
|
||||
else:
|
||||
break
|
||||
|
||||
return parents
|
||||
|
||||
def validate_structure(self) -> List[str]:
|
||||
"""Validate ontology structure and return list of issues."""
|
||||
issues = []
|
||||
|
||||
# Check for circular inheritance
|
||||
for class_id in self.classes:
|
||||
visited = set()
|
||||
current = class_id
|
||||
while current:
|
||||
if current in visited:
|
||||
issues.append(f"Circular inheritance detected for class {class_id}")
|
||||
break
|
||||
visited.add(current)
|
||||
cls = self.get_class(current)
|
||||
if cls:
|
||||
current = cls.subclass_of
|
||||
else:
|
||||
break
|
||||
|
||||
# Check property domains and ranges exist
|
||||
for prop_id, prop in {**self.object_properties, **self.datatype_properties}.items():
|
||||
if prop.domain and prop.domain not in self.classes:
|
||||
issues.append(f"Property {prop_id} has unknown domain {prop.domain}")
|
||||
if prop.type == "owl:ObjectProperty" and prop.range and prop.range not in self.classes:
|
||||
issues.append(f"Object property {prop_id} has unknown range class {prop.range}")
|
||||
|
||||
# Check disjoint classes
|
||||
for class_id, cls in self.classes.items():
|
||||
for disjoint_id in cls.disjoint_with:
|
||||
if disjoint_id not in self.classes:
|
||||
issues.append(f"Class {class_id} disjoint with unknown class {disjoint_id}")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
class OntologyLoader:
|
||||
"""Manages ontologies received via event-driven config updates.
|
||||
|
||||
No direct database access - receives ontologies via config handler.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize empty ontology store."""
|
||||
self.ontologies: Dict[str, Ontology] = {}
|
||||
|
||||
def update_ontologies(self, ontology_configs: Dict[str, Any]):
|
||||
"""Update ontology definitions from config.
|
||||
|
||||
Args:
|
||||
ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts)
|
||||
"""
|
||||
self.ontologies.clear()
|
||||
|
||||
for ont_id, ont_data in ontology_configs.items():
|
||||
try:
|
||||
# Parse classes
|
||||
classes = {}
|
||||
for class_id, class_data in ont_data.get('classes', {}).items():
|
||||
classes[class_id] = OntologyClass.from_dict(class_id, class_data)
|
||||
|
||||
# Parse object properties
|
||||
object_props = {}
|
||||
for prop_id, prop_data in ont_data.get('objectProperties', {}).items():
|
||||
object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
|
||||
|
||||
# Parse datatype properties
|
||||
datatype_props = {}
|
||||
for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items():
|
||||
datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
|
||||
|
||||
# Create ontology
|
||||
ontology = Ontology(
|
||||
id=ont_id,
|
||||
metadata=ont_data.get('metadata', {}),
|
||||
classes=classes,
|
||||
object_properties=object_props,
|
||||
datatype_properties=datatype_props
|
||||
)
|
||||
|
||||
# Validate structure
|
||||
issues = ontology.validate_structure()
|
||||
if issues:
|
||||
logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
|
||||
|
||||
self.ontologies[ont_id] = ontology
|
||||
logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
|
||||
f"{len(object_props)} object properties, "
|
||||
f"{len(datatype_props)} datatype properties")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)
|
||||
|
||||
def get_ontology(self, ont_id: str) -> Optional[Ontology]:
|
||||
"""Get a specific ontology by ID.
|
||||
|
||||
Args:
|
||||
ont_id: Ontology identifier
|
||||
|
||||
Returns:
|
||||
Ontology object or None if not found
|
||||
"""
|
||||
return self.ontologies.get(ont_id)
|
||||
|
||||
def get_all_ontologies(self) -> Dict[str, Ontology]:
|
||||
"""Get all loaded ontologies.
|
||||
|
||||
Returns:
|
||||
Dictionary of ontology ID to Ontology objects
|
||||
"""
|
||||
return self.ontologies
|
||||
|
||||
def list_ontology_ids(self) -> List[str]:
|
||||
"""Get list of loaded ontology IDs.
|
||||
|
||||
Returns:
|
||||
List of ontology IDs
|
||||
"""
|
||||
return list(self.ontologies.keys())
|
||||
|
||||
def clear(self):
|
||||
"""Clear all loaded ontologies."""
|
||||
self.ontologies.clear()
|
||||
logger.info("Cleared all loaded ontologies")
|
||||
|
|
@ -0,0 +1,356 @@
|
|||
"""
|
||||
Ontology selection algorithm for OntoRAG system.
|
||||
Selects relevant ontology subsets based on text similarity.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Set, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
from .ontology_loader import Ontology, OntologyLoader
|
||||
from .ontology_embedder import OntologyEmbedder
|
||||
from .text_processor import TextSegment
|
||||
from .vector_store import SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologySubset:
|
||||
"""Represents a subset of an ontology relevant to a text chunk."""
|
||||
ontology_id: str
|
||||
classes: Dict[str, Any]
|
||||
object_properties: Dict[str, Any]
|
||||
datatype_properties: Dict[str, Any]
|
||||
metadata: Dict[str, Any]
|
||||
relevance_score: float = 0.0
|
||||
|
||||
|
||||
class OntologySelector:
|
||||
"""Selects relevant ontology elements for text segments using vector similarity."""
|
||||
|
||||
def __init__(self, ontology_embedder: OntologyEmbedder,
|
||||
ontology_loader: OntologyLoader,
|
||||
top_k: int = 10,
|
||||
similarity_threshold: float = 0.7):
|
||||
"""Initialize the ontology selector.
|
||||
|
||||
Args:
|
||||
ontology_embedder: Embedder with vector store
|
||||
ontology_loader: Loader with ontology definitions
|
||||
top_k: Number of top results to retrieve per segment
|
||||
similarity_threshold: Minimum similarity score
|
||||
"""
|
||||
self.embedder = ontology_embedder
|
||||
self.loader = ontology_loader
|
||||
self.top_k = top_k
|
||||
self.similarity_threshold = similarity_threshold
|
||||
|
||||
async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]:
|
||||
"""Select relevant ontology subsets for text segments.
|
||||
|
||||
Args:
|
||||
segments: List of text segments to match
|
||||
|
||||
Returns:
|
||||
List of ontology subsets with relevant elements
|
||||
"""
|
||||
# Collect all relevant elements
|
||||
relevant_elements = await self._find_relevant_elements(segments)
|
||||
|
||||
# Group by ontology and build subsets
|
||||
ontology_subsets = self._build_ontology_subsets(relevant_elements)
|
||||
|
||||
# Resolve dependencies
|
||||
for subset in ontology_subsets:
|
||||
self._resolve_dependencies(subset)
|
||||
|
||||
logger.info(f"Selected {len(ontology_subsets)} ontology subsets")
|
||||
return ontology_subsets
|
||||
|
||||
async def _find_relevant_elements(self, segments: List[TextSegment]) -> Set[Tuple[str, str, str, Dict]]:
|
||||
"""Find relevant ontology elements for text segments.
|
||||
|
||||
Args:
|
||||
segments: Text segments to match
|
||||
|
||||
Returns:
|
||||
Set of (ontology_id, element_type, element_id, definition) tuples
|
||||
"""
|
||||
relevant_elements = set()
|
||||
element_scores = defaultdict(float)
|
||||
|
||||
# Check if vector store has any elements
|
||||
vector_store = self.embedder.get_vector_store()
|
||||
store_size = vector_store.size()
|
||||
logger.info(f"Vector store size: {store_size} elements")
|
||||
|
||||
if store_size == 0:
|
||||
logger.warning("Vector store is empty - no ontology elements embedded")
|
||||
return relevant_elements
|
||||
|
||||
# Process each segment (log first few for debugging)
|
||||
for i, segment in enumerate(segments):
|
||||
# Get embedding for segment
|
||||
embedding = await self.embedder.embed_text(segment.text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Failed to embed segment: {segment.text[:50]}...")
|
||||
continue
|
||||
|
||||
# Search vector store with no threshold to see all scores
|
||||
all_results = vector_store.search(
|
||||
embedding=embedding,
|
||||
top_k=self.top_k,
|
||||
threshold=0.0 # Get all results to see scores
|
||||
)
|
||||
|
||||
# Log top scores for first 3 segments to debug
|
||||
if i < 3 and all_results:
|
||||
top_scores = [r.score for r in all_results[:3]]
|
||||
top_elements = [r.metadata['element'] for r in all_results[:3]]
|
||||
logger.info(f"Segment {i}: '{segment.text[:60]}...'")
|
||||
logger.info(f" Top 3 scores: {top_scores} (threshold={self.similarity_threshold})")
|
||||
logger.info(f" Top 3 elements: {top_elements}")
|
||||
|
||||
# Filter by threshold
|
||||
results = [r for r in all_results if r.score >= self.similarity_threshold]
|
||||
|
||||
# Process results
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
element_key = (
|
||||
metadata['ontology'],
|
||||
metadata['type'],
|
||||
metadata['element'],
|
||||
str(metadata['definition']) # Convert dict to string for hashability
|
||||
)
|
||||
relevant_elements.add(element_key)
|
||||
# Track scores for ranking
|
||||
element_scores[element_key] = max(element_scores[element_key], result.score)
|
||||
|
||||
logger.info(f"Found {len(relevant_elements)} relevant elements from {len(segments)} segments")
|
||||
return relevant_elements
|
||||
|
||||
def _build_ontology_subsets(self, relevant_elements: Set[Tuple[str, str, str, Dict]]) -> List[OntologySubset]:
|
||||
"""Build ontology subsets from relevant elements.
|
||||
|
||||
Args:
|
||||
relevant_elements: Set of relevant element tuples
|
||||
|
||||
Returns:
|
||||
List of ontology subsets
|
||||
"""
|
||||
# Group elements by ontology
|
||||
ontology_groups = defaultdict(lambda: {
|
||||
'classes': {},
|
||||
'object_properties': {},
|
||||
'datatype_properties': {},
|
||||
'scores': []
|
||||
})
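        # Note: per-element scores are collected in _find_relevant_elements but are
        # not passed through to this method, so 'scores' stays empty here and
        # relevance_score falls back to 0.0 below.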
|
||||
|
||||
for ont_id, elem_type, elem_id, definition in relevant_elements:
|
||||
# Parse definition back from string if needed
|
||||
if isinstance(definition, str):
|
||||
                import ast, json
|
||||
try:
|
||||
definition = json.loads(definition.replace("'", '"'))
|
||||
                except (ValueError, SyntaxError):
|
||||
                    definition = ast.literal_eval(definition)  # Safer fallback for dict-like strings
|
||||
|
||||
# Get the actual ontology and element
|
||||
ontology = self.loader.get_ontology(ont_id)
|
||||
if not ontology:
|
||||
logger.warning(f"Ontology {ont_id} not found in loader")
|
||||
continue
|
||||
|
||||
# Add element to appropriate category
|
||||
if elem_type == 'class':
|
||||
cls = ontology.get_class(elem_id)
|
||||
if cls:
|
||||
ontology_groups[ont_id]['classes'][elem_id] = cls.__dict__
|
||||
elif elem_type == 'objectProperty':
|
||||
prop = ontology.object_properties.get(elem_id)
|
||||
if prop:
|
||||
ontology_groups[ont_id]['object_properties'][elem_id] = prop.__dict__
|
||||
elif elem_type == 'datatypeProperty':
|
||||
prop = ontology.datatype_properties.get(elem_id)
|
||||
if prop:
|
||||
ontology_groups[ont_id]['datatype_properties'][elem_id] = prop.__dict__
|
||||
|
||||
# Create OntologySubset objects
|
||||
subsets = []
|
||||
for ont_id, elements in ontology_groups.items():
|
||||
ontology = self.loader.get_ontology(ont_id)
|
||||
if ontology:
|
||||
subset = OntologySubset(
|
||||
ontology_id=ont_id,
|
||||
classes=elements['classes'],
|
||||
object_properties=elements['object_properties'],
|
||||
datatype_properties=elements['datatype_properties'],
|
||||
metadata=ontology.metadata,
|
||||
relevance_score=sum(elements['scores']) / len(elements['scores']) if elements['scores'] else 0.0
|
||||
)
|
||||
subsets.append(subset)
|
||||
|
||||
return subsets
|
||||
|
||||
def _resolve_dependencies(self, subset: OntologySubset):
|
||||
"""Resolve dependencies for ontology subset elements.
|
||||
|
||||
Args:
|
||||
subset: Ontology subset to resolve dependencies for
|
||||
"""
|
||||
ontology = self.loader.get_ontology(subset.ontology_id)
|
||||
if not ontology:
|
||||
return
|
||||
|
||||
# Track classes to add
|
||||
classes_to_add = set()
|
||||
|
||||
# Resolve class hierarchies
|
||||
for class_id in list(subset.classes.keys()):
|
||||
# Add parent classes
|
||||
parents = ontology.get_parent_classes(class_id)
|
||||
for parent_id in parents:
|
||||
parent_class = ontology.get_class(parent_id)
|
||||
if parent_class and parent_id not in subset.classes:
|
||||
classes_to_add.add(parent_id)
|
||||
|
||||
# Resolve property domains and ranges
|
||||
for prop_id, prop_def in subset.object_properties.items():
|
||||
# Add domain class
|
||||
if 'domain' in prop_def and prop_def['domain']:
|
||||
domain_id = prop_def['domain']
|
||||
if domain_id not in subset.classes:
|
||||
domain_class = ontology.get_class(domain_id)
|
||||
if domain_class:
|
||||
classes_to_add.add(domain_id)
|
||||
|
||||
# Add range class
|
||||
if 'range' in prop_def and prop_def['range']:
|
||||
range_id = prop_def['range']
|
||||
if range_id not in subset.classes:
|
||||
range_class = ontology.get_class(range_id)
|
||||
if range_class:
|
||||
classes_to_add.add(range_id)
|
||||
|
||||
# Resolve datatype property domains
|
||||
for prop_id, prop_def in subset.datatype_properties.items():
|
||||
if 'domain' in prop_def and prop_def['domain']:
|
||||
domain_id = prop_def['domain']
|
||||
if domain_id not in subset.classes:
|
||||
domain_class = ontology.get_class(domain_id)
|
||||
if domain_class:
|
||||
classes_to_add.add(domain_id)
|
||||
|
||||
# Add inverse properties
|
||||
for prop_id, prop_def in list(subset.object_properties.items()):
|
||||
if 'inverse_of' in prop_def and prop_def['inverse_of']:
|
||||
inverse_id = prop_def['inverse_of']
|
||||
if inverse_id not in subset.object_properties:
|
||||
inverse_prop = ontology.object_properties.get(inverse_id)
|
||||
if inverse_prop:
|
||||
subset.object_properties[inverse_id] = inverse_prop.__dict__
|
||||
|
||||
# NEW: Auto-include properties related to selected classes
|
||||
# For each selected class, find all properties that reference it in domain or range
|
||||
properties_added = 0
|
||||
datatype_properties_added = 0
|
||||
|
||||
for class_id in list(subset.classes.keys()):
|
||||
# Check all object properties in the ontology
|
||||
for prop_id, prop_def in ontology.object_properties.items():
|
||||
if prop_id not in subset.object_properties:
|
||||
# Check if this class is in the property's domain or range
|
||||
prop_domain = getattr(prop_def, 'domain', None)
|
||||
prop_range = getattr(prop_def, 'range', None)
|
||||
|
||||
if prop_domain == class_id or prop_range == class_id:
|
||||
subset.object_properties[prop_id] = prop_def.__dict__
|
||||
properties_added += 1
|
||||
|
||||
# Also add the other class (domain or range) if not already present
|
||||
if prop_domain and prop_domain != class_id and prop_domain not in subset.classes:
|
||||
other_class = ontology.get_class(prop_domain)
|
||||
if other_class:
|
||||
classes_to_add.add(prop_domain)
|
||||
if prop_range and prop_range != class_id and prop_range not in subset.classes:
|
||||
other_class = ontology.get_class(prop_range)
|
||||
if other_class:
|
||||
classes_to_add.add(prop_range)
|
||||
|
||||
# Check all datatype properties in the ontology
|
||||
for prop_id, prop_def in ontology.datatype_properties.items():
|
||||
if prop_id not in subset.datatype_properties:
|
||||
# Check if this class is in the property's domain
|
||||
prop_domain = getattr(prop_def, 'domain', None)
|
||||
|
||||
if prop_domain == class_id:
|
||||
subset.datatype_properties[prop_id] = prop_def.__dict__
|
||||
datatype_properties_added += 1
|
||||
|
||||
# Add collected classes
|
||||
for class_id in classes_to_add:
|
||||
cls = ontology.get_class(class_id)
|
||||
if cls:
|
||||
subset.classes[class_id] = cls.__dict__
|
||||
|
||||
logger.debug(f"Resolved dependencies for subset {subset.ontology_id}: "
|
||||
f"added {len(classes_to_add)} classes, "
|
||||
f"{properties_added} object properties, "
|
||||
f"{datatype_properties_added} datatype properties")
|
||||
|
||||
def merge_subsets(self, subsets: List[OntologySubset]) -> OntologySubset:
|
||||
"""Merge multiple ontology subsets into one.
|
||||
|
||||
Args:
|
||||
subsets: List of subsets to merge
|
||||
|
||||
Returns:
|
||||
Merged ontology subset
|
||||
"""
|
||||
if not subsets:
|
||||
return None
|
||||
if len(subsets) == 1:
|
||||
return subsets[0]
|
||||
|
||||
# Use first subset as base
|
||||
merged = OntologySubset(
|
||||
ontology_id="merged",
|
||||
classes={},
|
||||
object_properties={},
|
||||
datatype_properties={},
|
||||
metadata={},
|
||||
relevance_score=0.0
|
||||
)
|
||||
|
||||
# Merge all subsets
|
||||
total_score = 0.0
|
||||
for subset in subsets:
|
||||
# Merge classes
|
||||
for class_id, class_def in subset.classes.items():
|
||||
key = f"{subset.ontology_id}:{class_id}"
|
||||
merged.classes[key] = class_def
|
||||
|
||||
# Merge object properties
|
||||
for prop_id, prop_def in subset.object_properties.items():
|
||||
key = f"{subset.ontology_id}:{prop_id}"
|
||||
merged.object_properties[key] = prop_def
|
||||
|
||||
# Merge datatype properties
|
||||
for prop_id, prop_def in subset.datatype_properties.items():
|
||||
key = f"{subset.ontology_id}:{prop_id}"
|
||||
merged.datatype_properties[key] = prop_def
|
||||
|
||||
total_score += subset.relevance_score
|
||||
|
||||
# Average relevance score
|
||||
merged.relevance_score = total_score / len(subsets)
|
||||
|
||||
logger.info(f"Merged {len(subsets)} subsets into one with "
|
||||
f"{len(merged.classes)} classes, "
|
||||
f"{len(merged.object_properties)} object properties, "
|
||||
f"{len(merged.datatype_properties)} datatype properties")
|
||||
|
||||
return merged
|
||||
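A minimal end-to-end sketch of the selector, assuming `embedder` and `loader` are pre-built OntologyEmbedder and OntologyLoader instances whose ontologies have already been embedded, and that the calls run inside an async context:

selector = OntologySelector(embedder, loader, top_k=10, similarity_threshold=0.7)
segments = TextProcessor().process_chunk("Acme Corp acquired Widget Inc in 2021.")
subsets = await selector.select_ontology_subset(segments)
merged = selector.merge_subsets(subsets)   # single combined subset for downstream prompting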
10
trustgraph-flow/trustgraph/extract/kg/ontology/run.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
OntoRAG extraction service launcher.
|
||||
"""
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
240
trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
"""
|
||||
Text processing components for OntoRAG system.
|
||||
Splits text into sentences and extracts phrases for granular matching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import nltk
|
||||
from nltk.corpus import stopwords
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ensure required NLTK data is downloaded
|
||||
try:
|
||||
nltk.data.find('tokenizers/punkt_tab')
|
||||
except LookupError:
|
||||
try:
|
||||
nltk.download('punkt_tab', quiet=True)
|
||||
    except Exception:
|
||||
# Fallback to older punkt if punkt_tab not available
|
||||
try:
|
||||
nltk.download('punkt', quiet=True)
|
||||
        except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
|
||||
except LookupError:
|
||||
try:
|
||||
nltk.download('averaged_perceptron_tagger_eng', quiet=True)
|
||||
    except Exception:
|
||||
# Fallback to older name
|
||||
try:
|
||||
nltk.download('averaged_perceptron_tagger', quiet=True)
|
||||
        except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
nltk.data.find('corpora/stopwords')
|
||||
except LookupError:
|
||||
nltk.download('stopwords', quiet=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextSegment:
|
||||
"""Represents a segment of text (sentence or phrase)."""
|
||||
text: str
|
||||
type: str # 'sentence', 'phrase', 'noun_phrase', 'verb_phrase'
|
||||
position: int
|
||||
parent_sentence: Optional[str] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
class SentenceSplitter:
|
||||
"""Splits text into sentences using NLTK."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize sentence splitter."""
|
||||
try:
|
||||
# Try newer punkt_tab first
|
||||
self.sent_detector = nltk.data.load('tokenizers/punkt_tab/english/')
|
||||
logger.info("Using NLTK sentence tokenizer (punkt_tab)")
|
||||
        except Exception:
|
||||
# Fallback to older punkt
|
||||
self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
|
||||
logger.info("Using NLTK sentence tokenizer (punkt)")
|
||||
|
||||
def split(self, text: str) -> List[str]:
|
||||
"""Split text into sentences.
|
||||
|
||||
Args:
|
||||
text: Text to split
|
||||
|
||||
Returns:
|
||||
List of sentences
|
||||
"""
|
||||
sentences = self.sent_detector.tokenize(text)
|
||||
return sentences
|
||||
|
||||
|
||||
class PhraseExtractor:
|
||||
"""Extracts meaningful phrases from sentences using NLTK."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize phrase extractor."""
|
||||
logger.info("Using NLTK phrase extraction")
|
||||
|
||||
def extract(self, sentence: str) -> List[Dict[str, str]]:
|
||||
"""Extract phrases from a sentence.
|
||||
|
||||
Args:
|
||||
sentence: Sentence to extract phrases from
|
||||
|
||||
Returns:
|
||||
List of phrases with their types
|
||||
"""
|
||||
phrases = []
|
||||
|
||||
# Tokenize and POS tag
|
||||
tokens = nltk.word_tokenize(sentence)
|
||||
pos_tags = nltk.pos_tag(tokens)
|
||||
|
||||
# Extract noun phrases (simple pattern)
|
||||
noun_phrase = []
|
||||
for word, pos in pos_tags:
|
||||
if pos.startswith('NN') or pos.startswith('JJ'):
|
||||
noun_phrase.append(word)
|
||||
elif noun_phrase:
|
||||
if len(noun_phrase) > 1:
|
||||
phrases.append({
|
||||
'text': ' '.join(noun_phrase),
|
||||
'type': 'noun_phrase'
|
||||
})
|
||||
noun_phrase = []
|
||||
|
||||
# Add last noun phrase if exists
|
||||
if noun_phrase and len(noun_phrase) > 1:
|
||||
phrases.append({
|
||||
'text': ' '.join(noun_phrase),
|
||||
'type': 'noun_phrase'
|
||||
})
|
||||
|
||||
# Extract verb phrases (simple pattern)
|
||||
verb_phrase = []
|
||||
for word, pos in pos_tags:
|
||||
if pos.startswith('VB') or pos.startswith('RB'):
|
||||
verb_phrase.append(word)
|
||||
elif verb_phrase:
|
||||
if len(verb_phrase) > 1:
|
||||
phrases.append({
|
||||
'text': ' '.join(verb_phrase),
|
||||
'type': 'verb_phrase'
|
||||
})
|
||||
verb_phrase = []
|
||||
|
||||
# Add last verb phrase if exists
|
||||
if verb_phrase and len(verb_phrase) > 1:
|
||||
phrases.append({
|
||||
'text': ' '.join(verb_phrase),
|
||||
'type': 'verb_phrase'
|
||||
})
|
||||
|
||||
return phrases
|
||||
|
||||
|
||||
class TextProcessor:
|
||||
"""Main text processing class that coordinates sentence splitting and phrase extraction."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize text processor."""
|
||||
self.sentence_splitter = SentenceSplitter()
|
||||
self.phrase_extractor = PhraseExtractor()
|
||||
|
||||
def process_chunk(self, chunk_text: str, extract_phrases: bool = True) -> List[TextSegment]:
|
||||
"""Process a text chunk into segments.
|
||||
|
||||
Args:
|
||||
chunk_text: Text chunk to process
|
||||
extract_phrases: Whether to extract phrases from sentences
|
||||
|
||||
Returns:
|
||||
List of TextSegment objects
|
||||
"""
|
||||
segments = []
|
||||
position = 0
|
||||
|
||||
# Split into sentences
|
||||
sentences = self.sentence_splitter.split(chunk_text)
|
||||
|
||||
for sentence in sentences:
|
||||
# Add sentence segment
|
||||
segments.append(TextSegment(
|
||||
text=sentence,
|
||||
type='sentence',
|
||||
position=position
|
||||
))
|
||||
position += 1
|
||||
|
||||
# Extract phrases if requested
|
||||
if extract_phrases:
|
||||
phrases = self.phrase_extractor.extract(sentence)
|
||||
for phrase_data in phrases:
|
||||
segments.append(TextSegment(
|
||||
text=phrase_data['text'],
|
||||
type=phrase_data['type'],
|
||||
position=position,
|
||||
parent_sentence=sentence
|
||||
))
|
||||
position += 1
|
||||
|
||||
logger.debug(f"Processed chunk into {len(segments)} segments")
|
||||
return segments
|
||||
|
||||
def extract_key_terms(self, text: str) -> List[str]:
|
||||
"""Extract key terms from text for matching.
|
||||
|
||||
Args:
|
||||
text: Text to extract terms from
|
||||
|
||||
Returns:
|
||||
List of key terms
|
||||
"""
|
||||
terms = []
|
||||
|
||||
# Split on word boundaries
|
||||
words = re.findall(r'\b\w+\b', text.lower())
|
||||
|
||||
# Use NLTK stopwords
|
||||
stop_words = set(stopwords.words('english'))
|
||||
|
||||
# Filter stopwords and short words
|
||||
terms = [w for w in words if w not in stop_words and len(w) > 2]
|
||||
|
||||
# Also extract multi-word terms (bigrams)
|
||||
for i in range(len(words) - 1):
|
||||
if words[i] not in stop_words and words[i+1] not in stop_words:
|
||||
bigram = f"{words[i]} {words[i+1]}"
|
||||
terms.append(bigram)
|
||||
|
||||
return terms
|
||||
|
||||
def normalize_text(self, text: str) -> str:
|
||||
"""Normalize text for consistent processing.
|
||||
|
||||
Args:
|
||||
text: Text to normalize
|
||||
|
||||
Returns:
|
||||
Normalized text
|
||||
"""
|
||||
# Remove extra whitespace
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
# Remove leading/trailing whitespace
|
||||
text = text.strip()
|
||||
# Normalize quotes
|
||||
        text = text.replace('\u201c', '"').replace('\u201d', '"')  # curly double quotes
|
||||
        text = text.replace('\u2018', "'").replace('\u2019', "'")  # curly single quotes
|
||||
return text
|
||||
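A small sketch of the processor in isolation; it assumes the NLTK punkt, POS-tagger, and stopwords data are available (the module attempts to download them at import time):

processor = TextProcessor()
segments = processor.process_chunk("The extractor emits triples. Dense vector stores speed up matching.")
for seg in segments:
    print(seg.type, "|", seg.text)
key_terms = processor.extract_key_terms("ontology based knowledge extraction")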
130
trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""
|
||||
Vector store implementation for OntoRAG system.
|
||||
Provides FAISS-based vector storage for ontology embeddings.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
import faiss
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Result from vector similarity search."""
|
||||
id: str
|
||||
score: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class InMemoryVectorStore:
|
||||
"""FAISS-based vector store implementation for ontology embeddings."""
|
||||
|
||||
def __init__(self, dimension: int = 1536, index_type: str = 'flat'):
|
||||
"""Initialize FAISS vector store.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (1536 for text-embedding-3-small)
|
||||
index_type: 'flat' for exact search, 'ivf' for larger datasets
|
||||
"""
|
||||
self.dimension = dimension
|
||||
self.metadata = []
|
||||
self.ids = []
|
||||
|
||||
if index_type == 'flat':
|
||||
# Exact search - best for ontologies with <10k elements
|
||||
self.index = faiss.IndexFlatIP(dimension)
|
||||
logger.info(f"Created FAISS flat index with dimension {dimension}")
|
||||
else:
|
||||
# Approximate search - for larger ontologies
|
||||
quantizer = faiss.IndexFlatIP(dimension)
|
||||
self.index = faiss.IndexIVFFlat(quantizer, dimension, 100)
|
||||
# Train with random vectors for initialization
|
||||
training_data = np.random.randn(1000, dimension).astype('float32')
|
||||
training_data = training_data / np.linalg.norm(
|
||||
training_data, axis=1, keepdims=True
|
||||
)
|
||||
self.index.train(training_data)
|
||||
logger.info(f"Created FAISS IVF index with dimension {dimension}")
|
||||
|
||||
def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
|
||||
"""Add single embedding with metadata."""
|
||||
# Normalize for cosine similarity
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
self.index.add(np.array([embedding], dtype=np.float32))
|
||||
self.metadata.append(metadata)
|
||||
self.ids.append(id)
|
||||
|
||||
def add_batch(self, ids: List[str], embeddings: np.ndarray,
|
||||
metadata_list: List[Dict[str, Any]]):
|
||||
"""Batch add for initial ontology loading."""
|
||||
# Normalize all embeddings
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
normalized = embeddings / norms
|
||||
self.index.add(normalized.astype(np.float32))
|
||||
self.metadata.extend(metadata_list)
|
||||
self.ids.extend(ids)
|
||||
logger.debug(f"Added batch of {len(ids)} embeddings to FAISS index")
|
||||
|
||||
def search(self, embedding: np.ndarray, top_k: int = 10,
|
||||
threshold: float = 0.0) -> List[SearchResult]:
|
||||
"""Search for similar vectors."""
|
||||
# Normalize query
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
|
||||
# Search
|
||||
scores, indices = self.index.search(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
min(top_k, self.index.ntotal)
|
||||
)
|
||||
|
||||
# Filter by threshold and format results
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx >= 0 and score >= threshold: # FAISS returns -1 for empty slots
|
||||
results.append(SearchResult(
|
||||
id=self.ids[idx],
|
||||
score=float(score),
|
||||
metadata=self.metadata[idx]
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
def clear(self):
|
||||
"""Reset the store."""
|
||||
self.index.reset()
|
||||
self.metadata = []
|
||||
self.ids = []
|
||||
logger.info("Cleared FAISS vector store")
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return number of stored vectors."""
|
||||
return self.index.ntotal
|
||||
|
||||
|
||||
# Utility functions for vector operations
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Compute cosine similarity between two vectors."""
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
|
||||
def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query vectors and target vectors.
|
||||
|
||||
Args:
|
||||
queries: Array of shape (n_queries, dimension)
|
||||
targets: Array of shape (n_targets, dimension)
|
||||
|
||||
Returns:
|
||||
Array of shape (n_queries, n_targets) with similarity scores
|
||||
"""
|
||||
# Normalize queries and targets
|
||||
queries_norm = queries / np.linalg.norm(queries, axis=1, keepdims=True)
|
||||
targets_norm = targets / np.linalg.norm(targets, axis=1, keepdims=True)
|
||||
|
||||
# Compute dot product
|
||||
similarities = np.dot(queries_norm, targets_norm.T)
|
||||
return similarities
|
||||
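A toy round-trip through the store, using a small dimension purely for illustration (real embeddings would be 1536-dimensional as noted above):

import numpy as np

store = InMemoryVectorStore(dimension=8)
vec = np.random.randn(8).astype('float32')
store.add("Person", vec, {"ontology": "demo", "type": "class",
                          "element": "Person", "definition": "{}"})
hits = store.search(vec, top_k=1)
print(hits[0].id, round(hits[0].score, 3))   # "Person" with a score close to 1.0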
54
trustgraph-flow/trustgraph/query/ontology/__init__.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
OntoRAG Query System.
|
||||
|
||||
Ontology-driven natural language query processing with multi-backend support.
|
||||
Provides semantic query understanding, ontology matching, and answer generation.
|
||||
"""
|
||||
|
||||
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
|
||||
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
|
||||
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
|
||||
from .backend_router import BackendRouter, BackendType, QueryRoute
|
||||
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
||||
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
||||
from .cypher_generator import CypherGenerator, CypherQuery
|
||||
from .cypher_executor import CypherExecutor, CypherResult
|
||||
from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata
|
||||
|
||||
__all__ = [
|
||||
# Main service
|
||||
'OntoRAGQueryService',
|
||||
'QueryRequest',
|
||||
'QueryResponse',
|
||||
|
||||
# Question analysis
|
||||
'QuestionAnalyzer',
|
||||
'QuestionComponents',
|
||||
'QuestionType',
|
||||
|
||||
# Ontology matching
|
||||
'OntologyMatcher',
|
||||
'QueryOntologySubset',
|
||||
|
||||
# Backend routing
|
||||
'BackendRouter',
|
||||
'BackendType',
|
||||
'QueryRoute',
|
||||
|
||||
# SPARQL components
|
||||
'SPARQLGenerator',
|
||||
'SPARQLQuery',
|
||||
'SPARQLCassandraEngine',
|
||||
'SPARQLResult',
|
||||
|
||||
# Cypher components
|
||||
'CypherGenerator',
|
||||
'CypherQuery',
|
||||
'CypherExecutor',
|
||||
'CypherResult',
|
||||
|
||||
# Answer generation
|
||||
'AnswerGenerator',
|
||||
'GeneratedAnswer',
|
||||
'AnswerMetadata',
|
||||
]
|
||||
521
trustgraph-flow/trustgraph/query/ontology/answer_generator.py
Normal file
|
|
@ -0,0 +1,521 @@
|
|||
"""
|
||||
Answer generator for natural language responses.
|
||||
Converts query results into natural language answers using LLM assistance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
from .sparql_cassandra import SPARQLResult
|
||||
from .cypher_executor import CypherResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnswerMetadata:
|
||||
"""Metadata about answer generation."""
|
||||
query_type: str
|
||||
backend_used: str
|
||||
execution_time: float
|
||||
result_count: int
|
||||
confidence: float
|
||||
explanation: str
|
||||
sources: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratedAnswer:
|
||||
"""Generated natural language answer."""
|
||||
answer: str
|
||||
metadata: AnswerMetadata
|
||||
supporting_facts: List[str]
|
||||
raw_results: Union[SPARQLResult, CypherResult]
|
||||
generation_time: float
|
||||
|
||||
|
||||
class AnswerGenerator:
|
||||
"""Generates natural language answers from query results."""
|
||||
|
||||
def __init__(self, prompt_service=None):
|
||||
"""Initialize answer generator.
|
||||
|
||||
Args:
|
||||
prompt_service: Service for LLM-based answer generation
|
||||
"""
|
||||
self.prompt_service = prompt_service
|
||||
|
||||
# Answer templates for different question types
|
||||
self.templates = {
|
||||
'count': "There are {count} {entity_type}.",
|
||||
'boolean_true': "Yes, {statement} is true.",
|
||||
'boolean_false': "No, {statement} is not true.",
|
||||
'list': "The {entity_type} are: {items}.",
|
||||
'single': "The {property} of {entity} is {value}.",
|
||||
'none': "No results were found for your query.",
|
||||
'error': "I encountered an error processing your query: {error}"
|
||||
}
|
||||
|
||||
async def generate_answer(self,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
ontology_subset: QueryOntologySubset,
|
||||
backend_used: str) -> GeneratedAnswer:
|
||||
"""Generate natural language answer from query results.
|
||||
|
||||
Args:
|
||||
question_components: Original question analysis
|
||||
query_results: Results from query execution
|
||||
ontology_subset: Ontology subset used
|
||||
backend_used: Backend that executed the query
|
||||
|
||||
Returns:
|
||||
Generated answer with metadata
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Try LLM-based generation first
|
||||
if self.prompt_service:
|
||||
llm_answer = await self._generate_with_llm(
|
||||
question_components, query_results, ontology_subset
|
||||
)
|
||||
if llm_answer:
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
return self._build_answer_response(
|
||||
llm_answer, question_components, query_results,
|
||||
backend_used, execution_time
|
||||
)
|
||||
|
||||
# Fall back to template-based generation
|
||||
template_answer = self._generate_with_template(
|
||||
question_components, query_results, ontology_subset
|
||||
)
|
||||
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
return self._build_answer_response(
|
||||
template_answer, question_components, query_results,
|
||||
backend_used, execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Answer generation failed: {e}")
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
error_answer = self.templates['error'].format(error=str(e))
|
||||
return self._build_answer_response(
|
||||
error_answer, question_components, query_results,
|
||||
backend_used, execution_time, confidence=0.0
|
||||
)
|
||||
|
||||
async def _generate_with_llm(self,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
ontology_subset: QueryOntologySubset) -> Optional[str]:
|
||||
"""Generate answer using LLM.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
query_results: Query results
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Generated answer or None if failed
|
||||
"""
|
||||
try:
|
||||
prompt = self._build_answer_prompt(
|
||||
question_components, query_results, ontology_subset
|
||||
)
|
||||
response = await self.prompt_service.generate_answer(prompt=prompt)
|
||||
|
||||
if response and isinstance(response, dict):
|
||||
return response.get('answer', '').strip()
|
||||
elif isinstance(response, str):
|
||||
return response.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM answer generation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _generate_with_template(self,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
ontology_subset: QueryOntologySubset) -> str:
|
||||
"""Generate answer using templates.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
query_results: Query results
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Template-based answer
|
||||
"""
|
||||
# Handle empty results
|
||||
if not self._has_results(query_results):
|
||||
return self.templates['none']
|
||||
|
||||
# Handle boolean queries
|
||||
if question_components.question_type == QuestionType.BOOLEAN:
|
||||
if hasattr(query_results, 'ask_result'):
|
||||
# SPARQL ASK result
|
||||
statement = self._extract_boolean_statement(question_components)
|
||||
if query_results.ask_result:
|
||||
return self.templates['boolean_true'].format(statement=statement)
|
||||
else:
|
||||
return self.templates['boolean_false'].format(statement=statement)
|
||||
else:
|
||||
# Cypher boolean (check if any results)
|
||||
has_results = len(query_results.records) > 0
|
||||
statement = self._extract_boolean_statement(question_components)
|
||||
if has_results:
|
||||
return self.templates['boolean_true'].format(statement=statement)
|
||||
else:
|
||||
return self.templates['boolean_false'].format(statement=statement)
|
||||
|
||||
# Handle count queries
|
||||
if question_components.question_type == QuestionType.AGGREGATION:
|
||||
count = self._extract_count(query_results)
|
||||
entity_type = self._infer_entity_type(question_components, ontology_subset)
|
||||
return self.templates['count'].format(count=count, entity_type=entity_type)
|
||||
|
||||
# Handle retrieval queries
|
||||
if question_components.question_type == QuestionType.RETRIEVAL:
|
||||
items = self._extract_items(query_results)
|
||||
if len(items) == 1:
|
||||
# Single result
|
||||
entity = question_components.entities[0] if question_components.entities else "entity"
|
||||
property_name = "value"
|
||||
return self.templates['single'].format(
|
||||
property=property_name, entity=entity, value=items[0]
|
||||
)
|
||||
else:
|
||||
# Multiple results
|
||||
entity_type = self._infer_entity_type(question_components, ontology_subset)
|
||||
items_str = ", ".join(items)
|
||||
return self.templates['list'].format(entity_type=entity_type, items=items_str)
|
||||
|
||||
# Handle factual queries
|
||||
if question_components.question_type == QuestionType.FACTUAL:
|
||||
facts = self._extract_facts(query_results)
|
||||
return ". ".join(facts) if facts else self.templates['none']
|
||||
|
||||
# Default fallback
|
||||
items = self._extract_items(query_results)
|
||||
if items:
|
||||
return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "")
|
||||
else:
|
||||
return self.templates['none']
|
||||
|
||||
def _build_answer_prompt(self,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
ontology_subset: QueryOntologySubset) -> str:
|
||||
"""Build prompt for LLM answer generation.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
query_results: Query results
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# Format results for prompt
|
||||
results_str = self._format_results_for_prompt(query_results)
|
||||
|
||||
# Extract ontology context
|
||||
context_classes = list(ontology_subset.classes.keys())[:5]
|
||||
context_properties = list(ontology_subset.object_properties.keys())[:5]
|
||||
|
||||
prompt = f"""Generate a natural language answer for the following question based on the query results.
|
||||
|
||||
ORIGINAL QUESTION: {question_components.original_question}
|
||||
|
||||
QUESTION TYPE: {question_components.question_type.value}
|
||||
EXPECTED ANSWER: {question_components.expected_answer_type}
|
||||
|
||||
ONTOLOGY CONTEXT:
|
||||
- Classes: {', '.join(context_classes)}
|
||||
- Properties: {', '.join(context_properties)}
|
||||
|
||||
QUERY RESULTS:
|
||||
{results_str}
|
||||
|
||||
INSTRUCTIONS:
|
||||
- Provide a clear, concise answer in natural language
|
||||
- Use the original question's tone and style
|
||||
- Include specific facts from the results
|
||||
- If no results, explain that no information was found
|
||||
- Be accurate and don't make assumptions beyond the data
|
||||
- Limit response to 2-3 sentences unless the question requires more detail
|
||||
|
||||
ANSWER:"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str:
|
||||
"""Format query results for prompt inclusion.
|
||||
|
||||
Args:
|
||||
query_results: Query results to format
|
||||
|
||||
Returns:
|
||||
Formatted results string
|
||||
"""
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
if hasattr(query_results, 'ask_result') and query_results.ask_result is not None:
|
||||
return f"Boolean result: {query_results.ask_result}"
|
||||
|
||||
if not query_results.bindings:
|
||||
return "No results found"
|
||||
|
||||
# Format SPARQL bindings
|
||||
lines = []
|
||||
for binding in query_results.bindings[:10]: # Limit to first 10
|
||||
formatted = []
|
||||
for var, value in binding.items():
|
||||
if isinstance(value, dict):
|
||||
formatted.append(f"{var}: {value.get('value', value)}")
|
||||
else:
|
||||
formatted.append(f"{var}: {value}")
|
||||
lines.append("- " + ", ".join(formatted))
|
||||
|
||||
if len(query_results.bindings) > 10:
|
||||
lines.append(f"... and {len(query_results.bindings) - 10} more results")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
else: # CypherResult
|
||||
if not query_results.records:
|
||||
return "No results found"
|
||||
|
||||
# Format Cypher records
|
||||
lines = []
|
||||
for record in query_results.records[:10]: # Limit to first 10
|
||||
if isinstance(record, dict):
|
||||
formatted = [f"{k}: {v}" for k, v in record.items()]
|
||||
lines.append("- " + ", ".join(formatted))
|
||||
else:
|
||||
lines.append(f"- {record}")
|
||||
|
||||
if len(query_results.records) > 10:
|
||||
lines.append(f"... and {len(query_results.records) - 10} more results")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool:
|
||||
"""Check if query results contain data.
|
||||
|
||||
Args:
|
||||
query_results: Query results to check
|
||||
|
||||
Returns:
|
||||
True if results contain data
|
||||
"""
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
return bool(query_results.bindings) or query_results.ask_result is not None
|
||||
else: # CypherResult
|
||||
return bool(query_results.records)
|
||||
|
||||
def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int:
|
||||
"""Extract count from aggregation query results.
|
||||
|
||||
Args:
|
||||
query_results: Query results
|
||||
|
||||
Returns:
|
||||
Count value
|
||||
"""
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
if query_results.bindings:
|
||||
binding = query_results.bindings[0]
|
||||
# Look for count variable
|
||||
for var, value in binding.items():
|
||||
if 'count' in var.lower():
|
||||
if isinstance(value, dict):
|
||||
return int(value.get('value', 0))
|
||||
return int(value)
|
||||
return len(query_results.bindings)
|
||||
else: # CypherResult
|
||||
if query_results.records:
|
||||
record = query_results.records[0]
|
||||
if isinstance(record, dict):
|
||||
# Look for count key
|
||||
for key, value in record.items():
|
||||
if 'count' in key.lower():
|
||||
return int(value)
|
||||
elif isinstance(record, (int, float)):
|
||||
return int(record)
|
||||
return len(query_results.records)
|
||||
|
||||
def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
|
||||
"""Extract items from query results.
|
||||
|
||||
Args:
|
||||
query_results: Query results
|
||||
|
||||
Returns:
|
||||
List of extracted items
|
||||
"""
|
||||
items = []
|
||||
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
for binding in query_results.bindings:
|
||||
for var, value in binding.items():
|
||||
if isinstance(value, dict):
|
||||
item_value = value.get('value', str(value))
|
||||
else:
|
||||
item_value = str(value)
|
||||
|
||||
# Clean up URIs
|
||||
if item_value.startswith('http'):
|
||||
item_value = item_value.split('/')[-1].split('#')[-1]
|
||||
|
||||
items.append(item_value)
|
||||
break # Take first value per binding
|
||||
|
||||
else: # CypherResult
|
||||
for record in query_results.records:
|
||||
if isinstance(record, dict):
|
||||
# Take first value from record
|
||||
for key, value in record.items():
|
||||
items.append(str(value))
|
||||
break
|
||||
else:
|
||||
items.append(str(record))
|
||||
|
||||
return items
|
||||
|
||||
def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
|
||||
"""Extract facts from query results.
|
||||
|
||||
Args:
|
||||
query_results: Query results
|
||||
|
||||
Returns:
|
||||
List of facts
|
||||
"""
|
||||
facts = []
|
||||
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
for binding in query_results.bindings:
|
||||
fact_parts = []
|
||||
for var, value in binding.items():
|
||||
if isinstance(value, dict):
|
||||
val_str = value.get('value', str(value))
|
||||
else:
|
||||
val_str = str(value)
|
||||
|
||||
# Clean up URIs
|
||||
if val_str.startswith('http'):
|
||||
val_str = val_str.split('/')[-1].split('#')[-1]
|
||||
|
||||
fact_parts.append(f"{var}: {val_str}")
|
||||
|
||||
facts.append(", ".join(fact_parts))
|
||||
|
||||
else: # CypherResult
|
||||
for record in query_results.records:
|
||||
if isinstance(record, dict):
|
||||
fact_parts = [f"{k}: {v}" for k, v in record.items()]
|
||||
facts.append(", ".join(fact_parts))
|
||||
else:
|
||||
facts.append(str(record))
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_boolean_statement(self, question_components: QuestionComponents) -> str:
|
||||
"""Extract statement for boolean answer.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
|
||||
Returns:
|
||||
Statement string
|
||||
"""
|
||||
# Extract the key assertion from the question
|
||||
question = question_components.original_question.lower()
|
||||
|
||||
# Remove question words
|
||||
statement = question.replace('is ', '').replace('are ', '').replace('does ', '')
|
||||
statement = statement.replace('?', '').strip()
|
||||
|
||||
return statement
|
||||
|
||||
def _infer_entity_type(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> str:
|
||||
"""Infer entity type from question and ontology.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Entity type string
|
||||
"""
|
||||
# Try to match entities to ontology classes
|
||||
for entity in question_components.entities:
|
||||
entity_lower = entity.lower()
|
||||
for class_id in ontology_subset.classes:
|
||||
if class_id.lower() == entity_lower or entity_lower in class_id.lower():
|
||||
return class_id
|
||||
|
||||
# Fallback to first entity or generic term
|
||||
if question_components.entities:
|
||||
return question_components.entities[0]
|
||||
else:
|
||||
return "entities"
|
||||
|
||||
def _build_answer_response(self,
|
||||
answer: str,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
backend_used: str,
|
||||
execution_time: float,
|
||||
confidence: float = 0.8) -> GeneratedAnswer:
|
||||
"""Build final answer response.
|
||||
|
||||
Args:
|
||||
answer: Generated answer text
|
||||
question_components: Question analysis
|
||||
query_results: Query results
|
||||
backend_used: Backend used for query
|
||||
execution_time: Answer generation time
|
||||
confidence: Confidence score
|
||||
|
||||
Returns:
|
||||
Complete answer response
|
||||
"""
|
||||
# Extract supporting facts
|
||||
supporting_facts = self._extract_facts(query_results)
|
||||
|
||||
# Build metadata
|
||||
result_count = 0
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
result_count = len(query_results.bindings)
|
||||
else: # CypherResult
|
||||
result_count = len(query_results.records)
|
||||
|
||||
metadata = AnswerMetadata(
|
||||
query_type=question_components.question_type.value,
|
||||
backend_used=backend_used,
|
||||
execution_time=execution_time,
|
||||
result_count=result_count,
|
||||
confidence=confidence,
|
||||
explanation=f"Generated answer using {backend_used} backend",
|
||||
sources=[] # Could be populated with data source information
|
||||
)
|
||||
|
||||
return GeneratedAnswer(
|
||||
answer=answer,
|
||||
metadata=metadata,
|
||||
supporting_facts=supporting_facts[:5], # Limit to top 5
|
||||
raw_results=query_results,
|
||||
generation_time=execution_time
|
||||
)
|
||||
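A minimal sketch of the template fallback path; with no prompt service configured, answers come straight from the templates dictionary shown above:

gen = AnswerGenerator()                        # no prompt_service: template fallback only
print(gen.templates['count'].format(count=3, entity_type="datasets"))
# -> There are 3 datasets.
print(gen.templates['none'])
# -> No results were found for your query.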
350
trustgraph-flow/trustgraph/query/ontology/backend_router.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""
|
||||
Backend router for ontology query system.
|
||||
Routes queries to appropriate backend based on configuration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from .question_analyzer import QuestionComponents
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackendType(Enum):
|
||||
"""Supported backend types."""
|
||||
CASSANDRA = "cassandra"
|
||||
NEO4J = "neo4j"
|
||||
MEMGRAPH = "memgraph"
|
||||
FALKORDB = "falkordb"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
"""Configuration for a backend."""
|
||||
type: BackendType
|
||||
priority: int = 0
|
||||
enabled: bool = True
|
||||
config: Dict[str, Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryRoute:
|
||||
"""Routing decision for a query."""
|
||||
backend_type: BackendType
|
||||
query_language: str # 'sparql' or 'cypher'
|
||||
confidence: float
|
||||
reasoning: str
|
||||
|
||||
|
||||
class BackendRouter:
|
||||
"""Routes queries to appropriate backends based on configuration and heuristics."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize backend router.
|
||||
|
||||
Args:
|
||||
config: Router configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.backends = self._parse_backend_config(config)
|
||||
self.routing_strategy = config.get('routing_strategy', 'priority')
|
||||
self.enable_fallback = config.get('enable_fallback', True)
|
||||
|
||||
def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]:
|
||||
"""Parse backend configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Dictionary of backend type to configuration
|
||||
"""
|
||||
backends = {}
|
||||
|
||||
# Parse primary backend
|
||||
primary = config.get('primary', 'cassandra')
|
||||
if primary:
|
||||
try:
|
||||
backend_type = BackendType(primary)
|
||||
backends[backend_type] = BackendConfig(
|
||||
type=backend_type,
|
||||
priority=100,
|
||||
enabled=True,
|
||||
config=config.get(primary, {})
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown primary backend type: {primary}")
|
||||
|
||||
# Parse fallback backends
|
||||
fallbacks = config.get('fallback', [])
|
||||
for i, fallback in enumerate(fallbacks):
|
||||
try:
|
||||
backend_type = BackendType(fallback)
|
||||
backends[backend_type] = BackendConfig(
|
||||
type=backend_type,
|
||||
priority=50 - i * 10, # Decreasing priority
|
||||
enabled=True,
|
||||
config=config.get(fallback, {})
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown fallback backend type: {fallback}")
|
||||
|
||||
return backends
|
||||
|
||||
def route_query(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
|
||||
"""Route a query to the best backend.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question
|
||||
ontology_subsets: Relevant ontology subsets
|
||||
|
||||
Returns:
|
||||
QueryRoute with routing decision
|
||||
"""
|
||||
if self.routing_strategy == 'priority':
|
||||
return self._route_by_priority()
|
||||
elif self.routing_strategy == 'adaptive':
|
||||
return self._route_adaptive(question_components, ontology_subsets)
|
||||
elif self.routing_strategy == 'round_robin':
|
||||
return self._route_round_robin()
|
||||
else:
|
||||
return self._route_by_priority()
|
||||
|
||||
def _route_by_priority(self) -> QueryRoute:
|
||||
"""Route based on backend priority.
|
||||
|
||||
Returns:
|
||||
QueryRoute to highest priority backend
|
||||
"""
|
||||
# Find highest priority enabled backend
|
||||
best_backend = None
|
||||
best_priority = -1
|
||||
|
||||
for backend_type, backend_config in self.backends.items():
|
||||
if backend_config.enabled and backend_config.priority > best_priority:
|
||||
best_backend = backend_type
|
||||
best_priority = backend_config.priority
|
||||
|
||||
if best_backend is None:
|
||||
raise RuntimeError("No enabled backends available")
|
||||
|
||||
# Determine query language
|
||||
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
|
||||
|
||||
return QueryRoute(
|
||||
backend_type=best_backend,
|
||||
query_language=query_language,
|
||||
confidence=1.0,
|
||||
reasoning=f"Priority routing to {best_backend.value}"
|
||||
)
|
||||
|
||||
def _route_adaptive(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
|
||||
"""Route based on question characteristics and ontology complexity.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question
|
||||
ontology_subsets: Relevant ontology subsets
|
||||
|
||||
Returns:
|
||||
QueryRoute with adaptive decision
|
||||
"""
|
||||
scores = {}
|
||||
|
||||
for backend_type, backend_config in self.backends.items():
|
||||
if not backend_config.enabled:
|
||||
continue
|
||||
|
||||
score = self._calculate_backend_score(
|
||||
backend_type, question_components, ontology_subsets
|
||||
)
|
||||
scores[backend_type] = score
|
||||
|
||||
if not scores:
|
||||
raise RuntimeError("No enabled backends available")
|
||||
|
||||
# Select backend with highest score
|
||||
best_backend = max(scores.keys(), key=lambda k: scores[k])
|
||||
best_score = scores[best_backend]
|
||||
|
||||
# Determine query language
|
||||
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
|
||||
|
||||
return QueryRoute(
|
||||
backend_type=best_backend,
|
||||
query_language=query_language,
|
||||
confidence=best_score,
|
||||
reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}"
|
||||
)
|
||||
|
||||
def _calculate_backend_score(self,
|
||||
backend_type: BackendType,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset]) -> float:
|
||||
"""Calculate score for a backend based on query characteristics.
|
||||
|
||||
Args:
|
||||
backend_type: Backend to score
|
||||
question_components: Question analysis
|
||||
ontology_subsets: Ontology subsets
|
||||
|
||||
Returns:
|
||||
Score (0.0 to 1.0)
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# Base priority score
|
||||
backend_config = self.backends[backend_type]
|
||||
score += backend_config.priority / 100.0
|
||||
|
||||
# Question type preferences
|
||||
if backend_type == BackendType.CASSANDRA:
|
||||
# SPARQL is good for hierarchical and complex reasoning
|
||||
if question_components.question_type.value in ['factual', 'aggregation']:
|
||||
score += 0.3
|
||||
# Good for ontology-heavy queries
|
||||
if len(ontology_subsets) > 1:
|
||||
score += 0.2
|
||||
else:
|
||||
# Cypher is good for graph traversal and relationships
|
||||
if question_components.question_type.value in ['relationship', 'retrieval']:
|
||||
score += 0.3
|
||||
# Good for simple graph patterns
|
||||
if len(question_components.relationships) > 0:
|
||||
score += 0.2
|
||||
|
||||
# Complexity considerations
|
||||
total_elements = sum(
|
||||
len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
|
||||
for subset in ontology_subsets
|
||||
)
|
||||
|
||||
if backend_type == BackendType.CASSANDRA:
|
||||
# SPARQL handles complex ontologies well
|
||||
if total_elements > 20:
|
||||
score += 0.2
|
||||
else:
|
||||
# Cypher is efficient for simpler queries
|
||||
if total_elements <= 10:
|
||||
score += 0.2
|
||||
|
||||
# Aggregation considerations
|
||||
if question_components.aggregations:
|
||||
if backend_type == BackendType.CASSANDRA:
|
||||
score += 0.1 # SPARQL has built-in aggregation
|
||||
else:
|
||||
score += 0.2 # Cypher has excellent aggregation
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def _route_round_robin(self) -> QueryRoute:
|
||||
"""Route using round-robin strategy.
|
||||
|
||||
Returns:
|
||||
QueryRoute using round-robin selection
|
||||
"""
|
||||
# Simple round-robin implementation
|
||||
enabled_backends = [
|
||||
bt for bt, bc in self.backends.items() if bc.enabled
|
||||
]
|
||||
|
||||
if not enabled_backends:
|
||||
raise RuntimeError("No enabled backends available")
|
||||
|
||||
# For simplicity, just return the first enabled backend
|
||||
# In a real implementation, you'd track state
|
||||
backend_type = enabled_backends[0]
|
||||
query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher'
|
||||
|
||||
return QueryRoute(
|
||||
backend_type=backend_type,
|
||||
query_language=query_language,
|
||||
confidence=0.8,
|
||||
reasoning=f"Round-robin routing to {backend_type.value}"
|
||||
)
|
||||
|
||||
def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]:
|
||||
"""Get fallback route when a backend fails.
|
||||
|
||||
Args:
|
||||
failed_backend: Backend that failed
|
||||
|
||||
Returns:
|
||||
Fallback route or None if no fallback available
|
||||
"""
|
||||
if not self.enable_fallback:
|
||||
return None
|
||||
|
||||
# Find next best backend
|
||||
fallback_backends = [
|
||||
(bt, bc) for bt, bc in self.backends.items()
|
||||
if bc.enabled and bt != failed_backend
|
||||
]
|
||||
|
||||
if not fallback_backends:
|
||||
return None
|
||||
|
||||
# Sort by priority
|
||||
fallback_backends.sort(key=lambda x: x[1].priority, reverse=True)
|
||||
fallback_type = fallback_backends[0][0]
|
||||
|
||||
query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher'
|
||||
|
||||
return QueryRoute(
|
||||
backend_type=fallback_type,
|
||||
query_language=query_language,
|
||||
confidence=0.7,
|
||||
reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}"
|
||||
)
|
||||
|
||||
def get_available_backends(self) -> List[BackendType]:
|
||||
"""Get list of available backends.
|
||||
|
||||
Returns:
|
||||
List of enabled backend types
|
||||
"""
|
||||
return [bt for bt, bc in self.backends.items() if bc.enabled]
|
||||
|
||||
def is_backend_enabled(self, backend_type: BackendType) -> bool:
|
||||
"""Check if a backend is enabled.
|
||||
|
||||
Args:
|
||||
backend_type: Backend to check
|
||||
|
||||
Returns:
|
||||
True if backend is enabled
|
||||
"""
|
||||
backend_config = self.backends.get(backend_type)
|
||||
return backend_config is not None and backend_config.enabled
|
||||
|
||||
def update_backend_status(self, backend_type: BackendType, enabled: bool):
|
||||
"""Update backend enabled status.
|
||||
|
||||
Args:
|
||||
backend_type: Backend to update
|
||||
enabled: New enabled status
|
||||
"""
|
||||
if backend_type in self.backends:
|
||||
self.backends[backend_type].enabled = enabled
|
||||
logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}")
|
||||
else:
|
||||
logger.warning(f"Unknown backend type: {backend_type}")
|
||||
|
||||
def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]:
|
||||
"""Get configuration for a backend.
|
||||
|
||||
Args:
|
||||
backend_type: Backend type
|
||||
|
||||
Returns:
|
||||
Configuration dictionary or None
|
||||
"""
|
||||
backend_config = self.backends.get(backend_type)
|
||||
return backend_config.config if backend_config else None
|
||||
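A small configuration sketch for the router; under the 'priority' strategy the question analysis is not consulted, so placeholder arguments are passed purely for illustration:

router = BackendRouter({
    "primary": "cassandra",
    "fallback": ["neo4j"],
    "routing_strategy": "priority",
})
route = router.route_query(None, [])
print(route.backend_type.value, route.query_language)   # cassandra sparql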
651
trustgraph-flow/trustgraph/query/ontology/cache.py
Normal file
|
|
@ -0,0 +1,651 @@
|
|||
"""
|
||||
Caching system for OntoRAG query results and computations.
|
||||
Provides multiple cache backends and intelligent cache management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import pickle
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timedelta
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Cache entry with metadata."""
|
||||
key: str
|
||||
value: Any
|
||||
created_at: datetime
|
||||
accessed_at: datetime
|
||||
access_count: int
|
||||
ttl_seconds: Optional[int] = None
|
||||
tags: List[str] = None
|
||||
size_bytes: int = 0
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if cache entry is expired."""
|
||||
if self.ttl_seconds is None:
|
||||
return False
|
||||
return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds
|
||||
|
||||
def touch(self):
|
||||
"""Update access time and count."""
|
||||
self.accessed_at = datetime.now()
|
||||
self.access_count += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Cache performance statistics."""
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
total_entries: int = 0
|
||||
total_size_bytes: int = 0
|
||||
hit_rate: float = 0.0
|
||||
|
||||
def update_hit_rate(self):
|
||||
"""Update hit rate calculation."""
|
||||
total_requests = self.hits + self.misses
|
||||
self.hit_rate = self.hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""Abstract base class for cache backends."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get cache entry by key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
|
||||
"""Set cache entry."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, tags: Optional[List[str]] = None):
|
||||
"""Clear cache entries."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries."""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryCache(CacheBackend):
|
||||
"""In-memory cache backend."""
|
||||
|
||||
def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024):
|
||||
"""Initialize in-memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
max_size_bytes: Maximum total size in bytes
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.max_size_bytes = max_size_bytes
|
||||
self.entries: Dict[str, CacheEntry] = {}
|
||||
self.stats = CacheStats()
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get cache entry by key."""
|
||||
with self._lock:
|
||||
entry = self.entries.get(key)
|
||||
if entry is None:
|
||||
self.stats.misses += 1
|
||||
self.stats.update_hit_rate()
|
||||
return None
|
||||
|
||||
if entry.is_expired():
|
||||
del self.entries[key]
|
||||
self.stats.misses += 1
|
||||
self.stats.evictions += 1
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= entry.size_bytes
|
||||
self.stats.update_hit_rate()
|
||||
return None
|
||||
|
||||
entry.touch()
|
||||
self.stats.hits += 1
|
||||
self.stats.update_hit_rate()
|
||||
return entry
|
||||
|
||||
def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
|
||||
"""Set cache entry."""
|
||||
with self._lock:
|
||||
# Calculate size
|
||||
try:
|
||||
size_bytes = len(pickle.dumps(value))
|
||||
except Exception:
|
||||
size_bytes = len(str(value).encode('utf-8'))
|
||||
|
||||
# Create entry
|
||||
now = datetime.now()
|
||||
entry = CacheEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
created_at=now,
|
||||
accessed_at=now,
|
||||
access_count=1,
|
||||
ttl_seconds=ttl_seconds,
|
||||
tags=tags or [],
|
||||
size_bytes=size_bytes
|
||||
)
|
||||
|
||||
# Check if we need to evict
|
||||
self._ensure_capacity(size_bytes)
|
||||
|
||||
# Store entry
|
||||
old_entry = self.entries.get(key)
|
||||
if old_entry:
|
||||
self.stats.total_size_bytes -= old_entry.size_bytes
|
||||
else:
|
||||
self.stats.total_entries += 1
|
||||
|
||||
self.entries[key] = entry
|
||||
self.stats.total_size_bytes += size_bytes
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry."""
|
||||
with self._lock:
|
||||
entry = self.entries.pop(key, None)
|
||||
if entry:
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= entry.size_bytes
|
||||
self.stats.evictions += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self, tags: Optional[List[str]] = None):
|
||||
"""Clear cache entries."""
|
||||
with self._lock:
|
||||
if tags is None:
|
||||
# Clear all
|
||||
self.stats.evictions += len(self.entries)
|
||||
self.entries.clear()
|
||||
self.stats.total_entries = 0
|
||||
self.stats.total_size_bytes = 0
|
||||
else:
|
||||
# Clear by tags
|
||||
to_delete = []
|
||||
for key, entry in self.entries.items():
|
||||
if any(tag in entry.tags for tag in tags):
|
||||
to_delete.append(key)
|
||||
|
||||
for key in to_delete:
|
||||
self.delete(key)
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
return CacheStats(
|
||||
hits=self.stats.hits,
|
||||
misses=self.stats.misses,
|
||||
evictions=self.stats.evictions,
|
||||
total_entries=self.stats.total_entries,
|
||||
total_size_bytes=self.stats.total_size_bytes,
|
||||
hit_rate=self.stats.hit_rate
|
||||
)
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries."""
|
||||
with self._lock:
|
||||
to_delete = []
|
||||
for key, entry in self.entries.items():
|
||||
if entry.is_expired():
|
||||
to_delete.append(key)
|
||||
|
||||
for key in to_delete:
|
||||
self.delete(key)
|
||||
|
||||
def _ensure_capacity(self, new_size_bytes: int):
|
||||
"""Ensure cache has capacity for new entry."""
|
||||
# Check size limit
|
||||
if self.stats.total_size_bytes + new_size_bytes > self.max_size_bytes:
|
||||
self._evict_by_size(new_size_bytes)
|
||||
|
||||
# Check count limit
|
||||
if len(self.entries) >= self.max_size:
|
||||
self._evict_by_count()
|
||||
|
||||
def _evict_by_size(self, needed_bytes: int):
|
||||
"""Evict entries to free up space."""
|
||||
# Sort by access time (LRU)
|
||||
sorted_entries = sorted(
|
||||
self.entries.items(),
|
||||
key=lambda x: (x[1].accessed_at, x[1].access_count)
|
||||
)
|
||||
|
||||
freed_bytes = 0
|
||||
for key, entry in sorted_entries:
|
||||
if freed_bytes >= needed_bytes:
|
||||
break
|
||||
freed_bytes += entry.size_bytes
|
||||
del self.entries[key]
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= entry.size_bytes
|
||||
self.stats.evictions += 1
|
||||
|
||||
def _evict_by_count(self):
|
||||
"""Evict least recently used entry."""
|
||||
if not self.entries:
|
||||
return
|
||||
|
||||
# Find LRU entry
|
||||
lru_key = min(
|
||||
self.entries.keys(),
|
||||
key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count)
|
||||
)
|
||||
self.delete(lru_key)
|
||||
|
||||
|
||||
class FileCache(CacheBackend):
|
||||
"""File-based cache backend."""
|
||||
|
||||
def __init__(self, cache_dir: str, max_files: int = 10000):
|
||||
"""Initialize file cache.
|
||||
|
||||
Args:
|
||||
cache_dir: Directory to store cache files
|
||||
max_files: Maximum number of cache files
|
||||
"""
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.max_files = max_files
|
||||
self.stats = CacheStats()
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Load existing stats
|
||||
self._load_stats()
|
||||
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get cache entry by key."""
|
||||
with self._lock:
|
||||
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
|
||||
if not cache_file.exists():
|
||||
self.stats.misses += 1
|
||||
self.stats.update_hit_rate()
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
entry = pickle.load(f)
|
||||
|
||||
if entry.is_expired():
|
||||
cache_file.unlink()
|
||||
self.stats.misses += 1
|
||||
self.stats.evictions += 1
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.update_hit_rate()
|
||||
return None
|
||||
|
||||
entry.touch()
|
||||
# Update file modification time
|
||||
cache_file.touch()
|
||||
|
||||
self.stats.hits += 1
|
||||
self.stats.update_hit_rate()
|
||||
return entry
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading cache file {cache_file}: {e}")
|
||||
cache_file.unlink(missing_ok=True)
|
||||
self.stats.misses += 1
|
||||
self.stats.update_hit_rate()
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
|
||||
"""Set cache entry."""
|
||||
with self._lock:
|
||||
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
|
||||
|
||||
# Create entry
|
||||
now = datetime.now()
|
||||
entry = CacheEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
created_at=now,
|
||||
accessed_at=now,
|
||||
access_count=1,
|
||||
ttl_seconds=ttl_seconds,
|
||||
tags=tags or []
|
||||
)
|
||||
|
||||
try:
|
||||
# Ensure capacity
|
||||
self._ensure_capacity()
|
||||
|
||||
# Write to file
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(entry, f)
|
||||
|
||||
entry.size_bytes = cache_file.stat().st_size
|
||||
|
||||
if not cache_file.exists():
|
||||
self.stats.total_entries += 1
|
||||
|
||||
self.stats.total_size_bytes += entry.size_bytes
|
||||
self._save_stats()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing cache file {cache_file}: {e}")
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete cache entry."""
|
||||
with self._lock:
|
||||
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
|
||||
if cache_file.exists():
|
||||
size = cache_file.stat().st_size
|
||||
cache_file.unlink()
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= size
|
||||
self.stats.evictions += 1
|
||||
self._save_stats()
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self, tags: Optional[List[str]] = None):
|
||||
"""Clear cache entries."""
|
||||
with self._lock:
|
||||
if tags is None:
|
||||
# Clear all
|
||||
for cache_file in self.cache_dir.glob("*.cache"):
|
||||
cache_file.unlink()
|
||||
self.stats.evictions += self.stats.total_entries
|
||||
self.stats.total_entries = 0
|
||||
self.stats.total_size_bytes = 0
|
||||
else:
|
||||
# Clear by tags
|
||||
for cache_file in self.cache_dir.glob("*.cache"):
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
entry = pickle.load(f)
|
||||
if any(tag in entry.tags for tag in tags):
|
||||
size = cache_file.stat().st_size
|
||||
cache_file.unlink()
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= size
|
||||
self.stats.evictions += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
self._save_stats()
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
return CacheStats(
|
||||
hits=self.stats.hits,
|
||||
misses=self.stats.misses,
|
||||
evictions=self.stats.evictions,
|
||||
total_entries=self.stats.total_entries,
|
||||
total_size_bytes=self.stats.total_size_bytes,
|
||||
hit_rate=self.stats.hit_rate
|
||||
)
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries."""
|
||||
with self._lock:
|
||||
for cache_file in self.cache_dir.glob("*.cache"):
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
entry = pickle.load(f)
|
||||
if entry.is_expired():
|
||||
size = cache_file.stat().st_size
|
||||
cache_file.unlink()
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= size
|
||||
self.stats.evictions += 1
|
||||
except Exception:
|
||||
# Remove corrupted files
|
||||
cache_file.unlink()
|
||||
|
||||
self._save_stats()
|
||||
|
||||
def _safe_key(self, key: str) -> str:
|
||||
"""Convert key to safe filename."""
|
||||
import hashlib
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
def _ensure_capacity(self):
|
||||
"""Ensure cache has capacity for new entry."""
|
||||
cache_files = list(self.cache_dir.glob("*.cache"))
|
||||
if len(cache_files) >= self.max_files:
|
||||
# Remove oldest file
|
||||
oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime)
|
||||
size = oldest_file.stat().st_size
|
||||
oldest_file.unlink()
|
||||
self.stats.total_entries -= 1
|
||||
self.stats.total_size_bytes -= size
|
||||
self.stats.evictions += 1
|
||||
|
||||
def _load_stats(self):
|
||||
"""Load statistics from file."""
|
||||
stats_file = self.cache_dir / "stats.json"
|
||||
if stats_file.exists():
|
||||
try:
|
||||
with open(stats_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
self.stats = CacheStats(**data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _save_stats(self):
|
||||
"""Save statistics to file."""
|
||||
stats_file = self.cache_dir / "stats.json"
|
||||
try:
|
||||
with open(stats_file, 'w') as f:
|
||||
json.dump(asdict(self.stats), f, default=str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Cache manager with multiple backends and intelligent caching strategies."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize cache manager.
|
||||
|
||||
Args:
|
||||
config: Cache configuration
|
||||
"""
|
||||
self.config = config
|
||||
self.backends: Dict[str, CacheBackend] = {}
|
||||
self.default_backend = config.get('default_backend', 'memory')
|
||||
self.default_ttl = config.get('default_ttl_seconds', 3600) # 1 hour
|
||||
|
||||
# Initialize backends
|
||||
self._init_backends()
|
||||
|
||||
# Start cleanup task
|
||||
self.cleanup_interval = config.get('cleanup_interval_seconds', 300) # 5 minutes
|
||||
self._start_cleanup_task()
|
||||
|
||||
def _init_backends(self):
|
||||
"""Initialize cache backends."""
|
||||
backends_config = self.config.get('backends', {})
|
||||
|
||||
# Memory backend
|
||||
if 'memory' in backends_config or self.default_backend == 'memory':
|
||||
memory_config = backends_config.get('memory', {})
|
||||
self.backends['memory'] = InMemoryCache(
|
||||
max_size=memory_config.get('max_size', 1000),
|
||||
max_size_bytes=memory_config.get('max_size_bytes', 100 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# File backend
|
||||
if 'file' in backends_config or self.default_backend == 'file':
|
||||
file_config = backends_config.get('file', {})
|
||||
self.backends['file'] = FileCache(
|
||||
cache_dir=file_config.get('cache_dir', './cache'),
|
||||
max_files=file_config.get('max_files', 10000)
|
||||
)
|
||||
|
||||
def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]:
|
||||
"""Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
backend: Backend name (optional)
|
||||
|
||||
Returns:
|
||||
Cached value or None
|
||||
"""
|
||||
backend_name = backend or self.default_backend
|
||||
cache_backend = self.backends.get(backend_name)
|
||||
|
||||
if cache_backend is None:
|
||||
logger.warning(f"Cache backend '{backend_name}' not found")
|
||||
return None
|
||||
|
||||
entry = cache_backend.get(key)
|
||||
return entry.value if entry else None
|
||||
|
||||
def set(self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
backend: Optional[str] = None):
|
||||
"""Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl_seconds: Time to live in seconds
|
||||
tags: Cache tags
|
||||
backend: Backend name (optional)
|
||||
"""
|
||||
backend_name = backend or self.default_backend
|
||||
cache_backend = self.backends.get(backend_name)
|
||||
|
||||
if cache_backend is None:
|
||||
logger.warning(f"Cache backend '{backend_name}' not found")
|
||||
return
|
||||
|
||||
ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl
|
||||
cache_backend.set(key, value, ttl, tags)
|
||||
|
||||
def delete(self, key: str, backend: Optional[str] = None) -> bool:
|
||||
"""Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
backend: Backend name (optional)
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
backend_name = backend or self.default_backend
|
||||
cache_backend = self.backends.get(backend_name)
|
||||
|
||||
if cache_backend is None:
|
||||
return False
|
||||
|
||||
return cache_backend.delete(key)
|
||||
|
||||
def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None):
|
||||
"""Clear cache entries.
|
||||
|
||||
Args:
|
||||
tags: Tags to clear (optional)
|
||||
backend: Backend name (optional)
|
||||
"""
|
||||
if backend:
|
||||
cache_backend = self.backends.get(backend)
|
||||
if cache_backend:
|
||||
cache_backend.clear(tags)
|
||||
else:
|
||||
# Clear all backends
|
||||
for cache_backend in self.backends.values():
|
||||
cache_backend.clear(tags)
|
||||
|
||||
def get_stats(self) -> Dict[str, CacheStats]:
|
||||
"""Get statistics for all backends.
|
||||
|
||||
Returns:
|
||||
Dictionary of backend name to statistics
|
||||
"""
|
||||
return {name: backend.get_stats() for name, backend in self.backends.items()}
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries in all backends."""
|
||||
for backend in self.backends.values():
|
||||
try:
|
||||
backend.cleanup_expired()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up cache backend: {e}")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start periodic cleanup task."""
|
||||
def cleanup_worker():
|
||||
while True:
|
||||
try:
|
||||
time.sleep(self.cleanup_interval)
|
||||
self.cleanup_expired()
|
||||
except Exception as e:
|
||||
logger.error(f"Cache cleanup error: {e}")
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
|
||||
# Cache decorators and utilities
|
||||
|
||||
def cache_result(cache_manager: CacheManager,
|
||||
key_func: Optional[callable] = None,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
backend: Optional[str] = None):
|
||||
"""Decorator to cache function results.
|
||||
|
||||
Args:
|
||||
cache_manager: Cache manager instance
|
||||
key_func: Function to generate cache key
|
||||
ttl_seconds: Time to live
|
||||
tags: Cache tags
|
||||
backend: Backend name
|
||||
"""
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Generate cache key
|
||||
if key_func:
|
||||
cache_key = key_func(*args, **kwargs)
|
||||
else:
|
||||
cache_key = f"{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}"
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = cache_manager.get(cache_key, backend)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
# Execute function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Cache result
|
||||
cache_manager.set(cache_key, result, ttl_seconds, tags, backend)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
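A short usage sketch (illustrative, not part of the diff) of how CacheManager and the cache_result decorator compose. The import path follows the file header above, and the decorated function is a hypothetical expensive call.

    from trustgraph.query.ontology.cache import CacheManager, cache_result

    manager = CacheManager({
        "default_backend": "memory",
        "default_ttl_seconds": 600,
        "backends": {"memory": {"max_size": 500}},
    })

    @cache_result(manager, ttl_seconds=300, tags=["ontology"])
    def expand_subset(ontology_id: str):      # hypothetical expensive computation
        return {"ontology": ontology_id}

    expand_subset("animals")                  # computed, then stored in the memory backend
    expand_subset("animals")                  # served from cache until the TTL expires
    manager.clear(tags=["ontology"])          # drops every entry tagged "ontology"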
trustgraph-flow/trustgraph/query/ontology/cypher_executor.py (610 lines, Normal file)
@@ -0,0 +1,610 @@
"""
|
||||
Cypher executor for multiple graph databases.
|
||||
Executes Cypher queries against Neo4j, Memgraph, and FalkorDB.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .cypher_generator import CypherQuery
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import various database drivers
|
||||
try:
|
||||
from neo4j import GraphDatabase, Driver as Neo4jDriver
|
||||
NEO4J_AVAILABLE = True
|
||||
except ImportError:
|
||||
NEO4J_AVAILABLE = False
|
||||
Neo4jDriver = None
|
||||
|
||||
try:
|
||||
import redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CypherResult:
|
||||
"""Result from Cypher query execution."""
|
||||
records: List[Dict[str, Any]]
|
||||
summary: Dict[str, Any]
|
||||
execution_time: float
|
||||
database_type: str
|
||||
query_plan: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class CypherExecutorBase(ABC):
|
||||
"""Abstract base class for Cypher executors."""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, cypher_query: CypherQuery) -> CypherResult:
|
||||
"""Execute Cypher query."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Close database connection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to database."""
|
||||
pass
|
||||
|
||||
|
||||
class Neo4jExecutor(CypherExecutorBase):
|
||||
"""Cypher executor for Neo4j database."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize Neo4j executor.
|
||||
|
||||
Args:
|
||||
config: Neo4j configuration
|
||||
"""
|
||||
if not NEO4J_AVAILABLE:
|
||||
raise RuntimeError("Neo4j driver not available")
|
||||
|
||||
self.config = config
|
||||
self.driver: Optional[Neo4jDriver] = None
|
||||
self._connection_pool_size = config.get('connection_pool_size', 10)
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to Neo4j database."""
|
||||
try:
|
||||
uri = self.config.get('uri', 'bolt://localhost:7687')
|
||||
username = self.config.get('username')
|
||||
password = self.config.get('password')
|
||||
|
||||
auth = (username, password) if username and password else None
|
||||
|
||||
# Create driver with connection pool
|
||||
self.driver = GraphDatabase.driver(
|
||||
uri,
|
||||
auth=auth,
|
||||
max_connection_pool_size=self._connection_pool_size,
|
||||
connection_timeout=self.config.get('connection_timeout', 30),
|
||||
max_retry_time=self.config.get('max_retry_time', 15)
|
||||
)
|
||||
|
||||
# Verify connectivity
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.driver.verify_connectivity
|
||||
)
|
||||
|
||||
logger.info(f"Connected to Neo4j at {uri}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Neo4j: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, cypher_query: CypherQuery) -> CypherResult:
|
||||
"""Execute Cypher query against Neo4j.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
if not self.driver:
|
||||
await self.connect()
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Execute query in a session
|
||||
records = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._execute_sync, cypher_query
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return CypherResult(
|
||||
records=records,
|
||||
summary={'record_count': len(records)},
|
||||
execution_time=execution_time,
|
||||
database_type='neo4j'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j query execution error: {e}")
|
||||
execution_time = time.time() - start_time
|
||||
return CypherResult(
|
||||
records=[],
|
||||
summary={'error': str(e)},
|
||||
execution_time=execution_time,
|
||||
database_type='neo4j'
|
||||
)
|
||||
|
||||
def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
|
||||
"""Execute query synchronously in thread executor.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
List of record dictionaries
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
result = session.run(cypher_query.query, cypher_query.parameters)
|
||||
records = []
|
||||
for record in result:
|
||||
record_dict = {}
|
||||
for key in record.keys():
|
||||
value = record[key]
|
||||
record_dict[key] = self._format_neo4j_value(value)
|
||||
records.append(record_dict)
|
||||
return records
|
||||
|
||||
def _format_neo4j_value(self, value):
|
||||
"""Format Neo4j value for JSON serialization.
|
||||
|
||||
Args:
|
||||
value: Neo4j value
|
||||
|
||||
Returns:
|
||||
JSON-serializable value
|
||||
"""
|
||||
# Handle Neo4j node objects
|
||||
if hasattr(value, 'labels') and hasattr(value, 'items'):
|
||||
return {
|
||||
'labels': list(value.labels),
|
||||
'properties': dict(value.items())
|
||||
}
|
||||
# Handle Neo4j relationship objects
|
||||
elif hasattr(value, 'type') and hasattr(value, 'items'):
|
||||
return {
|
||||
'type': value.type,
|
||||
'properties': dict(value.items())
|
||||
}
|
||||
# Handle Neo4j path objects
|
||||
elif hasattr(value, 'nodes') and hasattr(value, 'relationships'):
|
||||
return {
|
||||
'nodes': [self._format_neo4j_value(n) for n in value.nodes],
|
||||
'relationships': [self._format_neo4j_value(r) for r in value.relationships]
|
||||
}
|
||||
else:
|
||||
return value
|
||||
|
||||
async def close(self):
|
||||
"""Close Neo4j connection."""
|
||||
if self.driver:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.driver.close
|
||||
)
|
||||
self.driver = None
|
||||
logger.info("Neo4j connection closed")
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to Neo4j."""
|
||||
return self.driver is not None
|
||||
|
||||
|
||||
class MemgraphExecutor(CypherExecutorBase):
|
||||
"""Cypher executor for Memgraph database."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize Memgraph executor.
|
||||
|
||||
Args:
|
||||
config: Memgraph configuration
|
||||
"""
|
||||
if not NEO4J_AVAILABLE: # Memgraph uses Neo4j driver
|
||||
raise RuntimeError("Neo4j driver required for Memgraph")
|
||||
|
||||
self.config = config
|
||||
self.driver: Optional[Neo4jDriver] = None
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to Memgraph database."""
|
||||
try:
|
||||
uri = self.config.get('uri', 'bolt://localhost:7688')
|
||||
username = self.config.get('username')
|
||||
password = self.config.get('password')
|
||||
|
||||
auth = (username, password) if username and password else None
|
||||
|
||||
# Memgraph uses Neo4j driver but with different defaults
|
||||
self.driver = GraphDatabase.driver(
|
||||
uri,
|
||||
auth=auth,
|
||||
max_connection_pool_size=self.config.get('connection_pool_size', 5),
|
||||
connection_timeout=self.config.get('connection_timeout', 10)
|
||||
)
|
||||
|
||||
# Verify connectivity
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.driver.verify_connectivity
|
||||
)
|
||||
|
||||
logger.info(f"Connected to Memgraph at {uri}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Memgraph: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, cypher_query: CypherQuery) -> CypherResult:
|
||||
"""Execute Cypher query against Memgraph.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
if not self.driver:
|
||||
await self.connect()
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Execute query with Memgraph-specific optimizations
|
||||
records = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._execute_memgraph_sync, cypher_query
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return CypherResult(
|
||||
records=records,
|
||||
summary={
|
||||
'record_count': len(records),
|
||||
'engine': 'memgraph'
|
||||
},
|
||||
execution_time=execution_time,
|
||||
database_type='memgraph'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Memgraph query execution error: {e}")
|
||||
execution_time = time.time() - start_time
|
||||
return CypherResult(
|
||||
records=[],
|
||||
summary={'error': str(e)},
|
||||
execution_time=execution_time,
|
||||
database_type='memgraph'
|
||||
)
|
||||
|
||||
def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
|
||||
"""Execute query synchronously for Memgraph.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
List of record dictionaries
|
||||
"""
|
||||
with self.driver.session() as session:
|
||||
# Add Memgraph-specific query hints if available
|
||||
query = cypher_query.query
|
||||
if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'):
|
||||
# Memgraph supports memory limits
|
||||
query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}"
|
||||
|
||||
result = session.run(query, cypher_query.parameters)
|
||||
records = []
|
||||
for record in result:
|
||||
record_dict = {}
|
||||
for key in record.keys():
|
||||
record_dict[key] = record[key]
|
||||
records.append(record_dict)
|
||||
return records
|
||||
|
||||
async def close(self):
|
||||
"""Close Memgraph connection."""
|
||||
if self.driver:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.driver.close
|
||||
)
|
||||
self.driver = None
|
||||
logger.info("Memgraph connection closed")
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to Memgraph."""
|
||||
return self.driver is not None
|
||||
|
||||
|
||||
class FalkorDBExecutor(CypherExecutorBase):
|
||||
"""Cypher executor for FalkorDB (Redis-based graph database)."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize FalkorDB executor.
|
||||
|
||||
Args:
|
||||
config: FalkorDB configuration
|
||||
"""
|
||||
if not REDIS_AVAILABLE:
|
||||
raise RuntimeError("Redis driver required for FalkorDB")
|
||||
|
||||
self.config = config
|
||||
self.redis_client: Optional[redis.Redis] = None
|
||||
self.graph_name = config.get('graph_name', 'knowledge_graph')
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to FalkorDB (Redis)."""
|
||||
try:
|
||||
self.redis_client = redis.Redis(
|
||||
host=self.config.get('host', 'localhost'),
|
||||
port=self.config.get('port', 6379),
|
||||
password=self.config.get('password'),
|
||||
db=self.config.get('db', 0),
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self.config.get('connection_timeout', 10),
|
||||
socket_timeout=self.config.get('socket_timeout', 10)
|
||||
)
|
||||
|
||||
# Test connection
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.redis_client.ping
|
||||
)
|
||||
|
||||
logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to FalkorDB: {e}")
|
||||
raise
|
||||
|
||||
async def execute(self, cypher_query: CypherQuery) -> CypherResult:
|
||||
"""Execute Cypher query against FalkorDB.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
if not self.redis_client:
|
||||
await self.connect()
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Execute query using FalkorDB's GRAPH.QUERY command
|
||||
records = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._execute_falkordb_sync, cypher_query
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return CypherResult(
|
||||
records=records,
|
||||
summary={
|
||||
'record_count': len(records),
|
||||
'engine': 'falkordb'
|
||||
},
|
||||
execution_time=execution_time,
|
||||
database_type='falkordb'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"FalkorDB query execution error: {e}")
|
||||
execution_time = time.time() - start_time
|
||||
return CypherResult(
|
||||
records=[],
|
||||
summary={'error': str(e)},
|
||||
execution_time=execution_time,
|
||||
database_type='falkordb'
|
||||
)
|
||||
|
||||
def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
|
||||
"""Execute query synchronously for FalkorDB.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
List of record dictionaries
|
||||
"""
|
||||
# Substitute parameters in query (FalkorDB parameter handling)
|
||||
query = cypher_query.query
|
||||
for param, value in cypher_query.parameters.items():
|
||||
if isinstance(value, str):
|
||||
query = query.replace(f'${param}', f'"{value}"')
|
||||
else:
|
||||
query = query.replace(f'${param}', str(value))
|
||||
|
||||
# Execute using FalkorDB GRAPH.QUERY command
|
||||
result = self.redis_client.execute_command(
|
||||
'GRAPH.QUERY', self.graph_name, query
|
||||
)
|
||||
|
||||
# Parse FalkorDB result format
|
||||
records = []
|
||||
if result and len(result) > 1:
|
||||
# FalkorDB returns [header, data rows, statistics]
|
||||
headers = result[0] if result[0] else []
|
||||
data_rows = result[1] if len(result) > 1 else []
|
||||
|
||||
for row in data_rows:
|
||||
record = {}
|
||||
for i, header in enumerate(headers):
|
||||
if i < len(row):
|
||||
record[header] = self._format_falkordb_value(row[i])
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
def _format_falkordb_value(self, value):
|
||||
"""Format FalkorDB value for JSON serialization.
|
||||
|
||||
Args:
|
||||
value: FalkorDB value
|
||||
|
||||
Returns:
|
||||
JSON-serializable value
|
||||
"""
|
||||
# FalkorDB returns values in specific formats
|
||||
if isinstance(value, list) and len(value) == 3:
|
||||
# Check if it's a node/relationship representation
|
||||
if value[0] == 1: # Node
|
||||
return {
|
||||
'type': 'node',
|
||||
'labels': value[1],
|
||||
'properties': value[2]
|
||||
}
|
||||
elif value[0] == 2: # Relationship
|
||||
return {
|
||||
'type': 'relationship',
|
||||
'rel_type': value[1],
|
||||
'properties': value[2]
|
||||
}
|
||||
|
||||
return value
|
||||
|
||||
async def close(self):
|
||||
"""Close FalkorDB connection."""
|
||||
if self.redis_client:
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.redis_client.close
|
||||
)
|
||||
self.redis_client = None
|
||||
logger.info("FalkorDB connection closed")
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to FalkorDB."""
|
||||
return self.redis_client is not None
|
||||
|
||||
|
||||
class CypherExecutor:
|
||||
"""Multi-database Cypher executor with automatic routing."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize multi-database executor.
|
||||
|
||||
Args:
|
||||
config: Configuration for all database types
|
||||
"""
|
||||
self.config = config
|
||||
self.executors: Dict[str, CypherExecutorBase] = {}
|
||||
|
||||
# Initialize available executors
|
||||
self._initialize_executors()
|
||||
|
||||
def _initialize_executors(self):
|
||||
"""Initialize database executors based on configuration."""
|
||||
# Neo4j executor
|
||||
if 'neo4j' in self.config and NEO4J_AVAILABLE:
|
||||
try:
|
||||
self.executors['neo4j'] = Neo4jExecutor(self.config['neo4j'])
|
||||
logger.info("Neo4j executor initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Neo4j executor: {e}")
|
||||
|
||||
# Memgraph executor
|
||||
if 'memgraph' in self.config and NEO4J_AVAILABLE:
|
||||
try:
|
||||
self.executors['memgraph'] = MemgraphExecutor(self.config['memgraph'])
|
||||
logger.info("Memgraph executor initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Memgraph executor: {e}")
|
||||
|
||||
# FalkorDB executor
|
||||
if 'falkordb' in self.config and REDIS_AVAILABLE:
|
||||
try:
|
||||
self.executors['falkordb'] = FalkorDBExecutor(self.config['falkordb'])
|
||||
logger.info("FalkorDB executor initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize FalkorDB executor: {e}")
|
||||
|
||||
if not self.executors:
|
||||
raise RuntimeError("No database executors could be initialized")
|
||||
|
||||
async def execute_cypher(self, cypher_query: CypherQuery,
|
||||
database_type: str) -> CypherResult:
|
||||
"""Execute Cypher query on specified database.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
database_type: Target database type
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
if database_type not in self.executors:
|
||||
raise ValueError(f"Database type {database_type} not available. "
|
||||
f"Available: {list(self.executors.keys())}")
|
||||
|
||||
executor = self.executors[database_type]
|
||||
|
||||
# Ensure connection
|
||||
if not executor.is_connected():
|
||||
await executor.connect()
|
||||
|
||||
# Execute query
|
||||
return await executor.execute(cypher_query)
|
||||
|
||||
async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]:
|
||||
"""Execute query on all available databases.
|
||||
|
||||
Args:
|
||||
cypher_query: Cypher query to execute
|
||||
|
||||
Returns:
|
||||
Results from all databases
|
||||
"""
|
||||
results = {}
|
||||
tasks = []
|
||||
|
||||
for db_type, executor in self.executors.items():
|
||||
task = asyncio.create_task(
|
||||
self.execute_cypher(cypher_query, db_type),
|
||||
name=f"cypher_query_{db_type}"
|
||||
)
|
||||
tasks.append((db_type, task))
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for db_type, task in tasks:
|
||||
try:
|
||||
results[db_type] = await task
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed on {db_type}: {e}")
|
||||
results[db_type] = CypherResult(
|
||||
records=[],
|
||||
summary={'error': str(e)},
|
||||
execution_time=0.0,
|
||||
database_type=db_type
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def get_available_databases(self) -> List[str]:
|
||||
"""Get list of available database types.
|
||||
|
||||
Returns:
|
||||
List of available database type names
|
||||
"""
|
||||
return list(self.executors.keys())
|
||||
|
||||
async def close_all(self):
|
||||
"""Close all database connections."""
|
||||
for executor in self.executors.values():
|
||||
await executor.close()
|
||||
logger.info("All Cypher executor connections closed")
|
||||
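An illustrative async round trip through the multi-database executor (a sketch, not part of the diff): the Neo4j connection settings are placeholders, and CypherQuery is the dataclass defined in cypher_generator.py below.

    import asyncio
    from trustgraph.query.ontology.cypher_executor import CypherExecutor
    from trustgraph.query.ontology.cypher_generator import CypherQuery

    async def main():
        executor = CypherExecutor({
            "neo4j": {"uri": "bolt://localhost:7687",
                      "username": "neo4j", "password": "secret"},   # placeholder credentials
        })
        query = CypherQuery(
            query="MATCH (n:Person) RETURN n.name AS name LIMIT 5",
            parameters={},
            variables=["name"],
            explanation="List a few Person nodes",
            complexity_score=0.2,
        )
        result = await executor.execute_cypher(query, "neo4j")
        print(result.records, result.execution_time)
        await executor.close_all()

    asyncio.run(main())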
trustgraph-flow/trustgraph/query/ontology/cypher_generator.py (628 lines, Normal file)
@@ -0,0 +1,628 @@
"""
|
||||
Cypher query generator for ontology-sensitive queries.
|
||||
Converts natural language questions to Cypher queries for graph databases.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CypherQuery:
|
||||
"""Generated Cypher query with metadata."""
|
||||
query: str
|
||||
parameters: Dict[str, Any]
|
||||
variables: List[str]
|
||||
explanation: str
|
||||
complexity_score: float
|
||||
database_hints: Dict[str, Any] = None # Database-specific optimization hints
|
||||
|
||||
|
||||
class CypherGenerator:
|
||||
"""Generates Cypher queries from natural language questions using LLM assistance."""
|
||||
|
||||
def __init__(self, prompt_service=None):
|
||||
"""Initialize Cypher generator.
|
||||
|
||||
Args:
|
||||
prompt_service: Service for LLM-based query generation
|
||||
"""
|
||||
self.prompt_service = prompt_service
|
||||
|
||||
# Cypher query templates for common patterns
|
||||
self.templates = {
|
||||
'simple_node_query': """
|
||||
MATCH (n:{node_label})
|
||||
RETURN n.name AS name, n.{property} AS {property}
|
||||
LIMIT {limit}""",
|
||||
|
||||
'relationship_query': """
|
||||
MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label})
|
||||
WHERE a.name = $source_name
|
||||
RETURN b.name AS name, r.{rel_property} AS property""",
|
||||
|
||||
'path_query': """
|
||||
MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label})
|
||||
WHERE start.name = $start_name
|
||||
RETURN path, length(path) AS path_length
|
||||
ORDER BY path_length""",
|
||||
|
||||
'count_query': """
|
||||
MATCH (n:{node_label})
|
||||
{where_clause}
|
||||
RETURN count(n) AS count""",
|
||||
|
||||
'aggregation_query': """
|
||||
MATCH (n:{node_label})
|
||||
{where_clause}
|
||||
RETURN
|
||||
count(n) AS count,
|
||||
avg(n.{numeric_property}) AS average,
|
||||
sum(n.{numeric_property}) AS total""",
|
||||
|
||||
'boolean_query': """
|
||||
MATCH (a:{source_label})-[:{relationship}]->(b:{target_label})
|
||||
WHERE a.name = $source_name AND b.name = $target_name
|
||||
RETURN count(*) > 0 AS exists""",
|
||||
|
||||
'hierarchy_query': """
|
||||
MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label})
|
||||
WHERE parent.name = $parent_name
|
||||
RETURN child.name AS child_name, parent.name AS parent_name""",
|
||||
|
||||
'property_filter_query': """
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.{property} {operator} ${property}_value
|
||||
RETURN n.name AS name, n.{property} AS {property}
|
||||
ORDER BY n.{property}"""
|
||||
}
|
||||
|
||||
async def generate_cypher(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
database_type: str = "neo4j") -> CypherQuery:
|
||||
"""Generate Cypher query for a question.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question components
|
||||
ontology_subset: Relevant ontology subset
|
||||
database_type: Target database (neo4j, memgraph, falkordb)
|
||||
|
||||
Returns:
|
||||
Generated Cypher query
|
||||
"""
|
||||
# Try template-based generation first
|
||||
template_query = self._try_template_generation(
|
||||
question_components, ontology_subset, database_type
|
||||
)
|
||||
if template_query:
|
||||
logger.debug("Generated Cypher using template")
|
||||
return template_query
|
||||
|
||||
# Fall back to LLM-based generation
|
||||
if self.prompt_service:
|
||||
llm_query = await self._generate_with_llm(
|
||||
question_components, ontology_subset, database_type
|
||||
)
|
||||
if llm_query:
|
||||
logger.debug("Generated Cypher using LLM")
|
||||
return llm_query
|
||||
|
||||
# Final fallback to simple pattern
|
||||
logger.warning("Falling back to simple Cypher pattern")
|
||||
return self._generate_fallback_query(question_components, ontology_subset)
|
||||
|
||||
def _try_template_generation(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
database_type: str) -> Optional[CypherQuery]:
|
||||
"""Try to generate query using templates.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
database_type: Target database type
|
||||
|
||||
Returns:
|
||||
Generated query or None if no template matches
|
||||
"""
|
||||
# Simple node query (What are the animals?)
|
||||
if (question_components.question_type == QuestionType.RETRIEVAL and
|
||||
len(question_components.entities) == 1):
|
||||
|
||||
node_label = self._find_matching_node_label(
|
||||
question_components.entities[0], ontology_subset
|
||||
)
|
||||
if node_label:
|
||||
query = self.templates['simple_node_query'].format(
|
||||
node_label=node_label,
|
||||
property='name',
|
||||
limit=100
|
||||
)
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters={},
|
||||
variables=['name'],
|
||||
explanation=f"Retrieve all nodes of type {node_label}",
|
||||
complexity_score=0.2,
|
||||
database_hints=self._get_database_hints(database_type, 'simple')
|
||||
)
|
||||
|
||||
# Count query (How many animals are there?)
|
||||
if (question_components.question_type == QuestionType.AGGREGATION and
|
||||
'count' in question_components.aggregations):
|
||||
|
||||
node_label = self._find_matching_node_label(
|
||||
question_components.entities[0] if question_components.entities else 'Entity',
|
||||
ontology_subset
|
||||
)
|
||||
if node_label:
|
||||
where_clause = self._build_where_clause(question_components)
|
||||
query = self.templates['count_query'].format(
|
||||
node_label=node_label,
|
||||
where_clause=where_clause
|
||||
)
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters=self._extract_parameters(question_components),
|
||||
variables=['count'],
|
||||
explanation=f"Count nodes of type {node_label}",
|
||||
complexity_score=0.3,
|
||||
database_hints=self._get_database_hints(database_type, 'aggregation')
|
||||
)
|
||||
|
||||
# Relationship query (Which documents were authored by John Smith?)
|
||||
if (question_components.question_type == QuestionType.RETRIEVAL and
|
||||
len(question_components.entities) >= 2):
|
||||
|
||||
source_label = self._find_matching_node_label(
|
||||
question_components.entities[1], ontology_subset
|
||||
)
|
||||
target_label = self._find_matching_node_label(
|
||||
question_components.entities[0], ontology_subset
|
||||
)
|
||||
relationship = self._find_matching_relationship(
|
||||
question_components, ontology_subset
|
||||
)
|
||||
|
||||
if source_label and target_label and relationship:
|
||||
query = self.templates['relationship_query'].format(
|
||||
source_label=source_label,
|
||||
target_label=target_label,
|
||||
relationship=relationship,
|
||||
rel_property='name'
|
||||
)
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters={'source_name': question_components.entities[1]},
|
||||
variables=['name'],
|
||||
explanation=f"Find {target_label} related to {source_label} via {relationship}",
|
||||
complexity_score=0.4,
|
||||
database_hints=self._get_database_hints(database_type, 'relationship')
|
||||
)
|
||||
|
||||
# Boolean query (Is X related to Y?)
|
||||
if question_components.question_type == QuestionType.BOOLEAN:
|
||||
if len(question_components.entities) >= 2:
|
||||
source_label = self._find_matching_node_label(
|
||||
question_components.entities[0], ontology_subset
|
||||
)
|
||||
target_label = self._find_matching_node_label(
|
||||
question_components.entities[1], ontology_subset
|
||||
)
|
||||
relationship = self._find_matching_relationship(
|
||||
question_components, ontology_subset
|
||||
)
|
||||
|
||||
if source_label and target_label and relationship:
|
||||
query = self.templates['boolean_query'].format(
|
||||
source_label=source_label,
|
||||
target_label=target_label,
|
||||
relationship=relationship
|
||||
)
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters={
|
||||
'source_name': question_components.entities[0],
|
||||
'target_name': question_components.entities[1]
|
||||
},
|
||||
variables=['exists'],
|
||||
explanation="Boolean check for relationship existence",
|
||||
complexity_score=0.3,
|
||||
database_hints=self._get_database_hints(database_type, 'boolean')
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _generate_with_llm(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
database_type: str) -> Optional[CypherQuery]:
|
||||
"""Generate Cypher using LLM.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
database_type: Target database type
|
||||
|
||||
Returns:
|
||||
Generated query or None if failed
|
||||
"""
|
||||
try:
|
||||
prompt = self._build_cypher_prompt(
|
||||
question_components, ontology_subset, database_type
|
||||
)
|
||||
response = await self.prompt_service.generate_cypher(prompt=prompt)
|
||||
|
||||
if response and isinstance(response, dict):
|
||||
query = response.get('query', '').strip()
|
||||
if query.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')):
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters=response.get('parameters', {}),
|
||||
variables=self._extract_variables(query),
|
||||
explanation=response.get('explanation', 'Generated by LLM'),
|
||||
complexity_score=self._calculate_complexity(query),
|
||||
database_hints=self._get_database_hints(database_type, 'complex')
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Cypher generation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _build_cypher_prompt(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
database_type: str) -> str:
|
||||
"""Build prompt for LLM Cypher generation.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
database_type: Target database type
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# Format ontology elements as node labels and relationships
|
||||
node_labels = self._format_node_labels(ontology_subset.classes)
|
||||
relationships = self._format_relationships(
|
||||
ontology_subset.object_properties,
|
||||
ontology_subset.datatype_properties
|
||||
)
|
||||
|
||||
prompt = f"""Generate a Cypher query for the following question using the provided ontology.
|
||||
|
||||
QUESTION: {question_components.original_question}
|
||||
|
||||
TARGET DATABASE: {database_type}
|
||||
|
||||
AVAILABLE NODE LABELS (from classes):
|
||||
{node_labels}
|
||||
|
||||
AVAILABLE RELATIONSHIP TYPES (from properties):
|
||||
{relationships}
|
||||
|
||||
RULES:
|
||||
- Use MATCH patterns for graph traversal
|
||||
- Include WHERE clauses for filters
|
||||
- Use aggregation functions when needed (COUNT, SUM, AVG)
|
||||
- Optimize for {database_type} performance
|
||||
- Consider index hints for large datasets
|
||||
- Use parameters for values (e.g., $name)
|
||||
|
||||
QUERY TYPE HINTS:
|
||||
- Question type: {question_components.question_type.value}
|
||||
- Expected answer: {question_components.expected_answer_type}
|
||||
- Entities mentioned: {', '.join(question_components.entities)}
|
||||
- Aggregations: {', '.join(question_components.aggregations)}
|
||||
|
||||
DATABASE-SPECIFIC OPTIMIZATIONS:
|
||||
{self._get_database_specific_hints(database_type)}
|
||||
|
||||
Generate a complete Cypher query with parameters:"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _generate_fallback_query(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> CypherQuery:
|
||||
"""Generate simple fallback query.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Basic Cypher query
|
||||
"""
|
||||
# Very basic MATCH query
|
||||
first_class = list(ontology_subset.classes.keys())[0] if ontology_subset.classes else 'Entity'
|
||||
|
||||
query = f"""MATCH (n:{first_class})
|
||||
WHERE n.name CONTAINS $keyword
|
||||
RETURN n.name AS name, labels(n) AS types
|
||||
LIMIT 10"""
|
||||
|
||||
return CypherQuery(
|
||||
query=query,
|
||||
parameters={'keyword': question_components.keywords[0] if question_components.keywords else 'entity'},
|
||||
variables=['name', 'types'],
|
||||
explanation="Fallback query for basic pattern matching",
|
||||
complexity_score=0.1
|
||||
)
|
||||
|
||||
def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
|
||||
"""Find matching node label in ontology subset.
|
||||
|
||||
Args:
|
||||
entity: Entity string to match
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Matching node label or None
|
||||
"""
|
||||
entity_lower = entity.lower()
|
||||
|
||||
# Direct match
|
||||
for class_id in ontology_subset.classes:
|
||||
if class_id.lower() == entity_lower:
|
||||
return class_id
|
||||
|
||||
# Label match
|
||||
for class_id, class_def in ontology_subset.classes.items():
|
||||
labels = class_def.get('labels', [])
|
||||
for label in labels:
|
||||
if isinstance(label, dict):
|
||||
label_value = label.get('value', '').lower()
|
||||
if label_value == entity_lower:
|
||||
return class_id
|
||||
|
||||
# Partial match
|
||||
for class_id in ontology_subset.classes:
|
||||
if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
|
||||
return class_id
|
||||
|
||||
return None
|
||||
|
||||
def _find_matching_relationship(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> Optional[str]:
|
||||
"""Find matching relationship type.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Matching relationship type or None
|
||||
"""
|
||||
# Look for relationship keywords
|
||||
for keyword in question_components.keywords:
|
||||
keyword_lower = keyword.lower()
|
||||
|
||||
# Check object properties
|
||||
for prop_id in ontology_subset.object_properties:
|
||||
if keyword_lower in prop_id.lower() or prop_id.lower() in keyword_lower:
|
||||
return prop_id.upper().replace('-', '_')
|
||||
|
||||
# Common relationship mappings
|
||||
relationship_mappings = {
|
||||
'author': 'AUTHORED_BY',
|
||||
'created': 'CREATED_BY',
|
||||
'owns': 'OWNS',
|
||||
'has': 'HAS',
|
||||
'contains': 'CONTAINS',
|
||||
'parent': 'PARENT_OF',
|
||||
'child': 'CHILD_OF',
|
||||
'related': 'RELATED_TO'
|
||||
}
|
||||
|
||||
for keyword in question_components.keywords:
|
||||
if keyword.lower() in relationship_mappings:
|
||||
return relationship_mappings[keyword.lower()]
|
||||
|
||||
# Default relationship
|
||||
return 'RELATED_TO'
|
||||
|
||||
def _build_where_clause(self, question_components: QuestionComponents) -> str:
|
||||
"""Build WHERE clause for Cypher query.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
|
||||
Returns:
|
||||
WHERE clause string
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
for constraint in question_components.constraints:
|
||||
if 'greater than' in constraint.lower():
|
||||
import re
|
||||
numbers = re.findall(r'\d+', constraint)
|
||||
if numbers:
|
||||
conditions.append(f"n.value > {numbers[0]}")
|
||||
elif 'less than' in constraint.lower():
|
||||
numbers = re.findall(r'\d+', constraint)
|
||||
if numbers:
|
||||
conditions.append(f"n.value < {numbers[0]}")
|
||||
|
||||
if conditions:
|
||||
return f"WHERE {' AND '.join(conditions)}"
|
||||
return ""
|
||||
|
||||
def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]:
|
||||
"""Extract parameters from question components.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
|
||||
Returns:
|
||||
Parameters dictionary
|
||||
"""
|
||||
parameters = {}
|
||||
|
||||
# Extract numeric values
|
||||
import re
|
||||
for constraint in question_components.constraints:
|
||||
numbers = re.findall(r'\d+', constraint)
|
||||
for i, number in enumerate(numbers):
|
||||
parameters[f'value_{i}'] = int(number)
|
||||
|
||||
return parameters
|
||||
|
||||
def _format_node_labels(self, classes: Dict[str, Any]) -> str:
|
||||
"""Format classes as node labels for prompt.
|
||||
|
||||
Args:
|
||||
classes: Classes dictionary
|
||||
|
||||
Returns:
|
||||
Formatted node labels string
|
||||
"""
|
||||
if not classes:
|
||||
return "None"
|
||||
|
||||
lines = []
|
||||
for class_id, definition in classes.items():
|
||||
comment = definition.get('comment', '')
|
||||
lines.append(f"- :{class_id} - {comment}")
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _format_relationships(self,
|
||||
object_props: Dict[str, Any],
|
||||
datatype_props: Dict[str, Any]) -> str:
|
||||
"""Format properties as relationships for prompt.
|
||||
|
||||
Args:
|
||||
object_props: Object properties
|
||||
datatype_props: Datatype properties
|
||||
|
||||
Returns:
|
||||
Formatted relationships string
|
||||
"""
|
||||
lines = []
|
||||
|
||||
for prop_id, definition in object_props.items():
|
||||
domain = definition.get('domain', 'Any')
|
||||
range_val = definition.get('range', 'Any')
|
||||
comment = definition.get('comment', '')
|
||||
rel_type = prop_id.upper().replace('-', '_')
|
||||
lines.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}")
|
||||
|
||||
return '\n'.join(lines) if lines else "None"
|
||||
|
||||
def _extract_variables(self, query: str) -> List[str]:
|
||||
"""Extract variables from Cypher query.
|
||||
|
||||
Args:
|
||||
query: Cypher query string
|
||||
|
||||
Returns:
|
||||
List of variable names
|
||||
"""
|
||||
import re
|
||||
# Extract RETURN clause variables
|
||||
return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL)
|
||||
if return_match:
|
||||
return_clause = return_match.group(1)
|
||||
variables = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause)
|
||||
return [var[1] if var[1] else var[0] for var in variables]
|
||||
return []
|
||||
|
||||
def _calculate_complexity(self, query: str) -> float:
|
||||
"""Calculate complexity score for Cypher query.
|
||||
|
||||
Args:
|
||||
query: Cypher query string
|
||||
|
||||
Returns:
|
||||
Complexity score (0.0 to 1.0)
|
||||
"""
|
||||
complexity = 0.0
|
||||
query_upper = query.upper()
|
||||
|
||||
# Count different Cypher features
|
||||
if 'JOIN' in query_upper or 'UNION' in query_upper:
|
||||
complexity += 0.3
|
||||
if 'WHERE' in query_upper:
|
||||
complexity += 0.2
|
||||
if 'OPTIONAL' in query_upper:
|
||||
complexity += 0.1
|
||||
if 'ORDER BY' in query_upper:
|
||||
complexity += 0.1
|
||||
if '*' in query: # Variable length paths
|
||||
complexity += 0.2
|
||||
if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
|
||||
complexity += 0.2
|
||||
|
||||
# Count path length
|
||||
path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query)
|
||||
for start, end in path_matches:
|
||||
complexity += (int(end) - int(start)) * 0.05
|
||||
|
||||
return min(complexity, 1.0)
|
||||
|
||||
    def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]:
        """Get database-specific optimization hints.

        Args:
            database_type: Target database
            query_category: Category of query

        Returns:
            Optimization hints
        """
        hints = {}

        if database_type == "neo4j":
            hints.update({
                'use_index': True,
                'explain_plan': 'EXPLAIN',
                'profile_query': 'PROFILE'
            })
        elif database_type == "memgraph":
            hints.update({
                'use_index': True,
                'explain_plan': 'EXPLAIN',
                'memory_limit': '1GB'
            })
        elif database_type == "falkordb":
            hints.update({
                'use_index': False,  # Redis-based, different indexing
                'cache_result': True
            })

        return hints

    def _get_database_specific_hints(self, database_type: str) -> str:
        """Get database-specific optimization hints as text.

        Args:
            database_type: Target database

        Returns:
            Hints as formatted string
        """
        if database_type == "neo4j":
            return """- Use USING INDEX hints for large datasets
- Consider PROFILE for query optimization
- Prefer MERGE over CREATE when appropriate"""
        elif database_type == "memgraph":
            return """- Leverage in-memory processing advantages
- Use streaming for large result sets
- Consider query parallelization"""
        elif database_type == "falkordb":
            return """- Optimize for Redis memory constraints
- Use simple patterns for best performance
- Leverage Redis data structures when possible"""
        else:
            return "- Use standard Cypher optimization patterns"
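A hedged usage sketch of the hints method above; the generator class name (CypherGenerator) and the query text are assumptions, since the class header is defined earlier in the file and not shown here.

# Hypothetical usage sketch only.
generator = CypherGenerator()
hints = generator._get_database_hints("neo4j", "retrieval")
if hints.get('explain_plan'):
    # prepend EXPLAIN so the backend returns the plan instead of rows
    query = f"{hints['explain_plan']} MATCH (n:Person) RETURN n LIMIT 5"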
trustgraph-flow/trustgraph/query/ontology/error_handling.py (new file)
@@ -0,0 +1,557 @@
"""
|
||||
Error handling and recovery mechanisms for OntoRAG.
|
||||
Provides comprehensive error handling, retry logic, and graceful degradation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Callable, Union, Type
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
import traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorSeverity(Enum):
|
||||
"""Error severity levels."""
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class ErrorCategory(Enum):
|
||||
"""Error categories for better handling."""
|
||||
ONTOLOGY_LOADING = "ontology_loading"
|
||||
QUESTION_ANALYSIS = "question_analysis"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
QUERY_EXECUTION = "query_execution"
|
||||
ANSWER_GENERATION = "answer_generation"
|
||||
BACKEND_CONNECTION = "backend_connection"
|
||||
CACHE_ERROR = "cache_error"
|
||||
VALIDATION_ERROR = "validation_error"
|
||||
TIMEOUT_ERROR = "timeout_error"
|
||||
AUTHENTICATION_ERROR = "authentication_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorContext:
|
||||
"""Context information for an error."""
|
||||
category: ErrorCategory
|
||||
severity: ErrorSeverity
|
||||
component: str
|
||||
operation: str
|
||||
user_message: Optional[str] = None
|
||||
technical_details: Optional[str] = None
|
||||
suggestion: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
class OntoRAGError(Exception):
|
||||
"""Base exception for OntoRAG system."""
|
||||
|
||||
def __init__(self,
|
||||
message: str,
|
||||
context: Optional[ErrorContext] = None,
|
||||
cause: Optional[Exception] = None):
|
||||
"""Initialize OntoRAG error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
context: Error context
|
||||
cause: Original exception that caused this error
|
||||
"""
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.context = context or ErrorContext(
|
||||
category=ErrorCategory.VALIDATION_ERROR,
|
||||
severity=ErrorSeverity.MEDIUM,
|
||||
component="unknown",
|
||||
operation="unknown"
|
||||
)
|
||||
self.cause = cause
|
||||
self.timestamp = time.time()
|
||||
|
||||
|
||||
class OntologyLoadingError(OntoRAGError):
|
||||
"""Error loading ontology."""
|
||||
pass
|
||||
|
||||
|
||||
class QuestionAnalysisError(OntoRAGError):
|
||||
"""Error analyzing question."""
|
||||
pass
|
||||
|
||||
|
||||
class QueryGenerationError(OntoRAGError):
|
||||
"""Error generating query."""
|
||||
pass
|
||||
|
||||
|
||||
class QueryExecutionError(OntoRAGError):
|
||||
"""Error executing query."""
|
||||
pass
|
||||
|
||||
|
||||
class AnswerGenerationError(OntoRAGError):
|
||||
"""Error generating answer."""
|
||||
pass
|
||||
|
||||
|
||||
class BackendConnectionError(OntoRAGError):
|
||||
"""Error connecting to backend."""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(OntoRAGError):
|
||||
"""Operation timeout error."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Configuration for retry logic."""
|
||||
max_retries: int = 3
|
||||
base_delay: float = 1.0
|
||||
max_delay: float = 60.0
|
||||
exponential_backoff: bool = True
|
||||
jitter: bool = True
|
||||
retry_on_exceptions: List[Type[Exception]] = None
|
||||
|
||||
|
||||
class ErrorRecoveryStrategy:
|
||||
"""Strategy for handling and recovering from errors."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize error recovery strategy.
|
||||
|
||||
Args:
|
||||
config: Recovery configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.retry_configs = self._build_retry_configs()
|
||||
self.fallback_strategies = self._build_fallback_strategies()
|
||||
self.error_counters: Dict[str, int] = {}
|
||||
self.circuit_breakers: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]:
|
||||
"""Build retry configurations for different error categories."""
|
||||
return {
|
||||
ErrorCategory.BACKEND_CONNECTION: RetryConfig(
|
||||
max_retries=5,
|
||||
base_delay=2.0,
|
||||
retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError]
|
||||
),
|
||||
ErrorCategory.QUERY_EXECUTION: RetryConfig(
|
||||
max_retries=3,
|
||||
base_delay=1.0,
|
||||
retry_on_exceptions=[QueryExecutionError, TimeoutError]
|
||||
),
|
||||
ErrorCategory.ONTOLOGY_LOADING: RetryConfig(
|
||||
max_retries=2,
|
||||
base_delay=0.5,
|
||||
retry_on_exceptions=[OntologyLoadingError, IOError]
|
||||
),
|
||||
ErrorCategory.QUESTION_ANALYSIS: RetryConfig(
|
||||
max_retries=2,
|
||||
base_delay=1.0,
|
||||
retry_on_exceptions=[QuestionAnalysisError, TimeoutError]
|
||||
),
|
||||
ErrorCategory.ANSWER_GENERATION: RetryConfig(
|
||||
max_retries=2,
|
||||
base_delay=1.0,
|
||||
retry_on_exceptions=[AnswerGenerationError, TimeoutError]
|
||||
)
|
||||
}
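As a small sketch of the backoff behaviour these configurations imply (before jitter), the delays for the BACKEND_CONNECTION category with base_delay=2.0 and the 60-second cap work out as follows; the names below are illustrative only.

# Sketch: exponential backoff delays for attempts 1..5 under base_delay=2.0.
base_delay, max_delay = 2.0, 60.0
delays = [min(base_delay * 2 ** (n - 1), max_delay) for n in range(1, 6)]
# -> [2.0, 4.0, 8.0, 16.0, 32.0]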
|
||||
|
||||
def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]:
|
||||
"""Build fallback strategies for different error categories."""
|
||||
return {
|
||||
ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis,
|
||||
ErrorCategory.QUERY_GENERATION: self._fallback_query_generation,
|
||||
ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution,
|
||||
ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation,
|
||||
ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection
|
||||
}
|
||||
|
||||
async def handle_error(self,
|
||||
error: Exception,
|
||||
context: ErrorContext,
|
||||
operation: Callable,
|
||||
*args,
|
||||
**kwargs) -> Any:
|
||||
"""Handle error with recovery strategies.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred
|
||||
context: Error context
|
||||
operation: Function to retry
|
||||
*args: Operation arguments
|
||||
**kwargs: Operation keyword arguments
|
||||
|
||||
Returns:
|
||||
Result of successful operation or fallback
|
||||
"""
|
||||
logger.error(f"Handling error in {context.component}.{context.operation}: {error}")
|
||||
|
||||
# Update error counters
|
||||
error_key = f"{context.category.value}:{context.component}"
|
||||
self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1
|
||||
|
||||
# Check circuit breaker
|
||||
if self._is_circuit_open(error_key):
|
||||
return await self._execute_fallback(context, *args, **kwargs)
|
||||
|
||||
# Try retry if configured
|
||||
retry_config = self.retry_configs.get(context.category)
|
||||
if retry_config and context.retry_count < retry_config.max_retries:
|
||||
if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []):
|
||||
return await self._retry_operation(
|
||||
operation, context, retry_config, *args, **kwargs
|
||||
)
|
||||
|
||||
# Execute fallback strategy
|
||||
return await self._execute_fallback(context, *args, **kwargs)
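A hedged usage sketch of handle_error; fetch_triples and its query argument are hypothetical names used only to show the wiring, not part of the module.

# Sketch: routing a failed coroutine through the recovery strategy.
async def _demo_recovery(fetch_triples, query):
    strategy = ErrorRecoveryStrategy()
    ctx = ErrorContext(
        category=ErrorCategory.QUERY_EXECUTION,
        severity=ErrorSeverity.HIGH,
        component="cypher-executor",
        operation="run_query",
    )
    try:
        return await fetch_triples(query)
    except Exception as exc:
        # retries with backoff if configured, otherwise falls back
        return await strategy.handle_error(exc, ctx, fetch_triples, query)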
|
||||
|
||||
async def _retry_operation(self,
|
||||
operation: Callable,
|
||||
context: ErrorContext,
|
||||
retry_config: RetryConfig,
|
||||
*args,
|
||||
**kwargs) -> Any:
|
||||
"""Retry operation with backoff."""
|
||||
context.retry_count += 1
|
||||
|
||||
# Calculate delay
|
||||
delay = retry_config.base_delay
|
||||
if retry_config.exponential_backoff:
|
||||
delay *= (2 ** (context.retry_count - 1))
|
||||
delay = min(delay, retry_config.max_delay)
|
||||
|
||||
# Add jitter
|
||||
if retry_config.jitter:
|
||||
import random
|
||||
delay *= (0.5 + random.random())
|
||||
|
||||
logger.info(f"Retrying {context.component}.{context.operation} "
|
||||
f"(attempt {context.retry_count}) after {delay:.2f}s")
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(operation):
|
||||
return await operation(*args, **kwargs)
|
||||
else:
|
||||
return operation(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return await self.handle_error(e, context, operation, *args, **kwargs)
|
||||
|
||||
async def _execute_fallback(self,
|
||||
context: ErrorContext,
|
||||
*args,
|
||||
**kwargs) -> Any:
|
||||
"""Execute fallback strategy."""
|
||||
fallback_func = self.fallback_strategies.get(context.category)
|
||||
if fallback_func:
|
||||
logger.info(f"Executing fallback for {context.category.value}")
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(fallback_func):
|
||||
return await fallback_func(context, *args, **kwargs)
|
||||
else:
|
||||
return fallback_func(context, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Fallback strategy failed: {e}")
|
||||
|
||||
# Default fallback
|
||||
return self._default_fallback(context)
|
||||
|
||||
def _is_circuit_open(self, error_key: str) -> bool:
|
||||
"""Check if circuit breaker is open."""
|
||||
circuit = self.circuit_breakers.get(error_key, {})
|
||||
error_count = self.error_counters.get(error_key, 0)
|
||||
error_threshold = self.config.get('circuit_breaker_threshold', 10)
|
||||
window_seconds = self.config.get('circuit_breaker_window', 300) # 5 minutes
|
||||
|
||||
current_time = time.time()
|
||||
window_start = circuit.get('window_start', current_time)
|
||||
|
||||
# Reset window if expired
|
||||
if current_time - window_start > window_seconds:
|
||||
self.circuit_breakers[error_key] = {'window_start': current_time}
|
||||
self.error_counters[error_key] = 0
|
||||
return False
|
||||
|
||||
return error_count >= error_threshold
|
||||
|
||||
def _default_fallback(self, context: ErrorContext) -> Any:
|
||||
"""Default fallback response."""
|
||||
if context.category == ErrorCategory.ANSWER_GENERATION:
|
||||
return "I'm sorry, I encountered an error while processing your question. Please try again."
|
||||
elif context.category == ErrorCategory.QUERY_EXECUTION:
|
||||
return {"error": "Query execution failed", "results": []}
|
||||
else:
|
||||
return None
|
||||
|
||||
# Fallback strategy implementations
|
||||
|
||||
async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs):
|
||||
"""Fallback for question analysis."""
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
|
||||
# Simple keyword-based analysis
|
||||
question_lower = question.lower()
|
||||
|
||||
# Determine question type
|
||||
if any(word in question_lower for word in ['how many', 'count', 'number']):
|
||||
question_type = QuestionType.AGGREGATION
|
||||
elif question_lower.startswith(('is', 'are', 'does', 'can')):
|
||||
question_type = QuestionType.BOOLEAN
|
||||
elif any(word in question_lower for word in ['what', 'which', 'who', 'where']):
|
||||
question_type = QuestionType.RETRIEVAL
|
||||
else:
|
||||
question_type = QuestionType.FACTUAL
|
||||
|
||||
# Extract simple entities (nouns)
|
||||
import re
|
||||
words = re.findall(r'\b[a-zA-Z]+\b', question)
|
||||
entities = [word for word in words if len(word) > 3 and word.lower() not in
|
||||
{'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}]
|
||||
|
||||
return QuestionComponents(
|
||||
original_question=question,
|
||||
normalized_question=question.lower(),
|
||||
question_type=question_type,
|
||||
entities=entities[:3], # Limit to 3 entities
|
||||
keywords=words[:5], # Limit to 5 keywords
|
||||
relationships=[],
|
||||
constraints=[],
|
||||
aggregations=['count'] if question_type == QuestionType.AGGREGATION else [],
|
||||
expected_answer_type='text'
|
||||
)
|
||||
|
||||
    async def _fallback_query_generation(self, context: ErrorContext, **kwargs):
        """Fallback for query generation."""
        # Generate simple query based on available information
        metadata = context.metadata or {}  # metadata may be None on ErrorContext
        if 'sparql' in metadata.get('query_language', '').lower():
            query = """
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

            SELECT ?subject ?predicate ?object WHERE {
                ?subject ?predicate ?object .
            }
            LIMIT 10
            """
            from .sparql_generator import SPARQLQuery
            return SPARQLQuery(
                query=query,
                variables=['subject', 'predicate', 'object'],
                query_type='SELECT',
                explanation='Fallback SPARQL query',
                complexity_score=0.1
            )
        else:
            query = "MATCH (n) RETURN n LIMIT 10"
            from .cypher_generator import CypherQuery
            return CypherQuery(
                query=query,
                variables=['n'],
                query_type='MATCH',
                explanation='Fallback Cypher query',
                complexity_score=0.1
            )
async def _fallback_query_execution(self, context: ErrorContext, **kwargs):
|
||||
"""Fallback for query execution."""
|
||||
# Return empty results
|
||||
if 'sparql' in context.metadata.get('query_language', '').lower():
|
||||
from .sparql_cassandra import SPARQLResult
|
||||
return SPARQLResult(
|
||||
bindings=[],
|
||||
variables=[],
|
||||
execution_time=0.0
|
||||
)
|
||||
else:
|
||||
from .cypher_executor import CypherResult
|
||||
return CypherResult(
|
||||
records=[],
|
||||
summary={'type': 'fallback'},
|
||||
metadata={'query': 'fallback'},
|
||||
execution_time=0.0
|
||||
)
|
||||
|
||||
async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs):
|
||||
"""Fallback for answer generation."""
|
||||
fallback_messages = [
|
||||
"I'm experiencing some technical difficulties. Please try rephrasing your question.",
|
||||
"I couldn't process your question at the moment. Could you try asking it differently?",
|
||||
"There seems to be an issue with my analysis. Please try again in a moment.",
|
||||
"I'm having trouble understanding your question right now. Please try again."
|
||||
]
|
||||
|
||||
import random
|
||||
return random.choice(fallback_messages)
|
||||
|
||||
async def _fallback_backend_connection(self, context: ErrorContext, **kwargs):
|
||||
"""Fallback for backend connection."""
|
||||
logger.warning(f"Backend connection failed for {context.component}")
|
||||
# Could switch to alternative backend here
|
||||
return None
|
||||
|
||||
|
||||
def with_error_handling(category: ErrorCategory,
|
||||
component: str,
|
||||
operation: str,
|
||||
severity: ErrorSeverity = ErrorSeverity.MEDIUM):
|
||||
"""Decorator for automatic error handling.
|
||||
|
||||
Args:
|
||||
category: Error category
|
||||
component: Component name
|
||||
operation: Operation name
|
||||
severity: Error severity
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
context = ErrorContext(
|
||||
category=category,
|
||||
severity=severity,
|
||||
component=component,
|
||||
operation=operation,
|
||||
technical_details=str(e),
|
||||
metadata={'args': str(args), 'kwargs': str(kwargs)}
|
||||
)
|
||||
|
||||
# Get error recovery strategy from first argument if it's available
|
||||
error_strategy = None
|
||||
if args and hasattr(args[0], '_error_strategy'):
|
||||
error_strategy = args[0]._error_strategy
|
||||
|
||||
if error_strategy:
|
||||
return await error_strategy.handle_error(e, context, func, *args, **kwargs)
|
||||
else:
|
||||
# Re-raise as OntoRAG error
|
||||
raise OntoRAGError(
|
||||
f"Error in {component}.{operation}: {str(e)}",
|
||||
context=context,
|
||||
cause=e
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
context = ErrorContext(
|
||||
category=category,
|
||||
severity=severity,
|
||||
component=component,
|
||||
operation=operation,
|
||||
technical_details=str(e),
|
||||
metadata={'args': str(args), 'kwargs': str(kwargs)}
|
||||
)
|
||||
|
||||
raise OntoRAGError(
|
||||
f"Error in {component}.{operation}: {str(e)}",
|
||||
context=context,
|
||||
cause=e
|
||||
)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
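A hedged usage sketch of the decorator above; OntologyService and load_ontology are hypothetical names, and the only real requirement shown is that the instance exposes an _error_strategy attribute for the decorator to pick up.

# Sketch: attaching automatic error handling to a service method.
class OntologyService:
    def __init__(self, strategy: ErrorRecoveryStrategy):
        self._error_strategy = strategy  # discovered by with_error_handling

    @with_error_handling(ErrorCategory.ONTOLOGY_LOADING,
                         component="ontology-service",
                         operation="load")
    async def load_ontology(self, path: str):
        ...  # load and parse the ontology file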
|
||||
|
||||
|
||||
class ErrorReporter:
|
||||
"""Reports and tracks errors for monitoring and debugging."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize error reporter.
|
||||
|
||||
Args:
|
||||
config: Reporter configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.error_log: List[Dict[str, Any]] = []
|
||||
self.max_log_size = self.config.get('max_log_size', 1000)
|
||||
|
||||
def report_error(self, error: OntoRAGError):
|
||||
"""Report an error for tracking.
|
||||
|
||||
Args:
|
||||
error: The error to report
|
||||
"""
|
||||
error_entry = {
|
||||
'timestamp': error.timestamp,
|
||||
'message': error.message,
|
||||
'category': error.context.category.value,
|
||||
'severity': error.context.severity.value,
|
||||
'component': error.context.component,
|
||||
'operation': error.context.operation,
|
||||
'retry_count': error.context.retry_count,
|
||||
'technical_details': error.context.technical_details,
|
||||
'stack_trace': traceback.format_exc() if error.cause else None
|
||||
}
|
||||
|
||||
self.error_log.append(error_entry)
|
||||
|
||||
# Trim log if too large
|
||||
if len(self.error_log) > self.max_log_size:
|
||||
self.error_log = self.error_log[-self.max_log_size:]
|
||||
|
||||
# Log based on severity
|
||||
if error.context.severity == ErrorSeverity.CRITICAL:
|
||||
logger.critical(f"CRITICAL ERROR: {error.message}")
|
||||
elif error.context.severity == ErrorSeverity.HIGH:
|
||||
logger.error(f"HIGH SEVERITY: {error.message}")
|
||||
elif error.context.severity == ErrorSeverity.MEDIUM:
|
||||
logger.warning(f"MEDIUM SEVERITY: {error.message}")
|
||||
else:
|
||||
logger.info(f"LOW SEVERITY: {error.message}")
|
||||
|
||||
def get_error_summary(self) -> Dict[str, Any]:
|
||||
"""Get summary of recent errors.
|
||||
|
||||
Returns:
|
||||
Error summary statistics
|
||||
"""
|
||||
if not self.error_log:
|
||||
return {'total_errors': 0}
|
||||
|
||||
recent_errors = [
|
||||
e for e in self.error_log
|
||||
if time.time() - e['timestamp'] < 3600 # Last hour
|
||||
]
|
||||
|
||||
category_counts = {}
|
||||
severity_counts = {}
|
||||
component_counts = {}
|
||||
|
||||
for error in recent_errors:
|
||||
category_counts[error['category']] = category_counts.get(error['category'], 0) + 1
|
||||
severity_counts[error['severity']] = severity_counts.get(error['severity'], 0) + 1
|
||||
component_counts[error['component']] = component_counts.get(error['component'], 0) + 1
|
||||
|
||||
return {
|
||||
'total_errors': len(self.error_log),
|
||||
'recent_errors': len(recent_errors),
|
||||
'category_breakdown': category_counts,
|
||||
'severity_breakdown': severity_counts,
|
||||
'component_breakdown': component_counts,
|
||||
'most_recent_error': self.error_log[-1] if self.error_log else None
|
||||
}
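A brief, hedged sketch of how the reporter might be exercised; the error message and configuration values are illustrative.

# Sketch: report a caught OntoRAG error and inspect the rolling summary.
reporter = ErrorReporter({'max_log_size': 500})
try:
    raise QueryExecutionError("backend unavailable")
except OntoRAGError as err:
    reporter.report_error(err)
print(reporter.get_error_summary()['recent_errors'])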
|
||||
trustgraph-flow/trustgraph/query/ontology/monitoring.py (new file)
@@ -0,0 +1,737 @@
"""
|
||||
Performance monitoring and metrics collection for OntoRAG.
|
||||
Provides comprehensive monitoring of system performance, query patterns, and resource usage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
import statistics
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(Enum):
|
||||
"""Types of metrics to collect."""
|
||||
COUNTER = "counter"
|
||||
GAUGE = "gauge"
|
||||
HISTOGRAM = "histogram"
|
||||
TIMER = "timer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metric:
|
||||
"""Individual metric data point."""
|
||||
name: str
|
||||
value: float
|
||||
timestamp: datetime
|
||||
labels: Dict[str, str] = field(default_factory=dict)
|
||||
metric_type: MetricType = MetricType.GAUGE
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimerMetric:
|
||||
"""Timer metric for measuring duration."""
|
||||
name: str
|
||||
start_time: float
|
||||
labels: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def stop(self) -> float:
|
||||
"""Stop timer and return duration."""
|
||||
return time.time() - self.start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceStats:
|
||||
"""Performance statistics for a component."""
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
failed_requests: int = 0
|
||||
avg_response_time: float = 0.0
|
||||
min_response_time: float = float('inf')
|
||||
max_response_time: float = 0.0
|
||||
p95_response_time: float = 0.0
|
||||
p99_response_time: float = 0.0
|
||||
throughput_per_second: float = 0.0
|
||||
error_rate: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemHealth:
|
||||
"""Overall system health metrics."""
|
||||
status: str = "healthy" # healthy, degraded, unhealthy
|
||||
uptime_seconds: float = 0.0
|
||||
cpu_usage_percent: float = 0.0
|
||||
memory_usage_percent: float = 0.0
|
||||
active_connections: int = 0
|
||||
queue_size: int = 0
|
||||
cache_hit_rate: float = 0.0
|
||||
error_rate: float = 0.0
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collects and stores metrics data."""
|
||||
|
||||
def __init__(self, max_metrics: int = 10000, retention_hours: int = 24):
|
||||
"""Initialize metrics collector.
|
||||
|
||||
Args:
|
||||
max_metrics: Maximum number of metrics to retain
|
||||
retention_hours: Hours to retain metrics
|
||||
"""
|
||||
self.max_metrics = max_metrics
|
||||
self.retention_hours = retention_hours
|
||||
self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics))
|
||||
self.counters: Dict[str, float] = defaultdict(float)
|
||||
self.gauges: Dict[str, float] = defaultdict(float)
|
||||
self.timers: Dict[str, List[float]] = defaultdict(list)
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None):
|
||||
"""Increment a counter metric.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
value: Value to increment by
|
||||
labels: Metric labels
|
||||
"""
|
||||
with self._lock:
|
||||
metric_key = self._build_key(name, labels)
|
||||
self.counters[metric_key] += value
|
||||
self._add_metric(name, value, MetricType.COUNTER, labels)
|
||||
|
||||
def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None):
|
||||
"""Set a gauge metric value.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
value: Gauge value
|
||||
labels: Metric labels
|
||||
"""
|
||||
with self._lock:
|
||||
metric_key = self._build_key(name, labels)
|
||||
self.gauges[metric_key] = value
|
||||
self._add_metric(name, value, MetricType.GAUGE, labels)
|
||||
|
||||
def record_timer(self, name: str, duration: float, labels: Dict[str, str] = None):
|
||||
"""Record a timer measurement.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
duration: Duration in seconds
|
||||
labels: Metric labels
|
||||
"""
|
||||
with self._lock:
|
||||
metric_key = self._build_key(name, labels)
|
||||
self.timers[metric_key].append(duration)
|
||||
|
||||
# Keep only recent measurements
|
||||
max_timer_values = 1000
|
||||
if len(self.timers[metric_key]) > max_timer_values:
|
||||
self.timers[metric_key] = self.timers[metric_key][-max_timer_values:]
|
||||
|
||||
self._add_metric(name, duration, MetricType.TIMER, labels)
|
||||
|
||||
def start_timer(self, name: str, labels: Dict[str, str] = None) -> TimerMetric:
|
||||
"""Start a timer.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
labels: Metric labels
|
||||
|
||||
Returns:
|
||||
Timer metric object
|
||||
"""
|
||||
return TimerMetric(name=name, start_time=time.time(), labels=labels or {})
|
||||
|
||||
def stop_timer(self, timer: TimerMetric):
|
||||
"""Stop a timer and record the measurement.
|
||||
|
||||
Args:
|
||||
timer: Timer metric to stop
|
||||
"""
|
||||
duration = timer.stop()
|
||||
self.record_timer(timer.name, duration, timer.labels)
|
||||
return duration
|
||||
|
||||
def get_counter(self, name: str, labels: Dict[str, str] = None) -> float:
|
||||
"""Get counter value.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
labels: Metric labels
|
||||
|
||||
Returns:
|
||||
Counter value
|
||||
"""
|
||||
metric_key = self._build_key(name, labels)
|
||||
return self.counters.get(metric_key, 0.0)
|
||||
|
||||
def get_gauge(self, name: str, labels: Dict[str, str] = None) -> float:
|
||||
"""Get gauge value.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
labels: Metric labels
|
||||
|
||||
Returns:
|
||||
Gauge value
|
||||
"""
|
||||
metric_key = self._build_key(name, labels)
|
||||
return self.gauges.get(metric_key, 0.0)
|
||||
|
||||
def get_timer_stats(self, name: str, labels: Dict[str, str] = None) -> Dict[str, float]:
|
||||
"""Get timer statistics.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
labels: Metric labels
|
||||
|
||||
Returns:
|
||||
Timer statistics
|
||||
"""
|
||||
metric_key = self._build_key(name, labels)
|
||||
values = self.timers.get(metric_key, [])
|
||||
|
||||
if not values:
|
||||
return {}
|
||||
|
||||
sorted_values = sorted(values)
|
||||
return {
|
||||
'count': len(values),
|
||||
'sum': sum(values),
|
||||
'avg': statistics.mean(values),
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'p50': sorted_values[int(len(sorted_values) * 0.5)],
|
||||
'p95': sorted_values[int(len(sorted_values) * 0.95)],
|
||||
'p99': sorted_values[int(len(sorted_values) * 0.99)]
|
||||
}
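A hedged usage sketch of the timer API above; the metric name and label values are arbitrary examples.

# Sketch: time an operation and read back its percentile statistics.
collector = MetricsCollector()
t = collector.start_timer('request_duration', {'component': 'query'})
# ... do the work being measured ...
collector.stop_timer(t)
print(collector.get_timer_stats('request_duration', {'component': 'query'}))
# -> {'count': 1, 'sum': ..., 'avg': ..., 'min': ..., 'max': ..., 'p50': ..., 'p95': ..., 'p99': ...}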
|
||||
|
||||
def get_metrics(self,
|
||||
name_pattern: Optional[str] = None,
|
||||
since: Optional[datetime] = None) -> List[Metric]:
|
||||
"""Get metrics matching pattern and time range.
|
||||
|
||||
Args:
|
||||
name_pattern: Pattern to match metric names
|
||||
since: Only return metrics since this time
|
||||
|
||||
Returns:
|
||||
List of matching metrics
|
||||
"""
|
||||
with self._lock:
|
||||
results = []
|
||||
cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours)
|
||||
|
||||
for metric_name, metric_queue in self.metrics.items():
|
||||
if name_pattern and name_pattern not in metric_name:
|
||||
continue
|
||||
|
||||
for metric in metric_queue:
|
||||
if metric.timestamp >= cutoff_time:
|
||||
results.append(metric)
|
||||
|
||||
return sorted(results, key=lambda m: m.timestamp)
|
||||
|
||||
def cleanup_old_metrics(self):
|
||||
"""Remove old metrics beyond retention period."""
|
||||
with self._lock:
|
||||
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
|
||||
|
||||
for metric_name in list(self.metrics.keys()):
|
||||
metric_queue = self.metrics[metric_name]
|
||||
# Remove old metrics
|
||||
while metric_queue and metric_queue[0].timestamp < cutoff_time:
|
||||
metric_queue.popleft()
|
||||
|
||||
# Remove empty queues
|
||||
if not metric_queue:
|
||||
del self.metrics[metric_name]
|
||||
|
||||
def _add_metric(self, name: str, value: float, metric_type: MetricType, labels: Dict[str, str]):
|
||||
"""Add metric to storage."""
|
||||
metric = Metric(
|
||||
name=name,
|
||||
value=value,
|
||||
timestamp=datetime.now(),
|
||||
labels=labels or {},
|
||||
metric_type=metric_type
|
||||
)
|
||||
self.metrics[name].append(metric)
|
||||
|
||||
def _build_key(self, name: str, labels: Dict[str, str]) -> str:
|
||||
"""Build metric key from name and labels."""
|
||||
if not labels:
|
||||
return name
|
||||
|
||||
label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items()))
|
||||
return f"{name}{{{label_str}}}"
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""Monitors system performance and component health."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize performance monitor.
|
||||
|
||||
Args:
|
||||
config: Monitor configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.metrics_collector = MetricsCollector(
|
||||
max_metrics=self.config.get('max_metrics', 10000),
|
||||
retention_hours=self.config.get('retention_hours', 24)
|
||||
)
|
||||
|
||||
self.component_stats: Dict[str, PerformanceStats] = {}
|
||||
self.start_time = time.time()
|
||||
self.monitoring_enabled = self.config.get('enabled', True)
|
||||
|
||||
# Start background monitoring tasks
|
||||
if self.monitoring_enabled:
|
||||
self._start_background_tasks()
|
||||
|
||||
def record_request(self,
|
||||
component: str,
|
||||
operation: str,
|
||||
duration: float,
|
||||
success: bool = True,
|
||||
labels: Dict[str, str] = None):
|
||||
"""Record a request completion.
|
||||
|
||||
Args:
|
||||
component: Component name
|
||||
operation: Operation name
|
||||
duration: Request duration in seconds
|
||||
success: Whether request was successful
|
||||
labels: Additional labels
|
||||
"""
|
||||
if not self.monitoring_enabled:
|
||||
return
|
||||
|
||||
base_labels = {'component': component, 'operation': operation}
|
||||
if labels:
|
||||
base_labels.update(labels)
|
||||
|
||||
# Record metrics
|
||||
self.metrics_collector.increment('requests_total', labels=base_labels)
|
||||
self.metrics_collector.record_timer('request_duration', duration, base_labels)
|
||||
|
||||
if success:
|
||||
self.metrics_collector.increment('requests_successful', labels=base_labels)
|
||||
else:
|
||||
self.metrics_collector.increment('requests_failed', labels=base_labels)
|
||||
|
||||
# Update component stats
|
||||
self._update_component_stats(component, duration, success)
|
||||
|
||||
def record_query_complexity(self,
|
||||
complexity_score: float,
|
||||
query_type: str,
|
||||
backend: str):
|
||||
"""Record query complexity metrics.
|
||||
|
||||
Args:
|
||||
complexity_score: Query complexity score (0.0 to 1.0)
|
||||
query_type: Type of query (SPARQL, Cypher)
|
||||
backend: Backend used
|
||||
"""
|
||||
if not self.monitoring_enabled:
|
||||
return
|
||||
|
||||
labels = {'query_type': query_type, 'backend': backend}
|
||||
self.metrics_collector.set_gauge('query_complexity', complexity_score, labels)
|
||||
|
||||
def record_cache_access(self, hit: bool, cache_type: str = 'default'):
|
||||
"""Record cache access.
|
||||
|
||||
Args:
|
||||
hit: Whether it was a cache hit
|
||||
cache_type: Type of cache
|
||||
"""
|
||||
if not self.monitoring_enabled:
|
||||
return
|
||||
|
||||
labels = {'cache_type': cache_type}
|
||||
self.metrics_collector.increment('cache_requests_total', labels=labels)
|
||||
|
||||
if hit:
|
||||
self.metrics_collector.increment('cache_hits_total', labels=labels)
|
||||
else:
|
||||
self.metrics_collector.increment('cache_misses_total', labels=labels)
|
||||
|
||||
def record_ontology_selection(self,
|
||||
selected_elements: int,
|
||||
total_elements: int,
|
||||
ontology_id: str):
|
||||
"""Record ontology selection metrics.
|
||||
|
||||
Args:
|
||||
selected_elements: Number of selected ontology elements
|
||||
total_elements: Total ontology elements
|
||||
ontology_id: Ontology identifier
|
||||
"""
|
||||
if not self.monitoring_enabled:
|
||||
return
|
||||
|
||||
labels = {'ontology_id': ontology_id}
|
||||
self.metrics_collector.set_gauge('ontology_elements_selected', selected_elements, labels)
|
||||
self.metrics_collector.set_gauge('ontology_elements_total', total_elements, labels)
|
||||
|
||||
selection_ratio = selected_elements / total_elements if total_elements > 0 else 0
|
||||
self.metrics_collector.set_gauge('ontology_selection_ratio', selection_ratio, labels)
|
||||
|
||||
def get_component_stats(self, component: str) -> Optional[PerformanceStats]:
|
||||
"""Get performance statistics for a component.
|
||||
|
||||
Args:
|
||||
component: Component name
|
||||
|
||||
Returns:
|
||||
Performance statistics or None
|
||||
"""
|
||||
return self.component_stats.get(component)
|
||||
|
||||
def get_system_health(self) -> SystemHealth:
|
||||
"""Get overall system health status.
|
||||
|
||||
Returns:
|
||||
System health metrics
|
||||
"""
|
||||
# Calculate uptime
|
||||
uptime = time.time() - self.start_time
|
||||
|
||||
# Get error rate
|
||||
total_requests = self.metrics_collector.get_counter('requests_total')
|
||||
failed_requests = self.metrics_collector.get_counter('requests_failed')
|
||||
error_rate = failed_requests / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
# Get cache hit rate
|
||||
cache_hits = self.metrics_collector.get_counter('cache_hits_total')
|
||||
cache_requests = self.metrics_collector.get_counter('cache_requests_total')
|
||||
cache_hit_rate = cache_hits / cache_requests if cache_requests > 0 else 0.0
|
||||
|
||||
# Determine status
|
||||
status = "healthy"
|
||||
if error_rate > 0.1: # More than 10% error rate
|
||||
status = "degraded"
|
||||
if error_rate > 0.3: # More than 30% error rate
|
||||
status = "unhealthy"
|
||||
|
||||
return SystemHealth(
|
||||
status=status,
|
||||
uptime_seconds=uptime,
|
||||
error_rate=error_rate,
|
||||
cache_hit_rate=cache_hit_rate
|
||||
)
|
||||
|
||||
def get_performance_report(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive performance report.
|
||||
|
||||
Returns:
|
||||
Performance report
|
||||
"""
|
||||
report = {
|
||||
'system_health': self.get_system_health(),
|
||||
'component_stats': {},
|
||||
'top_slow_operations': [],
|
||||
'error_patterns': {},
|
||||
'cache_performance': {},
|
||||
'ontology_usage': {}
|
||||
}
|
||||
|
||||
# Component statistics
|
||||
for component, stats in self.component_stats.items():
|
||||
report['component_stats'][component] = stats
|
||||
|
||||
# Top slow operations
|
||||
timer_stats = {}
|
||||
for metric_name in self.metrics_collector.timers.keys():
|
||||
if 'request_duration' in metric_name:
|
||||
stats = self.metrics_collector.get_timer_stats(metric_name)
|
||||
if stats:
|
||||
timer_stats[metric_name] = stats
|
||||
|
||||
# Sort by p95 latency
|
||||
slow_ops = sorted(
|
||||
timer_stats.items(),
|
||||
key=lambda x: x[1].get('p95', 0),
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
report['top_slow_operations'] = [
|
||||
{'operation': op, 'stats': stats} for op, stats in slow_ops
|
||||
]
|
||||
|
||||
# Cache performance
|
||||
cache_types = set()
|
||||
for metric_name in self.metrics_collector.counters.keys():
|
||||
if 'cache_type=' in metric_name:
|
||||
cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
|
||||
cache_types.add(cache_type)
|
||||
|
||||
for cache_type in cache_types:
|
||||
labels = {'cache_type': cache_type}
|
||||
hits = self.metrics_collector.get_counter('cache_hits_total', labels)
|
||||
requests = self.metrics_collector.get_counter('cache_requests_total', labels)
|
||||
hit_rate = hits / requests if requests > 0 else 0.0
|
||||
|
||||
report['cache_performance'][cache_type] = {
|
||||
'hit_rate': hit_rate,
|
||||
'total_requests': requests,
|
||||
'total_hits': hits
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
def _update_component_stats(self, component: str, duration: float, success: bool):
|
||||
"""Update component performance statistics."""
|
||||
if component not in self.component_stats:
|
||||
self.component_stats[component] = PerformanceStats()
|
||||
|
||||
stats = self.component_stats[component]
|
||||
stats.total_requests += 1
|
||||
|
||||
if success:
|
||||
stats.successful_requests += 1
|
||||
else:
|
||||
stats.failed_requests += 1
|
||||
|
||||
# Update response time stats
|
||||
stats.min_response_time = min(stats.min_response_time, duration)
|
||||
stats.max_response_time = max(stats.max_response_time, duration)
|
||||
|
||||
# Get timer stats for percentiles
|
||||
timer_stats = self.metrics_collector.get_timer_stats(
|
||||
'request_duration', {'component': component}
|
||||
)
|
||||
|
||||
if timer_stats:
|
||||
stats.avg_response_time = timer_stats.get('avg', 0.0)
|
||||
stats.p95_response_time = timer_stats.get('p95', 0.0)
|
||||
stats.p99_response_time = timer_stats.get('p99', 0.0)
|
||||
|
||||
# Calculate rates
|
||||
stats.error_rate = stats.failed_requests / stats.total_requests
|
||||
|
||||
# Calculate throughput (requests per second over last minute)
|
||||
recent_requests = len([
|
||||
m for m in self.metrics_collector.get_metrics('requests_total')
|
||||
if m.labels.get('component') == component and
|
||||
m.timestamp > datetime.now() - timedelta(minutes=1)
|
||||
])
|
||||
stats.throughput_per_second = recent_requests / 60.0
|
||||
|
||||
def _start_background_tasks(self):
|
||||
"""Start background monitoring tasks."""
|
||||
def cleanup_worker():
|
||||
"""Worker to clean up old metrics."""
|
||||
while True:
|
||||
try:
|
||||
time.sleep(300) # 5 minutes
|
||||
self.metrics_collector.cleanup_old_metrics()
|
||||
except Exception as e:
|
||||
logger.error(f"Metrics cleanup error: {e}")
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
|
||||
# Monitoring decorators
|
||||
|
||||
def monitor_performance(component: str,
|
||||
operation: str,
|
||||
monitor: Optional[PerformanceMonitor] = None):
|
||||
"""Decorator to monitor function performance.
|
||||
|
||||
Args:
|
||||
component: Component name
|
||||
operation: Operation name
|
||||
monitor: Performance monitor instance
|
||||
"""
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not monitor or not monitor.monitoring_enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
timer = monitor.metrics_collector.start_timer(
|
||||
'request_duration',
|
||||
{'component': component, 'operation': operation}
|
||||
)
|
||||
|
||||
success = True
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
duration = monitor.metrics_collector.stop_timer(timer)
|
||||
monitor.record_request(component, operation, duration, success)
|
||||
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
if not monitor or not monitor.monitoring_enabled:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
timer = monitor.metrics_collector.start_timer(
|
||||
'request_duration',
|
||||
{'component': component, 'operation': operation}
|
||||
)
|
||||
|
||||
success = True
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
duration = monitor.metrics_collector.stop_timer(timer)
|
||||
monitor.record_request(component, operation, duration, success)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return wrapper
|
||||
|
||||
return decorator
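A hedged usage sketch of the monitoring decorator; answer_question is a hypothetical coroutine and the component/operation labels are examples.

# Sketch: wrapping an async pipeline step so its latency and success rate
# are recorded by the monitor.
monitor = PerformanceMonitor({'enabled': True})

@monitor_performance("graph-rag", "answer", monitor=monitor)
async def answer_question(question: str) -> str:
    ...  # run retrieval and answer generation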
|
||||
|
||||
|
||||
class QueryPatternAnalyzer:
|
||||
"""Analyzes query patterns for optimization insights."""
|
||||
|
||||
def __init__(self, monitor: PerformanceMonitor):
|
||||
"""Initialize query pattern analyzer.
|
||||
|
||||
Args:
|
||||
monitor: Performance monitor instance
|
||||
"""
|
||||
self.monitor = monitor
|
||||
self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
def record_query_pattern(self,
|
||||
question_type: str,
|
||||
entities: List[str],
|
||||
complexity: float,
|
||||
backend: str,
|
||||
duration: float,
|
||||
success: bool):
|
||||
"""Record a query pattern for analysis.
|
||||
|
||||
Args:
|
||||
question_type: Type of question
|
||||
entities: Entities in question
|
||||
complexity: Query complexity score
|
||||
backend: Backend used
|
||||
duration: Query duration
|
||||
success: Whether query succeeded
|
||||
"""
|
||||
pattern = {
|
||||
'timestamp': datetime.now(),
|
||||
'question_type': question_type,
|
||||
'entity_count': len(entities),
|
||||
'entities': entities,
|
||||
'complexity': complexity,
|
||||
'backend': backend,
|
||||
'duration': duration,
|
||||
'success': success
|
||||
}
|
||||
|
||||
pattern_key = f"{question_type}:{len(entities)}"
|
||||
self.query_patterns[pattern_key].append(pattern)
|
||||
|
||||
# Keep only recent patterns
|
||||
cutoff_time = datetime.now() - timedelta(hours=24)
|
||||
self.query_patterns[pattern_key] = [
|
||||
p for p in self.query_patterns[pattern_key]
|
||||
if p['timestamp'] > cutoff_time
|
||||
]
|
||||
|
||||
def get_optimization_insights(self) -> Dict[str, Any]:
|
||||
"""Get insights for query optimization.
|
||||
|
||||
Returns:
|
||||
Optimization insights and recommendations
|
||||
"""
|
||||
insights = {
|
||||
'slow_patterns': [],
|
||||
'common_failures': [],
|
||||
'backend_performance': {},
|
||||
'complexity_analysis': {},
|
||||
'recommendations': []
|
||||
}
|
||||
|
||||
# Analyze slow patterns
|
||||
for pattern_key, patterns in self.query_patterns.items():
|
||||
if not patterns:
|
||||
continue
|
||||
|
||||
avg_duration = statistics.mean([p['duration'] for p in patterns])
|
||||
success_rate = sum(1 for p in patterns if p['success']) / len(patterns)
|
||||
|
||||
if avg_duration > 5.0: # Slow queries > 5 seconds
|
||||
insights['slow_patterns'].append({
|
||||
'pattern': pattern_key,
|
||||
'avg_duration': avg_duration,
|
||||
'count': len(patterns),
|
||||
'success_rate': success_rate
|
||||
})
|
||||
|
||||
if success_rate < 0.8: # Low success rate
|
||||
insights['common_failures'].append({
|
||||
'pattern': pattern_key,
|
||||
'success_rate': success_rate,
|
||||
'count': len(patterns)
|
||||
})
|
||||
|
||||
# Analyze backend performance
|
||||
backend_stats = defaultdict(list)
|
||||
for patterns in self.query_patterns.values():
|
||||
for pattern in patterns:
|
||||
backend_stats[pattern['backend']].append(pattern['duration'])
|
||||
|
||||
for backend, durations in backend_stats.items():
|
||||
insights['backend_performance'][backend] = {
|
||||
'avg_duration': statistics.mean(durations),
|
||||
'p95_duration': sorted(durations)[int(len(durations) * 0.95)],
|
||||
'query_count': len(durations)
|
||||
}
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = []
|
||||
|
||||
# Slow pattern recommendations
|
||||
for slow_pattern in insights['slow_patterns']:
|
||||
recommendations.append(
|
||||
f"Consider optimizing {slow_pattern['pattern']} queries - "
|
||||
f"average duration {slow_pattern['avg_duration']:.2f}s"
|
||||
)
|
||||
|
||||
# Backend recommendations
|
||||
if len(insights['backend_performance']) > 1:
|
||||
fastest_backend = min(
|
||||
insights['backend_performance'].items(),
|
||||
key=lambda x: x[1]['avg_duration']
|
||||
)[0]
|
||||
recommendations.append(
|
||||
f"Consider routing more queries to {fastest_backend} "
|
||||
f"for better performance"
|
||||
)
|
||||
|
||||
insights['recommendations'] = recommendations
|
||||
|
||||
return insights
|
||||
trustgraph-flow/trustgraph/query/ontology/multi_language.py (new file)
@@ -0,0 +1,656 @@
"""
|
||||
Multi-language support for OntoRAG.
|
||||
Provides language detection, translation, and multilingual query processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Language(Enum):
|
||||
"""Supported languages."""
|
||||
ENGLISH = "en"
|
||||
SPANISH = "es"
|
||||
FRENCH = "fr"
|
||||
GERMAN = "de"
|
||||
ITALIAN = "it"
|
||||
PORTUGUESE = "pt"
|
||||
CHINESE = "zh"
|
||||
JAPANESE = "ja"
|
||||
KOREAN = "ko"
|
||||
ARABIC = "ar"
|
||||
RUSSIAN = "ru"
|
||||
DUTCH = "nl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageDetectionResult:
|
||||
"""Language detection result."""
|
||||
language: Language
|
||||
confidence: float
|
||||
detected_text: str
|
||||
alternative_languages: List[Tuple[Language, float]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranslationResult:
|
||||
"""Translation result."""
|
||||
original_text: str
|
||||
translated_text: str
|
||||
source_language: Language
|
||||
target_language: Language
|
||||
confidence: float
|
||||
|
||||
|
||||
class LanguageDetector:
|
||||
"""Detects language of input text."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize language detector.
|
||||
|
||||
Args:
|
||||
config: Detector configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.default_language = Language(self.config.get('default_language', 'en'))
|
||||
self.confidence_threshold = self.config.get('confidence_threshold', 0.7)
|
||||
|
||||
# Try to import language detection libraries
|
||||
self.detector = None
|
||||
self._init_detector()
|
||||
|
||||
def _init_detector(self):
|
||||
"""Initialize language detection backend."""
|
||||
try:
|
||||
# Try langdetect first
|
||||
import langdetect
|
||||
self.detector = 'langdetect'
|
||||
logger.info("Using langdetect for language detection")
|
||||
except ImportError:
|
||||
try:
|
||||
# Try textblob as fallback
|
||||
from textblob import TextBlob
|
||||
self.detector = 'textblob'
|
||||
logger.info("Using TextBlob for language detection")
|
||||
except ImportError:
|
||||
logger.warning("No language detection library available, using rule-based detection")
|
||||
self.detector = 'rule_based'
|
||||
|
||||
def detect_language(self, text: str) -> LanguageDetectionResult:
|
||||
"""Detect language of input text.
|
||||
|
||||
Args:
|
||||
text: Text to analyze
|
||||
|
||||
Returns:
|
||||
Language detection result
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return LanguageDetectionResult(
|
||||
language=self.default_language,
|
||||
confidence=0.0,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
try:
|
||||
if self.detector == 'langdetect':
|
||||
return self._detect_with_langdetect(text)
|
||||
elif self.detector == 'textblob':
|
||||
return self._detect_with_textblob(text)
|
||||
else:
|
||||
return self._detect_with_rules(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Language detection failed: {e}")
|
||||
return LanguageDetectionResult(
|
||||
language=self.default_language,
|
||||
confidence=0.0,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult:
|
||||
"""Detect language using langdetect library."""
|
||||
import langdetect
|
||||
from langdetect.lang_detect_exception import LangDetectException
|
||||
|
||||
try:
|
||||
# Get detailed detection results
|
||||
probabilities = langdetect.detect_langs(text)
|
||||
|
||||
if not probabilities:
|
||||
return LanguageDetectionResult(
|
||||
language=self.default_language,
|
||||
confidence=0.0,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
best_match = probabilities[0]
|
||||
detected_lang_code = best_match.lang
|
||||
confidence = best_match.prob
|
||||
|
||||
# Map to our Language enum
|
||||
try:
|
||||
detected_language = Language(detected_lang_code)
|
||||
except ValueError:
|
||||
# Map common variations
|
||||
lang_mapping = {
|
||||
'ca': Language.SPANISH, # Catalan -> Spanish
|
||||
'eu': Language.SPANISH, # Basque -> Spanish
|
||||
'gl': Language.SPANISH, # Galician -> Spanish
|
||||
'zh-cn': Language.CHINESE,
|
||||
'zh-tw': Language.CHINESE,
|
||||
}
|
||||
detected_language = lang_mapping.get(detected_lang_code, self.default_language)
|
||||
|
||||
# Get alternatives
|
||||
alternatives = []
|
||||
for lang_prob in probabilities[1:3]: # Top 3 alternatives
|
||||
try:
|
||||
alt_lang = Language(lang_prob.lang)
|
||||
alternatives.append((alt_lang, lang_prob.prob))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return LanguageDetectionResult(
|
||||
language=detected_language,
|
||||
confidence=confidence,
|
||||
detected_text=text,
|
||||
alternative_languages=alternatives
|
||||
)
|
||||
|
||||
except LangDetectException:
|
||||
return LanguageDetectionResult(
|
||||
language=self.default_language,
|
||||
confidence=0.0,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
def _detect_with_textblob(self, text: str) -> LanguageDetectionResult:
|
||||
"""Detect language using TextBlob."""
|
||||
from textblob import TextBlob
|
||||
|
||||
try:
|
||||
blob = TextBlob(text)
|
||||
detected_lang_code = blob.detect_language()
|
||||
|
||||
try:
|
||||
detected_language = Language(detected_lang_code)
|
||||
except ValueError:
|
||||
detected_language = self.default_language
|
||||
|
||||
# TextBlob doesn't provide confidence, so estimate based on text length
|
||||
confidence = min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5
|
||||
|
||||
return LanguageDetectionResult(
|
||||
language=detected_language,
|
||||
confidence=confidence,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return LanguageDetectionResult(
|
||||
language=self.default_language,
|
||||
confidence=0.0,
|
||||
detected_text=text
|
||||
)
|
||||
|
||||
def _detect_with_rules(self, text: str) -> LanguageDetectionResult:
|
||||
"""Rule-based language detection fallback."""
|
||||
text_lower = text.lower()
|
||||
|
||||
# Simple keyword-based detection
|
||||
language_keywords = {
|
||||
Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'],
|
||||
Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'],
|
||||
Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'],
|
||||
Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'],
|
||||
Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'],
|
||||
Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel']
|
||||
}
|
||||
|
||||
best_match = self.default_language
|
||||
best_score = 0
|
||||
|
||||
for language, keywords in language_keywords.items():
|
||||
score = sum(1 for keyword in keywords if keyword in text_lower)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = language
|
||||
|
||||
confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1
|
||||
|
||||
return LanguageDetectionResult(
|
||||
language=best_match,
|
||||
confidence=confidence,
|
||||
detected_text=text
|
||||
)
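A hedged usage sketch of the detector above; the sample question is arbitrary, and which backend (langdetect, TextBlob, or the rule-based fallback) actually answers depends on what is installed.

# Sketch: detect the language of an incoming question.
detector = LanguageDetector({'default_language': 'en'})
result = detector.detect_language("¿Cuántos pacientes fueron tratados en 2023?")
print(result.language, result.confidence)  # e.g. Language.SPANISH 0.33 (rule-based)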
|
||||
|
||||
|
||||
class TextTranslator:
|
||||
"""Translates text between languages."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize text translator.
|
||||
|
||||
Args:
|
||||
config: Translator configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.translator = None
|
||||
self._init_translator()
|
||||
|
||||
def _init_translator(self):
|
||||
"""Initialize translation backend."""
|
||||
try:
|
||||
# Try Google Translate first
|
||||
from googletrans import Translator
|
||||
self.translator = Translator()
|
||||
self.backend = 'googletrans'
|
||||
logger.info("Using Google Translate for translation")
|
||||
except ImportError:
|
||||
try:
|
||||
# Try TextBlob as fallback
|
||||
from textblob import TextBlob
|
||||
self.backend = 'textblob'
|
||||
logger.info("Using TextBlob for translation")
|
||||
except ImportError:
|
||||
logger.warning("No translation library available")
|
||||
self.backend = None
|
||||
|
||||
def translate(self,
|
||||
text: str,
|
||||
target_language: Language,
|
||||
source_language: Optional[Language] = None) -> TranslationResult:
|
||||
"""Translate text to target language.
|
||||
|
||||
Args:
|
||||
text: Text to translate
|
||||
target_language: Target language
|
||||
source_language: Source language (auto-detect if None)
|
||||
|
||||
Returns:
|
||||
Translation result
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return TranslationResult(
|
||||
original_text=text,
|
||||
translated_text=text,
|
||||
source_language=source_language or Language.ENGLISH,
|
||||
target_language=target_language,
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
try:
|
||||
if self.backend == 'googletrans':
|
||||
return self._translate_with_googletrans(text, target_language, source_language)
|
||||
elif self.backend == 'textblob':
|
||||
return self._translate_with_textblob(text, target_language, source_language)
|
||||
else:
|
||||
# No translation available
|
||||
return TranslationResult(
|
||||
original_text=text,
|
||||
translated_text=text,
|
||||
source_language=source_language or Language.ENGLISH,
|
||||
target_language=target_language,
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Translation failed: {e}")
|
||||
return TranslationResult(
|
||||
original_text=text,
|
||||
translated_text=text,
|
||||
source_language=source_language or Language.ENGLISH,
|
||||
target_language=target_language,
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
def _translate_with_googletrans(self,
|
||||
text: str,
|
||||
target_language: Language,
|
||||
source_language: Optional[Language]) -> TranslationResult:
|
||||
"""Translate using Google Translate."""
|
||||
try:
|
||||
src_code = source_language.value if source_language else 'auto'
|
||||
dest_code = target_language.value
|
||||
|
||||
result = self.translator.translate(text, src=src_code, dest=dest_code)
|
||||
|
||||
detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH
|
||||
confidence = 0.9 # Google Translate is generally reliable
|
||||
|
||||
return TranslationResult(
|
||||
original_text=text,
|
||||
translated_text=result.text,
|
||||
source_language=detected_source,
|
||||
target_language=target_language,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Translate error: {e}")
|
||||
raise
|
||||
|
||||
def _translate_with_textblob(self,
|
||||
text: str,
|
||||
target_language: Language,
|
||||
source_language: Optional[Language]) -> TranslationResult:
|
||||
"""Translate using TextBlob."""
|
||||
from textblob import TextBlob
|
||||
|
||||
try:
|
||||
blob = TextBlob(text)
|
||||
|
||||
if not source_language:
|
||||
# Auto-detect source language
|
||||
detected_lang = blob.detect_language()
|
||||
try:
|
||||
source_language = Language(detected_lang)
|
||||
except ValueError:
|
||||
source_language = Language.ENGLISH
|
||||
|
||||
translated_blob = blob.translate(to=target_language.value)
|
||||
translated_text = str(translated_blob)
|
||||
|
||||
# TextBlob confidence estimation
|
||||
confidence = 0.7 if len(text) > 10 else 0.5
|
||||
|
||||
return TranslationResult(
|
||||
original_text=text,
|
||||
translated_text=translated_text,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TextBlob translation error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class MultiLanguageQueryProcessor:
|
||||
"""Processes queries in multiple languages."""
|
||||
|
||||
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize multi-language query processor.

        Args:
            config: Processor configuration
        """
        self.config = config or {}
        # Read from self.config so a None config doesn't raise AttributeError
        self.language_detector = LanguageDetector(self.config.get('language_detection', {}))
        self.translator = TextTranslator(self.config.get('translation', {}))
        self.supported_languages = [Language(lang) for lang in self.config.get('supported_languages', ['en'])]
        self.primary_language = Language(self.config.get('primary_language', 'en'))
async def process_multilingual_query(self, question: str) -> Dict[str, Any]:
|
||||
"""Process a query in any supported language.
|
||||
|
||||
Args:
|
||||
question: Question in any language
|
||||
|
||||
Returns:
|
||||
Processing result with language information
|
||||
"""
|
||||
# Step 1: Detect language
|
||||
detection_result = self.language_detector.detect_language(question)
|
||||
detected_language = detection_result.language
|
||||
|
||||
logger.info(f"Detected language: {detected_language.value} "
|
||||
f"(confidence: {detection_result.confidence:.2f})")
|
||||
|
||||
# Step 2: Translate to primary language if needed
|
||||
translated_question = question
|
||||
translation_result = None
|
||||
|
||||
if detected_language != self.primary_language:
|
||||
if detection_result.confidence >= self.language_detector.confidence_threshold:
|
||||
translation_result = self.translator.translate(
|
||||
question, self.primary_language, detected_language
|
||||
)
|
||||
translated_question = translation_result.translated_text
|
||||
logger.info(f"Translated question: {translated_question}")
|
||||
else:
|
||||
logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}")
|
||||
|
||||
# Step 3: Return processing information
|
||||
return {
|
||||
'original_question': question,
|
||||
'translated_question': translated_question,
|
||||
'detected_language': detected_language,
|
||||
'detection_confidence': detection_result.confidence,
|
||||
'translation_result': translation_result,
|
||||
'processing_language': self.primary_language,
|
||||
'alternative_languages': detection_result.alternative_languages
|
||||
}
|
||||
|
||||
async def translate_answer(self,
|
||||
answer: str,
|
||||
target_language: Language) -> TranslationResult:
|
||||
"""Translate answer back to target language.
|
||||
|
||||
Args:
|
||||
answer: Answer in primary language
|
||||
target_language: Target language for answer
|
||||
|
||||
Returns:
|
||||
Translation result
|
||||
"""
|
||||
if target_language == self.primary_language:
|
||||
# No translation needed
|
||||
return TranslationResult(
|
||||
original_text=answer,
|
||||
translated_text=answer,
|
||||
source_language=self.primary_language,
|
||||
target_language=target_language,
|
||||
confidence=1.0
|
||||
)
|
||||
|
||||
return self.translator.translate(answer, target_language, self.primary_language)
|
||||
|
||||
def get_language_specific_ontology_terms(self,
|
||||
ontology_subset: Dict[str, Any],
|
||||
language: Language) -> Dict[str, Any]:
|
||||
"""Get language-specific terms from ontology.
|
||||
|
||||
Args:
|
||||
ontology_subset: Ontology subset
|
||||
language: Target language
|
||||
|
||||
Returns:
|
||||
Language-specific ontology terms
|
||||
"""
|
||||
# Extract language-specific labels and descriptions
|
||||
lang_code = language.value
|
||||
result = {}
|
||||
|
||||
# Process classes
|
||||
if 'classes' in ontology_subset:
|
||||
result['classes'] = {}
|
||||
for class_id, class_def in ontology_subset['classes'].items():
|
||||
lang_labels = []
|
||||
if 'labels' in class_def:
|
||||
for label in class_def['labels']:
|
||||
if isinstance(label, dict) and label.get('language') == lang_code:
|
||||
lang_labels.append(label['value'])
|
||||
elif isinstance(label, str):
|
||||
lang_labels.append(label)
|
||||
|
||||
result['classes'][class_id] = {
|
||||
**class_def,
|
||||
'language_labels': lang_labels
|
||||
}
|
||||
|
||||
# Process properties
|
||||
for prop_type in ['object_properties', 'datatype_properties']:
|
||||
if prop_type in ontology_subset:
|
||||
result[prop_type] = {}
|
||||
for prop_id, prop_def in ontology_subset[prop_type].items():
|
||||
lang_labels = []
|
||||
if 'labels' in prop_def:
|
||||
for label in prop_def['labels']:
|
||||
if isinstance(label, dict) and label.get('language') == lang_code:
|
||||
lang_labels.append(label['value'])
|
||||
elif isinstance(label, str):
|
||||
lang_labels.append(label)
|
||||
|
||||
result[prop_type][prop_id] = {
|
||||
**prop_def,
|
||||
'language_labels': lang_labels
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def is_language_supported(self, language: Language) -> bool:
|
||||
"""Check if language is supported.
|
||||
|
||||
Args:
|
||||
language: Language to check
|
||||
|
||||
Returns:
|
||||
True if language is supported
|
||||
"""
|
||||
return language in self.supported_languages
|
||||
|
||||
def get_supported_languages(self) -> List[Language]:
|
||||
"""Get list of supported languages.
|
||||
|
||||
Returns:
|
||||
List of supported languages
|
||||
"""
|
||||
return self.supported_languages.copy()
|
||||
|
||||
def add_language_support(self, language: Language):
|
||||
"""Add support for a new language.
|
||||
|
||||
Args:
|
||||
language: Language to add support for
|
||||
"""
|
||||
if language not in self.supported_languages:
|
||||
self.supported_languages.append(language)
|
||||
logger.info(f"Added support for language: {language.value}")
|
||||
|
||||
def remove_language_support(self, language: Language):
|
||||
"""Remove support for a language.
|
||||
|
||||
Args:
|
||||
language: Language to remove support for
|
||||
"""
|
||||
if language in self.supported_languages and language != self.primary_language:
|
||||
self.supported_languages.remove(language)
|
||||
logger.info(f"Removed support for language: {language.value}")
|
||||
else:
|
||||
logger.warning(f"Cannot remove primary language or unsupported language: {language.value}")
|
||||
|
||||
|
||||
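# Usage sketch (illustrative, not part of the committed module): the round trip
# through the processor above. The config values and the 'answer' variable are
# placeholders; in the real flow the answer comes from the downstream query stage.
#
#   processor = MultiLanguageQueryProcessor({
#       'supported_languages': ['en', 'es'],
#       'primary_language': 'en',
#   })
#   info = await processor.process_multilingual_query("¿Cuántos sensores hay?")
#   reply = await processor.translate_answer(answer, info['detected_language'])
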
class LanguageSpecificTemplates:
|
||||
"""Manages language-specific query and answer templates."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize language-specific templates."""
|
||||
self.question_templates = {
|
||||
Language.ENGLISH: {
|
||||
'count': ['how many', 'count of', 'number of'],
|
||||
'boolean': ['is', 'are', 'does', 'can', 'will'],
|
||||
'retrieval': ['what', 'which', 'who', 'where'],
|
||||
'factual': ['tell me about', 'describe', 'explain']
|
||||
},
|
||||
Language.SPANISH: {
|
||||
'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'],
|
||||
'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'],
|
||||
'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'],
|
||||
'factual': ['dime sobre', 'describe', 'explica']
|
||||
},
|
||||
Language.FRENCH: {
|
||||
'count': ['combien', 'nombre de', 'quantité de'],
|
||||
'boolean': ['est', 'sont', 'peut', 'peuvent'],
|
||||
'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'],
|
||||
'factual': ['dis-moi sur', 'décris', 'explique']
|
||||
},
|
||||
Language.GERMAN: {
|
||||
'count': ['wie viele', 'anzahl der', 'zahl der'],
|
||||
'boolean': ['ist', 'sind', 'kann', 'können'],
|
||||
'retrieval': ['was', 'welche', 'wer', 'wo'],
|
||||
'factual': ['erzähl mir über', 'beschreibe', 'erkläre']
|
||||
}
|
||||
}
|
||||
|
||||
self.answer_templates = {
|
||||
Language.ENGLISH: {
|
||||
'count': 'There are {count} {entity}.',
|
||||
'boolean_true': 'Yes, {statement}.',
|
||||
'boolean_false': 'No, {statement}.',
|
||||
'not_found': 'No information found.',
|
||||
'error': 'Sorry, I encountered an error.'
|
||||
},
|
||||
Language.SPANISH: {
|
||||
'count': 'Hay {count} {entity}.',
|
||||
'boolean_true': 'Sí, {statement}.',
|
||||
'boolean_false': 'No, {statement}.',
|
||||
'not_found': 'No se encontró información.',
|
||||
'error': 'Lo siento, encontré un error.'
|
||||
},
|
||||
Language.FRENCH: {
|
||||
'count': 'Il y a {count} {entity}.',
|
||||
'boolean_true': 'Oui, {statement}.',
|
||||
'boolean_false': 'Non, {statement}.',
|
||||
'not_found': 'Aucune information trouvée.',
|
||||
'error': 'Désolé, j\'ai rencontré une erreur.'
|
||||
},
|
||||
Language.GERMAN: {
|
||||
'count': 'Es gibt {count} {entity}.',
|
||||
'boolean_true': 'Ja, {statement}.',
|
||||
'boolean_false': 'Nein, {statement}.',
|
||||
'not_found': 'Keine Informationen gefunden.',
|
||||
'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.'
|
||||
}
|
||||
}
|
||||
|
||||
def get_question_patterns(self, language: Language) -> Dict[str, List[str]]:
|
||||
"""Get question patterns for a language.
|
||||
|
||||
Args:
|
||||
language: Target language
|
||||
|
||||
Returns:
|
||||
Dictionary of question patterns
|
||||
"""
|
||||
return self.question_templates.get(language, self.question_templates[Language.ENGLISH])
|
||||
|
||||
def get_answer_template(self, language: Language, template_type: str) -> str:
|
||||
"""Get answer template for a language and type.
|
||||
|
||||
Args:
|
||||
language: Target language
|
||||
template_type: Template type
|
||||
|
||||
Returns:
|
||||
Answer template string
|
||||
"""
|
||||
templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH])
|
||||
return templates.get(template_type, templates.get('error', 'Error'))
|
||||
|
||||
def format_answer(self,
|
||||
language: Language,
|
||||
template_type: str,
|
||||
**kwargs) -> str:
|
||||
"""Format answer using language-specific template.
|
||||
|
||||
Args:
|
||||
language: Target language
|
||||
template_type: Template type
|
||||
**kwargs: Template variables
|
||||
|
||||
Returns:
|
||||
Formatted answer
|
||||
"""
|
||||
template = self.get_answer_template(language, template_type)
|
||||
try:
|
||||
return template.format(**kwargs)
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing template variable: {e}")
|
||||
return self.get_answer_template(language, 'error')
|
||||
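A minimal sketch of how the templates above resolve a localized answer; the
count and entity values are made up for the example:

    templates = LanguageSpecificTemplates()
    text = templates.format_answer(Language.SPANISH, 'count', count=3, entity='sensores')
    # -> "Hay 3 sensores."
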
256
trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py
Normal file
@@ -0,0 +1,256 @@
"""
|
||||
Ontology matcher for query system.
|
||||
Identifies relevant ontology subsets for answering questions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Set, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
|
||||
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
|
||||
from ...extract.kg.ontology.text_processor import TextSegment
|
||||
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryOntologySubset(OntologySubset):
|
||||
"""Extended ontology subset for query processing."""
|
||||
traversal_properties: Dict[str, Any] = None # Additional properties for graph traversal
|
||||
inference_rules: List[Dict[str, Any]] = None # Inference rules for reasoning
|
||||
|
||||
|
||||
class OntologyMatcherForQueries(OntologySelector):
|
||||
"""
|
||||
Specialized ontology matcher for question answering.
|
||||
Extends OntologySelector with query-specific logic.
|
||||
"""
|
||||
|
||||
def __init__(self, ontology_embedder: OntologyEmbedder,
|
||||
ontology_loader: OntologyLoader,
|
||||
top_k: int = 15, # Higher k for queries
|
||||
similarity_threshold: float = 0.6): # Lower threshold for broader coverage
|
||||
"""Initialize query-specific ontology matcher.
|
||||
|
||||
Args:
|
||||
ontology_embedder: Embedder with vector store
|
||||
ontology_loader: Loader with ontology definitions
|
||||
top_k: Number of top results to retrieve
|
||||
similarity_threshold: Minimum similarity score
|
||||
"""
|
||||
super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold)
|
||||
|
||||
async def match_question_to_ontology(self,
|
||||
question_components: QuestionComponents,
|
||||
question_segments: List[str]) -> List[QueryOntologySubset]:
|
||||
"""Match question components to relevant ontology elements.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question components
|
||||
question_segments: Text segments from question
|
||||
|
||||
Returns:
|
||||
List of query-optimized ontology subsets
|
||||
"""
|
||||
# Convert question segments to TextSegment objects
|
||||
text_segments = [
|
||||
TextSegment(text=seg, type='question', position=i)
|
||||
for i, seg in enumerate(question_segments)
|
||||
]
|
||||
|
||||
# Get base ontology subsets using parent class method
|
||||
base_subsets = await self.select_ontology_subset(text_segments)
|
||||
|
||||
# Enhance subsets for query processing
|
||||
query_subsets = []
|
||||
for subset in base_subsets:
|
||||
query_subset = self._enhance_for_query(subset, question_components)
|
||||
query_subsets.append(query_subset)
|
||||
|
||||
return query_subsets
|
||||
|
||||
def _enhance_for_query(self, subset: OntologySubset,
|
||||
question_components: QuestionComponents) -> QueryOntologySubset:
|
||||
"""Enhance ontology subset with query-specific elements.
|
||||
|
||||
Args:
|
||||
subset: Base ontology subset
|
||||
question_components: Analyzed question components
|
||||
|
||||
Returns:
|
||||
Enhanced query ontology subset
|
||||
"""
|
||||
# Create query subset
|
||||
query_subset = QueryOntologySubset(
|
||||
ontology_id=subset.ontology_id,
|
||||
classes=dict(subset.classes),
|
||||
object_properties=dict(subset.object_properties),
|
||||
datatype_properties=dict(subset.datatype_properties),
|
||||
metadata=subset.metadata,
|
||||
relevance_score=subset.relevance_score,
|
||||
traversal_properties={},
|
||||
inference_rules=[]
|
||||
)
|
||||
|
||||
# Add traversal properties based on question type
|
||||
self._add_traversal_properties(query_subset, question_components)
|
||||
|
||||
# Add related properties for exploration
|
||||
self._add_related_properties(query_subset)
|
||||
|
||||
# Add inference rules if needed
|
||||
self._add_inference_rules(query_subset, question_components)
|
||||
|
||||
return query_subset
|
||||
|
||||
def _add_traversal_properties(self, subset: QueryOntologySubset,
|
||||
question_components: QuestionComponents):
|
||||
"""Add properties useful for graph traversal.
|
||||
|
||||
Args:
|
||||
subset: Query ontology subset to enhance
|
||||
question_components: Question analysis
|
||||
"""
|
||||
ontology = self.loader.get_ontology(subset.ontology_id)
|
||||
if not ontology:
|
||||
return
|
||||
|
||||
# For relationship questions, add all properties connecting mentioned classes
|
||||
if question_components.question_type == QuestionType.RELATIONSHIP:
|
||||
for prop_id, prop_def in ontology.object_properties.items():
|
||||
domain = prop_def.domain
|
||||
range_val = prop_def.range
|
||||
|
||||
# Check if property connects relevant classes
|
||||
if domain in subset.classes or range_val in subset.classes:
|
||||
if prop_id not in subset.object_properties:
|
||||
subset.traversal_properties[prop_id] = prop_def.__dict__
|
||||
logger.debug(f"Added traversal property: {prop_id}")
|
||||
|
||||
# For retrieval questions, add properties that might filter results
|
||||
elif question_components.question_type == QuestionType.RETRIEVAL:
|
||||
# Add all properties with domains in our classes
|
||||
for prop_id, prop_def in ontology.object_properties.items():
|
||||
if prop_def.domain in subset.classes:
|
||||
if prop_id not in subset.object_properties:
|
||||
subset.traversal_properties[prop_id] = prop_def.__dict__
|
||||
|
||||
for prop_id, prop_def in ontology.datatype_properties.items():
|
||||
if prop_def.domain in subset.classes:
|
||||
if prop_id not in subset.datatype_properties:
|
||||
subset.traversal_properties[prop_id] = prop_def.__dict__
|
||||
|
||||
# For aggregation questions, ensure we have counting properties
|
||||
elif question_components.question_type == QuestionType.AGGREGATION:
|
||||
# Add properties that might be counted
|
||||
for prop_id, prop_def in ontology.datatype_properties.items():
|
||||
if 'count' in prop_id.lower() or 'number' in prop_id.lower():
|
||||
if prop_id not in subset.datatype_properties:
|
||||
subset.traversal_properties[prop_id] = prop_def.__dict__
|
||||
|
||||
def _add_related_properties(self, subset: QueryOntologySubset):
|
||||
"""Add properties related to already selected ones.
|
||||
|
||||
Args:
|
||||
subset: Query ontology subset to enhance
|
||||
"""
|
||||
ontology = self.loader.get_ontology(subset.ontology_id)
|
||||
if not ontology:
|
||||
return
|
||||
|
||||
# Add inverse properties
|
||||
for prop_id in list(subset.object_properties.keys()):
|
||||
prop = ontology.object_properties.get(prop_id)
|
||||
if prop and prop.inverse_of:
|
||||
inverse_prop = ontology.object_properties.get(prop.inverse_of)
|
||||
if inverse_prop and prop.inverse_of not in subset.object_properties:
|
||||
subset.object_properties[prop.inverse_of] = inverse_prop.__dict__
|
||||
logger.debug(f"Added inverse property: {prop.inverse_of}")
|
||||
|
||||
# Add sibling properties (same domain)
|
||||
domains_in_subset = set()
|
||||
for prop_def in subset.object_properties.values():
|
||||
if 'domain' in prop_def and prop_def['domain']:
|
||||
domains_in_subset.add(prop_def['domain'])
|
||||
|
||||
for domain in domains_in_subset:
|
||||
for prop_id, prop_def in ontology.object_properties.items():
|
||||
if prop_def.domain == domain and prop_id not in subset.object_properties:
|
||||
# Add up to 3 sibling properties
|
||||
if len(subset.traversal_properties) < 3:
|
||||
subset.traversal_properties[prop_id] = prop_def.__dict__
|
||||
|
||||
def _add_inference_rules(self, subset: QueryOntologySubset,
|
||||
question_components: QuestionComponents):
|
||||
"""Add inference rules for reasoning.
|
||||
|
||||
Args:
|
||||
subset: Query ontology subset to enhance
|
||||
question_components: Question analysis
|
||||
"""
|
||||
# Add transitivity rules for subclass relationships
|
||||
if any(cls.get('subclass_of') for cls in subset.classes.values()):
|
||||
subset.inference_rules.append({
|
||||
'type': 'transitivity',
|
||||
'property': 'rdfs:subClassOf',
|
||||
'description': 'Subclass relationships are transitive'
|
||||
})
|
||||
|
||||
# Add symmetry rules for equivalent classes
|
||||
if any(cls.get('equivalent_classes') for cls in subset.classes.values()):
|
||||
subset.inference_rules.append({
|
||||
'type': 'symmetry',
|
||||
'property': 'owl:equivalentClass',
|
||||
'description': 'Equivalent class relationships are symmetric'
|
||||
})
|
||||
|
||||
# Add inverse property rules
|
||||
for prop_id, prop_def in subset.object_properties.items():
|
||||
if 'inverse_of' in prop_def and prop_def['inverse_of']:
|
||||
subset.inference_rules.append({
|
||||
'type': 'inverse',
|
||||
'property': prop_id,
|
||||
'inverse': prop_def['inverse_of'],
|
||||
'description': f'{prop_id} is inverse of {prop_def["inverse_of"]}'
|
||||
})
|
||||
|
||||
def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset:
|
||||
"""Expand subset to include full class hierarchies.
|
||||
|
||||
Args:
|
||||
subset: Query ontology subset
|
||||
|
||||
Returns:
|
||||
Expanded subset with complete hierarchies
|
||||
"""
|
||||
ontology = self.loader.get_ontology(subset.ontology_id)
|
||||
if not ontology:
|
||||
return subset
|
||||
|
||||
# Add all parent and child classes
|
||||
classes_to_add = set()
|
||||
for class_id in list(subset.classes.keys()):
|
||||
# Add all parents
|
||||
parents = ontology.get_parent_classes(class_id)
|
||||
for parent_id in parents:
|
||||
if parent_id not in subset.classes:
|
||||
parent_class = ontology.get_class(parent_id)
|
||||
if parent_class:
|
||||
classes_to_add.add(parent_id)
|
||||
|
||||
# Add all children
|
||||
for other_class_id, other_class in ontology.classes.items():
|
||||
if other_class.subclass_of == class_id and other_class_id not in subset.classes:
|
||||
classes_to_add.add(other_class_id)
|
||||
|
||||
# Add collected classes
|
||||
for class_id in classes_to_add:
|
||||
cls = ontology.get_class(class_id)
|
||||
if cls:
|
||||
subset.classes[class_id] = cls.__dict__
|
||||
|
||||
logger.debug(f"Expanded hierarchy: added {len(classes_to_add)} classes")
|
||||
return subset
|
||||
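A hedged sketch of how this matcher is driven from the query flow; the embedder,
loader and components objects are placeholders for instances supplied by the
surrounding processor:

    matcher = OntologyMatcherForQueries(embedder, loader, top_k=15)
    subsets = await matcher.match_question_to_ontology(
        components, ["Which sensors measure temperature?"])
    subsets = [matcher.expand_for_hierarchical_queries(s) for s in subsets]
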
640
trustgraph-flow/trustgraph/query/ontology/query_explanation.py
Normal file
@@ -0,0 +1,640 @@
"""
|
||||
Query explanation system for OntoRAG.
|
||||
Provides detailed explanations of how queries are processed and results are derived.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
from .sparql_generator import SPARQLQuery
|
||||
from .cypher_generator import CypherQuery
|
||||
from .sparql_cassandra import SPARQLResult
|
||||
from .cypher_executor import CypherResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExplanationStep:
|
||||
"""Individual step in query explanation."""
|
||||
step_number: int
|
||||
component: str
|
||||
operation: str
|
||||
input_data: Dict[str, Any]
|
||||
output_data: Dict[str, Any]
|
||||
explanation: str
|
||||
duration_ms: float
|
||||
success: bool
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryExplanation:
|
||||
"""Complete explanation of query processing."""
|
||||
query_id: str
|
||||
original_question: str
|
||||
processing_steps: List[ExplanationStep]
|
||||
final_answer: str
|
||||
confidence_score: float
|
||||
total_duration_ms: float
|
||||
ontologies_used: List[str]
|
||||
backend_used: str
|
||||
reasoning_chain: List[str]
|
||||
technical_details: Dict[str, Any]
|
||||
user_friendly_explanation: str
|
||||
|
||||
|
||||
class QueryExplainer:
|
||||
"""Generates explanations for query processing."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize query explainer.
|
||||
|
||||
Args:
|
||||
config: Explainer configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.explanation_level = self.config.get('explanation_level', 'detailed') # basic, detailed, technical
|
||||
self.include_technical_details = self.config.get('include_technical_details', True)
|
||||
self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10)
|
||||
|
||||
# Templates for different explanation types
|
||||
self.step_templates = {
|
||||
'question_analysis': {
|
||||
'basic': "I analyzed your question to understand what you're asking.",
|
||||
'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.",
|
||||
'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}"
|
||||
},
|
||||
'ontology_matching': {
|
||||
'basic': "I found relevant knowledge about {entities} in the available ontologies.",
|
||||
'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.",
|
||||
'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}"
|
||||
},
|
||||
'query_generation': {
|
||||
'basic': "I generated a query to search for the information.",
|
||||
'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.",
|
||||
'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}"
|
||||
},
|
||||
'query_execution': {
|
||||
'basic': "I searched the database and found {result_count} results.",
|
||||
'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.",
|
||||
'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms"
|
||||
},
|
||||
'answer_generation': {
|
||||
'basic': "I generated a natural language answer from the results.",
|
||||
'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.",
|
||||
'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}"
|
||||
}
|
||||
}
|
||||
|
||||
self.reasoning_templates = {
|
||||
'entity_identification': "I identified '{entity}' as a key concept in your question.",
|
||||
'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.",
|
||||
'query_strategy': "I chose a {strategy} query approach because {reason}.",
|
||||
'result_filtering': "I filtered the results to show only the most relevant {count} items.",
|
||||
'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}."
|
||||
}
|
||||
|
||||
def explain_query_processing(self,
|
||||
question: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset],
|
||||
generated_query: Union[SPARQLQuery, CypherQuery],
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
final_answer: str,
|
||||
processing_metadata: Dict[str, Any]) -> QueryExplanation:
|
||||
"""Generate comprehensive explanation of query processing.
|
||||
|
||||
Args:
|
||||
question: Original question
|
||||
question_components: Analyzed question components
|
||||
ontology_subsets: Selected ontology subsets
|
||||
generated_query: Generated query
|
||||
query_results: Query execution results
|
||||
final_answer: Final generated answer
|
||||
processing_metadata: Processing metadata
|
||||
|
||||
Returns:
|
||||
Complete query explanation
|
||||
"""
|
||||
query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}")
|
||||
start_time = processing_metadata.get('start_time', datetime.now())
|
||||
|
||||
# Build explanation steps
|
||||
steps = []
|
||||
step_number = 1
|
||||
|
||||
# Step 1: Question Analysis
|
||||
steps.append(self._explain_question_analysis(
|
||||
step_number, question, question_components
|
||||
))
|
||||
step_number += 1
|
||||
|
||||
# Step 2: Ontology Matching
|
||||
steps.append(self._explain_ontology_matching(
|
||||
step_number, question_components, ontology_subsets
|
||||
))
|
||||
step_number += 1
|
||||
|
||||
# Step 3: Query Generation
|
||||
steps.append(self._explain_query_generation(
|
||||
step_number, generated_query, processing_metadata
|
||||
))
|
||||
step_number += 1
|
||||
|
||||
# Step 4: Query Execution
|
||||
steps.append(self._explain_query_execution(
|
||||
step_number, generated_query, query_results, processing_metadata
|
||||
))
|
||||
step_number += 1
|
||||
|
||||
# Step 5: Answer Generation
|
||||
steps.append(self._explain_answer_generation(
|
||||
step_number, query_results, final_answer, processing_metadata
|
||||
))
|
||||
|
||||
# Build reasoning chain
|
||||
reasoning_chain = self._build_reasoning_chain(
|
||||
question_components, ontology_subsets, generated_query, processing_metadata
|
||||
)
|
||||
|
||||
# Calculate overall confidence
|
||||
confidence_score = self._calculate_explanation_confidence(
|
||||
question_components, query_results, processing_metadata
|
||||
)
|
||||
|
||||
# Generate user-friendly explanation
|
||||
user_friendly_explanation = self._generate_user_friendly_explanation(
|
||||
question, question_components, ontology_subsets, final_answer
|
||||
)
|
||||
|
||||
# Calculate total duration
|
||||
total_duration = processing_metadata.get('total_duration_ms', 0)
|
||||
|
||||
return QueryExplanation(
|
||||
query_id=query_id,
|
||||
original_question=question,
|
||||
processing_steps=steps,
|
||||
final_answer=final_answer,
|
||||
confidence_score=confidence_score,
|
||||
total_duration_ms=total_duration,
|
||||
ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets],
|
||||
backend_used=processing_metadata.get('backend_used', 'unknown'),
|
||||
reasoning_chain=reasoning_chain,
|
||||
technical_details=self._extract_technical_details(processing_metadata),
|
||||
user_friendly_explanation=user_friendly_explanation
|
||||
)
|
||||
|
||||
def _explain_question_analysis(self,
|
||||
step_number: int,
|
||||
question: str,
|
||||
question_components: QuestionComponents) -> ExplanationStep:
|
||||
"""Explain question analysis step."""
|
||||
template = self.step_templates['question_analysis'][self.explanation_level]
|
||||
|
||||
if self.explanation_level == 'basic':
|
||||
explanation = template
|
||||
elif self.explanation_level == 'detailed':
|
||||
explanation = template.format(
|
||||
question=question,
|
||||
question_type=question_components.question_type.value.replace('_', ' '),
|
||||
entities=', '.join(question_components.entities[:3])
|
||||
)
|
||||
else: # technical
|
||||
explanation = template.format(
|
||||
question_type=question_components.question_type.value,
|
||||
entities=question_components.entities,
|
||||
keywords=question_components.keywords,
|
||||
answer_type=question_components.expected_answer_type
|
||||
)
|
||||
|
||||
return ExplanationStep(
|
||||
step_number=step_number,
|
||||
component="question_analyzer",
|
||||
operation="analyze_question",
|
||||
input_data={"question": question},
|
||||
output_data={
|
||||
"question_type": question_components.question_type.value,
|
||||
"entities": question_components.entities,
|
||||
"keywords": question_components.keywords
|
||||
},
|
||||
explanation=explanation,
|
||||
duration_ms=0.0, # Would be tracked in actual implementation
|
||||
success=True
|
||||
)
|
||||
|
||||
def _explain_ontology_matching(self,
|
||||
step_number: int,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep:
|
||||
"""Explain ontology matching step."""
|
||||
template = self.step_templates['ontology_matching'][self.explanation_level]
|
||||
|
||||
total_elements = sum(
|
||||
len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
|
||||
for subset in ontology_subsets
|
||||
)
|
||||
|
||||
if self.explanation_level == 'basic':
|
||||
explanation = template.format(
|
||||
entities=', '.join(question_components.entities[:3])
|
||||
)
|
||||
elif self.explanation_level == 'detailed':
|
||||
explanation = template.format(
|
||||
ontology_count=len(ontology_subsets),
|
||||
selected_elements=total_elements
|
||||
)
|
||||
else: # technical
|
||||
total_classes = sum(len(subset.classes) for subset in ontology_subsets)
|
||||
total_properties = sum(
|
||||
len(subset.object_properties) + len(subset.datatype_properties)
|
||||
for subset in ontology_subsets
|
||||
)
|
||||
ontology_names = [subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets]
|
||||
|
||||
explanation = template.format(
|
||||
classes=total_classes,
|
||||
properties=total_properties,
|
||||
ontologies=', '.join(ontology_names)
|
||||
)
|
||||
|
||||
return ExplanationStep(
|
||||
step_number=step_number,
|
||||
component="ontology_matcher",
|
||||
operation="select_relevant_subset",
|
||||
input_data={"entities": question_components.entities},
|
||||
output_data={
|
||||
"ontology_count": len(ontology_subsets),
|
||||
"total_elements": total_elements
|
||||
},
|
||||
explanation=explanation,
|
||||
duration_ms=0.0,
|
||||
success=True
|
||||
)
|
||||
|
||||
def _explain_query_generation(self,
|
||||
step_number: int,
|
||||
generated_query: Union[SPARQLQuery, CypherQuery],
|
||||
metadata: Dict[str, Any]) -> ExplanationStep:
|
||||
"""Explain query generation step."""
|
||||
template = self.step_templates['query_generation'][self.explanation_level]
|
||||
|
||||
query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
|
||||
backend = metadata.get('backend_used', 'unknown')
|
||||
|
||||
if self.explanation_level == 'basic':
|
||||
explanation = template
|
||||
elif self.explanation_level == 'detailed':
|
||||
explanation = template.format(
|
||||
query_type=generated_query.query_type,
|
||||
query_language=query_language,
|
||||
backend=backend
|
||||
)
|
||||
else: # technical
|
||||
explanation = template.format(
|
||||
query_language=query_language,
|
||||
variables=len(generated_query.variables),
|
||||
complexity=f"{generated_query.complexity_score:.2f}"
|
||||
)
|
||||
|
||||
return ExplanationStep(
|
||||
step_number=step_number,
|
||||
component="query_generator",
|
||||
operation="generate_query",
|
||||
input_data={"query_type": generated_query.query_type},
|
||||
output_data={
|
||||
"query_language": query_language,
|
||||
"variables": generated_query.variables,
|
||||
"complexity": generated_query.complexity_score
|
||||
},
|
||||
explanation=explanation,
|
||||
duration_ms=0.0,
|
||||
success=True,
|
||||
metadata={"generated_query": generated_query.query}
|
||||
)
|
||||
|
||||
def _explain_query_execution(self,
|
||||
step_number: int,
|
||||
generated_query: Union[SPARQLQuery, CypherQuery],
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
metadata: Dict[str, Any]) -> ExplanationStep:
|
||||
"""Explain query execution step."""
|
||||
template = self.step_templates['query_execution'][self.explanation_level]
|
||||
|
||||
backend = metadata.get('backend_used', 'unknown')
|
||||
duration = getattr(query_results, 'execution_time', 0) * 1000 # Convert to ms
|
||||
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
result_count = len(query_results.bindings)
|
||||
else: # CypherResult
|
||||
result_count = len(query_results.records)
|
||||
|
||||
if self.explanation_level == 'basic':
|
||||
explanation = template.format(result_count=result_count)
|
||||
elif self.explanation_level == 'detailed':
|
||||
explanation = template.format(
|
||||
backend=backend,
|
||||
result_count=result_count,
|
||||
duration=f"{duration:.1f}"
|
||||
)
|
||||
else: # technical
|
||||
explanation = template.format(
|
||||
backend=backend,
|
||||
result_count=result_count,
|
||||
duration=f"{duration:.1f}"
|
||||
)
|
||||
|
||||
return ExplanationStep(
|
||||
step_number=step_number,
|
||||
component="query_executor",
|
||||
operation="execute_query",
|
||||
input_data={"query": generated_query.query},
|
||||
output_data={
|
||||
"result_count": result_count,
|
||||
"execution_time_ms": duration
|
||||
},
|
||||
explanation=explanation,
|
||||
duration_ms=duration,
|
||||
success=result_count >= 0
|
||||
)
|
||||
|
||||
def _explain_answer_generation(self,
|
||||
step_number: int,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
final_answer: str,
|
||||
metadata: Dict[str, Any]) -> ExplanationStep:
|
||||
"""Explain answer generation step."""
|
||||
template = self.step_templates['answer_generation'][self.explanation_level]
|
||||
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
result_count = len(query_results.bindings)
|
||||
else: # CypherResult
|
||||
result_count = len(query_results.records)
|
||||
|
||||
confidence = metadata.get('answer_confidence', 0.8) * 100 # Convert to percentage
|
||||
|
||||
if self.explanation_level == 'basic':
|
||||
explanation = template
|
||||
elif self.explanation_level == 'detailed':
|
||||
explanation = template.format(
|
||||
result_count=result_count,
|
||||
confidence=f"{confidence:.0f}"
|
||||
)
|
||||
else: # technical
|
||||
generation_method = metadata.get('generation_method', 'template_based')
|
||||
explanation = template.format(
|
||||
result_count=result_count,
|
||||
generation_method=generation_method,
|
||||
confidence=f"{confidence:.1f}"
|
||||
)
|
||||
|
||||
return ExplanationStep(
|
||||
step_number=step_number,
|
||||
component="answer_generator",
|
||||
operation="generate_answer",
|
||||
input_data={"result_count": result_count},
|
||||
output_data={
|
||||
"answer": final_answer,
|
||||
"confidence": confidence / 100
|
||||
},
|
||||
explanation=explanation,
|
||||
duration_ms=0.0,
|
||||
success=bool(final_answer)
|
||||
)
|
||||
|
||||
def _build_reasoning_chain(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset],
|
||||
generated_query: Union[SPARQLQuery, CypherQuery],
|
||||
metadata: Dict[str, Any]) -> List[str]:
|
||||
"""Build reasoning chain explaining the decision process."""
|
||||
reasoning = []
|
||||
|
||||
# Entity identification reasoning
|
||||
if question_components.entities:
|
||||
for entity in question_components.entities[:3]:
|
||||
reasoning.append(
|
||||
self.reasoning_templates['entity_identification'].format(entity=entity)
|
||||
)
|
||||
|
||||
# Ontology selection reasoning
|
||||
if ontology_subsets:
|
||||
primary_ontology = ontology_subsets[0]
|
||||
ontology_id = primary_ontology.metadata.get('ontology_id', 'primary')
|
||||
concepts = list(primary_ontology.classes.keys())[:3]
|
||||
reasoning.append(
|
||||
self.reasoning_templates['ontology_selection'].format(
|
||||
ontology=ontology_id,
|
||||
concepts=', '.join(concepts)
|
||||
)
|
||||
)
|
||||
|
||||
# Query strategy reasoning
|
||||
query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
|
||||
if question_components.question_type == QuestionType.AGGREGATION:
|
||||
strategy = "aggregation"
|
||||
reason = "you asked for a count or sum"
|
||||
elif question_components.question_type == QuestionType.BOOLEAN:
|
||||
strategy = "boolean"
|
||||
reason = "you asked a yes/no question"
|
||||
else:
|
||||
strategy = "retrieval"
|
||||
reason = "you asked for specific information"
|
||||
|
||||
reasoning.append(
|
||||
self.reasoning_templates['query_strategy'].format(
|
||||
strategy=strategy,
|
||||
reason=reason
|
||||
)
|
||||
)
|
||||
|
||||
# Confidence assessment
|
||||
confidence = metadata.get('answer_confidence', 0.8) * 100
|
||||
if confidence > 90:
|
||||
confidence_reason = "the query matched well with available data"
|
||||
elif confidence > 70:
|
||||
confidence_reason = "the query found relevant information with some uncertainty"
|
||||
else:
|
||||
confidence_reason = "the available data partially matches your question"
|
||||
|
||||
reasoning.append(
|
||||
self.reasoning_templates['confidence_assessment'].format(
|
||||
confidence=f"{confidence:.0f}",
|
||||
reasoning=confidence_reason
|
||||
)
|
||||
)
|
||||
|
||||
return reasoning[:self.max_reasoning_steps]
|
||||
|
||||
def _calculate_explanation_confidence(self,
|
||||
question_components: QuestionComponents,
|
||||
query_results: Union[SPARQLResult, CypherResult],
|
||||
metadata: Dict[str, Any]) -> float:
|
||||
"""Calculate confidence score for the explanation."""
|
||||
confidence = 0.8 # Base confidence
|
||||
|
||||
# Adjust based on result count
|
||||
if isinstance(query_results, SPARQLResult):
|
||||
result_count = len(query_results.bindings)
|
||||
else:
|
||||
result_count = len(query_results.records)
|
||||
|
||||
if result_count > 0:
|
||||
confidence += 0.1
|
||||
if result_count > 5:
|
||||
confidence += 0.05
|
||||
|
||||
# Adjust based on question complexity
|
||||
if len(question_components.entities) > 0:
|
||||
confidence += 0.05
|
||||
|
||||
# Adjust based on processing success
|
||||
if metadata.get('all_steps_successful', True):
|
||||
confidence += 0.05
|
||||
|
||||
return min(confidence, 1.0)
|
||||
|
||||
def _generate_user_friendly_explanation(self,
|
||||
question: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset],
|
||||
final_answer: str) -> str:
|
||||
"""Generate user-friendly explanation of the process."""
|
||||
explanation_parts = []
|
||||
|
||||
# Introduction
|
||||
explanation_parts.append(f"To answer your question '{question}', I followed these steps:")
|
||||
|
||||
# Process summary
|
||||
if question_components.question_type == QuestionType.AGGREGATION:
|
||||
explanation_parts.append("1. I recognized this as a counting or aggregation question")
|
||||
elif question_components.question_type == QuestionType.BOOLEAN:
|
||||
explanation_parts.append("1. I recognized this as a yes/no question")
|
||||
else:
|
||||
explanation_parts.append("1. I analyzed your question to understand what information you need")
|
||||
|
||||
# Ontology usage
|
||||
if ontology_subsets:
|
||||
ontology_count = len(ontology_subsets)
|
||||
if ontology_count == 1:
|
||||
explanation_parts.append("2. I searched through the relevant knowledge base")
|
||||
else:
|
||||
explanation_parts.append(f"2. I searched through {ontology_count} knowledge bases")
|
||||
|
||||
# Result processing
|
||||
explanation_parts.append("3. I found the relevant information and generated your answer")
|
||||
|
||||
# Conclusion
|
||||
explanation_parts.append(f"The answer is: {final_answer}")
|
||||
|
||||
return " ".join(explanation_parts)
|
||||
|
||||
def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract technical details for debugging and optimization."""
|
||||
return {
|
||||
'query_optimization': metadata.get('query_optimization', {}),
|
||||
'backend_performance': metadata.get('backend_performance', {}),
|
||||
'cache_usage': metadata.get('cache_usage', {}),
|
||||
'error_handling': metadata.get('error_handling', {}),
|
||||
'routing_decision': metadata.get('routing_decision', {})
|
||||
}
|
||||
|
||||
def format_explanation_for_display(self,
|
||||
explanation: QueryExplanation,
|
||||
format_type: str = 'html') -> str:
|
||||
"""Format explanation for display.
|
||||
|
||||
Args:
|
||||
explanation: Query explanation
|
||||
format_type: Output format ('html', 'markdown', 'text')
|
||||
|
||||
Returns:
|
||||
Formatted explanation
|
||||
"""
|
||||
if format_type == 'html':
|
||||
return self._format_html_explanation(explanation)
|
||||
elif format_type == 'markdown':
|
||||
return self._format_markdown_explanation(explanation)
|
||||
else:
|
||||
return self._format_text_explanation(explanation)
|
||||
|
||||
def _format_html_explanation(self, explanation: QueryExplanation) -> str:
|
||||
"""Format explanation as HTML."""
|
||||
html_parts = [
|
||||
f"<h2>Query Explanation: {explanation.query_id}</h2>",
|
||||
f"<p><strong>Question:</strong> {explanation.original_question}</p>",
|
||||
f"<p><strong>Answer:</strong> {explanation.final_answer}</p>",
|
||||
f"<p><strong>Confidence:</strong> {explanation.confidence_score:.1%}</p>",
|
||||
"<h3>Processing Steps:</h3>",
|
||||
"<ol>"
|
||||
]
|
||||
|
||||
for step in explanation.processing_steps:
|
||||
html_parts.append(f"<li><strong>{step.component}</strong>: {step.explanation}</li>")
|
||||
|
||||
html_parts.extend([
|
||||
"</ol>",
|
||||
"<h3>Reasoning:</h3>",
|
||||
"<ul>"
|
||||
])
|
||||
|
||||
for reasoning in explanation.reasoning_chain:
|
||||
html_parts.append(f"<li>{reasoning}</li>")
|
||||
|
||||
html_parts.append("</ul>")
|
||||
|
||||
return "".join(html_parts)
|
||||
|
||||
def _format_markdown_explanation(self, explanation: QueryExplanation) -> str:
|
||||
"""Format explanation as Markdown."""
|
||||
md_parts = [
|
||||
f"## Query Explanation: {explanation.query_id}",
|
||||
f"**Question:** {explanation.original_question}",
|
||||
f"**Answer:** {explanation.final_answer}",
|
||||
f"**Confidence:** {explanation.confidence_score:.1%}",
|
||||
"",
|
||||
"### Processing Steps:",
|
||||
""
|
||||
]
|
||||
|
||||
for i, step in enumerate(explanation.processing_steps, 1):
|
||||
md_parts.append(f"{i}. **{step.component}**: {step.explanation}")
|
||||
|
||||
md_parts.extend([
|
||||
"",
|
||||
"### Reasoning:",
|
||||
""
|
||||
])
|
||||
|
||||
for reasoning in explanation.reasoning_chain:
|
||||
md_parts.append(f"- {reasoning}")
|
||||
|
||||
return "\n".join(md_parts)
|
||||
|
||||
def _format_text_explanation(self, explanation: QueryExplanation) -> str:
|
||||
"""Format explanation as plain text."""
|
||||
text_parts = [
|
||||
f"Query Explanation: {explanation.query_id}",
|
||||
f"Question: {explanation.original_question}",
|
||||
f"Answer: {explanation.final_answer}",
|
||||
f"Confidence: {explanation.confidence_score:.1%}",
|
||||
"",
|
||||
"Processing Steps:",
|
||||
]
|
||||
|
||||
for i, step in enumerate(explanation.processing_steps, 1):
|
||||
text_parts.append(f" {i}. {step.component}: {step.explanation}")
|
||||
|
||||
text_parts.extend([
|
||||
"",
|
||||
"Reasoning:",
|
||||
])
|
||||
|
||||
for reasoning in explanation.reasoning_chain:
|
||||
text_parts.append(f" - {reasoning}")
|
||||
|
||||
return "\n".join(text_parts)
|
||||
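A brief sketch of producing and rendering an explanation; every argument object
here is a placeholder standing in for the output of an earlier pipeline stage:

    explainer = QueryExplainer({'explanation_level': 'detailed'})
    explanation = explainer.explain_query_processing(
        question, components, subsets, query, results, answer,
        {'backend_used': 'cassandra', 'total_duration_ms': 120.0})
    print(explainer.format_explanation_for_display(explanation, 'markdown'))
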
519
trustgraph-flow/trustgraph/query/ontology/query_optimizer.py
Normal file
@@ -0,0 +1,519 @@
"""
|
||||
Query optimization module for OntoRAG.
|
||||
Optimizes SPARQL and Cypher queries for better performance and accuracy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import re
|
||||
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
from .sparql_generator import SPARQLQuery
|
||||
from .cypher_generator import CypherQuery
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OptimizationStrategy(Enum):
|
||||
"""Query optimization strategies."""
|
||||
PERFORMANCE = "performance"
|
||||
ACCURACY = "accuracy"
|
||||
BALANCED = "balanced"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizationHint:
|
||||
"""Optimization hint for query processing."""
|
||||
strategy: OptimizationStrategy
|
||||
max_results: Optional[int] = None
|
||||
timeout_seconds: Optional[int] = None
|
||||
use_indices: bool = True
|
||||
enable_parallel: bool = False
|
||||
cache_results: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryPlan:
|
||||
"""Query execution plan with optimization metadata."""
|
||||
original_query: str
|
||||
optimized_query: str
|
||||
estimated_cost: float
|
||||
optimization_notes: List[str]
|
||||
index_hints: List[str]
|
||||
execution_order: List[str]
|
||||
|
||||
|
||||
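# Illustrative sketch: a hint for a latency-sensitive call and the plan the
# optimizer (defined below) returns. The optimizer, sparql_query, components and
# subset names are placeholders, and the field values are examples, not defaults.
#
#   hint = OptimizationHint(strategy=OptimizationStrategy.PERFORMANCE,
#                           max_results=100, timeout_seconds=10)
#   optimized, plan = optimizer.optimize_sparql(sparql_query, components, subset, hint)
#   print(plan.estimated_cost, plan.optimization_notes)
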
class QueryOptimizer:
|
||||
"""Optimizes SPARQL and Cypher queries for performance and accuracy."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""Initialize query optimizer.
|
||||
|
||||
Args:
|
||||
config: Optimizer configuration
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.default_strategy = OptimizationStrategy(
|
||||
self.config.get('default_strategy', 'balanced')
|
||||
)
|
||||
self.max_query_complexity = self.config.get('max_query_complexity', 10)
|
||||
self.enable_query_rewriting = self.config.get('enable_query_rewriting', True)
|
||||
|
||||
# Performance thresholds
|
||||
self.large_result_threshold = self.config.get('large_result_threshold', 1000)
|
||||
self.complex_join_threshold = self.config.get('complex_join_threshold', 3)
|
||||
|
||||
def optimize_sparql(self,
|
||||
sparql_query: SPARQLQuery,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]:
|
||||
"""Optimize SPARQL query.
|
||||
|
||||
Args:
|
||||
sparql_query: Original SPARQL query
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
optimization_hint: Optimization hints
|
||||
|
||||
Returns:
|
||||
Optimized SPARQL query and execution plan
|
||||
"""
|
||||
hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
|
||||
|
||||
optimized_query = sparql_query.query
|
||||
optimization_notes = []
|
||||
index_hints = []
|
||||
execution_order = []
|
||||
|
||||
# Apply optimizations based on strategy
|
||||
if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
|
||||
optimized_query, perf_notes, perf_hints = self._optimize_sparql_performance(
|
||||
optimized_query, question_components, ontology_subset, hint
|
||||
)
|
||||
optimization_notes.extend(perf_notes)
|
||||
index_hints.extend(perf_hints)
|
||||
|
||||
if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
|
||||
optimized_query, acc_notes = self._optimize_sparql_accuracy(
|
||||
optimized_query, question_components, ontology_subset
|
||||
)
|
||||
optimization_notes.extend(acc_notes)
|
||||
|
||||
# Estimate query cost
|
||||
estimated_cost = self._estimate_sparql_cost(optimized_query, ontology_subset)
|
||||
|
||||
# Build execution plan
|
||||
query_plan = QueryPlan(
|
||||
original_query=sparql_query.query,
|
||||
optimized_query=optimized_query,
|
||||
estimated_cost=estimated_cost,
|
||||
optimization_notes=optimization_notes,
|
||||
index_hints=index_hints,
|
||||
execution_order=execution_order
|
||||
)
|
||||
|
||||
# Create optimized query object
|
||||
optimized_sparql = SPARQLQuery(
|
||||
query=optimized_query,
|
||||
variables=sparql_query.variables,
|
||||
query_type=sparql_query.query_type,
|
||||
explanation=f"Optimized: {sparql_query.explanation}",
|
||||
complexity_score=min(sparql_query.complexity_score * 0.8, 1.0) # Assume optimization reduces complexity
|
||||
)
|
||||
|
||||
return optimized_sparql, query_plan
|
||||
|
||||
def optimize_cypher(self,
|
||||
cypher_query: CypherQuery,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]:
|
||||
"""Optimize Cypher query.
|
||||
|
||||
Args:
|
||||
cypher_query: Original Cypher query
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
optimization_hint: Optimization hints
|
||||
|
||||
Returns:
|
||||
Optimized Cypher query and execution plan
|
||||
"""
|
||||
hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
|
||||
|
||||
optimized_query = cypher_query.query
|
||||
optimization_notes = []
|
||||
index_hints = []
|
||||
execution_order = []
|
||||
|
||||
# Apply optimizations based on strategy
|
||||
if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]:
|
||||
optimized_query, perf_notes, perf_hints = self._optimize_cypher_performance(
|
||||
optimized_query, question_components, ontology_subset, hint
|
||||
)
|
||||
optimization_notes.extend(perf_notes)
|
||||
index_hints.extend(perf_hints)
|
||||
|
||||
if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]:
|
||||
optimized_query, acc_notes = self._optimize_cypher_accuracy(
|
||||
optimized_query, question_components, ontology_subset
|
||||
)
|
||||
optimization_notes.extend(acc_notes)
|
||||
|
||||
# Estimate query cost
|
||||
estimated_cost = self._estimate_cypher_cost(optimized_query, ontology_subset)
|
||||
|
||||
# Build execution plan
|
||||
query_plan = QueryPlan(
|
||||
original_query=cypher_query.query,
|
||||
optimized_query=optimized_query,
|
||||
estimated_cost=estimated_cost,
|
||||
optimization_notes=optimization_notes,
|
||||
index_hints=index_hints,
|
||||
execution_order=execution_order
|
||||
)
|
||||
|
||||
# Create optimized query object
|
||||
optimized_cypher = CypherQuery(
|
||||
query=optimized_query,
|
||||
variables=cypher_query.variables,
|
||||
query_type=cypher_query.query_type,
|
||||
explanation=f"Optimized: {cypher_query.explanation}",
|
||||
complexity_score=min(cypher_query.complexity_score * 0.8, 1.0)
|
||||
)
|
||||
|
||||
return optimized_cypher, query_plan
|
||||
|
||||
def _optimize_sparql_performance(self,
|
||||
query: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
|
||||
"""Apply performance optimizations to SPARQL query.
|
||||
|
||||
Args:
|
||||
query: SPARQL query string
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
hint: Optimization hints
|
||||
|
||||
Returns:
|
||||
Optimized query, optimization notes, and index hints
|
||||
"""
|
||||
optimized = query
|
||||
notes = []
|
||||
index_hints = []
|
||||
|
||||
# Add LIMIT if not present and large results expected
|
||||
if hint.max_results and 'LIMIT' not in optimized.upper():
|
||||
optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
|
||||
notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
|
||||
|
||||
# Optimize OPTIONAL clauses (move to end)
|
||||
optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL)
|
||||
optionals = optional_pattern.findall(optimized)
|
||||
if optionals:
|
||||
# Remove optionals from current position
|
||||
for optional in optionals:
|
||||
optimized = optimized.replace(optional, '')
|
||||
|
||||
            # Re-insert them at the end of the WHERE group, just before its
            # closing brace, so the query remains syntactically valid
            insert_point = optimized.rfind('}')
            if insert_point == -1:
                insert_point = len(optimized.rstrip())

            for optional in optionals:
                optimized = optimized[:insert_point] + f"\n    {optional}\n" + optimized[insert_point:]

            notes.append("Moved OPTIONAL clauses to the end of the WHERE clause for better performance")
|
||||
|
||||
# Add index hints for Cassandra
|
||||
if 'WHERE' in optimized.upper():
|
||||
# Suggest indices for common patterns
|
||||
if '?subject rdf:type' in optimized:
|
||||
index_hints.append("type_index")
|
||||
if 'rdfs:subClassOf' in optimized:
|
||||
index_hints.append("hierarchy_index")
|
||||
|
||||
# Optimize FILTER clauses (move closer to variable bindings)
|
||||
filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE)
|
||||
filters = filter_pattern.findall(optimized)
|
||||
if filters:
|
||||
notes.append("FILTER clauses present - ensure they're positioned optimally")
|
||||
|
||||
return optimized, notes, index_hints
|
||||
|
||||
def _optimize_sparql_accuracy(self,
|
||||
query: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
|
||||
"""Apply accuracy optimizations to SPARQL query.
|
||||
|
||||
Args:
|
||||
query: SPARQL query string
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Optimized query and optimization notes
|
||||
"""
|
||||
optimized = query
|
||||
notes = []
|
||||
|
||||
# Add missing namespace checks
|
||||
if question_components.question_type == QuestionType.RETRIEVAL:
|
||||
# Ensure we're not mixing namespaces inappropriately
|
||||
if 'http://' in optimized and '?' in optimized:
|
||||
notes.append("Verified namespace consistency for accuracy")
|
||||
|
||||
# Add type constraints for better precision
|
||||
if '?entity' in optimized and 'rdf:type' not in optimized:
|
||||
# Find a good insertion point
|
||||
where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE)
|
||||
if where_clause and ontology_subset.classes:
|
||||
# Add type constraint for the most relevant class
|
||||
main_class = list(ontology_subset.classes.keys())[0]
|
||||
type_constraint = f"\n ?entity rdf:type :{main_class} ."
|
||||
|
||||
# Insert after the WHERE {
|
||||
where_start = where_clause.start(1)
|
||||
optimized = optimized[:where_start] + type_constraint + optimized[where_start:]
|
||||
notes.append(f"Added type constraint for {main_class} to improve accuracy")
|
||||
|
||||
# Add DISTINCT if not present for retrieval queries
|
||||
if (question_components.question_type == QuestionType.RETRIEVAL and
|
||||
'DISTINCT' not in optimized.upper() and
|
||||
'SELECT' in optimized.upper()):
|
||||
optimized = optimized.replace('SELECT ', 'SELECT DISTINCT ', 1)
|
||||
notes.append("Added DISTINCT to eliminate duplicate results")
|
||||
|
||||
return optimized, notes
|
||||
|
||||
def _optimize_cypher_performance(self,
|
||||
query: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset,
|
||||
hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
|
||||
"""Apply performance optimizations to Cypher query.
|
||||
|
||||
Args:
|
||||
query: Cypher query string
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
hint: Optimization hints
|
||||
|
||||
Returns:
|
||||
Optimized query, optimization notes, and index hints
|
||||
"""
|
||||
optimized = query
|
||||
notes = []
|
||||
index_hints = []
|
||||
|
||||
# Add LIMIT if not present
|
||||
if hint.max_results and 'LIMIT' not in optimized.upper():
|
||||
optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
|
||||
notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
|
||||
|
||||
# Use parameters for literals to enable query plan caching
|
||||
if "'" in optimized or '"' in optimized:
|
||||
notes.append("Consider using parameters for literal values to enable query plan caching")
|
||||
|
||||
# Suggest indices based on query patterns
|
||||
if 'MATCH (n:' in optimized:
|
||||
label_match = re.search(r'MATCH \(n:(\w+)\)', optimized)
|
||||
if label_match:
|
||||
label = label_match.group(1)
|
||||
index_hints.append(f"node_label_index:{label}")
|
||||
|
||||
if 'WHERE' in optimized.upper() and '.' in optimized:
|
||||
# Property access patterns
|
||||
property_pattern = re.compile(r'\.(\w+)', re.IGNORECASE)
|
||||
properties = property_pattern.findall(optimized)
|
||||
for prop in set(properties):
|
||||
index_hints.append(f"property_index:{prop}")
|
||||
|
||||
# Optimize relationship traversals
|
||||
if '-[' in optimized and '*' in optimized:
|
||||
notes.append("Variable length path detected - consider adding relationship type filters")
|
||||
|
||||
# Early filtering optimization
|
||||
if 'WHERE' in optimized.upper():
|
||||
# Move WHERE clauses closer to MATCH clauses
|
||||
notes.append("WHERE clauses present - ensure early filtering for performance")
|
||||
|
||||
return optimized, notes, index_hints
|
||||
|
||||
def _optimize_cypher_accuracy(self,
|
||||
query: str,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
|
||||
"""Apply accuracy optimizations to Cypher query.
|
||||
|
||||
Args:
|
||||
query: Cypher query string
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Optimized query and optimization notes
|
||||
"""
|
||||
optimized = query
|
||||
notes = []
|
||||
|
||||
# Add DISTINCT if not present for retrieval queries
|
||||
if (question_components.question_type == QuestionType.RETRIEVAL and
|
||||
'DISTINCT' not in optimized.upper() and
|
||||
'RETURN' in optimized.upper()):
|
||||
optimized = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', optimized, count=1, flags=re.IGNORECASE)
|
||||
notes.append("Added DISTINCT to eliminate duplicate results")
|
||||
|
||||
# Ensure proper relationship direction
|
||||
if '-[' in optimized and question_components.relationships:
|
||||
notes.append("Verified relationship directions for semantic accuracy")
|
||||
|
||||
# Add null checks for optional properties
|
||||
if '?' in optimized or 'OPTIONAL' in optimized.upper():
|
||||
notes.append("Consider adding null checks for optional properties")
|
||||
|
||||
return optimized, notes
|
||||
|
||||
def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
|
||||
"""Estimate execution cost for SPARQL query.
|
||||
|
||||
Args:
|
||||
query: SPARQL query string
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Estimated cost (0.0 to 1.0)
|
||||
"""
|
||||
cost = 0.0
|
||||
|
||||
# Basic query complexity
|
||||
cost += len(query.split('\n')) * 0.01
|
||||
|
||||
# Join complexity
|
||||
triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query))
|
||||
cost += triple_patterns * 0.1
|
||||
|
||||
# OPTIONAL clauses
|
||||
optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE))
|
||||
cost += optional_count * 0.15
|
||||
|
||||
# FILTER clauses
|
||||
filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE))
|
||||
cost += filter_count * 0.1
|
||||
|
||||
# Property paths
|
||||
path_count = len(re.findall(r'\*|\+', query))
|
||||
cost += path_count * 0.2
|
||||
|
||||
# Ontology subset size impact
|
||||
total_elements = (len(ontology_subset.classes) +
|
||||
len(ontology_subset.object_properties) +
|
||||
len(ontology_subset.datatype_properties))
|
||||
cost += (total_elements / 100.0) * 0.1
|
||||
|
||||
return min(cost, 1.0)
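    # Worked example (illustrative): a 6-line query containing one
    # fully-variable triple pattern, one OPTIONAL, one FILTER, no '*' or '+'
    # characters, and a 20-element ontology subset scores roughly
    # 0.06 + 0.10 + 0.15 + 0.10 + 0.02 = 0.43.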

    def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
        """Estimate execution cost for Cypher query.

        Args:
            query: Cypher query string
            ontology_subset: Ontology subset

        Returns:
            Estimated cost (0.0 to 1.0)
        """
        cost = 0.0

        # Basic query complexity
        cost += len(query.split('\n')) * 0.01

        # Pattern complexity
        match_count = len(re.findall(r'MATCH', query, re.IGNORECASE))
        cost += match_count * 0.1

        # Relationship traversals
        rel_count = len(re.findall(r'-\[.*?\]-', query))
        cost += rel_count * 0.1

        # Variable length paths
        var_path_count = len(re.findall(r'\*\d*\.\.', query))
        cost += var_path_count * 0.3

        # WHERE clauses
        where_count = len(re.findall(r'WHERE', query, re.IGNORECASE))
        cost += where_count * 0.05

        # Aggregation functions
        agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE))
        cost += agg_count * 0.1

        # Ontology subset size impact
        total_elements = (len(ontology_subset.classes) +
                          len(ontology_subset.object_properties) +
                          len(ontology_subset.datatype_properties))
        cost += (total_elements / 100.0) * 0.1

        return min(cost, 1.0)

    def should_use_cache(self,
                         query: str,
                         question_components: QuestionComponents,
                         optimization_hint: OptimizationHint) -> bool:
        """Determine if query results should be cached.

        Args:
            query: Query string
            question_components: Question analysis
            optimization_hint: Optimization hints

        Returns:
            True if results should be cached
        """
        if not optimization_hint.cache_results:
            return False

        # Cache simple retrieval and factual queries
        if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]:
            return True

        # Cache expensive aggregation queries
        if (question_components.question_type == QuestionType.AGGREGATION and
                ('COUNT' in query.upper() or 'SUM' in query.upper())):
            return True

        # Don't cache real-time or time-sensitive queries
        if any(keyword in question_components.original_question.lower()
               for keyword in ['now', 'current', 'latest', 'recent']):
            return False

        return False

    def get_cache_key(self,
                      query: str,
                      ontology_subset: QueryOntologySubset) -> str:
        """Generate cache key for query.

        Args:
            query: Query string
            ontology_subset: Ontology subset

        Returns:
            Cache key string
        """
        import hashlib

        # Create stable representation
        ontology_repr = f"{sorted(ontology_subset.classes.keys())}-{sorted(ontology_subset.object_properties.keys())}"
        combined = f"{query.strip()}-{ontology_repr}"

        return hashlib.md5(combined.encode()).hexdigest()
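    # Sketch of how the caching helpers combine (names such as `optimizer`,
    # `cache`, `components`, `hint` and `engine` are assumed for illustration):
    #
    #   key = optimizer.get_cache_key(sparql.query, ontology_subset)
    #   if optimizer.should_use_cache(sparql.query, components, hint):
    #       if key not in cache:
    #           cache[key] = engine.execute_sparql(sparql.query)
    #       result = cache[key]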
438
trustgraph-flow/trustgraph/query/ontology/query_service.py
Normal file
@ -0,0 +1,438 @@
"""
|
||||
Main OntoRAG query service.
|
||||
Orchestrates question analysis, ontology matching, query generation, execution, and answer generation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from ....flow.flow_processor import FlowProcessor
|
||||
from ....tables.config import ConfigTableStore
|
||||
from ...extract.kg.ontology.ontology_loader import OntologyLoader
|
||||
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
|
||||
|
||||
from .question_analyzer import QuestionAnalyzer, QuestionComponents
|
||||
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
|
||||
from .backend_router import BackendRouter, QueryRoute, BackendType
|
||||
from .sparql_generator import SPARQLGenerator, SPARQLQuery
|
||||
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
|
||||
from .cypher_generator import CypherGenerator, CypherQuery
|
||||
from .cypher_executor import CypherExecutor, CypherResult
|
||||
from .answer_generator import AnswerGenerator, GeneratedAnswer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryRequest:
|
||||
"""Query request from user."""
|
||||
question: str
|
||||
context: Optional[str] = None
|
||||
ontology_hint: Optional[str] = None
|
||||
max_results: int = 10
|
||||
confidence_threshold: float = 0.7
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResponse:
|
||||
"""Complete query response."""
|
||||
answer: str
|
||||
confidence: float
|
||||
execution_time: float
|
||||
question_analysis: QuestionComponents
|
||||
ontology_subsets: List[QueryOntologySubset]
|
||||
query_route: QueryRoute
|
||||
generated_query: Union[SPARQLQuery, CypherQuery]
|
||||
raw_results: Union[SPARQLResult, CypherResult]
|
||||
supporting_facts: List[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class OntoRAGQueryService(FlowProcessor):
|
||||
"""Main OntoRAG query service orchestrating all components."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize OntoRAG query service.
|
||||
|
||||
Args:
|
||||
config: Service configuration
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize components
|
||||
self.config_store = None
|
||||
self.ontology_loader = None
|
||||
self.vector_store = None
|
||||
self.question_analyzer = None
|
||||
self.ontology_matcher = None
|
||||
self.backend_router = None
|
||||
self.sparql_generator = None
|
||||
self.sparql_engine = None
|
||||
self.cypher_generator = None
|
||||
self.cypher_executor = None
|
||||
self.answer_generator = None
|
||||
|
||||
# Cache for loaded ontologies
|
||||
self.ontology_cache = {}
|
||||
|
||||
async def init(self):
|
||||
"""Initialize all components."""
|
||||
await super().init()
|
||||
|
||||
# Initialize configuration store
|
||||
self.config_store = ConfigTableStore(self.config.get('config_store', {}))
|
||||
|
||||
# Initialize ontology components
|
||||
self.ontology_loader = OntologyLoader(self.config_store)
|
||||
|
||||
# Initialize vector store
|
||||
vector_config = self.config.get('vector_store', {})
|
||||
self.vector_store = InMemoryVectorStore.create(
|
||||
store_type=vector_config.get('type', 'numpy'),
|
||||
dimension=vector_config.get('dimension', 384),
|
||||
similarity_threshold=vector_config.get('similarity_threshold', 0.7)
|
||||
)
|
||||
|
||||
# Initialize question analyzer
|
||||
analyzer_config = self.config.get('question_analyzer', {})
|
||||
self.question_analyzer = QuestionAnalyzer(
|
||||
prompt_service=self.prompt_service,
|
||||
config=analyzer_config
|
||||
)
|
||||
|
||||
# Initialize ontology matcher
|
||||
matcher_config = self.config.get('ontology_matcher', {})
|
||||
self.ontology_matcher = OntologyMatcher(
|
||||
vector_store=self.vector_store,
|
||||
embedding_service=self.embedding_service,
|
||||
config=matcher_config
|
||||
)
|
||||
|
||||
# Initialize backend router
|
||||
router_config = self.config.get('backend_router', {})
|
||||
self.backend_router = BackendRouter(router_config)
|
||||
|
||||
# Initialize query generators
|
||||
self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service)
|
||||
self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service)
|
||||
|
||||
# Initialize executors
|
||||
sparql_config = self.config.get('sparql_executor', {})
|
||||
if self.backend_router.is_backend_enabled(BackendType.CASSANDRA):
|
||||
cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA)
|
||||
if cassandra_config:
|
||||
self.sparql_engine = SPARQLCassandraEngine(cassandra_config)
|
||||
await self.sparql_engine.initialize()
|
||||
|
||||
cypher_config = self.config.get('cypher_executor', {})
|
||||
enabled_graph_backends = [
|
||||
bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB]
|
||||
if self.backend_router.is_backend_enabled(bt)
|
||||
]
|
||||
if enabled_graph_backends:
|
||||
self.cypher_executor = CypherExecutor(cypher_config)
|
||||
await self.cypher_executor.initialize()
|
||||
|
||||
# Initialize answer generator
|
||||
self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service)
|
||||
|
||||
logger.info("OntoRAG query service initialized")
|
||||
|
||||
async def process(self, request: QueryRequest) -> QueryResponse:
|
||||
"""Process a natural language query.
|
||||
|
||||
Args:
|
||||
request: Query request
|
||||
|
||||
Returns:
|
||||
Complete query response
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
logger.info(f"Processing query: {request.question}")
|
||||
|
||||
# Step 1: Analyze question
|
||||
question_components = await self.question_analyzer.analyze_question(
|
||||
request.question, context=request.context
|
||||
)
|
||||
logger.debug(f"Question analysis: {question_components.question_type}")
|
||||
|
||||
# Step 2: Load and match ontologies
|
||||
ontology_subsets = await self._load_and_match_ontologies(
|
||||
question_components, request.ontology_hint
|
||||
)
|
||||
logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets")
|
||||
|
||||
# Step 3: Route to appropriate backend
|
||||
query_route = self.backend_router.route_query(
|
||||
question_components, ontology_subsets
|
||||
)
|
||||
logger.debug(f"Routed to {query_route.backend_type.value} backend")
|
||||
|
||||
# Step 4: Generate and execute query
|
||||
if query_route.query_language == 'sparql':
|
||||
query_results = await self._execute_sparql_path(
|
||||
question_components, ontology_subsets, query_route
|
||||
)
|
||||
else: # cypher
|
||||
query_results = await self._execute_cypher_path(
|
||||
question_components, ontology_subsets, query_route
|
||||
)
|
||||
|
||||
# Step 5: Generate natural language answer
|
||||
generated_answer = await self.answer_generator.generate_answer(
|
||||
question_components,
|
||||
query_results['raw_results'],
|
||||
ontology_subsets[0] if ontology_subsets else None,
|
||||
query_route.backend_type.value
|
||||
)
|
||||
|
||||
# Build response
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
response = QueryResponse(
|
||||
answer=generated_answer.answer,
|
||||
confidence=min(query_route.confidence, generated_answer.metadata.confidence),
|
||||
execution_time=execution_time,
|
||||
question_analysis=question_components,
|
||||
ontology_subsets=ontology_subsets,
|
||||
query_route=query_route,
|
||||
generated_query=query_results['generated_query'],
|
||||
raw_results=query_results['raw_results'],
|
||||
supporting_facts=generated_answer.supporting_facts,
|
||||
metadata={
|
||||
'backend_used': query_route.backend_type.value,
|
||||
'query_language': query_route.query_language,
|
||||
'ontology_count': len(ontology_subsets),
|
||||
'result_count': generated_answer.metadata.result_count,
|
||||
'routing_reasoning': query_route.reasoning,
|
||||
'generation_time': generated_answer.generation_time
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Query processed successfully in {execution_time:.2f}s")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query processing failed: {e}")
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
# Return error response
|
||||
return QueryResponse(
|
||||
answer=f"I encountered an error processing your query: {str(e)}",
|
||||
confidence=0.0,
|
||||
execution_time=execution_time,
|
||||
                question_analysis=QuestionComponents(
                    original_question=request.question,
                    question_type=None,
                    entities=[], keywords=[], relationships=[], constraints=[],
                    aggregations=[], expected_answer_type="unknown"
                ),
|
||||
ontology_subsets=[],
|
||||
query_route=None,
|
||||
generated_query=None,
|
||||
raw_results=None,
|
||||
supporting_facts=[],
|
||||
metadata={'error': str(e), 'execution_time': execution_time}
|
||||
)
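    # Example of driving the service end-to-end (illustrative; configuration
    # and wiring are normally supplied by the TrustGraph flow runtime):
    #
    #   service = OntoRAGQueryService(config)
    #   await service.init()
    #   response = await service.process(
    #       QueryRequest(question="How many reactors are active?"))
    #   print(response.answer, response.confidence,
    #         response.metadata.get('backend_used'))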
|
||||
|
||||
async def _load_and_match_ontologies(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]:
|
||||
"""Load ontologies and find relevant subsets.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question
|
||||
ontology_hint: Optional ontology hint
|
||||
|
||||
Returns:
|
||||
List of relevant ontology subsets
|
||||
"""
|
||||
try:
|
||||
# Load available ontologies
|
||||
if ontology_hint:
|
||||
# Load specific ontology
|
||||
ontologies = [await self.ontology_loader.load_ontology(ontology_hint)]
|
||||
else:
|
||||
# Load all available ontologies
|
||||
available_ontologies = await self.ontology_loader.list_available_ontologies()
|
||||
ontologies = []
|
||||
for ontology_id in available_ontologies[:5]: # Limit to 5 for performance
|
||||
try:
|
||||
ontology = await self.ontology_loader.load_ontology(ontology_id)
|
||||
ontologies.append(ontology)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology {ontology_id}: {e}")
|
||||
|
||||
if not ontologies:
|
||||
logger.warning("No ontologies loaded")
|
||||
return []
|
||||
|
||||
# Extract relevant subsets
|
||||
ontology_subsets = []
|
||||
for ontology in ontologies:
|
||||
subset = await self.ontology_matcher.select_relevant_subset(
|
||||
question_components, ontology
|
||||
)
|
||||
if subset and (subset.classes or subset.object_properties or subset.datatype_properties):
|
||||
ontology_subsets.append(subset)
|
||||
|
||||
return ontology_subsets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load and match ontologies: {e}")
|
||||
return []
|
||||
|
||||
async def _execute_sparql_path(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset],
|
||||
query_route: QueryRoute) -> Dict[str, Any]:
|
||||
"""Execute SPARQL query path.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subsets: Ontology subsets
|
||||
query_route: Query route
|
||||
|
||||
Returns:
|
||||
Query execution results
|
||||
"""
|
||||
if not self.sparql_engine:
|
||||
raise RuntimeError("SPARQL engine not initialized")
|
||||
|
||||
# Generate SPARQL query
|
||||
primary_subset = ontology_subsets[0] if ontology_subsets else None
|
||||
sparql_query = await self.sparql_generator.generate_sparql(
|
||||
question_components, primary_subset
|
||||
)
|
||||
|
||||
logger.debug(f"Generated SPARQL: {sparql_query.query}")
|
||||
|
||||
# Execute query
|
||||
sparql_results = self.sparql_engine.execute_sparql(sparql_query.query)
|
||||
|
||||
return {
|
||||
'generated_query': sparql_query,
|
||||
'raw_results': sparql_results
|
||||
}
|
||||
|
||||
async def _execute_cypher_path(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subsets: List[QueryOntologySubset],
|
||||
query_route: QueryRoute) -> Dict[str, Any]:
|
||||
"""Execute Cypher query path.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subsets: Ontology subsets
|
||||
query_route: Query route
|
||||
|
||||
Returns:
|
||||
Query execution results
|
||||
"""
|
||||
if not self.cypher_executor:
|
||||
raise RuntimeError("Cypher executor not initialized")
|
||||
|
||||
# Generate Cypher query
|
||||
primary_subset = ontology_subsets[0] if ontology_subsets else None
|
||||
cypher_query = await self.cypher_generator.generate_cypher(
|
||||
question_components, primary_subset
|
||||
)
|
||||
|
||||
logger.debug(f"Generated Cypher: {cypher_query.query}")
|
||||
|
||||
# Execute query
|
||||
database_type = query_route.backend_type.value
|
||||
cypher_results = await self.cypher_executor.execute_query(
|
||||
cypher_query.query, database_type=database_type
|
||||
)
|
||||
|
||||
return {
|
||||
'generated_query': cypher_query,
|
||||
'raw_results': cypher_results
|
||||
}
|
||||
|
||||
async def get_supported_backends(self) -> List[str]:
|
||||
"""Get list of supported and enabled backends.
|
||||
|
||||
Returns:
|
||||
List of backend names
|
||||
"""
|
||||
return [bt.value for bt in self.backend_router.get_available_backends()]
|
||||
|
||||
async def get_available_ontologies(self) -> List[str]:
|
||||
"""Get list of available ontologies.
|
||||
|
||||
Returns:
|
||||
List of ontology identifiers
|
||||
"""
|
||||
if self.ontology_loader:
|
||||
return await self.ontology_loader.list_available_ontologies()
|
||||
return []
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check on all components.
|
||||
|
||||
Returns:
|
||||
Health status of all components
|
||||
"""
|
||||
health = {
|
||||
'service': 'healthy',
|
||||
'components': {},
|
||||
'backends': {},
|
||||
'ontologies': {}
|
||||
}
|
||||
|
||||
try:
|
||||
# Check ontology loader
|
||||
if self.ontology_loader:
|
||||
ontologies = await self.ontology_loader.list_available_ontologies()
|
||||
health['components']['ontology_loader'] = 'healthy'
|
||||
health['ontologies']['count'] = len(ontologies)
|
||||
else:
|
||||
health['components']['ontology_loader'] = 'not_initialized'
|
||||
|
||||
# Check vector store
|
||||
if self.vector_store:
|
||||
health['components']['vector_store'] = 'healthy'
|
||||
health['components']['vector_store_type'] = type(self.vector_store).__name__
|
||||
else:
|
||||
health['components']['vector_store'] = 'not_initialized'
|
||||
|
||||
# Check backends
|
||||
for backend_type in self.backend_router.get_available_backends():
|
||||
if backend_type == BackendType.CASSANDRA and self.sparql_engine:
|
||||
health['backends']['cassandra'] = 'healthy'
|
||||
elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor:
|
||||
health['backends'][backend_type.value] = 'healthy'
|
||||
else:
|
||||
health['backends'][backend_type.value] = 'configured_but_not_initialized'
|
||||
|
||||
except Exception as e:
|
||||
health['service'] = 'degraded'
|
||||
health['error'] = str(e)
|
||||
|
||||
return health
|
||||
|
||||
async def close(self):
|
||||
"""Close all connections and cleanup resources."""
|
||||
try:
|
||||
if self.sparql_engine:
|
||||
self.sparql_engine.close()
|
||||
|
||||
if self.cypher_executor:
|
||||
await self.cypher_executor.close()
|
||||
|
||||
if self.config_store:
|
||||
# ConfigTableStore cleanup if needed
|
||||
pass
|
||||
|
||||
logger.info("OntoRAG query service closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing OntoRAG query service: {e}")
|
||||
364
trustgraph-flow/trustgraph/query/ontology/question_analyzer.py
Normal file
@ -0,0 +1,364 @@
"""
|
||||
Question analyzer for ontology-sensitive query system.
|
||||
Decomposes user questions into semantic components.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuestionType(Enum):
|
||||
"""Types of questions that can be asked."""
|
||||
FACTUAL = "factual" # What is X?
|
||||
RETRIEVAL = "retrieval" # Find all X
|
||||
AGGREGATION = "aggregation" # How many X?
|
||||
COMPARISON = "comparison" # Is X better than Y?
|
||||
RELATIONSHIP = "relationship" # How is X related to Y?
|
||||
BOOLEAN = "boolean" # Yes/no questions
|
||||
PROCESS = "process" # How to do X?
|
||||
TEMPORAL = "temporal" # When did X happen?
|
||||
SPATIAL = "spatial" # Where is X?
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionComponents:
|
||||
"""Components extracted from a question."""
|
||||
original_question: str
|
||||
question_type: QuestionType
|
||||
entities: List[str]
|
||||
relationships: List[str]
|
||||
constraints: List[str]
|
||||
aggregations: List[str]
|
||||
expected_answer_type: str
|
||||
keywords: List[str]
|
||||
|
||||
|
||||
class QuestionAnalyzer:
|
||||
"""Analyzes natural language questions to extract semantic components."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize question analyzer."""
|
||||
# Question word patterns
|
||||
self.question_patterns = {
|
||||
QuestionType.FACTUAL: [
|
||||
r'^what\s+(?:is|are)',
|
||||
r'^who\s+(?:is|are)',
|
||||
r'^which\s+',
|
||||
],
|
||||
QuestionType.RETRIEVAL: [
|
||||
r'^find\s+',
|
||||
r'^list\s+',
|
||||
r'^show\s+',
|
||||
r'^get\s+',
|
||||
r'^retrieve\s+',
|
||||
],
|
||||
QuestionType.AGGREGATION: [
|
||||
r'^how\s+many',
|
||||
r'^count\s+',
|
||||
r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)',
|
||||
],
|
||||
QuestionType.COMPARISON: [
|
||||
r'(?:better|worse|more|less|greater|smaller)\s+than',
|
||||
r'compare\s+',
|
||||
r'difference\s+between',
|
||||
],
|
||||
QuestionType.RELATIONSHIP: [
|
||||
r'^how\s+(?:is|are).*related',
|
||||
r'relationship\s+between',
|
||||
r'connection\s+between',
|
||||
],
|
||||
QuestionType.BOOLEAN: [
|
||||
r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)',
|
||||
r'^has\s+',
|
||||
r'^have\s+',
|
||||
],
|
||||
QuestionType.PROCESS: [
|
||||
r'^how\s+(?:to|do)',
|
||||
r'^explain\s+how',
|
||||
],
|
||||
QuestionType.TEMPORAL: [
|
||||
r'^when\s+',
|
||||
r'what\s+time',
|
||||
r'what\s+date',
|
||||
],
|
||||
QuestionType.SPATIAL: [
|
||||
r'^where\s+',
|
||||
r'location\s+of',
|
||||
],
|
||||
}
|
||||
|
||||
# Aggregation keywords
|
||||
self.aggregation_keywords = [
|
||||
'count', 'sum', 'total', 'average', 'mean', 'median',
|
||||
'maximum', 'minimum', 'max', 'min', 'number of'
|
||||
]
|
||||
|
||||
# Constraint patterns
|
||||
self.constraint_patterns = [
|
||||
r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)',
|
||||
r'(?:greater|less|more|fewer)\s+than\s+(\d+)',
|
||||
r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)',
|
||||
r'(?:before|after|since|until)\s+(.+)',
|
||||
]
|
||||
|
||||
def analyze(self, question: str) -> QuestionComponents:
|
||||
"""Analyze a question to extract components.
|
||||
|
||||
Args:
|
||||
question: Natural language question
|
||||
|
||||
Returns:
|
||||
QuestionComponents with extracted information
|
||||
"""
|
||||
# Normalize question
|
||||
question_lower = question.lower().strip()
|
||||
|
||||
# Determine question type
|
||||
question_type = self._identify_question_type(question_lower)
|
||||
|
||||
# Extract entities
|
||||
entities = self._extract_entities(question)
|
||||
|
||||
# Extract relationships
|
||||
relationships = self._extract_relationships(question_lower)
|
||||
|
||||
# Extract constraints
|
||||
constraints = self._extract_constraints(question_lower)
|
||||
|
||||
# Extract aggregations
|
||||
aggregations = self._extract_aggregations(question_lower)
|
||||
|
||||
# Determine expected answer type
|
||||
answer_type = self._determine_answer_type(question_type, aggregations)
|
||||
|
||||
# Extract keywords
|
||||
keywords = self._extract_keywords(question_lower)
|
||||
|
||||
return QuestionComponents(
|
||||
original_question=question,
|
||||
question_type=question_type,
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
constraints=constraints,
|
||||
aggregations=aggregations,
|
||||
expected_answer_type=answer_type,
|
||||
keywords=keywords
|
||||
)
|
||||
|
||||
def _identify_question_type(self, question: str) -> QuestionType:
|
||||
"""Identify the type of question.
|
||||
|
||||
Args:
|
||||
question: Lowercase question text
|
||||
|
||||
Returns:
|
||||
QuestionType enum value
|
||||
"""
|
||||
for q_type, patterns in self.question_patterns.items():
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, question):
|
||||
return q_type
|
||||
|
||||
# Default to factual
|
||||
return QuestionType.FACTUAL
|
||||
|
||||
def _extract_entities(self, question: str) -> List[str]:
|
||||
"""Extract potential entities from question.
|
||||
|
||||
Args:
|
||||
question: Original question text
|
||||
|
||||
Returns:
|
||||
List of entity strings
|
||||
"""
|
||||
entities = []
|
||||
|
||||
# Extract capitalized words/phrases (potential proper nouns)
|
||||
# Pattern for consecutive capitalized words
|
||||
pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'
|
||||
matches = re.findall(pattern, question)
|
||||
entities.extend(matches)
|
||||
|
||||
# Extract quoted strings
|
||||
quoted = re.findall(r'"([^"]+)"', question)
|
||||
entities.extend(quoted)
|
||||
quoted = re.findall(r"'([^']+)'", question)
|
||||
entities.extend(quoted)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_entities = []
|
||||
for entity in entities:
|
||||
if entity not in seen:
|
||||
seen.add(entity)
|
||||
unique_entities.append(entity)
|
||||
|
||||
return unique_entities
|
||||
|
||||
def _extract_relationships(self, question: str) -> List[str]:
|
||||
"""Extract relationship indicators from question.
|
||||
|
||||
Args:
|
||||
question: Lowercase question text
|
||||
|
||||
Returns:
|
||||
List of relationship strings
|
||||
"""
|
||||
relationships = []
|
||||
|
||||
# Common relationship patterns
|
||||
rel_patterns = [
|
||||
r'(\w+)\s+(?:of|by|from|to|with|for)\s+',
|
||||
r'has\s+(\w+)',
|
||||
r'belongs?\s+to',
|
||||
r'(?:created|written|authored|owned)\s+by',
|
||||
r'related\s+to',
|
||||
r'connected\s+to',
|
||||
r'associated\s+with',
|
||||
]
|
||||
|
||||
for pattern in rel_patterns:
|
||||
matches = re.findall(pattern, question)
|
||||
relationships.extend(matches)
|
||||
|
||||
# Clean up
|
||||
relationships = [r for r in relationships if len(r) > 2]
|
||||
return list(set(relationships))
|
||||
|
||||
def _extract_constraints(self, question: str) -> List[str]:
|
||||
"""Extract constraints from question.
|
||||
|
||||
Args:
|
||||
question: Lowercase question text
|
||||
|
||||
Returns:
|
||||
List of constraint strings
|
||||
"""
|
||||
constraints = []
|
||||
|
||||
for pattern in self.constraint_patterns:
|
||||
matches = re.findall(pattern, question)
|
||||
if matches:
|
||||
if isinstance(matches[0], tuple):
|
||||
constraints.extend(list(matches[0]))
|
||||
else:
|
||||
constraints.extend(matches)
|
||||
|
||||
# Clean up
|
||||
constraints = [c.strip() for c in constraints if c and len(c.strip()) > 0]
|
||||
return constraints
|
||||
|
||||
def _extract_aggregations(self, question: str) -> List[str]:
|
||||
"""Extract aggregation operations from question.
|
||||
|
||||
Args:
|
||||
question: Lowercase question text
|
||||
|
||||
Returns:
|
||||
List of aggregation operations
|
||||
"""
|
||||
aggregations = []
|
||||
|
||||
for keyword in self.aggregation_keywords:
|
||||
if keyword in question:
|
||||
aggregations.append(keyword)
|
||||
|
||||
return aggregations
|
||||
|
||||
def _determine_answer_type(self, question_type: QuestionType,
|
||||
aggregations: List[str]) -> str:
|
||||
"""Determine expected answer type.
|
||||
|
||||
Args:
|
||||
question_type: Type of question
|
||||
aggregations: Aggregation operations found
|
||||
|
||||
Returns:
|
||||
Expected answer type string
|
||||
"""
|
||||
if aggregations:
|
||||
if any(a in ['count', 'number of', 'total'] for a in aggregations):
|
||||
return 'number'
|
||||
elif any(a in ['average', 'mean', 'median'] for a in aggregations):
|
||||
return 'number'
|
||||
elif any(a in ['sum'] for a in aggregations):
|
||||
return 'number'
|
||||
|
||||
if question_type == QuestionType.BOOLEAN:
|
||||
return 'boolean'
|
||||
elif question_type == QuestionType.TEMPORAL:
|
||||
return 'datetime'
|
||||
elif question_type == QuestionType.SPATIAL:
|
||||
return 'location'
|
||||
elif question_type == QuestionType.RETRIEVAL:
|
||||
return 'list'
|
||||
elif question_type == QuestionType.COMPARISON:
|
||||
return 'comparison'
|
||||
else:
|
||||
return 'text'
|
||||
|
||||
def _extract_keywords(self, question: str) -> List[str]:
|
||||
"""Extract important keywords from question.
|
||||
|
||||
Args:
|
||||
question: Lowercase question text
|
||||
|
||||
Returns:
|
||||
List of keywords
|
||||
"""
|
||||
# Remove common stop words
|
||||
stop_words = {
|
||||
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
|
||||
'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are',
|
||||
'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do',
|
||||
'does', 'did', 'will', 'would', 'could', 'should', 'may',
|
||||
'might', 'must', 'can', 'shall', 'what', 'which', 'who',
|
||||
'when', 'where', 'why', 'how'
|
||||
}
|
||||
|
||||
# Extract words
|
||||
words = re.findall(r'\b\w+\b', question)
|
||||
|
||||
# Filter stop words and short words
|
||||
keywords = [w for w in words if w not in stop_words and len(w) > 2]
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_keywords = []
|
||||
for kw in keywords:
|
||||
if kw not in seen:
|
||||
seen.add(kw)
|
||||
unique_keywords.append(kw)
|
||||
|
||||
return unique_keywords
|
||||
|
||||
def get_question_segments(self, question: str) -> List[str]:
|
||||
"""Split question into segments for embedding.
|
||||
|
||||
Args:
|
||||
question: Question text
|
||||
|
||||
Returns:
|
||||
List of question segments
|
||||
"""
|
||||
segments = []
|
||||
|
||||
# Add full question
|
||||
segments.append(question)
|
||||
|
||||
# Split by clauses
|
||||
clauses = re.split(r'[,;]', question)
|
||||
segments.extend([c.strip() for c in clauses if len(c.strip()) > 3])
|
||||
|
||||
# Extract key phrases
|
||||
components = self.analyze(question)
|
||||
segments.extend(components.entities)
|
||||
segments.extend(components.keywords)
|
||||
|
||||
# Remove duplicates
|
||||
return list(dict.fromkeys(segments))
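    # Illustrative behaviour, given the patterns above:
    #
    #   analyzer = QuestionAnalyzer()
    #   components = analyzer.analyze("Count the reactors located in France")
    #   # components.question_type        -> QuestionType.AGGREGATION
    #   # components.aggregations         -> ['count']
    #   # components.expected_answer_type -> 'number'
    #   # The capitalisation heuristic also picks up the leading question
    #   # word, so entities may include it alongside 'France'.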
|
||||
481
trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
Normal file
@ -0,0 +1,481 @@
"""
|
||||
SPARQL-Cassandra engine using Python rdflib.
|
||||
Executes SPARQL queries against Cassandra using a custom Store implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Iterator, Tuple
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
|
||||
# Try to import rdflib
# (rdflib's own SPARQLResult is intentionally not imported: this module
# defines its own SPARQLResult dataclass below, which would shadow it.)
try:
    from rdflib import Graph, Namespace, URIRef, Literal, BNode
    from rdflib.store import Store
    from rdflib.plugins.sparql import prepareQuery
    from rdflib.term import Node
    RDFLIB_AVAILABLE = True
except ImportError:
    RDFLIB_AVAILABLE = False
|
||||
|
||||
# Try to import Cassandra driver
|
||||
try:
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.policies import DCAwareRoundRobinPolicy
|
||||
CASSANDRA_AVAILABLE = True
|
||||
except ImportError:
|
||||
CASSANDRA_AVAILABLE = False
|
||||
|
||||
from ....tables.config import ConfigTableStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SPARQLResult:
|
||||
"""Result from SPARQL query execution."""
|
||||
bindings: List[Dict[str, Any]]
|
||||
variables: List[str]
|
||||
ask_result: Optional[bool] = None # For ASK queries
|
||||
execution_time: float = 0.0
|
||||
query_plan: Optional[str] = None
|
||||
|
||||
|
||||
class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
|
||||
"""Custom rdflib Store implementation for Cassandra."""
|
||||
|
||||
def __init__(self, cassandra_config: Dict[str, Any]):
|
||||
"""Initialize Cassandra triple store.
|
||||
|
||||
Args:
|
||||
cassandra_config: Cassandra connection configuration
|
||||
"""
|
||||
if not CASSANDRA_AVAILABLE:
|
||||
raise RuntimeError("Cassandra driver not available")
|
||||
if not RDFLIB_AVAILABLE:
|
||||
raise RuntimeError("rdflib not available")
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.cassandra_config = cassandra_config
|
||||
self.cluster = None
|
||||
self.session = None
|
||||
self.keyspace = cassandra_config.get('keyspace', 'trustgraph')
|
||||
|
||||
# Triple storage table structure
|
||||
self.triple_table = f"{self.keyspace}.triples"
|
||||
self.metadata_table = f"{self.keyspace}.triple_metadata"
|
||||
|
||||
def open(self, configuration=None, create=False):
|
||||
"""Open connection to Cassandra."""
|
||||
try:
|
||||
# Create authentication if provided
|
||||
auth_provider = None
|
||||
if 'username' in self.cassandra_config and 'password' in self.cassandra_config:
|
||||
auth_provider = PlainTextAuthProvider(
|
||||
username=self.cassandra_config['username'],
|
||||
password=self.cassandra_config['password']
|
||||
)
|
||||
|
||||
# Create cluster
|
||||
self.cluster = Cluster(
|
||||
[self.cassandra_config.get('host', 'localhost')],
|
||||
port=self.cassandra_config.get('port', 9042),
|
||||
auth_provider=auth_provider,
|
||||
load_balancing_policy=DCAwareRoundRobinPolicy()
|
||||
)
|
||||
|
||||
# Connect
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
# Ensure keyspace exists
|
||||
if create:
|
||||
self._create_schema()
|
||||
|
||||
# Set keyspace
|
||||
self.session.set_keyspace(self.keyspace)
|
||||
|
||||
logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Cassandra: {e}")
|
||||
return False
|
||||
|
||||
def close(self, commit_pending_transaction=True):
|
||||
"""Close Cassandra connection."""
|
||||
if self.session:
|
||||
self.session.shutdown()
|
||||
if self.cluster:
|
||||
self.cluster.shutdown()
|
||||
|
||||
def _create_schema(self):
|
||||
"""Create Cassandra schema for triple storage."""
|
||||
# Create keyspace
|
||||
self.session.execute(f"""
|
||||
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
|
||||
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
|
||||
""")
|
||||
|
||||
# Create triples table optimized for SPARQL queries
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.triple_table} (
|
||||
subject text,
|
||||
predicate text,
|
||||
object text,
|
||||
object_datatype text,
|
||||
object_language text,
|
||||
is_literal boolean,
|
||||
graph_id text,
|
||||
PRIMARY KEY ((subject), predicate, object)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes for efficient querying
|
||||
self.session.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate)
|
||||
""")
|
||||
self.session.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object)
|
||||
""")
|
||||
|
||||
# Metadata table for graph information
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
graph_id text PRIMARY KEY,
|
||||
created timestamp,
|
||||
modified timestamp,
|
||||
triple_count counter
|
||||
)
|
||||
""")
|
||||
|
||||
def triples(self, triple_pattern, context=None):
|
||||
"""Retrieve triples matching the given pattern.
|
||||
|
||||
Args:
|
||||
triple_pattern: (subject, predicate, object) pattern with None for variables
|
||||
context: Graph context (optional)
|
||||
|
||||
Yields:
|
||||
Matching triples as (subject, predicate, object) tuples
|
||||
"""
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
subject, predicate, object_val = triple_pattern
|
||||
|
||||
# Build CQL query based on pattern
|
||||
cql_queries = self._pattern_to_cql(subject, predicate, object_val)
|
||||
|
||||
for cql, params in cql_queries:
|
||||
try:
|
||||
rows = self.session.execute(cql, params)
|
||||
for row in rows:
|
||||
yield self._row_to_triple(row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing CQL query: {e}")
|
||||
|
||||
def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]:
|
||||
"""Convert triple pattern to CQL queries.
|
||||
|
||||
Args:
|
||||
subject: Subject node or None
|
||||
predicate: Predicate node or None
|
||||
object_val: Object node or None
|
||||
|
||||
Returns:
|
||||
List of (CQL query, parameters) tuples
|
||||
"""
|
||||
queries = []
|
||||
|
||||
# Convert None to wildcard, nodes to strings
|
||||
s_str = str(subject) if subject else None
|
||||
p_str = str(predicate) if predicate else None
|
||||
o_str = str(object_val) if object_val else None
|
||||
|
||||
if s_str and p_str and o_str:
|
||||
# Specific triple lookup
|
||||
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ? AND object = ?"
|
||||
queries.append((cql, [s_str, p_str, o_str]))
|
||||
|
||||
elif s_str and p_str:
|
||||
# Subject and predicate known
|
||||
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ?"
|
||||
queries.append((cql, [s_str, p_str]))
|
||||
|
||||
elif s_str:
|
||||
# Subject known
|
||||
cql = f"SELECT * FROM {self.triple_table} WHERE subject = ?"
|
||||
queries.append((cql, [s_str]))
|
||||
|
||||
elif p_str:
|
||||
# Predicate known (requires index scan)
|
||||
cql = f"SELECT * FROM {self.triple_table} WHERE predicate = ? ALLOW FILTERING"
|
||||
queries.append((cql, [p_str]))
|
||||
|
||||
elif o_str:
|
||||
# Object known (requires index scan)
|
||||
cql = f"SELECT * FROM {self.triple_table} WHERE object = ? ALLOW FILTERING"
|
||||
queries.append((cql, [o_str]))
|
||||
|
||||
else:
|
||||
# Full scan (should be avoided in production)
|
||||
cql = f"SELECT * FROM {self.triple_table}"
|
||||
queries.append((cql, []))
|
||||
|
||||
return queries
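    # For example, the pattern (URIRef('http://ex.org/s'), None, None) maps to
    # "SELECT * FROM trustgraph.triples WHERE subject = ?" with parameters
    # ['http://ex.org/s'] under the default keyspace. Note that the DataStax
    # Python driver only accepts '?' markers in prepared statements;
    # unprepared execute() calls expect '%s'-style placeholders.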
|
||||
|
||||
def _row_to_triple(self, row):
|
||||
"""Convert Cassandra row to RDF triple.
|
||||
|
||||
Args:
|
||||
row: Cassandra row object
|
||||
|
||||
Returns:
|
||||
(subject, predicate, object) tuple with rdflib nodes
|
||||
"""
|
||||
# Convert to rdflib nodes
|
||||
subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject)
|
||||
|
||||
predicate = URIRef(row.predicate)
|
||||
|
||||
if row.is_literal:
|
||||
# Create literal with datatype/language
|
||||
if row.object_datatype:
|
||||
object_node = Literal(row.object, datatype=URIRef(row.object_datatype))
|
||||
elif row.object_language:
|
||||
object_node = Literal(row.object, lang=row.object_language)
|
||||
else:
|
||||
object_node = Literal(row.object)
|
||||
else:
|
||||
object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object)
|
||||
|
||||
return (subject, predicate, object_node)
|
||||
|
||||
def add(self, triple, context=None, quoted=False):
|
||||
"""Add a triple to the store.
|
||||
|
||||
Args:
|
||||
triple: (subject, predicate, object) tuple
|
||||
context: Graph context
|
||||
quoted: Whether triple is quoted
|
||||
"""
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
subject, predicate, object_val = triple
|
||||
|
||||
# Convert to storage format
|
||||
s_str = str(subject)
|
||||
p_str = str(predicate)
|
||||
|
||||
is_literal = isinstance(object_val, Literal)
|
||||
o_str = str(object_val)
|
||||
o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None
|
||||
o_language = object_val.language if is_literal and object_val.language else None
|
||||
|
||||
# Insert into Cassandra
|
||||
cql = f"""
|
||||
INSERT INTO {self.triple_table}
|
||||
(subject, predicate, object, object_datatype, object_language, is_literal, graph_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(cql, [
|
||||
s_str, p_str, o_str, o_datatype, o_language, is_literal,
|
||||
str(context) if context else 'default'
|
||||
])
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding triple: {e}")
|
||||
|
||||
def remove(self, triple, context=None):
|
||||
"""Remove a triple from the store.
|
||||
|
||||
Args:
|
||||
triple: (subject, predicate, object) tuple
|
||||
context: Graph context
|
||||
"""
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
subject, predicate, object_val = triple
|
||||
|
||||
cql = f"""
|
||||
DELETE FROM {self.triple_table}
|
||||
WHERE subject = ? AND predicate = ? AND object = ?
|
||||
"""
|
||||
|
||||
try:
|
||||
self.session.execute(cql, [str(subject), str(predicate), str(object_val)])
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing triple: {e}")
|
||||
|
||||
def __len__(self, context=None):
|
||||
"""Get number of triples in store.
|
||||
|
||||
Args:
|
||||
context: Graph context
|
||||
|
||||
Returns:
|
||||
Number of triples
|
||||
"""
|
||||
if not self.session:
|
||||
return 0
|
||||
|
||||
try:
|
||||
cql = f"SELECT COUNT(*) FROM {self.triple_table}"
|
||||
result = self.session.execute(cql)
|
||||
return result.one().count
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting triples: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
class SPARQLCassandraEngine:
|
||||
"""SPARQL processor using Cassandra backend."""
|
||||
|
||||
def __init__(self, cassandra_config: Dict[str, Any]):
|
||||
"""Initialize SPARQL-Cassandra engine.
|
||||
|
||||
Args:
|
||||
cassandra_config: Cassandra configuration
|
||||
"""
|
||||
if not RDFLIB_AVAILABLE:
|
||||
raise RuntimeError("rdflib is required for SPARQL processing")
|
||||
if not CASSANDRA_AVAILABLE:
|
||||
raise RuntimeError("Cassandra driver is required")
|
||||
|
||||
self.cassandra_config = cassandra_config
|
||||
self.store = CassandraTripleStore(cassandra_config)
|
||||
self.graph = Graph(store=self.store)
|
||||
|
||||
# Common namespaces
|
||||
self.namespaces = {
|
||||
'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
|
||||
'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'),
|
||||
'owl': Namespace('http://www.w3.org/2002/07/owl#'),
|
||||
'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'),
|
||||
}
|
||||
|
||||
# Bind namespaces to graph
|
||||
for prefix, namespace in self.namespaces.items():
|
||||
self.graph.bind(prefix, namespace)
|
||||
|
||||
async def initialize(self, create_schema=False):
|
||||
"""Initialize the engine.
|
||||
|
||||
Args:
|
||||
create_schema: Whether to create Cassandra schema
|
||||
"""
|
||||
success = self.store.open(create=create_schema)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to connect to Cassandra")
|
||||
|
||||
logger.info("SPARQL-Cassandra engine initialized")
|
||||
|
||||
def execute_sparql(self, sparql_query: str) -> SPARQLResult:
|
||||
"""Execute SPARQL query against Cassandra.
|
||||
|
||||
Args:
|
||||
sparql_query: SPARQL query string
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Prepare and execute query
|
||||
prepared_query = prepareQuery(sparql_query)
|
||||
result = self.graph.query(prepared_query)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Format results based on query type
|
||||
if sparql_query.strip().upper().startswith('ASK'):
|
||||
return SPARQLResult(
|
||||
bindings=[],
|
||||
variables=[],
|
||||
ask_result=bool(result),
|
||||
execution_time=execution_time
|
||||
)
|
||||
else:
|
||||
# SELECT query
|
||||
bindings = []
|
||||
variables = result.vars if hasattr(result, 'vars') else []
|
||||
|
||||
for row in result:
|
||||
binding = {}
|
||||
for i, var in enumerate(variables):
|
||||
if i < len(row):
|
||||
value = row[i]
|
||||
binding[str(var)] = self._format_result_value(value)
|
||||
bindings.append(binding)
|
||||
|
||||
return SPARQLResult(
|
||||
bindings=bindings,
|
||||
variables=[str(v) for v in variables],
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SPARQL execution error: {e}")
|
||||
return SPARQLResult(
|
||||
bindings=[],
|
||||
variables=[],
|
||||
execution_time=time.time() - start_time
|
||||
)
|
||||
|
||||
def _format_result_value(self, value):
|
||||
"""Format result value for output.
|
||||
|
||||
Args:
|
||||
value: RDF value (URIRef, Literal, BNode)
|
||||
|
||||
Returns:
|
||||
Formatted value
|
||||
"""
|
||||
if isinstance(value, URIRef):
|
||||
return {'type': 'uri', 'value': str(value)}
|
||||
elif isinstance(value, Literal):
|
||||
result = {'type': 'literal', 'value': str(value)}
|
||||
if value.datatype:
|
||||
result['datatype'] = str(value.datatype)
|
||||
if value.language:
|
||||
result['language'] = value.language
|
||||
return result
|
||||
elif isinstance(value, BNode):
|
||||
return {'type': 'bnode', 'value': str(value)}
|
||||
else:
|
||||
return {'type': 'unknown', 'value': str(value)}
|
||||
|
||||
def load_triples_from_store(self, config_store: ConfigTableStore):
|
||||
"""Load triples from TrustGraph's storage into the RDF graph.
|
||||
|
||||
Args:
|
||||
config_store: Configuration store with triples
|
||||
"""
|
||||
# This would need to be implemented based on how triples are stored
|
||||
# in TrustGraph's Cassandra tables
|
||||
logger.info("Loading triples from TrustGraph store...")
|
||||
|
||||
# Example implementation - would need to be adapted
|
||||
# to actual TrustGraph storage format
|
||||
try:
|
||||
# Get all triple data
|
||||
# This is a placeholder - actual implementation would need
|
||||
# to query the appropriate TrustGraph tables
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading triples: {e}")
|
||||
|
||||
def close(self):
|
||||
"""Close the engine and connections."""
|
||||
if self.store:
|
||||
self.store.close()
|
||||
logger.info("SPARQL-Cassandra engine closed")
|
||||
487
trustgraph-flow/trustgraph/query/ontology/sparql_generator.py
Normal file
@ -0,0 +1,487 @@
"""
|
||||
SPARQL query generator for ontology-sensitive queries.
|
||||
Converts natural language questions to SPARQL queries for Cassandra execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .question_analyzer import QuestionComponents, QuestionType
|
||||
from .ontology_matcher import QueryOntologySubset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SPARQLQuery:
|
||||
"""Generated SPARQL query with metadata."""
|
||||
query: str
|
||||
variables: List[str]
|
||||
query_type: str # SELECT, ASK, CONSTRUCT, DESCRIBE
|
||||
explanation: str
|
||||
complexity_score: float
|
||||
|
||||
|
||||
class SPARQLGenerator:
|
||||
"""Generates SPARQL queries from natural language questions using LLM assistance."""
|
||||
|
||||
def __init__(self, prompt_service=None):
|
||||
"""Initialize SPARQL generator.
|
||||
|
||||
Args:
|
||||
prompt_service: Service for LLM-based query generation
|
||||
"""
|
||||
self.prompt_service = prompt_service
|
||||
|
||||
# SPARQL query templates for common patterns
|
||||
self.templates = {
|
||||
'simple_class_query': """
|
||||
PREFIX : <{namespace}>
|
||||
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
|
||||
SELECT ?entity ?label WHERE {{
|
||||
?entity rdf:type :{class_name} .
|
||||
OPTIONAL {{ ?entity rdfs:label ?label }}
|
||||
}}""",
|
||||
|
||||
'property_query': """
|
||||
PREFIX : <{namespace}>
|
||||
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
|
||||
SELECT ?subject ?object WHERE {{
|
||||
?subject :{property} ?object .
|
||||
?subject rdf:type :{subject_class} .
|
||||
}}""",
|
||||
|
||||
'hierarchy_query': """
|
||||
PREFIX : <{namespace}>
|
||||
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
|
||||
SELECT ?subclass ?superclass WHERE {{
|
||||
?subclass rdfs:subClassOf* ?superclass .
|
||||
?superclass rdf:type :{root_class} .
|
||||
}}""",
|
||||
|
||||
'count_query': """
|
||||
PREFIX : <{namespace}>
|
||||
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
|
||||
|
||||
SELECT (COUNT(?entity) AS ?count) WHERE {{
|
||||
?entity rdf:type :{class_name} .
|
||||
{additional_constraints}
|
||||
}}""",
|
||||
|
||||
'boolean_query': """
|
||||
PREFIX : <{namespace}>
|
||||
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
|
||||
|
||||
ASK {{
|
||||
{triple_pattern}
|
||||
}}"""
|
||||
}
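        # For illustration, templates['simple_class_query'] rendered with
        # namespace='http://example.org/' and class_name='Person' (hypothetical
        # values) expands to:
        #
        #   PREFIX : <http://example.org/>
        #   PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        #   PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        #
        #   SELECT ?entity ?label WHERE {
        #       ?entity rdf:type :Person .
        #       OPTIONAL { ?entity rdfs:label ?label }
        #   }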
|
||||
|
||||
async def generate_sparql(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> SPARQLQuery:
|
||||
"""Generate SPARQL query for a question.
|
||||
|
||||
Args:
|
||||
question_components: Analyzed question components
|
||||
ontology_subset: Relevant ontology subset
|
||||
|
||||
Returns:
|
||||
Generated SPARQL query
|
||||
"""
|
||||
# Try template-based generation first
|
||||
template_query = self._try_template_generation(question_components, ontology_subset)
|
||||
if template_query:
|
||||
logger.debug("Generated SPARQL using template")
|
||||
return template_query
|
||||
|
||||
# Fall back to LLM-based generation
|
||||
if self.prompt_service:
|
||||
llm_query = await self._generate_with_llm(question_components, ontology_subset)
|
||||
if llm_query:
|
||||
logger.debug("Generated SPARQL using LLM")
|
||||
return llm_query
|
||||
|
||||
# Final fallback to simple pattern
|
||||
logger.warning("Falling back to simple SPARQL pattern")
|
||||
return self._generate_fallback_query(question_components, ontology_subset)
|
||||
|
||||
def _try_template_generation(self,
|
||||
question_components: QuestionComponents,
|
||||
ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
|
||||
"""Try to generate query using templates.
|
||||
|
||||
Args:
|
||||
question_components: Question analysis
|
||||
ontology_subset: Ontology subset
|
||||
|
||||
Returns:
|
||||
Generated query or None if no template matches
|
||||
"""
|
||||
namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
|
||||
|
||||
# Simple class query (What are the animals?)
|
||||
if (question_components.question_type == QuestionType.RETRIEVAL and
|
||||
len(question_components.entities) == 1 and
|
||||
question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]):
|
||||
|
||||
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
|
||||
if class_name:
|
||||
query = self.templates['simple_class_query'].format(
|
||||
namespace=namespace,
|
||||
class_name=class_name
|
||||
)
|
||||
return SPARQLQuery(
|
||||
query=query,
|
||||
variables=['entity', 'label'],
|
||||
query_type='SELECT',
|
||||
explanation=f"Retrieve all instances of {class_name}",
|
||||
complexity_score=0.3
|
||||
)
|
||||
|
||||
# Count query (How many animals are there?)
|
||||
if (question_components.question_type == QuestionType.AGGREGATION and
|
||||
'count' in question_components.aggregations and
|
||||
len(question_components.entities) >= 1):
|
||||
|
||||
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
|
||||
if class_name:
|
||||
query = self.templates['count_query'].format(
|
||||
namespace=namespace,
|
||||
class_name=class_name,
|
||||
additional_constraints=self._build_constraints(question_components, ontology_subset)
|
||||
)
|
||||
return SPARQLQuery(
|
||||
query=query,
|
||||
variables=['count'],
|
||||
query_type='SELECT',
|
||||
explanation=f"Count instances of {class_name}",
|
||||
complexity_score=0.4
|
||||
)
|
||||
|
||||
# Boolean query (Is X a Y?)
|
||||
if question_components.question_type == QuestionType.BOOLEAN:
|
||||
triple_pattern = self._build_boolean_pattern(question_components, ontology_subset)
|
||||
if triple_pattern:
|
||||
query = self.templates['boolean_query'].format(
|
||||
namespace=namespace,
|
||||
triple_pattern=triple_pattern
|
||||
)
|
||||
return SPARQLQuery(
|
||||
query=query,
|
||||
variables=[],
|
||||
query_type='ASK',
|
||||
explanation="Boolean query for fact checking",
|
||||
complexity_score=0.2
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
    async def _generate_with_llm(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
        """Generate SPARQL using LLM.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Generated query or None if failed
        """
        try:
            prompt = self._build_sparql_prompt(question_components, ontology_subset)
            response = await self.prompt_service.generate_sparql(prompt=prompt)

            if response and isinstance(response, dict):
                query = response.get('query', '').strip()
                if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')):
                    return SPARQLQuery(
                        query=query,
                        variables=self._extract_variables(query),
                        query_type=query.split()[0].upper(),
                        explanation=response.get('explanation', 'Generated by LLM'),
                        complexity_score=self._calculate_complexity(query)
                    )

        except Exception as e:
            logger.error(f"LLM SPARQL generation failed: {e}")

        return None

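    # Response shape assumed by _generate_with_llm above: a dict with at least
    # a 'query' key and an optional 'explanation' key, e.g.
    #
    #   {
    #       "query": "SELECT ?entity WHERE { ?entity rdf:type :Animal . }",
    #       "explanation": "Lists all instances of the Animal class"
    #   }
    #
    # Anything that is not a dict, or whose query does not start with
    # SELECT/ASK/CONSTRUCT/DESCRIBE, is rejected and the method returns None.
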
    def _build_sparql_prompt(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM SPARQL generation.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Formatted prompt string
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')

        # Format ontology elements
        classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace)
        props_str = self._format_properties_for_prompt(
            ontology_subset.object_properties,
            ontology_subset.datatype_properties,
            namespace
        )

        prompt = f"""Generate a SPARQL query for the following question using the provided ontology.

QUESTION: {question_components.original_question}

ONTOLOGY NAMESPACE: {namespace}

AVAILABLE CLASSES:
{classes_str}

AVAILABLE PROPERTIES:
{props_str}

RULES:
- Use proper SPARQL syntax
- Include appropriate prefixes
- Use property paths for hierarchical queries (rdfs:subClassOf*)
- Add FILTER clauses for constraints
- Optimize for Cassandra backend
- Return both query and explanation

QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}

Generate a complete SPARQL query:"""

        return prompt

    def _generate_fallback_query(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset) -> SPARQLQuery:
        """Generate simple fallback query.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            Basic SPARQL query
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')

        # Very basic SELECT query
        query = f"""PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>

SELECT ?subject ?predicate ?object WHERE {{
    ?subject ?predicate ?object .
    FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}"))
}}
LIMIT 10"""

        return SPARQLQuery(
            query=query,
            variables=['subject', 'predicate', 'object'],
            query_type='SELECT',
            explanation="Fallback query for basic pattern matching",
            complexity_score=0.1
        )

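    # Worked example for _generate_fallback_query above (hypothetical input):
    # with keywords=['lion'] and the default namespace, the f-string yields
    #
    #   PREFIX : <http://example.org/>
    #   PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    #   PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    #
    #   SELECT ?subject ?predicate ?object WHERE {
    #       ?subject ?predicate ?object .
    #       FILTER(CONTAINS(STR(?subject), "lion"))
    #   }
    #   LIMIT 10
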
    def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Find matching class in ontology subset.

        Args:
            entity: Entity string to match
            ontology_subset: Ontology subset

        Returns:
            Matching class name or None
        """
        entity_lower = entity.lower()

        # Direct match
        for class_id in ontology_subset.classes:
            if class_id.lower() == entity_lower:
                return class_id

        # Label match
        for class_id, class_def in ontology_subset.classes.items():
            labels = class_def.get('labels', [])
            for label in labels:
                if isinstance(label, dict):
                    label_value = label.get('value', '').lower()
                    if label_value == entity_lower:
                        return class_id

        # Partial match
        for class_id in ontology_subset.classes:
            if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
                return class_id

        return None

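    # Matching order in _find_matching_class above, e.g. for entity "animals"
    # against an illustrative class "Animal":
    #   1. direct match:  "animals" == class id (case-insensitive)  -> no
    #   2. label match:   any label value == "animals"              -> no
    #   3. partial match: "animal" is a substring of "animals"      -> returns "Animal"
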
    def _build_constraints(self,
                           question_components: QuestionComponents,
                           ontology_subset: QueryOntologySubset) -> str:
        """Build constraint clauses for SPARQL.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            SPARQL constraint string
        """
        import re

        constraints = []

        for constraint in question_components.constraints:
            # Simple constraint patterns: extract the number and compare against ?value
            if 'greater than' in constraint.lower():
                numbers = re.findall(r'\d+', constraint)
                if numbers:
                    constraints.append(f"FILTER(?value > {numbers[0]})")

            elif 'less than' in constraint.lower():
                numbers = re.findall(r'\d+', constraint)
                if numbers:
                    constraints.append(f"FILTER(?value < {numbers[0]})")

        return '\n    '.join(constraints)

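    # Example for _build_constraints above: a constraint phrased as
    # "weight greater than 100" produces "FILTER(?value > 100)"; multiple
    # constraints are joined with newlines for insertion into the count template.
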
    def _build_boolean_pattern(self,
                               question_components: QuestionComponents,
                               ontology_subset: QueryOntologySubset) -> Optional[str]:
        """Build triple pattern for boolean queries.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset

        Returns:
            SPARQL triple pattern or None
        """
        if len(question_components.entities) >= 2:
            subject = question_components.entities[0]
            object_val = question_components.entities[1]

            # Heuristic: use the first object property in the subset as the
            # property connecting the two entities
            for prop_id in ontology_subset.object_properties:
                return f":{subject} :{prop_id} :{object_val} ."

            # Fallback to type check
            return f":{subject} rdf:type :{object_val} ."

        return None

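    # Example for _build_boolean_pattern above (hypothetical ontology): with
    # entities ["Lion", "Carnivore"] and an object property "eats" in the
    # subset, the pattern is ":Lion :eats :Carnivore ."; with no object
    # properties it falls back to ":Lion rdf:type :Carnivore ."
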
    def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str:
        """Format classes for prompt.

        Args:
            classes: Classes dictionary
            namespace: Ontology namespace

        Returns:
            Formatted classes string
        """
        if not classes:
            return "None"

        lines = []
        for class_id, definition in classes.items():
            comment = definition.get('comment', '')
            parent = definition.get('subclass_of', 'Thing')
            lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}")

        return '\n'.join(lines)

    def _format_properties_for_prompt(self,
                                      object_props: Dict[str, Any],
                                      datatype_props: Dict[str, Any],
                                      namespace: str) -> str:
        """Format properties for prompt.

        Args:
            object_props: Object properties
            datatype_props: Datatype properties
            namespace: Ontology namespace

        Returns:
            Formatted properties string
        """
        lines = []

        for prop_id, definition in object_props.items():
            domain = definition.get('domain', 'Any')
            range_val = definition.get('range', 'Any')
            comment = definition.get('comment', '')
            lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}")

        for prop_id, definition in datatype_props.items():
            domain = definition.get('domain', 'Any')
            range_val = definition.get('range', 'xsd:string')
            comment = definition.get('comment', '')
            lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}")

        return '\n'.join(lines) if lines else "None"

    def _extract_variables(self, query: str) -> List[str]:
        """Extract variables from SPARQL query.

        Args:
            query: SPARQL query string

        Returns:
            List of variable names
        """
        import re
        variables = re.findall(r'\?(\w+)', query)
        return list(set(variables))

    def _calculate_complexity(self, query: str) -> float:
        """Calculate complexity score for SPARQL query.

        Args:
            query: SPARQL query string

        Returns:
            Complexity score (0.0 to 1.0)
        """
        complexity = 0.0

        # Count different SPARQL features
        query_upper = query.upper()

        if 'JOIN' in query_upper or 'UNION' in query_upper:
            complexity += 0.3
        if 'FILTER' in query_upper:
            complexity += 0.2
        if 'OPTIONAL' in query_upper:
            complexity += 0.1
        if 'GROUP BY' in query_upper:
            complexity += 0.2
        if 'ORDER BY' in query_upper:
            complexity += 0.1
        if '*' in query:  # Property paths
            complexity += 0.1

        # Count variables
        variables = self._extract_variables(query)
        complexity += len(variables) * 0.05

        return min(complexity, 1.0)
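
    # Worked example for _calculate_complexity above: a query containing one
    # FILTER (0.2), an ORDER BY (0.1), three distinct variables (3 * 0.05) and
    # none of the other listed features scores 0.45; scores are capped at 1.0.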